// runmat_runtime/builtins/math/ode/ode15s.rs

//! MATLAB-compatible `ode15s` builtin.
2
3use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::ode::common::{
11    build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
12};
13use crate::builtins::math::ode::type_resolvers::ode_solution_type;
14use crate::BuiltinResult;
15
/// Builtin name handed to the shared ODE helpers (option parsing, input
/// parsing, solving, output building) and to `ode_error`.
const NAME: &str = "ode15s";
17
18#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode15s")]
19pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
20    name: "ode15s",
21    op_kind: GpuOpKind::Custom("ode-solve-stiff"),
22    supported_precisions: &[],
23    broadcast: BroadcastSemantics::None,
24    provider_hooks: &[],
25    constant_strategy: ConstantStrategy::InlineLiteral,
26    residency: ResidencyPolicy::GatherImmediately,
27    nan_mode: ReductionNaN::Include,
28    two_pass_threshold: None,
29    workgroup_size: None,
30    accepts_nan_mode: false,
31    notes: "Stiff ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
32};
33
34#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode15s")]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36    name: "ode15s",
37    shape: ShapeRequirements::Any,
38    constant_strategy: ConstantStrategy::InlineLiteral,
39    elementwise: None,
40    reduction: None,
41    emits_nan: false,
42    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
43};
44
45#[runtime_builtin(
46    name = "ode15s",
47    category = "math/ode",
48    summary = "Solve stiff ODE systems with an adaptive implicit host-side integrator.",
49    keywords = "ode15s,ode,stiff,implicit,adaptive step",
50    accel = "sink",
51    type_resolver(ode_solution_type),
52    builtin_path = "crate::builtins::math::ode::ode15s"
53)]
54async fn ode15s_builtin(
55    function: Value,
56    tspan: Value,
57    y0: Value,
58    rest: Vec<Value>,
59) -> BuiltinResult<Value> {
60    if rest.len() > 1 {
61        return Err(crate::builtins::math::ode::common::ode_error(
62            NAME,
63            "ode15s: too many input arguments",
64        ));
65    }
66    let options = parse_options(NAME, rest.first())?;
67    let opts = ode_options_from_struct(NAME, options.as_ref())?;
68    let input = parse_ode_input(NAME, tspan, y0).await?;
69    let result = solve_ode(NAME, OdeMethod::Ode15s, &function, &input, &opts).await?;
70    build_ode_output(NAME, result)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::{StructValue, Tensor};
    use std::sync::Arc;

    /// Reads the scalar state out of a callback argument.
    ///
    /// Panics when the value is anything other than `Value::Num`.
    fn scalar_state(arg: &Value) -> f64 {
        match arg {
            Value::Num(n) => *n,
            other => panic!("expected scalar state, got {other:?}"),
        }
    }

    /// Asserts that `out` is a single-column tensor whose entry at index
    /// `rows - 1` is finite and lies strictly between `0` and `upper`.
    ///
    /// Panics when `out` is not a `Value::Tensor`.
    #[track_caller]
    fn assert_final_state_below(out: Value, upper: f64) {
        match out {
            Value::Tensor(t) => {
                assert_eq!(t.cols(), 1);
                let last = t.data[t.rows() - 1];
                assert!(last.is_finite());
                assert!(last > 0.0);
                assert!(last < upper);
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    /// Integrates `y' = -15 y` from `y(0) = 1` over `[0, 1]` with default options
    /// and checks that the final state has decayed into `(0, 0.1)`.
    #[test]
    fn ode15s_handles_linear_stiff_decay() {
        // The invoker stands in for the user function; the state is read from
        // the second callback argument.
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, args| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-15.0 * y)) })
            },
        )));

        let out = block_on(ode15s_builtin(
            Value::FunctionHandle("stiff_decay".into()),
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap()),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        assert_final_state_below(out, 0.1);
    }

    /// Integrates `y' = -1000 y` from `y(0) = 1` over `[0, 0.1]` with the step
    /// pinned to the whole interval and checks the final state lies in `(0, 0.02)`.
    #[test]
    fn ode15s_accepts_picard_unstable_stiff_step_with_newton() {
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, args| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-1000.0 * y)) })
            },
        )));

        // NOTE(review): the tolerances are extremely loose, presumably so that
        // error control never rejects the single 0.1-wide step; confirm against
        // the solver's step-acceptance logic.
        let mut options = StructValue::new();
        options.insert("RelTol", Value::Num(1.0e6));
        options.insert("AbsTol", Value::Num(1.0e6));
        options.insert("InitialStep", Value::Num(0.1));
        options.insert("MaxStep", Value::Num(0.1));
        options.insert("MaxSteps", Value::Num(2.0));

        let out = block_on(ode15s_builtin(
            Value::FunctionHandle("very_stiff_decay".into()),
            Value::Tensor(Tensor::new(vec![0.0, 0.1], vec![1, 2]).unwrap()),
            Value::Num(1.0),
            vec![Value::Struct(options)],
        ))
        .unwrap();

        assert_final_state_below(out, 0.02);
    }
}