catgrad 0.2.1

a categorical deep learning compiler
Documentation
use super::types::*;

/// Check that an op has exactly `N` sources and that exactly `N` arguments were
/// supplied for it, then convert the arguments into a fixed-size array `[T; N]`.
///
/// # Errors
///
/// Returns [`InterpreterError::ArityError`] tagged with `ssa.edge_id` if either
/// `ssa.sources.len()` or `args.len()` differs from `N`.
pub(crate) fn get_exact_arity<const N: usize, T>(ssa: &CoreSSA, args: Vec<T>) -> Result<[T; N]> {
    // TODO: return a better error here, distinguishing "op has the wrong arity"
    // from "wrong number of arguments supplied".
    if ssa.sources.len() != N || args.len() != N {
        return Err(InterpreterError::ArityError(ssa.edge_id));
    }

    // After the length check this conversion cannot fail. `try_into` still returns a
    // `Result` (its error hands the original `Vec` back), so map it to the same error.
    args.try_into()
        .map_err(|_| InterpreterError::ArityError(ssa.edge_id))
}

/// Check that an op has profile `m -> n`: exactly `m` sources, exactly `m`
/// supplied arguments, and exactly `n` targets. On success the arguments are
/// handed back unchanged.
///
/// # Errors
///
/// Returns [`InterpreterError::ArityError`] tagged with `ssa.edge_id` on any mismatch.
pub(crate) fn ensure_profile<T>(ssa: &CoreSSA, args: Vec<T>, m: usize, n: usize) -> Result<Vec<T>> {
    let arity_matches = ssa.sources.len() == m && args.len() == m;

    // TODO: coarity error? A target-count mismatch is currently reported as an ArityError.
    let coarity_matches = ssa.targets.len() == n;

    if arity_matches && coarity_matches {
        Ok(args)
    } else {
        Err(InterpreterError::ArityError(ssa.edge_id))
    }
}

////////////////////////////////////////////////////////////////////////////////
// Match cases of Value, returning the inner payload or an InterpreterError::TypeError.
// NOTE: we don't use TryInto here because we need the ssa value to build an error.

/// Extract the natural-number payload of a [`Value::Nat`].
///
/// # Errors
///
/// Returns [`InterpreterError::TypeError`] tagged with `ssa.edge_id` for any other variant.
pub(crate) fn to_nat<V: Interpreter>(ssa: &CoreSSA, v: Value<V>) -> Result<V::Nat> {
    if let Value::Nat(nat) = v {
        Ok(nat)
    } else {
        Err(InterpreterError::TypeError(ssa.edge_id))
    }
}

/// Extract the shape payload of a [`Value::Shape`].
///
/// # Errors
///
/// Returns [`InterpreterError::TypeError`] tagged with `ssa.edge_id` for any other variant.
pub(crate) fn to_shape<V: Interpreter>(ssa: &CoreSSA, v: Value<V>) -> Result<V::Shape> {
    if let Value::Shape(shape) = v {
        Ok(shape)
    } else {
        Err(InterpreterError::TypeError(ssa.edge_id))
    }
}

/// Extract the tensor payload of a [`Value::Tensor`].
///
/// # Errors
///
/// Returns [`InterpreterError::TypeError`] tagged with `ssa.edge_id` for any other variant.
pub(crate) fn to_tensor<V: Interpreter>(ssa: &CoreSSA, v: Value<V>) -> Result<V::Tensor> {
    if let Value::Tensor(tensor) = v {
        Ok(tensor)
    } else {
        Err(InterpreterError::TypeError(ssa.edge_id))
    }
}

/// Extract the dtype payload of a [`Value::Dtype`].
///
/// # Errors
///
/// Returns [`InterpreterError::TypeError`] tagged with `ssa.edge_id` for any other variant.
pub(crate) fn to_dtype<V: Interpreter>(ssa: &CoreSSA, v: Value<V>) -> Result<V::Dtype> {
    if let Value::Dtype(dtype) = v {
        Ok(dtype)
    } else {
        Err(InterpreterError::TypeError(ssa.edge_id))
    }
}