// runmat_runtime/builtins/math/ode/ode23.rs
use runmat_builtins::Value;
use runmat_macros::runtime_builtin;

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::builtins::math::ode::common::{
    build_ode_output, ode_error, ode_options_from_struct, parse_ode_input, parse_options,
    solve_ode, OdeMethod,
};
use crate::builtins::math::ode::type_resolvers::ode_solution_type;
use crate::BuiltinResult;

/// Builtin name handed to every shared ODE helper (option parsing, input
/// validation, solving, output construction, and error construction).
const NAME: &str = "ode23";

18#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode23")]
19pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
20 name: "ode23",
21 op_kind: GpuOpKind::Custom("ode-solve"),
22 supported_precisions: &[],
23 broadcast: BroadcastSemantics::None,
24 provider_hooks: &[],
25 constant_strategy: ConstantStrategy::InlineLiteral,
26 residency: ResidencyPolicy::GatherImmediately,
27 nan_mode: ReductionNaN::Include,
28 two_pass_threshold: None,
29 workgroup_size: None,
30 accepts_nan_mode: false,
31 notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
32};
33
34#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode23")]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36 name: "ode23",
37 shape: ShapeRequirements::Any,
38 constant_strategy: ConstantStrategy::InlineLiteral,
39 elementwise: None,
40 reduction: None,
41 emits_nan: false,
42 notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
43};
44
45#[runtime_builtin(
46 name = "ode23",
47 category = "math/ode",
48 summary = "Solve nonstiff ODE systems with an adaptive Bogacki-Shampine 3(2) method.",
49 keywords = "ode23,ode,nonstiff,bogacki-shampine,adaptive step",
50 accel = "sink",
51 type_resolver(ode_solution_type),
52 builtin_path = "crate::builtins::math::ode::ode23"
53)]
54async fn ode23_builtin(
55 function: Value,
56 tspan: Value,
57 y0: Value,
58 rest: Vec<Value>,
59) -> BuiltinResult<Value> {
60 if rest.len() > 1 {
61 return Err(crate::builtins::math::ode::common::ode_error(
62 NAME,
63 "ode23: too many input arguments",
64 ));
65 }
66 let options = parse_options(NAME, rest.first())?;
67 let opts = ode_options_from_struct(NAME, options.as_ref())?;
68 let input = parse_ode_input(NAME, tspan, y0).await?;
69 let result = solve_ode(NAME, OdeMethod::Ode23, &function, &input, &opts).await?;
70 build_ode_output(NAME, result)
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    // With two outputs requested, integrating the decay ODE y' = -y over
    // tspan = [0 0.5 1] from y0 = 1 must yield a two-element output list.
    #[test]
    fn ode23_supports_two_output_form() {
        // The RHS callback negates the scalar state passed as its second argument.
        let _guard = crate::user_functions::install_user_function_invoker(Some(Arc::new(
            move |_name, args| {
                let state = if let Value::Num(v) = &args[1] {
                    *v
                } else {
                    panic!("expected scalar state, got {:?}", &args[1]);
                };
                Box::pin(async move { Ok(Value::Num(-state)) })
            },
        )));

        // Created after the invoker guard so drop order matches the original test.
        let _out_guard = crate::output_count::push_output_count(Some(2));

        let span = Tensor::new(vec![0.0, 0.5, 1.0], vec![1, 3]).unwrap();
        let result = block_on(ode23_builtin(
            Value::FunctionHandle("decay".into()),
            Value::Tensor(span),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let values = match result {
            Value::OutputList(values) => values,
            other => panic!("unexpected output {other:?}"),
        };
        assert_eq!(values.len(), 2);
    }
}