// runmat_runtime/builtins/math/ode/ode45.rs

//! MATLAB-compatible `ode45` builtin.

use runmat_builtins::Value;
use runmat_macros::runtime_builtin;

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::builtins::math::ode::common::{
    build_ode_output, ode_error, ode_options_from_struct, parse_ode_input, parse_options,
    solve_ode, OdeMethod,
};
use crate::builtins::math::ode::type_resolvers::ode_solution_type;
use crate::BuiltinResult;

/// Name of this builtin, passed to every shared `ode::common` helper it calls.
const NAME: &str = "ode45";

18#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode45")]
19pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
20    name: "ode45",
21    op_kind: GpuOpKind::Custom("ode-solve"),
22    supported_precisions: &[],
23    broadcast: BroadcastSemantics::None,
24    provider_hooks: &[],
25    constant_strategy: ConstantStrategy::InlineLiteral,
26    residency: ResidencyPolicy::GatherImmediately,
27    nan_mode: ReductionNaN::Include,
28    two_pass_threshold: None,
29    workgroup_size: None,
30    accepts_nan_mode: false,
31    notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
32};
33
34#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode45")]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36    name: "ode45",
37    shape: ShapeRequirements::Any,
38    constant_strategy: ConstantStrategy::InlineLiteral,
39    elementwise: None,
40    reduction: None,
41    emits_nan: false,
42    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
43};
44
45#[runtime_builtin(
46    name = "ode45",
47    category = "math/ode",
48    summary = "Solve nonstiff ODE systems with an adaptive Dormand-Prince 5(4) method.",
49    keywords = "ode45,ode,nonstiff,dormand-prince,adaptive step",
50    accel = "sink",
51    type_resolver(ode_solution_type),
52    builtin_path = "crate::builtins::math::ode::ode45"
53)]
54async fn ode45_builtin(
55    function: Value,
56    tspan: Value,
57    y0: Value,
58    rest: Vec<Value>,
59) -> BuiltinResult<Value> {
60    if rest.len() > 1 {
61        return Err(crate::builtins::math::ode::common::ode_error(
62            NAME,
63            "ode45: too many input arguments",
64        ));
65    }
66    let options = parse_options(NAME, rest.first())?;
67    let opts = ode_options_from_struct(NAME, options.as_ref())?;
68    let input = parse_ode_input(NAME, tspan, y0).await?;
69    let result = solve_ode(NAME, OdeMethod::Ode45, &function, &input, &opts).await?;
70    build_ode_output(NAME, result)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// The `[0 1]` row-vector time span shared by these tests.
    fn unit_tspan() -> Value {
        Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap())
    }

    /// y' = -y with y(0) = 1 on [0, 1]: the last solution sample should be close to exp(-1).
    #[test]
    fn ode45_scalar_decay_returns_reasonable_final_value() {
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, args| {
                // The test RHS only reads args[1], the scalar state (args[0] presumably t).
                let y = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-y)) })
            },
        )));

        let out = block_on(ode45_builtin(
            Value::FunctionHandle("decay".into()),
            unit_tspan(),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        match out {
            Value::Tensor(t) => {
                assert_eq!(t.cols(), 1);
                // With a single column, the final data element is the final row's state.
                let last = *t.data.last().expect("ode45 returned an empty solution");
                let expected = (-1.0_f64).exp();
                assert!(
                    (last - expected).abs() < 5.0e-3,
                    "final value {last} too far from {expected}"
                );
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    /// A right-hand side that yields NaN must surface as an error, not a NaN solution.
    #[test]
    fn ode45_rejects_nan_rhs() {
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, _args| Box::pin(async move { Ok(Value::Num(f64::NAN)) }),
        )));

        let err = block_on(ode45_builtin(
            Value::FunctionHandle("nan_rhs".into()),
            unit_tspan(),
            Value::Num(1.0),
            Vec::new(),
        ))
        .expect_err("ode45 should reject NaN derivative values");

        let message = err.to_string();
        assert!(
            message.contains("function value must be finite"),
            "unexpected error message: {message}"
        );
    }
}