sim-lib-numbers-numeric 0.1.0

SIM workspace package for the `sim-lib-numbers-numeric` library.
Documentation
use std::sync::Arc;

use sim_kernel::{
    Args, Cx, DefaultFactory, EagerPolicy, Error, Expr, NumberLiteral, QuoteMode, Symbol,
};
use sim_lib_numbers_func::Func;

use crate::{ComposedPipeline, NumericNumbersLib, PipelineKind, StateKind, numeric_compose_symbol};

/// Builds a `Cx` (with `EagerPolicy` and `DefaultFactory`) that has every
/// numbers library these tests depend on loaded into it.
///
/// Libraries are loaded in this fixed order: arithmetic, f64, CAS, CAS diff,
/// CAS eval, func, and finally the numeric library under test. The first
/// load that fails panics, aborting the calling test.
fn test_cx() -> Cx {
    let mut context = Cx::new(Arc::new(EagerPolicy), Arc::new(DefaultFactory));

    // Loads each listed library into `context`, unwrapping every result.
    macro_rules! load_all {
        ($($lib:expr),+ $(,)?) => {
            $( context.load_lib(&$lib).unwrap(); )+
        };
    }

    load_all!(
        sim_lib_numbers_arith::NumbersArithmeticLib::new(),
        sim_lib_numbers_f64::F64NumbersLib::new(),
        sim_lib_numbers_cas::CasNumbersLib::new(),
        sim_lib_numbers_cas_diff::CasDiffLib::new(),
        sim_lib_numbers_cas_eval::CasEvalLib::new(),
        sim_lib_numbers_func::FuncNumbersLib::new(),
        NumericNumbersLib::new(),
    );

    context
}

/// Wraps `text` as a number literal in the qualified `numbers` / `f64` domain.
///
/// The text is stored verbatim as the canonical form; nothing is parsed or
/// validated here.
fn f64_number(text: &str) -> Expr {
    let domain = Symbol::qualified("numbers", "f64");
    let canonical = String::from(text);
    Expr::Number(NumberLiteral { domain, canonical })
}

/// Wraps the bare symbol `name` in an `Expr::Quote` using `QuoteMode::Quote`.
fn quoted(name: &str) -> Expr {
    let symbol = Expr::Symbol(Symbol::new(name));
    Expr::Quote {
        mode: QuoteMode::Quote,
        expr: Box::new(symbol),
    }
}

fn native_identity_func(cx: &mut Cx) -> sim_kernel::Value {
    cx.factory()
        .opaque(Arc::new(Func::native(
            vec![Symbol::new("x")],
            Arc::new(|_cx, args| {
                let [x] = args else {
                    return Err(Error::Eval("expected one arg".to_owned()));
                };
                Ok(x.clone())
            }),
        )))
        .unwrap()
}

/// `ComposedPipeline::table_value` exposes every pipeline field as a table
/// entry that reads back as the expected `Expr`.
///
/// The `kind` entry is a string tag. The other entries are plain symbols.
#[test]
fn composed_pipeline_table_value_round_trips() {
    let mut cx = test_cx();
    let pipeline = ComposedPipeline::new(
        sim_kernel::Ref::Symbol(Symbol::new("test-func")),
        PipelineKind::OdeSolve,
        Symbol::new("rk4"),
        StateKind::F64,
    );

    let table = pipeline.table_value(cx.factory()).unwrap();
    let table_impl = table.object().as_table_impl().unwrap();

    // Looks up `key` in the pipeline table and converts the stored value back to an `Expr`.
    let mut field = |key: &str| {
        table_impl
            .get(&mut cx, Symbol::new(key))
            .unwrap()
            .object()
            .as_expr(&mut cx)
            .unwrap()
    };

    assert_eq!(field("kind"), Expr::String("composed-pipeline".to_owned()));

    let symbol_fields = [
        ("domain", "ode-solve"),
        ("method", "rk4"),
        ("state", "f64"),
        ("func", "test-func"),
    ];
    for (key, expected) in symbol_fields {
        assert_eq!(field(key), Expr::Symbol(Symbol::new(expected)));
    }
}

/// Calling the function named by `numeric_compose_symbol()` with a native
/// function plus kind/method/state symbols yields a value that downcasts to
/// `ComposedPipeline`.
///
/// The pipeline's fields must reflect those arguments, and the function must
/// be held by handle.
#[test]
fn numeric_compose_returns_composed_pipeline_value() {
    let mut cx = test_cx();
    let func = native_identity_func(&mut cx);
    let [kind, method, state] = [":ode-solve", "rk4", ":f64"]
        .map(|name| cx.factory().symbol(Symbol::new(name)).unwrap());

    let args = Args::new(vec![func, kind, method, state]);
    let value = cx.call_function(&numeric_compose_symbol(), args).unwrap();

    let pipeline = value
        .object()
        .downcast_ref::<ComposedPipeline>()
        .unwrap();
    assert_eq!(pipeline.kind, PipelineKind::OdeSolve);
    assert_eq!(pipeline.method, Symbol::new("rk4"));
    assert_eq!(pipeline.state, StateKind::F64);
    assert!(matches!(pipeline.func_ref, sim_kernel::Ref::Handle(_)));
}

