// runmat_runtime/builtins/math/ode/ode45.rs

//! MATLAB-compatible `ode45` builtin.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13use crate::builtins::math::ode::common::{
14    build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
15};
16use crate::builtins::math::ode::type_resolvers::ode_solution_type;
17use crate::{build_runtime_error, BuiltinResult, RuntimeError};
18
19const NAME: &str = "ode45";
20
21const ODE45_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
22    name: "y",
23    ty: BuiltinParamType::NumericArray,
24    arity: BuiltinParamArity::Required,
25    default: None,
26    description: "Solution states evaluated over tspan.",
27}];
28
29const ODE45_OUTPUT_TY: [BuiltinParamDescriptor; 2] = [
30    BuiltinParamDescriptor {
31        name: "t",
32        ty: BuiltinParamType::NumericArray,
33        arity: BuiltinParamArity::Required,
34        default: None,
35        description: "Time points selected by solver.",
36    },
37    BuiltinParamDescriptor {
38        name: "y",
39        ty: BuiltinParamType::NumericArray,
40        arity: BuiltinParamArity::Required,
41        default: None,
42        description: "Solution states at each returned time point.",
43    },
44];
45
46const ODE45_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
47    BuiltinParamDescriptor {
48        name: "odefun",
49        ty: BuiltinParamType::Any,
50        arity: BuiltinParamArity::Required,
51        default: None,
52        description: "ODE right-hand-side callback f(t,y).",
53    },
54    BuiltinParamDescriptor {
55        name: "tspan",
56        ty: BuiltinParamType::Any,
57        arity: BuiltinParamArity::Required,
58        default: None,
59        description: "Time interval or monotonic time vector.",
60    },
61    BuiltinParamDescriptor {
62        name: "y0",
63        ty: BuiltinParamType::Any,
64        arity: BuiltinParamArity::Required,
65        default: None,
66        description: "Initial state vector/value.",
67    },
68];
69
70const ODE45_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 4] = [
71    BuiltinParamDescriptor {
72        name: "odefun",
73        ty: BuiltinParamType::Any,
74        arity: BuiltinParamArity::Required,
75        default: None,
76        description: "ODE right-hand-side callback f(t,y).",
77    },
78    BuiltinParamDescriptor {
79        name: "tspan",
80        ty: BuiltinParamType::Any,
81        arity: BuiltinParamArity::Required,
82        default: None,
83        description: "Time interval or monotonic time vector.",
84    },
85    BuiltinParamDescriptor {
86        name: "y0",
87        ty: BuiltinParamType::Any,
88        arity: BuiltinParamArity::Required,
89        default: None,
90        description: "Initial state vector/value.",
91    },
92    BuiltinParamDescriptor {
93        name: "options",
94        ty: BuiltinParamType::Any,
95        arity: BuiltinParamArity::Optional,
96        default: None,
97        description: "Optional struct with tolerances and step controls.",
98    },
99];
100
101const ODE45_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
102    BuiltinSignatureDescriptor {
103        label: "y = ode45(odefun, tspan, y0)",
104        inputs: &ODE45_INPUTS_CORE,
105        outputs: &ODE45_OUTPUT_Y,
106    },
107    BuiltinSignatureDescriptor {
108        label: "y = ode45(odefun, tspan, y0, options)",
109        inputs: &ODE45_INPUTS_WITH_OPTIONS,
110        outputs: &ODE45_OUTPUT_Y,
111    },
112    BuiltinSignatureDescriptor {
113        label: "[t, y] = ode45(odefun, tspan, y0)",
114        inputs: &ODE45_INPUTS_CORE,
115        outputs: &ODE45_OUTPUT_TY,
116    },
117    BuiltinSignatureDescriptor {
118        label: "[t, y] = ode45(odefun, tspan, y0, options)",
119        inputs: &ODE45_INPUTS_WITH_OPTIONS,
120        outputs: &ODE45_OUTPUT_TY,
121    },
122];
123
124const ODE45_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
125    code: "RM.ODE45.INVALID_ARGUMENT",
126    identifier: Some("RunMat:ode45:InvalidArgument"),
127    when: "Input argument count/options struct grammar is invalid.",
128    message: "ode45: invalid argument",
129};
130
131const ODE45_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
132    code: "RM.ODE45.INVALID_INPUT",
133    identifier: Some("RunMat:ode45:InvalidInput"),
134    when: "ODE input/state/callback semantics are invalid for integration.",
135    message: "ode45: invalid input",
136};
137
138const ODE45_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
139    code: "RM.ODE45.INTERNAL",
140    identifier: Some("RunMat:ode45:Internal"),
141    when: "Internal output materialization fails.",
142    message: "ode45: internal runtime failure",
143};
144
145const ODE45_ERRORS: [BuiltinErrorDescriptor; 3] = [
146    ODE45_ERROR_INVALID_ARGUMENT,
147    ODE45_ERROR_INVALID_INPUT,
148    ODE45_ERROR_INTERNAL,
149];
150
151pub const ODE45_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152    signatures: &ODE45_SIGNATURES,
153    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
154    completion_policy: BuiltinCompletionPolicy::Public,
155    errors: &ODE45_ERRORS,
156};
157
158fn ode45_error_with_detail(
159    error: &'static BuiltinErrorDescriptor,
160    detail: impl AsRef<str>,
161) -> RuntimeError {
162    let detail = detail.as_ref();
163    let message = if detail.starts_with("ode45:") {
164        detail.to_string()
165    } else {
166        format!("{}: {}", error.message, detail)
167    };
168    let mut builder = build_runtime_error(message).with_builtin(NAME);
169    if let Some(identifier) = error.identifier {
170        builder = builder.with_identifier(identifier);
171    }
172    builder.build()
173}
174
175fn ode45_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
176    if err.identifier().is_some() {
177        err
178    } else {
179        ode45_error_with_detail(fallback, err.message())
180    }
181}
182
183#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode45")]
184pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
185    name: "ode45",
186    op_kind: GpuOpKind::Custom("ode-solve"),
187    supported_precisions: &[],
188    broadcast: BroadcastSemantics::None,
189    provider_hooks: &[],
190    constant_strategy: ConstantStrategy::InlineLiteral,
191    residency: ResidencyPolicy::GatherImmediately,
192    nan_mode: ReductionNaN::Include,
193    two_pass_threshold: None,
194    workgroup_size: None,
195    accepts_nan_mode: false,
196    notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
197};
198
199#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode45")]
200pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
201    name: "ode45",
202    shape: ShapeRequirements::Any,
203    constant_strategy: ConstantStrategy::InlineLiteral,
204    elementwise: None,
205    reduction: None,
206    emits_nan: false,
207    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
208};
209
210#[runtime_builtin(
211    name = "ode45",
212    category = "math/ode",
213    summary = "Solve nonstiff ODE systems using adaptive Dormand-Prince 5(4) integration.",
214    keywords = "ode45,ode,nonstiff,dormand-prince,adaptive step",
215    accel = "sink",
216    type_resolver(ode_solution_type),
217    descriptor(crate::builtins::math::ode::ode45::ODE45_DESCRIPTOR),
218    builtin_path = "crate::builtins::math::ode::ode45"
219)]
220async fn ode45_builtin(
221    function: Value,
222    tspan: Value,
223    y0: Value,
224    rest: Vec<Value>,
225) -> BuiltinResult<Value> {
226    if rest.len() > 1 {
227        return Err(ode45_error_with_detail(
228            &ODE45_ERROR_INVALID_ARGUMENT,
229            "too many input arguments",
230        ));
231    }
232    let options = parse_options(NAME, rest.first())
233        .map_err(|err| ode45_map_error(err, &ODE45_ERROR_INVALID_ARGUMENT))?;
234    let opts = ode_options_from_struct(NAME, options.as_ref())
235        .map_err(|err| ode45_map_error(err, &ODE45_ERROR_INVALID_ARGUMENT))?;
236    let input = parse_ode_input(NAME, tspan, y0)
237        .await
238        .map_err(|err| ode45_map_error(err, &ODE45_ERROR_INVALID_INPUT))?;
239    let result = solve_ode(NAME, OdeMethod::Ode45, &function, &input, &opts)
240        .await
241        .map_err(|err| ode45_map_error(err, &ODE45_ERROR_INVALID_INPUT))?;
242    build_ode_output(NAME, result).map_err(|err| ode45_map_error(err, &ODE45_ERROR_INTERNAL))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// The `tspan` used by every integration test: data `[0.0, 1.0]` with
    /// shape `[1, 2]`.
    fn unit_tspan() -> Value {
        Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap())
    }

    /// Checks that `out` is a single-column tensor and returns the element at
    /// index `rows - 1` of its data buffer. Panics on any other value.
    fn last_state(out: Value) -> f64 {
        match out {
            Value::Tensor(states) => {
                assert_eq!(states.cols(), 1);
                states.data[states.rows() - 1]
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    /// Integrating y' = -y from y(0) = 1 over [0, 1] should end near exp(-1).
    #[test]
    fn ode45_scalar_decay_returns_reasonable_final_value() {
        // Bound to named locals so whatever the install calls return stays
        // alive until the end of the test.
        let _installed_resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _installed_invoker = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(move |_function, args, _requested_outputs| {
                // args[1] is the state argument of f(t, y).
                let state = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-state)) })
            }),
        ));

        let solved = block_on(ode45_builtin(
            Value::FunctionHandle("decay".into()),
            unit_tspan(),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let final_value = last_state(solved);
        let expected = (-1.0_f64).exp();
        assert!((final_value - expected).abs() < 5.0e-3);
    }

    /// A callback that always yields NaN must make the solve fail.
    #[test]
    fn ode45_rejects_nan_rhs() {
        let _installed_resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _installed_invoker = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(move |_function, _args, _requested_outputs| {
                Box::pin(async move { Ok(Value::Num(f64::NAN)) })
            }),
        ));

        let failure = block_on(ode45_builtin(
            Value::FunctionHandle("nan_rhs".into()),
            unit_tspan(),
            Value::Num(1.0),
            Vec::new(),
        ))
        .expect_err("ode45 should reject NaN derivative values");

        let rendered = failure.to_string();
        assert!(rendered.contains("function value must be finite"));
    }

    /// An external function handle is resolved by name and then invoked with
    /// the resolved id.
    #[test]
    fn ode45_accepts_external_function_handle_rhs() {
        let _installed_resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|name| {
                (name == "pkg.decay").then_some(56)
            })));
        let _installed_invoker = crate::user_functions::install_semantic_function_invoker(Some(
            Arc::new(move |function, args, _requested_outputs| {
                // The id must be the one handed out by the resolver above.
                assert_eq!(function, 56);
                let state = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-state)) })
            }),
        ));

        let solved = block_on(ode45_builtin(
            Value::ExternalFunctionHandle("pkg.decay".to_string()),
            unit_tspan(),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let final_value = last_state(solved);
        assert!(final_value.is_finite());
        assert!(final_value > 0.0);
        assert!(final_value < 1.0);
    }

    /// Two trailing arguments must fail with the invalid-argument identifier.
    #[test]
    fn ode45_too_many_inputs_uses_stable_identifier() {
        let surplus = vec![Value::Num(1.0), Value::Num(2.0)];
        let failure = block_on(ode45_builtin(
            Value::FunctionHandle("decay".into()),
            unit_tspan(),
            Value::Num(1.0),
            surplus,
        ))
        .expect_err("expected too many inputs error");
        assert_eq!(failure.identifier(), ODE45_ERROR_INVALID_ARGUMENT.identifier);
    }

    /// The descriptor lists exactly the four documented call forms, in order.
    #[test]
    fn ode45_descriptor_signatures_cover_surface() {
        let expected = [
            "y = ode45(odefun, tspan, y0)",
            "y = ode45(odefun, tspan, y0, options)",
            "[t, y] = ode45(odefun, tspan, y0)",
            "[t, y] = ode45(odefun, tspan, y0, options)",
        ];
        let mut labels: Vec<&str> = Vec::new();
        for signature in ODE45_DESCRIPTOR.signatures.iter() {
            labels.push(signature.label);
        }
        assert_eq!(labels, expected);
    }

    /// The descriptor lists exactly the three documented error codes, in order.
    #[test]
    fn ode45_descriptor_errors_have_stable_codes() {
        let expected = [
            "RM.ODE45.INVALID_ARGUMENT",
            "RM.ODE45.INVALID_INPUT",
            "RM.ODE45.INTERNAL",
        ];
        let mut codes: Vec<&str> = Vec::new();
        for error in ODE45_DESCRIPTOR.errors.iter() {
            codes.push(error.code);
        }
        assert_eq!(codes, expected);
    }
}