Skip to main content

runmat_runtime/builtins/math/ode/
ode23.rs

1//! MATLAB-compatible `ode23` builtin.
2
3use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::ode::common::{
11    build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
12};
13use crate::builtins::math::ode::type_resolvers::ode_solution_type;
14use crate::BuiltinResult;
15
/// Builtin name passed to the shared ODE helpers (`parse_options`,
/// `ode_options_from_struct`, `parse_ode_input`, `solve_ode`, `build_ode_output`)
/// and to `ode_error` when reporting failures from this module.
const NAME: &str = "ode23";
17
18#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode23")]
19pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
20    name: "ode23",
21    op_kind: GpuOpKind::Custom("ode-solve"),
22    supported_precisions: &[],
23    broadcast: BroadcastSemantics::None,
24    provider_hooks: &[],
25    constant_strategy: ConstantStrategy::InlineLiteral,
26    residency: ResidencyPolicy::GatherImmediately,
27    nan_mode: ReductionNaN::Include,
28    two_pass_threshold: None,
29    workgroup_size: None,
30    accepts_nan_mode: false,
31    notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
32};
33
34#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode23")]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36    name: "ode23",
37    shape: ShapeRequirements::Any,
38    constant_strategy: ConstantStrategy::InlineLiteral,
39    elementwise: None,
40    reduction: None,
41    emits_nan: false,
42    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
43};
44
45#[runtime_builtin(
46    name = "ode23",
47    category = "math/ode",
48    summary = "Solve nonstiff ODE systems with an adaptive Bogacki-Shampine 3(2) method.",
49    keywords = "ode23,ode,nonstiff,bogacki-shampine,adaptive step",
50    accel = "sink",
51    type_resolver(ode_solution_type),
52    builtin_path = "crate::builtins::math::ode::ode23"
53)]
54async fn ode23_builtin(
55    function: Value,
56    tspan: Value,
57    y0: Value,
58    rest: Vec<Value>,
59) -> BuiltinResult<Value> {
60    if rest.len() > 1 {
61        return Err(crate::builtins::math::ode::common::ode_error(
62            NAME,
63            "ode23: too many input arguments",
64        ));
65    }
66    let options = parse_options(NAME, rest.first())?;
67    let opts = ode_options_from_struct(NAME, options.as_ref())?;
68    let input = parse_ode_input(NAME, tspan, y0).await?;
69    let result = solve_ode(NAME, OdeMethod::Ode23, &function, &input, &opts).await?;
70    build_ode_output(NAME, result)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// Solves `y' = -y`, `y(0) = 1` on `tspan = [0 0.5 1]` while requesting two
    /// outputs, and expects an output list of exactly two values back.
    #[test]
    fn ode23_supports_two_output_form() {
        // Every user-function invocation lands here: negate the scalar state
        // (second callback argument), i.e. the exponential-decay RHS.
        let _invoker_guard = crate::user_functions::install_user_function_invoker(Some(
            Arc::new(move |_name, args| {
                let state = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-state)) })
            }),
        ));

        // Two requested outputs, i.e. the `[t, y] = ode23(...)` call form.
        let _nargout_guard = crate::output_count::push_output_count(Some(2));
        let tspan = Tensor::new(vec![0.0, 0.5, 1.0], vec![1, 3]).unwrap();

        let out = block_on(ode23_builtin(
            Value::FunctionHandle("decay".into()),
            Value::Tensor(tspan),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let values = match out {
            Value::OutputList(values) => values,
            other => panic!("unexpected output {other:?}"),
        };
        assert_eq!(values.len(), 2);
    }
}
108}