/// Passing an unrecognised `:method` to `numeric-diff` makes evaluation fail
/// with an `Error::Eval` whose message mentions `UnknownNumericMethod`.
#[test]
fn unknown_method_errors_cleanly() {
    let mut cx = test_cx();
    let err = cx
        .eval_expr(Expr::Call {
            operator: Box::new(Expr::Symbol(Symbol::new("numeric-diff"))),
            args: vec![
                Expr::Call {
                    operator: Box::new(Expr::Symbol(Symbol::new("fn"))),
                    args: vec![
                        Expr::List(vec![Expr::Symbol(Symbol::new("x"))]),
                        Expr::Symbol(Symbol::new("x")),
                    ],
                },
                quoted("x"),
                f64_number("2.0"),
                Expr::Symbol(Symbol::new(":method")),
                quoted("no-such-method"),
            ],
        })
        .unwrap_err();
    // Match explicitly so a failure reports the error that was actually
    // produced, rather than only the text of a `matches!` assertion.
    match err {
        Error::Eval(message) => assert!(
            message.contains("UnknownNumericMethod"),
            "eval error does not mention UnknownNumericMethod: {message}"
        ),
        other => panic!("expected Error::Eval, got {other:?}"),
    }
}

/// `numeric-diff` accepts a native (Rust closure) `Func` as its function
/// argument.
///
/// The function is f(x) = x*x + x, built from the `math`/`mul` and
/// `math`/`add` value ops, so f'(3) = 2*3 + 1 = 7. The rendered result is
/// compared against 7 within an absolute tolerance. The result is presumably
/// a numeric approximation here, hence the tolerance.
#[test]
fn native_func_can_be_passed_to_numeric_diff() {
    const EXPECTED: f64 = 7.0;
    const TOLERANCE: f64 = 1.0e-3;

    let mut cx = test_cx();
    let func = cx
        .factory()
        .opaque(Arc::new(Func::native(
            vec![Symbol::new("x")],
            Arc::new(|cx, args| {
                let [x] = args else {
                    return Err(Error::Eval("expected one arg".to_owned()));
                };
                // x * x
                let x2 = cx.apply_value_number_binary_op(
                    &Symbol::qualified("math", "mul"),
                    x.clone(),
                    x.clone(),
                )?;
                // (x * x) + x
                cx.apply_value_number_binary_op(&Symbol::qualified("math", "add"), x2, x.clone())
            }),
        )))
        .unwrap();
    let var = cx.factory().expr(quoted("x")).unwrap();
    let at = cx
        .factory()
        .number_literal(Symbol::qualified("numbers", "f64"), "3.0".to_owned())
        .unwrap();
    let out = cx
        .call_function(&Symbol::new("numeric-diff"), Args::new(vec![func, var, at]))
        .unwrap();

    // Keep the rendered text so a failure can report what was actually produced.
    let text = out.object().display(&mut cx).unwrap();
    let rendered = text
        .parse::<f64>()
        .unwrap_or_else(|err| panic!("numeric-diff result {text:?} is not an f64: {err}"));
    assert!(
        (rendered - EXPECTED).abs() < TOLERANCE,
        "expected f'(3) within {TOLERANCE} of {EXPECTED}, got {rendered}"
    );
}

/// Differentiating an interpreted `fn` whose body is `+` applied to
/// `*('x, 'x)` and `'x` at 3.0 renders exactly `7`.
///
/// The collected diagnostics must also include a message containing
/// `method=auto`.
#[test]
fn symbolic_func_prefers_symbolic_derivative_before_numeric_fallback() {
    let mut cx = test_cx();

    // Builds a call expression `op(args...)` with a plain-symbol operator.
    let call = |op: &str, args: Vec<Expr>| Expr::Call {
        operator: Box::new(Expr::Symbol(Symbol::new(op))),
        args,
    };

    let square = call("*", vec![quoted("x"), quoted("x")]);
    let body = call("+", vec![square, quoted("x")]);
    let params = Expr::List(vec![Expr::Symbol(Symbol::new("x"))]);
    let lambda = call("fn", vec![params, body]);
    let program = call("numeric-diff", vec![lambda, quoted("x"), f64_number("3.0")]);

    let out = cx.eval_expr(program).unwrap();
    assert_eq!(out.object().display(&mut cx).unwrap(), "7");

    let diagnostics = cx.take_diagnostics();
    let reported_auto = diagnostics
        .iter()
        .any(|entry| entry.message.contains("method=auto"));
    assert!(reported_auto);
}