use catgrad::category::lang::*;
use catgrad::stdlib::*;
use catgrad::typecheck::*;
pub mod test_utils;
use test_utils::{get_forget_core_declarations, save_diagram_if_enabled};
pub mod test_models;
use test_models::LinearSigmoid;
/// Smoke test: both the plain sigmoid term and the `LinearSigmoid` model term
/// can be constructed, and each is printed with its `Debug` representation.
#[test]
fn test_construct_linear_sigmoid() {
    // Build and dump the standalone sigmoid first; `unwrap` fails the test
    // if no term is produced.
    let sigmoid_term = nn::Sigmoid.term().unwrap();
    println!("{sigmoid_term:?}");

    // Then do the same for the composite model.
    let model_term = LinearSigmoid.term().unwrap();
    println!("{model_term:?}");
}
/// Builds the sigmoid term, passes its inner `term` through
/// `forget_monogamous`, and hands the result to `save_diagram_if_enabled`
/// under the name `test_graph_sigmoid.svg`.
#[test]
fn test_graph_sigmoid() {
    let typed = nn::Sigmoid.term().unwrap();
    // Only the underlying term is needed for rendering.
    let diagram = open_hypergraphs::lax::var::forget::forget_monogamous(&typed.term);
    save_diagram_if_enabled("test_graph_sigmoid.svg", &diagram);
}
/// Builds the `LinearSigmoid` term, passes its inner `term` through
/// `forget_monogamous`, and hands the result to `save_diagram_if_enabled`
/// under the name `test_graph_linear_sigmoid.svg`.
#[test]
fn test_graph_linear_sigmoid() {
    let typed = LinearSigmoid.term().unwrap();
    // Only the underlying term is needed for rendering.
    let diagram = open_hypergraphs::lax::var::forget::forget_monogamous(&typed.term);
    save_diagram_if_enabled("test_graph_linear_sigmoid.svg", &diagram);
}
/// Runs `run_check_test` on the `LinearSigmoid` term and requires it to
/// succeed.
#[test]
fn test_check_linear_sigmoid() {
    let outcome = run_check_test(LinearSigmoid.term(), "test_check_linear_sigmoid.svg");
    outcome.expect("valid");
}
/// Runs `run_check_test` on the `nn::Sigmoid` term and requires it to
/// succeed.
#[test]
fn test_check_sigmoid() {
    let outcome = run_check_test(nn::Sigmoid.term(), "test_check_sigmoid.svg");
    outcome.expect("valid");
}
/// Runs `run_check_test` on the `nn::Exp` term and requires it to succeed.
#[test]
fn test_check_exp() {
    let outcome = run_check_test(nn::Exp.term(), "test_check_exp.svg");
    outcome.expect("valid");
}
/// Shared driver for the `test_check_*` tests.
///
/// Takes an optional typed term, applies `forget_monogamous` to its inner
/// term, checks it with `check_with` against the declarations returned by
/// `get_forget_core_declarations` and default `Parameters`, prints the check
/// result, labels the diagram's nodes with the `Debug` form of each result
/// entry, and passes the labelled diagram to `save_diagram_if_enabled` under
/// `svg_filename`.
///
/// # Errors
///
/// Propagates any error returned by `check_with`.
///
/// # Panics
///
/// Panics if `term` is `None`, or if `with_nodes` does not yield a relabelled
/// diagram.
// NOTE(review): the lint is silenced presumably because `InterpreterError` is
// large; boxing it is not worthwhile for a test helper — confirm.
#[allow(clippy::result_large_err)]
pub fn run_check_test(
    term: Option<catgrad::category::lang::TypedTerm>,
    svg_filename: &str,
) -> Result<(), InterpreterError> {
    let TypedTerm {
        term: raw_term,
        source_type,
        ..
    } = term.unwrap();

    let diagram = open_hypergraphs::lax::var::forget::forget_monogamous(&raw_term);
    let declarations = get_forget_core_declarations();
    let parameters = Parameters::default();

    // `check_with` takes the diagram by value, and the diagram is needed
    // again below for labelling, hence the clone.
    let result = check_with(&declarations, &parameters, diagram.clone(), source_type)?;
    println!("result: {result:?}");

    // Replace the node labels with the debug rendering of each check result.
    let labelled = diagram
        .with_nodes(|_| {
            result
                .into_iter()
                .map(|entry| format!("{entry:?}"))
                .collect()
        })
        .unwrap();

    save_diagram_if_enabled(svg_filename, &labelled);
    Ok(())
}