Skip to main content

runmat_runtime/builtins/math/ode/
ode23.rs

//! MATLAB-compatible `ode23` builtin (adaptive Bogacki-Shampine 3(2) solver for nonstiff ODE systems).
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13use crate::builtins::math::ode::common::{
14    build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
15};
16use crate::builtins::math::ode::type_resolvers::ode_solution_type;
17use crate::{build_runtime_error, BuiltinResult, RuntimeError};
18
19const NAME: &str = "ode23";
20
21const ODE23_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
22    name: "y",
23    ty: BuiltinParamType::NumericArray,
24    arity: BuiltinParamArity::Required,
25    default: None,
26    description: "Solution states evaluated over tspan.",
27}];
28
29const ODE23_OUTPUT_TY: [BuiltinParamDescriptor; 2] = [
30    BuiltinParamDescriptor {
31        name: "t",
32        ty: BuiltinParamType::NumericArray,
33        arity: BuiltinParamArity::Required,
34        default: None,
35        description: "Time points selected by solver.",
36    },
37    BuiltinParamDescriptor {
38        name: "y",
39        ty: BuiltinParamType::NumericArray,
40        arity: BuiltinParamArity::Required,
41        default: None,
42        description: "Solution states at each returned time point.",
43    },
44];
45
46const ODE23_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
47    BuiltinParamDescriptor {
48        name: "odefun",
49        ty: BuiltinParamType::Any,
50        arity: BuiltinParamArity::Required,
51        default: None,
52        description: "ODE right-hand-side callback f(t,y).",
53    },
54    BuiltinParamDescriptor {
55        name: "tspan",
56        ty: BuiltinParamType::Any,
57        arity: BuiltinParamArity::Required,
58        default: None,
59        description: "Time interval or monotonic time vector.",
60    },
61    BuiltinParamDescriptor {
62        name: "y0",
63        ty: BuiltinParamType::Any,
64        arity: BuiltinParamArity::Required,
65        default: None,
66        description: "Initial state vector/value.",
67    },
68];
69
70const ODE23_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 4] = [
71    BuiltinParamDescriptor {
72        name: "odefun",
73        ty: BuiltinParamType::Any,
74        arity: BuiltinParamArity::Required,
75        default: None,
76        description: "ODE right-hand-side callback f(t,y).",
77    },
78    BuiltinParamDescriptor {
79        name: "tspan",
80        ty: BuiltinParamType::Any,
81        arity: BuiltinParamArity::Required,
82        default: None,
83        description: "Time interval or monotonic time vector.",
84    },
85    BuiltinParamDescriptor {
86        name: "y0",
87        ty: BuiltinParamType::Any,
88        arity: BuiltinParamArity::Required,
89        default: None,
90        description: "Initial state vector/value.",
91    },
92    BuiltinParamDescriptor {
93        name: "options",
94        ty: BuiltinParamType::Any,
95        arity: BuiltinParamArity::Optional,
96        default: None,
97        description: "Optional struct with tolerances and step controls.",
98    },
99];
100
101const ODE23_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
102    BuiltinSignatureDescriptor {
103        label: "y = ode23(odefun, tspan, y0)",
104        inputs: &ODE23_INPUTS_CORE,
105        outputs: &ODE23_OUTPUT_Y,
106    },
107    BuiltinSignatureDescriptor {
108        label: "y = ode23(odefun, tspan, y0, options)",
109        inputs: &ODE23_INPUTS_WITH_OPTIONS,
110        outputs: &ODE23_OUTPUT_Y,
111    },
112    BuiltinSignatureDescriptor {
113        label: "[t, y] = ode23(odefun, tspan, y0)",
114        inputs: &ODE23_INPUTS_CORE,
115        outputs: &ODE23_OUTPUT_TY,
116    },
117    BuiltinSignatureDescriptor {
118        label: "[t, y] = ode23(odefun, tspan, y0, options)",
119        inputs: &ODE23_INPUTS_WITH_OPTIONS,
120        outputs: &ODE23_OUTPUT_TY,
121    },
122];
123
124const ODE23_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
125    code: "RM.ODE23.INVALID_ARGUMENT",
126    identifier: Some("RunMat:ode23:InvalidArgument"),
127    when: "Input argument count/options struct grammar is invalid.",
128    message: "ode23: invalid argument",
129};
130
131const ODE23_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
132    code: "RM.ODE23.INVALID_INPUT",
133    identifier: Some("RunMat:ode23:InvalidInput"),
134    when: "ODE input/state/callback semantics are invalid for integration.",
135    message: "ode23: invalid input",
136};
137
138const ODE23_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
139    code: "RM.ODE23.INTERNAL",
140    identifier: Some("RunMat:ode23:Internal"),
141    when: "Internal output materialization fails.",
142    message: "ode23: internal runtime failure",
143};
144
145const ODE23_ERRORS: [BuiltinErrorDescriptor; 3] = [
146    ODE23_ERROR_INVALID_ARGUMENT,
147    ODE23_ERROR_INVALID_INPUT,
148    ODE23_ERROR_INTERNAL,
149];
150
151pub const ODE23_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152    signatures: &ODE23_SIGNATURES,
153    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
154    completion_policy: BuiltinCompletionPolicy::Public,
155    errors: &ODE23_ERRORS,
156};
157
158fn ode23_error_with_detail(
159    error: &'static BuiltinErrorDescriptor,
160    detail: impl AsRef<str>,
161) -> RuntimeError {
162    let detail = detail.as_ref();
163    let message = if detail.starts_with("ode23:") {
164        detail.to_string()
165    } else {
166        format!("{}: {}", error.message, detail)
167    };
168    let mut builder = build_runtime_error(message).with_builtin(NAME);
169    if let Some(identifier) = error.identifier {
170        builder = builder.with_identifier(identifier);
171    }
172    builder.build()
173}
174
175fn ode23_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
176    if err.identifier().is_some() {
177        err
178    } else {
179        ode23_error_with_detail(fallback, err.message())
180    }
181}
182
183#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode23")]
184pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
185    name: "ode23",
186    op_kind: GpuOpKind::Custom("ode-solve"),
187    supported_precisions: &[],
188    broadcast: BroadcastSemantics::None,
189    provider_hooks: &[],
190    constant_strategy: ConstantStrategy::InlineLiteral,
191    residency: ResidencyPolicy::GatherImmediately,
192    nan_mode: ReductionNaN::Include,
193    two_pass_threshold: None,
194    workgroup_size: None,
195    accepts_nan_mode: false,
196    notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
197};
198
199#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode23")]
200pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
201    name: "ode23",
202    shape: ShapeRequirements::Any,
203    constant_strategy: ConstantStrategy::InlineLiteral,
204    elementwise: None,
205    reduction: None,
206    emits_nan: false,
207    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
208};
209
210#[runtime_builtin(
211    name = "ode23",
212    category = "math/ode",
213    summary = "Solve nonstiff ODE systems using adaptive Bogacki-Shampine 3(2) integration.",
214    keywords = "ode23,ode,nonstiff,bogacki-shampine,adaptive step",
215    accel = "sink",
216    type_resolver(ode_solution_type),
217    descriptor(crate::builtins::math::ode::ode23::ODE23_DESCRIPTOR),
218    builtin_path = "crate::builtins::math::ode::ode23"
219)]
220async fn ode23_builtin(
221    function: Value,
222    tspan: Value,
223    y0: Value,
224    rest: Vec<Value>,
225) -> BuiltinResult<Value> {
226    if rest.len() > 1 {
227        return Err(ode23_error_with_detail(
228            &ODE23_ERROR_INVALID_ARGUMENT,
229            "too many input arguments",
230        ));
231    }
232    let options = parse_options(NAME, rest.first())
233        .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_ARGUMENT))?;
234    let opts = ode_options_from_struct(NAME, options.as_ref())
235        .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_ARGUMENT))?;
236    let input = parse_ode_input(NAME, tspan, y0)
237        .await
238        .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_INPUT))?;
239    let result = solve_ode(NAME, OdeMethod::Ode23, &function, &input, &opts)
240        .await
241        .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_INPUT))?;
242    build_ode_output(NAME, result).map_err(|err| ode23_map_error(err, &ODE23_ERROR_INTERNAL))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// Requesting two outputs yields an `[t, y]` output list.
    #[test]
    fn ode23_supports_two_output_form() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let y = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-y)) })
            },
        )));

        let _out_guard = crate::output_count::push_output_count(Some(2));
        let out = block_on(ode23_builtin(
            Value::FunctionHandle("decay".into()),
            Value::Tensor(Tensor::new(vec![0.0, 0.5, 1.0], vec![1, 3]).unwrap()),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        match out {
            Value::OutputList(values) => {
                assert_eq!(values.len(), 2);
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    /// A bound function handle is dispatched through the semantic invoker, and the
    /// solution of y' = -y, y(0) = 1 is accurate at t = 1.
    #[test]
    fn ode23_accepts_semantic_function_handle_rhs() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |function, args, _requested_outputs| {
                assert_eq!(function, 55);
                let y = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-y)) })
            },
        )));

        let out = block_on(ode23_builtin(
            Value::BoundFunctionHandle {
                name: "ode_decay".to_string(),
                function: 55,
            },
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap()),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        match out {
            Value::Tensor(t) => {
                assert_eq!(t.cols(), 1);
                let last = t.data[t.rows() - 1];
                assert!(last.is_finite());
                // The exact solution is y(1) = e^{-1}. The tolerance is deliberately loose
                // compared with the solver's error control. It still rejects answers that
                // the old `0 < y < 1` bound would have accepted.
                let expected = (-1.0f64).exp();
                assert!(
                    (last - expected).abs() < 1e-2,
                    "y(1) = {last}, expected approximately {expected}"
                );
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    #[test]
    fn ode23_too_many_inputs_uses_stable_identifier() {
        let err = block_on(ode23_builtin(
            Value::FunctionHandle("decay".into()),
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap()),
            Value::Num(1.0),
            vec![Value::Num(1.0), Value::Num(2.0)],
        ))
        .expect_err("expected too many inputs error");
        assert_eq!(err.identifier(), ODE23_ERROR_INVALID_ARGUMENT.identifier);
    }

    #[test]
    fn ode23_descriptor_signatures_cover_surface() {
        let labels: Vec<&str> = ODE23_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            vec![
                "y = ode23(odefun, tspan, y0)",
                "y = ode23(odefun, tspan, y0, options)",
                "[t, y] = ode23(odefun, tspan, y0)",
                "[t, y] = ode23(odefun, tspan, y0, options)",
            ]
        );
    }

    #[test]
    fn ode23_descriptor_errors_have_stable_codes() {
        let codes: Vec<&str> = ODE23_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            vec![
                "RM.ODE23.INVALID_ARGUMENT",
                "RM.ODE23.INVALID_INPUT",
                "RM.ODE23.INTERNAL",
            ]
        );
    }
}