#![cfg(feature = "ndarray-backend")]
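//! Integration tests for the ndarray interpreter backend: elementwise add,
//! batched matrix multiplication, and pointwise exp.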
use catgrad::category::core::Shape;
use catgrad::category::lang::*;
use catgrad::{typecheck, typecheck::*};
use catgrad::stdlib::*;
use catgrad::interpreter::backend::Backend;
use catgrad::interpreter::backend::ndarray::NdArrayBackend;
use catgrad::interpreter::{
    Interpreter, Parameters, TaggedTensor, TaggedTensorTuple, Value, tensor,
};
pub mod test_models;
pub mod test_utils;
use catgrad::stdlib::nn::Exp;
use test_models::{Add, BatchMatMul};
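
/// Typecheck a model's `term` against its declared source type, then run it
/// through the ndarray-backend interpreter on the inputs produced by `build_inputs`.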
fn run_test_with_inputs<F>(
    TypedTerm {
        term, source_type, ..
    }: TypedTerm,
    build_inputs: F,
) -> Vec<catgrad::interpreter::Value<NdArrayBackend>>
where
    F: FnOnce(&NdArrayBackend) -> Vec<catgrad::interpreter::Value<NdArrayBackend>>,
{
    let env = catgrad::stdlib::stdlib();
    let _result = check_with(
        &env,
        &typecheck::Parameters::default(),
        term.clone(),
        source_type,
    )
    .unwrap();

    let backend = NdArrayBackend;
    let interpreter: Interpreter<NdArrayBackend> =
        Interpreter::new(backend, env, Parameters::default());
    let values = build_inputs(&interpreter.backend);
    interpreter.run(term, values).unwrap()
}
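
// Elementwise add of a tensor with itself: the output should be exactly double the input.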
#[test]
fn test_run_add() {
    let data: Vec<u32> = vec![1, 2, 3, 4, 5, 6];
    let result = run_test_with_inputs(Add.term().unwrap(), |backend| {
        let input = tensor(backend, Shape(vec![2, 1, 3]), &data).unwrap();
        vec![input.clone(), input]
    });
    println!("Interpreter result: {result:?}");

    let expected_data: Vec<u32> = data.iter().map(|&x| x * 2).collect();
    let backend = NdArrayBackend;
    let expected = tensor(&backend, Shape(vec![2, 1, 3]), &expected_data).unwrap();

    match (&result[0], &expected) {
        (Value::Tensor(TaggedTensor::U32([actual])), Value::Tensor(TaggedTensor::U32([exp]))) => {
            assert!(
                backend.compare(TaggedTensorTuple::U32([actual.clone(), exp.clone()])),
                "Result should be double the input data"
            );
        }
        _ => panic!("Expected U32 tensors"),
    }
}
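
// Batched matrix multiply: a (2, 2, 2) tensor times a (2, 2, 1) tensor gives a (2, 2, 1) result.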
#[test]
fn test_run_batch_matmul() {
    let x0_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let x1_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let result = run_test_with_inputs(BatchMatMul.term().unwrap(), |backend| {
        let x0 = tensor(backend, Shape(vec![2, 2, 2]), &x0_data).unwrap();
        let x1 = tensor(backend, Shape(vec![2, 2, 1]), &x1_data).unwrap();
        vec![x0, x1]
    });

    // Per batch: [[1, 2], [3, 4]] @ [[1], [2]] = [[5], [11]]
    //            [[5, 6], [7, 8]] @ [[3], [4]] = [[39], [53]]
    let expected_data: Vec<f32> = vec![5.0, 11.0, 39.0, 53.0];
    let backend = NdArrayBackend;
    let expected = tensor(&backend, Shape(vec![2, 2, 1]), &expected_data).unwrap();

    match (&result[0], &expected) {
        (Value::Tensor(TaggedTensor::F32([actual])), Value::Tensor(TaggedTensor::F32([exp]))) => {
            assert!(
                backend.compare(TaggedTensorTuple::F32([actual.clone(), exp.clone()])),
                "Batch matmul result should match expected output"
            );
        }
        _ => panic!("Expected F32 tensors"),
    }
}
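
/// Elementwise tolerance check for f32 slices, in the style of numpy.allclose:
/// every pair must satisfy `|x - y| <= atol + rtol * |y|`, and the lengths must match.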
fn allclose_f32(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b.iter()).all(|(&x, &y)| {
        let diff = (x - y).abs();
        diff <= atol + rtol * y.abs()
    })
}
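
// Pointwise exp on a 2x2 f32 tensor, checked against f32::exp within tolerance.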
#[test]
fn test_run_exp() {
    let data: Vec<f32> = vec![0.0, 1.0, 2.0, -1.0];
    let result = run_test_with_inputs(Exp.term().unwrap(), |backend| {
        vec![tensor(backend, Shape(vec![2, 2]), &data).unwrap()]
    });

    let actual = match &result[..] {
        [Value::Tensor(TaggedTensor::F32([actual]))] => actual,
        xs => panic!("wrong output type: {xs:?}"),
    };

    let expected: Vec<f32> = data.iter().map(|&x| x.exp()).collect();
    let backend = NdArrayBackend;
    let expected_tensor = tensor(&backend, Shape(vec![2, 2]), &expected).unwrap();

    match &expected_tensor {
        Value::Tensor(TaggedTensor::F32([exp])) => {
            assert!(
                allclose_f32(
                    actual.as_slice().unwrap(),
                    exp.as_slice().unwrap(),
                    1e-5,
                    1e-8
                ),
                "actual should be close to expected!"
            );
        }
        _ => panic!("Expected F32 tensors"),
    }
}