// runmat_runtime/builtins/math/ode/ode15s.rs

//! MATLAB-compatible `ode15s` builtin.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13use crate::builtins::math::ode::common::{
14    build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
15};
16use crate::builtins::math::ode::type_resolvers::ode_solution_type;
17use crate::{build_runtime_error, BuiltinResult, RuntimeError};
18
/// Builtin name passed to the shared ODE helpers and attached to every error built in this file.
const NAME: &str = "ode15s";
20
21const ODE15S_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
22    name: "y",
23    ty: BuiltinParamType::NumericArray,
24    arity: BuiltinParamArity::Required,
25    default: None,
26    description: "Solution states evaluated over tspan.",
27}];
28
29const ODE15S_OUTPUT_TY: [BuiltinParamDescriptor; 2] = [
30    BuiltinParamDescriptor {
31        name: "t",
32        ty: BuiltinParamType::NumericArray,
33        arity: BuiltinParamArity::Required,
34        default: None,
35        description: "Time points selected by solver.",
36    },
37    BuiltinParamDescriptor {
38        name: "y",
39        ty: BuiltinParamType::NumericArray,
40        arity: BuiltinParamArity::Required,
41        default: None,
42        description: "Solution states at each returned time point.",
43    },
44];
45
46const ODE15S_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
47    BuiltinParamDescriptor {
48        name: "odefun",
49        ty: BuiltinParamType::Any,
50        arity: BuiltinParamArity::Required,
51        default: None,
52        description: "ODE right-hand-side callback f(t,y).",
53    },
54    BuiltinParamDescriptor {
55        name: "tspan",
56        ty: BuiltinParamType::Any,
57        arity: BuiltinParamArity::Required,
58        default: None,
59        description: "Time interval or monotonic time vector.",
60    },
61    BuiltinParamDescriptor {
62        name: "y0",
63        ty: BuiltinParamType::Any,
64        arity: BuiltinParamArity::Required,
65        default: None,
66        description: "Initial state vector/value.",
67    },
68];
69
70const ODE15S_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 4] = [
71    BuiltinParamDescriptor {
72        name: "odefun",
73        ty: BuiltinParamType::Any,
74        arity: BuiltinParamArity::Required,
75        default: None,
76        description: "ODE right-hand-side callback f(t,y).",
77    },
78    BuiltinParamDescriptor {
79        name: "tspan",
80        ty: BuiltinParamType::Any,
81        arity: BuiltinParamArity::Required,
82        default: None,
83        description: "Time interval or monotonic time vector.",
84    },
85    BuiltinParamDescriptor {
86        name: "y0",
87        ty: BuiltinParamType::Any,
88        arity: BuiltinParamArity::Required,
89        default: None,
90        description: "Initial state vector/value.",
91    },
92    BuiltinParamDescriptor {
93        name: "options",
94        ty: BuiltinParamType::Any,
95        arity: BuiltinParamArity::Optional,
96        default: None,
97        description: "Optional struct with tolerances and step controls.",
98    },
99];
100
101const ODE15S_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
102    BuiltinSignatureDescriptor {
103        label: "y = ode15s(odefun, tspan, y0)",
104        inputs: &ODE15S_INPUTS_CORE,
105        outputs: &ODE15S_OUTPUT_Y,
106    },
107    BuiltinSignatureDescriptor {
108        label: "y = ode15s(odefun, tspan, y0, options)",
109        inputs: &ODE15S_INPUTS_WITH_OPTIONS,
110        outputs: &ODE15S_OUTPUT_Y,
111    },
112    BuiltinSignatureDescriptor {
113        label: "[t, y] = ode15s(odefun, tspan, y0)",
114        inputs: &ODE15S_INPUTS_CORE,
115        outputs: &ODE15S_OUTPUT_TY,
116    },
117    BuiltinSignatureDescriptor {
118        label: "[t, y] = ode15s(odefun, tspan, y0, options)",
119        inputs: &ODE15S_INPUTS_WITH_OPTIONS,
120        outputs: &ODE15S_OUTPUT_TY,
121    },
122];
123
124const ODE15S_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
125    code: "RM.ODE15S.INVALID_ARGUMENT",
126    identifier: Some("RunMat:ode15s:InvalidArgument"),
127    when: "Input argument count/options struct grammar is invalid.",
128    message: "ode15s: invalid argument",
129};
130
131const ODE15S_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
132    code: "RM.ODE15S.INVALID_INPUT",
133    identifier: Some("RunMat:ode15s:InvalidInput"),
134    when: "ODE input/state/callback semantics are invalid for integration.",
135    message: "ode15s: invalid input",
136};
137
138const ODE15S_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
139    code: "RM.ODE15S.INTERNAL",
140    identifier: Some("RunMat:ode15s:Internal"),
141    when: "Internal output materialization fails.",
142    message: "ode15s: internal runtime failure",
143};
144
145const ODE15S_ERRORS: [BuiltinErrorDescriptor; 3] = [
146    ODE15S_ERROR_INVALID_ARGUMENT,
147    ODE15S_ERROR_INVALID_INPUT,
148    ODE15S_ERROR_INTERNAL,
149];
150
151pub const ODE15S_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152    signatures: &ODE15S_SIGNATURES,
153    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
154    completion_policy: BuiltinCompletionPolicy::Public,
155    errors: &ODE15S_ERRORS,
156};
157
158fn ode15s_error_with_detail(
159    error: &'static BuiltinErrorDescriptor,
160    detail: impl AsRef<str>,
161) -> RuntimeError {
162    let detail = detail.as_ref();
163    let message = if detail.starts_with("ode15s:") {
164        detail.to_string()
165    } else {
166        format!("{}: {}", error.message, detail)
167    };
168    let mut builder = build_runtime_error(message).with_builtin(NAME);
169    if let Some(identifier) = error.identifier {
170        builder = builder.with_identifier(identifier);
171    }
172    builder.build()
173}
174
175fn ode15s_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
176    if err.identifier().is_some() {
177        err
178    } else {
179        ode15s_error_with_detail(fallback, err.message())
180    }
181}
182
// GPU metadata for `ode15s`, registered through `register_gpu_spec`.
//
// Declares a custom op kind ("ode-solve-stiff") with no supported precisions
// and no provider hooks. Per `notes`, the integration itself runs on the host.
// Plain `//` comments are used so the attribute macro receives the item unchanged.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode15s")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ode15s",
    op_kind: GpuOpKind::Custom("ode-solve-stiff"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Stiff ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
};
198
// Fusion metadata for `ode15s`, registered through `register_fusion_spec`.
//
// No elementwise or reduction template is provided. Per `notes`, the builtin
// ends fusion planning because it repeatedly invokes user callbacks.
// Plain `//` comments are used so the attribute macro receives the item unchanged.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode15s")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ode15s",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
};
209
210#[runtime_builtin(
211    name = "ode15s",
212    category = "math/ode",
213    summary = "Solve stiff ODE systems with adaptive implicit integration.",
214    keywords = "ode15s,ode,stiff,implicit,adaptive step",
215    accel = "sink",
216    type_resolver(ode_solution_type),
217    descriptor(crate::builtins::math::ode::ode15s::ODE15S_DESCRIPTOR),
218    builtin_path = "crate::builtins::math::ode::ode15s"
219)]
220async fn ode15s_builtin(
221    function: Value,
222    tspan: Value,
223    y0: Value,
224    rest: Vec<Value>,
225) -> BuiltinResult<Value> {
226    if rest.len() > 1 {
227        return Err(ode15s_error_with_detail(
228            &ODE15S_ERROR_INVALID_ARGUMENT,
229            "too many input arguments",
230        ));
231    }
232    let options = parse_options(NAME, rest.first())
233        .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_ARGUMENT))?;
234    let opts = ode_options_from_struct(NAME, options.as_ref())
235        .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_ARGUMENT))?;
236    let input = parse_ode_input(NAME, tspan, y0)
237        .await
238        .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_INPUT))?;
239    let result = solve_ode(NAME, OdeMethod::Ode15s, &function, &input, &opts)
240        .await
241        .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_INPUT))?;
242    build_ode_output(NAME, result).map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INTERNAL))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::{StructValue, Tensor};
    use std::sync::Arc;

    /// Builds the 1x2 tensor `[0, t_end]` passed as `tspan`.
    fn tspan_until(t_end: f64) -> Value {
        Value::Tensor(Tensor::new(vec![0.0, t_end], vec![1, 2]).unwrap())
    }

    /// Extracts the scalar state handed to an RHS callback, panicking on any other value.
    fn scalar_state(state: &Value) -> f64 {
        match state {
            Value::Num(n) => *n,
            other => panic!("expected scalar state, got {other:?}"),
        }
    }

    /// Checks that `out` is a one-column tensor whose entry at index `rows - 1`
    /// is finite and lies strictly between `0` and `upper`.
    fn assert_final_state_below(out: Value, upper: f64) {
        match out {
            Value::Tensor(t) => {
                assert_eq!(t.cols(), 1);
                let last = t.data[t.rows() - 1];
                assert!(last.is_finite());
                assert!(last > 0.0);
                assert!(last < upper);
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    /// Integrates y' = -15 y on [0, 1] through a named function handle.
    #[test]
    fn ode15s_handles_linear_stiff_decay() {
        let _resolver = crate::user_functions::install_semantic_function_resolver(Some(
            Arc::new(|_name| Some(0)),
        ));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-15.0 * y)) })
            },
        )));

        let solved = block_on(ode15s_builtin(
            Value::FunctionHandle("stiff_decay".into()),
            tspan_until(1.0),
            Value::Num(1.0),
            Vec::new(),
        ));

        assert_final_state_below(solved.unwrap(), 0.1);
    }

    /// Integrates y' = -1000 y over one 0.1-wide step with very loose tolerances.
    #[test]
    fn ode15s_accepts_picard_unstable_stiff_step_with_newton() {
        let _resolver = crate::user_functions::install_semantic_function_resolver(Some(
            Arc::new(|_name| Some(0)),
        ));
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-1000.0 * y)) })
            },
        )));

        // Same option fields and insertion order as a hand-written sequence of inserts.
        let mut options = StructValue::new();
        for (field, value) in [
            ("RelTol", 1.0e6),
            ("AbsTol", 1.0e6),
            ("InitialStep", 0.1),
            ("MaxStep", 0.1),
            ("MaxSteps", 2.0),
        ] {
            options.insert(field, Value::Num(value));
        }

        let solved = block_on(ode15s_builtin(
            Value::FunctionHandle("very_stiff_decay".into()),
            tspan_until(0.1),
            Value::Num(1.0),
            vec![Value::Struct(options)],
        ));

        assert_final_state_below(solved.unwrap(), 0.02);
    }

    /// Same decay problem, but the RHS is a bound handle; the invoker must see id 57.
    #[test]
    fn ode15s_accepts_semantic_function_handle_rhs() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |function, args, _requested_outputs| {
                assert_eq!(function, 57);
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-15.0 * y)) })
            },
        )));

        let rhs = Value::BoundFunctionHandle {
            name: "ode_stiff_decay".to_string(),
            function: 57,
        };
        let solved = block_on(ode15s_builtin(
            rhs,
            tspan_until(1.0),
            Value::Num(1.0),
            Vec::new(),
        ));

        assert_final_state_below(solved.unwrap(), 0.1);
    }

    /// Two trailing arguments must fail with the invalid-argument identifier.
    #[test]
    fn ode15s_too_many_inputs_uses_stable_identifier() {
        let extra_args = vec![Value::Num(1.0), Value::Num(2.0)];
        let err = block_on(ode15s_builtin(
            Value::FunctionHandle("stiff_decay".into()),
            tspan_until(1.0),
            Value::Num(1.0),
            extra_args,
        ))
        .expect_err("expected too many inputs error");
        assert_eq!(err.identifier(), ODE15S_ERROR_INVALID_ARGUMENT.identifier);
    }

    /// The descriptor lists exactly the four call forms, in order.
    #[test]
    fn ode15s_descriptor_signatures_cover_surface() {
        let expected = vec![
            "y = ode15s(odefun, tspan, y0)",
            "y = ode15s(odefun, tspan, y0, options)",
            "[t, y] = ode15s(odefun, tspan, y0)",
            "[t, y] = ode15s(odefun, tspan, y0, options)",
        ];
        let mut labels: Vec<&str> = Vec::new();
        for signature in ODE15S_DESCRIPTOR.signatures.iter() {
            labels.push(signature.label);
        }
        assert_eq!(labels, expected);
    }

    /// The descriptor lists exactly the three error codes, in order.
    #[test]
    fn ode15s_descriptor_errors_have_stable_codes() {
        let expected = vec![
            "RM.ODE15S.INVALID_ARGUMENT",
            "RM.ODE15S.INVALID_INPUT",
            "RM.ODE15S.INTERNAL",
        ];
        let mut codes: Vec<&str> = Vec::new();
        for error in ODE15S_DESCRIPTOR.errors.iter() {
            codes.push(error.code);
        }
        assert_eq!(codes, expected);
    }
}