use std::sync::Arc;
use sim_kernel::{
Args, Cx, DefaultFactory, EagerPolicy, Error, Expr, NumberLiteral, QuoteMode, Symbol,
};
use sim_lib_numbers_func::Func;
use crate::{ComposedPipeline, NumericNumbersLib, PipelineKind, StateKind, numeric_compose_symbol};
/// Builds a fresh evaluation context for the tests in this module.
///
/// The context is created with [`EagerPolicy`] and [`DefaultFactory`]. It then loads
/// the supporting number libraries (arithmetic, f64, CAS, CAS differentiation,
/// CAS evaluation, functions) and finally this crate's [`NumericNumbersLib`].
/// Panics if any library fails to load.
fn test_cx() -> Cx {
    let policy = Arc::new(EagerPolicy);
    let factory = Arc::new(DefaultFactory);
    let mut cx = Cx::new(policy, factory);

    // Supporting libraries first, the library under test last; order is preserved as-is.
    cx.load_lib(&sim_lib_numbers_arith::NumbersArithmeticLib::new()).unwrap();
    cx.load_lib(&sim_lib_numbers_f64::F64NumbersLib::new()).unwrap();
    cx.load_lib(&sim_lib_numbers_cas::CasNumbersLib::new()).unwrap();
    cx.load_lib(&sim_lib_numbers_cas_diff::CasDiffLib::new()).unwrap();
    cx.load_lib(&sim_lib_numbers_cas_eval::CasEvalLib::new()).unwrap();
    cx.load_lib(&sim_lib_numbers_func::FuncNumbersLib::new()).unwrap();
    cx.load_lib(&NumericNumbersLib::new()).unwrap();

    cx
}
/// Builds an `Expr::Number` literal whose domain is the qualified symbol
/// `numbers`/`f64` and whose canonical text is `text`, copied verbatim.
fn f64_number(text: &str) -> Expr {
    let domain = Symbol::qualified("numbers", "f64");
    let canonical = String::from(text);
    Expr::Number(NumberLiteral { domain, canonical })
}
/// Wraps the bare symbol `name` in an `Expr::Quote` with `QuoteMode::Quote`.
fn quoted(name: &str) -> Expr {
    let inner = Expr::Symbol(Symbol::new(name));
    Expr::Quote {
        mode: QuoteMode::Quote,
        expr: Box::new(inner),
    }
}
fn native_identity_func(cx: &mut Cx) -> sim_kernel::Value {
cx.factory()
.opaque(Arc::new(Func::native(
vec![Symbol::new("x")],
Arc::new(|_cx, args| {
let [x] = args else {
return Err(Error::Eval("expected one arg".to_owned()));
};
Ok(x.clone())
}),
)))
.unwrap()
}
/// Checks that a `ComposedPipeline`'s table view exposes each of its fields
/// (kind tag, domain, method, state, func) as the expected expression.
#[test]
fn composed_pipeline_table_value_round_trips() {
    let mut cx = test_cx();
    let pipeline = ComposedPipeline::new(
        sim_kernel::Ref::Symbol(Symbol::new("test-func")),
        PipelineKind::OdeSolve,
        Symbol::new("rk4"),
        StateKind::F64,
    );
    let table = pipeline.table_value(cx.factory()).unwrap();
    let table_impl = table.object().as_table_impl().unwrap();

    // One entry per table field: the key to look up and the expression it must hold.
    let expected_fields = [
        ("kind", Expr::String("composed-pipeline".to_owned())),
        ("domain", Expr::Symbol(Symbol::new("ode-solve"))),
        ("method", Expr::Symbol(Symbol::new("rk4"))),
        ("state", Expr::Symbol(Symbol::new("f64"))),
        ("func", Expr::Symbol(Symbol::new("test-func"))),
    ];
    for (key, expected) in expected_fields {
        let actual = table_impl
            .get(&mut cx, Symbol::new(key))
            .unwrap()
            .object()
            .as_expr(&mut cx)
            .unwrap();
        // The loop shares one source line, so name the field in the failure message.
        assert_eq!(actual, expected, "table field `{key}` mismatch");
    }
}
/// Calls the numeric compose function with a native function plus kind, method and
/// state symbols, then checks the result downcasts to a matching `ComposedPipeline`.
#[test]
fn numeric_compose_returns_composed_pipeline_value() {
    let mut cx = test_cx();

    // Argument order: function, pipeline kind, method, state kind.
    let mut args = vec![native_identity_func(&mut cx)];
    for name in [":ode-solve", "rk4", ":f64"] {
        args.push(cx.factory().symbol(Symbol::new(name)).unwrap());
    }
    let value = cx
        .call_function(&numeric_compose_symbol(), Args::new(args))
        .unwrap();

    let pipeline = value.object().downcast_ref::<ComposedPipeline>().unwrap();
    assert_eq!(pipeline.kind, PipelineKind::OdeSolve);
    assert_eq!(pipeline.method, Symbol::new("rk4"));
    assert_eq!(pipeline.state, StateKind::F64);
    // The function was passed as a runtime value, so it is expected to be held by handle.
    assert!(matches!(pipeline.func_ref, sim_kernel::Ref::Handle(_)));
}
/// Checks that `numeric-diff` given an unrecognised `:method` fails with an
/// `Error::Eval` whose message mentions `UnknownNumericMethod`.
#[test]
fn unknown_method_errors_cleanly() {
    let mut cx = test_cx();
    let sym = |name: &str| Expr::Symbol(Symbol::new(name));

    // Identity function `fn (x) x` as an unevaluated AST.
    let identity = Expr::Call {
        operator: Box::new(sym("fn")),
        args: vec![Expr::List(vec![sym("x")]), sym("x")],
    };
    let err = cx
        .eval_expr(Expr::Call {
            operator: Box::new(sym("numeric-diff")),
            args: vec![
                identity,
                quoted("x"),
                f64_number("2.0"),
                sym(":method"),
                quoted("no-such-method"),
            ],
        })
        .unwrap_err();

    // Match explicitly instead of `assert!(matches!(..))` so a failure shows the
    // error that was actually produced.
    match &err {
        Error::Eval(message) => assert!(
            message.contains("UnknownNumericMethod"),
            "unexpected eval error message: {message}"
        ),
        other => panic!("expected Error::Eval, got {other:?}"),
    }
}
/// Checks that an opaque native function can be differentiated by `numeric-diff`.
///
/// The native body computes f(x) = x*x + x, so f'(3) = 2*3 + 1 = 7.
#[test]
fn native_func_can_be_passed_to_numeric_diff() {
    let mut cx = test_cx();
    let func = cx
        .factory()
        .opaque(Arc::new(Func::native(
            vec![Symbol::new("x")],
            Arc::new(|cx, args| {
                let [x] = args else {
                    return Err(Error::Eval("expected one arg".to_owned()));
                };
                let x2 = cx.apply_value_number_binary_op(
                    &Symbol::qualified("math", "mul"),
                    x.clone(),
                    x.clone(),
                )?;
                cx.apply_value_number_binary_op(&Symbol::qualified("math", "add"), x2, x.clone())
            }),
        )))
        .unwrap();

    // Differentiate with respect to `x` at the point 3.0.
    let var = cx.factory().expr(quoted("x")).unwrap();
    let at = cx
        .factory()
        .number_literal(Symbol::qualified("numbers", "f64"), "3.0".to_owned())
        .unwrap();
    let out = cx
        .call_function(&Symbol::new("numeric-diff"), Args::new(vec![func, var, at]))
        .unwrap();

    let rendered = out.object().display(&mut cx).unwrap();
    let derivative = rendered
        .parse::<f64>()
        .unwrap_or_else(|e| panic!("derivative {rendered:?} is not an f64: {e}"));
    // The native body is opaque to symbolic differentiation, so the result is
    // presumably a numeric approximation; hence the tolerance.
    assert!(
        (derivative - 7.0).abs() < 1.0e-3,
        "expected d/dx(x*x + x) at 3.0 to be about 7, got {derivative}"
    );
}
/// Checks that differentiating a symbolic function body produces the exact
/// derivative and records a `method=auto` diagnostic.
///
/// The body is f(x) = x*x + x, so f'(3) = 7.
#[test]
fn symbolic_func_prefers_symbolic_derivative_before_numeric_fallback() {
    let mut cx = test_cx();
    let sym = |name: &str| Expr::Symbol(Symbol::new(name));
    let call = |op: &str, args: Vec<Expr>| Expr::Call {
        operator: Box::new(sym(op)),
        args,
    };

    let body = call(
        "+",
        vec![call("*", vec![quoted("x"), quoted("x")]), quoted("x")],
    );
    let func = call("fn", vec![Expr::List(vec![sym("x")]), body]);
    let out = cx
        .eval_expr(call(
            "numeric-diff",
            vec![func, quoted("x"), f64_number("3.0")],
        ))
        .unwrap();

    // The exact rendering "7" is expected from the symbolic derivative; a numeric
    // fallback would presumably not render this cleanly.
    assert_eq!(out.object().display(&mut cx).unwrap(), "7");

    let diagnostics = cx.take_diagnostics();
    let messages: Vec<_> = diagnostics.iter().map(|d| &d.message).collect();
    assert!(
        messages.iter().any(|m| m.contains("method=auto")),
        "no `method=auto` diagnostic among {messages:?}"
    );
}