runmat_runtime/builtins/math/ode/
ode45.rs1use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::ode::common::{
11 build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
12};
13use crate::builtins::math::ode::type_resolvers::ode_solution_type;
14use crate::BuiltinResult;
15
/// Builtin name passed as the first argument to every shared ODE helper
/// called from this file, and to `ode_error` when building errors.
const NAME: &str = "ode45";
17
18#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode45")]
19pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
20 name: "ode45",
21 op_kind: GpuOpKind::Custom("ode-solve"),
22 supported_precisions: &[],
23 broadcast: BroadcastSemantics::None,
24 provider_hooks: &[],
25 constant_strategy: ConstantStrategy::InlineLiteral,
26 residency: ResidencyPolicy::GatherImmediately,
27 nan_mode: ReductionNaN::Include,
28 two_pass_threshold: None,
29 workgroup_size: None,
30 accepts_nan_mode: false,
31 notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
32};
33
34#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode45")]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36 name: "ode45",
37 shape: ShapeRequirements::Any,
38 constant_strategy: ConstantStrategy::InlineLiteral,
39 elementwise: None,
40 reduction: None,
41 emits_nan: false,
42 notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
43};
44
45#[runtime_builtin(
46 name = "ode45",
47 category = "math/ode",
48 summary = "Solve nonstiff ODE systems with an adaptive Dormand-Prince 5(4) method.",
49 keywords = "ode45,ode,nonstiff,dormand-prince,adaptive step",
50 accel = "sink",
51 type_resolver(ode_solution_type),
52 builtin_path = "crate::builtins::math::ode::ode45"
53)]
54async fn ode45_builtin(
55 function: Value,
56 tspan: Value,
57 y0: Value,
58 rest: Vec<Value>,
59) -> BuiltinResult<Value> {
60 if rest.len() > 1 {
61 return Err(crate::builtins::math::ode::common::ode_error(
62 NAME,
63 "ode45: too many input arguments",
64 ));
65 }
66 let options = parse_options(NAME, rest.first())?;
67 let opts = ode_options_from_struct(NAME, options.as_ref())?;
68 let input = parse_ode_input(NAME, tspan, y0).await?;
69 let result = solve_ode(NAME, OdeMethod::Ode45, &function, &input, &opts).await?;
70 build_ode_output(NAME, result)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// Time-span argument shared by the tests: a tensor built from the data
    /// `[0.0, 1.0]` and the dimension vector `[1, 2]`.
    fn unit_tspan() -> Value {
        Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![1, 2]).unwrap())
    }

    /// Drives `ode45_builtin` to completion for the function handle `handle`,
    /// using `unit_tspan()`, the initial state `1.0` and no extra arguments.
    fn solve_from_one(handle: &'static str) -> BuiltinResult<Value> {
        block_on(ode45_builtin(
            Value::FunctionHandle(handle.into()),
            unit_tspan(),
            Value::Num(1.0),
            Vec::new(),
        ))
    }

    #[test]
    fn ode45_scalar_decay_returns_reasonable_final_value() {
        // The installed invoker answers every call with the negated second
        // argument, i.e. the right-hand side y' = -y. The returned value is
        // held in `_guard` until the end of the test.
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, args| {
                let state = match &args[1] {
                    Value::Num(n) => *n,
                    other => panic!("expected scalar state, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(-state)) })
            },
        )));

        let trajectory = match solve_from_one("decay").unwrap() {
            Value::Tensor(t) => t,
            other => panic!("unexpected output {other:?}"),
        };

        assert_eq!(trajectory.cols(), 1);
        // Compare the entry at index `rows - 1` against exp(-1).
        let expected = (-1.0_f64).exp();
        let final_value = trajectory.data[trajectory.rows() - 1];
        assert!((final_value - expected).abs() < 5.0e-3);
    }

    #[test]
    fn ode45_rejects_nan_rhs() {
        // Every invocation of the user function yields NaN.
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, _args| Box::pin(async move { Ok(Value::Num(f64::NAN)) }),
        )));

        let failure =
            solve_from_one("nan_rhs").expect_err("ode45 should reject NaN derivative values");

        let message = failure.to_string();
        assert!(message.contains("function value must be finite"));
    }
}