1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13use crate::builtins::math::ode::common::{
14 build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
15};
16use crate::builtins::math::ode::type_resolvers::ode_solution_type;
17use crate::{build_runtime_error, BuiltinResult, RuntimeError};
18
/// Builtin name: tags every error built in this file and is passed to the shared ODE helpers.
const NAME: &str = "ode23";
20
21const ODE23_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
22 name: "y",
23 ty: BuiltinParamType::NumericArray,
24 arity: BuiltinParamArity::Required,
25 default: None,
26 description: "Solution states evaluated over tspan.",
27}];
28
29const ODE23_OUTPUT_TY: [BuiltinParamDescriptor; 2] = [
30 BuiltinParamDescriptor {
31 name: "t",
32 ty: BuiltinParamType::NumericArray,
33 arity: BuiltinParamArity::Required,
34 default: None,
35 description: "Time points selected by solver.",
36 },
37 BuiltinParamDescriptor {
38 name: "y",
39 ty: BuiltinParamType::NumericArray,
40 arity: BuiltinParamArity::Required,
41 default: None,
42 description: "Solution states at each returned time point.",
43 },
44];
45
46const ODE23_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
47 BuiltinParamDescriptor {
48 name: "odefun",
49 ty: BuiltinParamType::Any,
50 arity: BuiltinParamArity::Required,
51 default: None,
52 description: "ODE right-hand-side callback f(t,y).",
53 },
54 BuiltinParamDescriptor {
55 name: "tspan",
56 ty: BuiltinParamType::Any,
57 arity: BuiltinParamArity::Required,
58 default: None,
59 description: "Time interval or monotonic time vector.",
60 },
61 BuiltinParamDescriptor {
62 name: "y0",
63 ty: BuiltinParamType::Any,
64 arity: BuiltinParamArity::Required,
65 default: None,
66 description: "Initial state vector/value.",
67 },
68];
69
70const ODE23_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 4] = [
71 BuiltinParamDescriptor {
72 name: "odefun",
73 ty: BuiltinParamType::Any,
74 arity: BuiltinParamArity::Required,
75 default: None,
76 description: "ODE right-hand-side callback f(t,y).",
77 },
78 BuiltinParamDescriptor {
79 name: "tspan",
80 ty: BuiltinParamType::Any,
81 arity: BuiltinParamArity::Required,
82 default: None,
83 description: "Time interval or monotonic time vector.",
84 },
85 BuiltinParamDescriptor {
86 name: "y0",
87 ty: BuiltinParamType::Any,
88 arity: BuiltinParamArity::Required,
89 default: None,
90 description: "Initial state vector/value.",
91 },
92 BuiltinParamDescriptor {
93 name: "options",
94 ty: BuiltinParamType::Any,
95 arity: BuiltinParamArity::Optional,
96 default: None,
97 description: "Optional struct with tolerances and step controls.",
98 },
99];
100
101const ODE23_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
102 BuiltinSignatureDescriptor {
103 label: "y = ode23(odefun, tspan, y0)",
104 inputs: &ODE23_INPUTS_CORE,
105 outputs: &ODE23_OUTPUT_Y,
106 },
107 BuiltinSignatureDescriptor {
108 label: "y = ode23(odefun, tspan, y0, options)",
109 inputs: &ODE23_INPUTS_WITH_OPTIONS,
110 outputs: &ODE23_OUTPUT_Y,
111 },
112 BuiltinSignatureDescriptor {
113 label: "[t, y] = ode23(odefun, tspan, y0)",
114 inputs: &ODE23_INPUTS_CORE,
115 outputs: &ODE23_OUTPUT_TY,
116 },
117 BuiltinSignatureDescriptor {
118 label: "[t, y] = ode23(odefun, tspan, y0, options)",
119 inputs: &ODE23_INPUTS_WITH_OPTIONS,
120 outputs: &ODE23_OUTPUT_TY,
121 },
122];
123
124const ODE23_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
125 code: "RM.ODE23.INVALID_ARGUMENT",
126 identifier: Some("RunMat:ode23:InvalidArgument"),
127 when: "Input argument count/options struct grammar is invalid.",
128 message: "ode23: invalid argument",
129};
130
131const ODE23_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
132 code: "RM.ODE23.INVALID_INPUT",
133 identifier: Some("RunMat:ode23:InvalidInput"),
134 when: "ODE input/state/callback semantics are invalid for integration.",
135 message: "ode23: invalid input",
136};
137
138const ODE23_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
139 code: "RM.ODE23.INTERNAL",
140 identifier: Some("RunMat:ode23:Internal"),
141 when: "Internal output materialization fails.",
142 message: "ode23: internal runtime failure",
143};
144
/// Full error catalogue exposed through [`ODE23_DESCRIPTOR`].
///
/// The order (argument, input, internal) is asserted by the descriptor tests in this file.
const ODE23_ERRORS: [BuiltinErrorDescriptor; 3] = [
    ODE23_ERROR_INVALID_ARGUMENT,
    ODE23_ERROR_INVALID_INPUT,
    ODE23_ERROR_INTERNAL,
];
150
151pub const ODE23_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152 signatures: &ODE23_SIGNATURES,
153 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
154 completion_policy: BuiltinCompletionPolicy::Public,
155 errors: &ODE23_ERRORS,
156};
157
158fn ode23_error_with_detail(
159 error: &'static BuiltinErrorDescriptor,
160 detail: impl AsRef<str>,
161) -> RuntimeError {
162 let detail = detail.as_ref();
163 let message = if detail.starts_with("ode23:") {
164 detail.to_string()
165 } else {
166 format!("{}: {}", error.message, detail)
167 };
168 let mut builder = build_runtime_error(message).with_builtin(NAME);
169 if let Some(identifier) = error.identifier {
170 builder = builder.with_identifier(identifier);
171 }
172 builder.build()
173}
174
175fn ode23_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
176 if err.identifier().is_some() {
177 err
178 } else {
179 ode23_error_with_detail(fallback, err.message())
180 }
181}
182
// GPU metadata registered for `ode23` through `register_gpu_spec`.
// It declares a custom "ode-solve" op kind with empty precision and provider-hook lists;
// the `notes` string records that integration runs on the host.
// NOTE(review): the exact effect of each policy field (residency, nan_mode, ...) is defined
// by `BuiltinGpuSpec` outside this file — confirm there before changing values.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode23")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ode23",
    op_kind: GpuOpKind::Custom("ode-solve"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Adaptive ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
};
198
// Fusion metadata registered for `ode23` through `register_fusion_spec`.
// No elementwise or reduction descriptor is provided; per the `notes` string, `ode23`
// terminates fusion planning because it repeatedly invokes user callbacks.
// NOTE(review): how the planner consumes `shape` and `emits_nan` is defined by
// `BuiltinFusionSpec` outside this file — confirm there before changing values.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode23")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ode23",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
};
209
210#[runtime_builtin(
211 name = "ode23",
212 category = "math/ode",
213 summary = "Solve nonstiff ODE systems using adaptive Bogacki-Shampine 3(2) integration.",
214 keywords = "ode23,ode,nonstiff,bogacki-shampine,adaptive step",
215 accel = "sink",
216 type_resolver(ode_solution_type),
217 descriptor(crate::builtins::math::ode::ode23::ODE23_DESCRIPTOR),
218 builtin_path = "crate::builtins::math::ode::ode23"
219)]
220async fn ode23_builtin(
221 function: Value,
222 tspan: Value,
223 y0: Value,
224 rest: Vec<Value>,
225) -> BuiltinResult<Value> {
226 if rest.len() > 1 {
227 return Err(ode23_error_with_detail(
228 &ODE23_ERROR_INVALID_ARGUMENT,
229 "too many input arguments",
230 ));
231 }
232 let options = parse_options(NAME, rest.first())
233 .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_ARGUMENT))?;
234 let opts = ode_options_from_struct(NAME, options.as_ref())
235 .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_ARGUMENT))?;
236 let input = parse_ode_input(NAME, tspan, y0)
237 .await
238 .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_INPUT))?;
239 let result = solve_ode(NAME, OdeMethod::Ode23, &function, &input, &opts)
240 .await
241 .map_err(|err| ode23_map_error(err, &ODE23_ERROR_INVALID_INPUT))?;
242 build_ode_output(NAME, result).map_err(|err| ode23_map_error(err, &ODE23_ERROR_INTERNAL))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;
    use std::sync::Arc;

    /// Builds a `1 x n` row-vector `tspan` value from the given time points.
    fn tspan_row(points: &[f64]) -> Value {
        Value::Tensor(Tensor::new(points.to_vec(), vec![1, points.len()]).unwrap())
    }

    /// Reads the scalar state handed to the RHS callback, panicking on any other value.
    fn scalar_state(state: &Value) -> f64 {
        match state {
            Value::Num(n) => *n,
            other => panic!("expected scalar state, got {other:?}"),
        }
    }

    /// Requesting two outputs yields an output list with exactly two entries.
    #[test]
    fn ode23_supports_two_output_form() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        // RHS callback for y' = -y; the state is the second callback argument.
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-y)) })
            },
        )));

        let _out_guard = crate::output_count::push_output_count(Some(2));
        let out = block_on(ode23_builtin(
            Value::FunctionHandle("decay".into()),
            tspan_row(&[0.0, 0.5, 1.0]),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let values = match out {
            Value::OutputList(values) => values,
            other => panic!("unexpected output {other:?}"),
        };
        assert_eq!(values.len(), 2);
    }

    /// A bound function handle is invoked with its function id, and the final state of
    /// y' = -y from y(0) = 1 is finite and lies strictly between 0 and 1.
    #[test]
    fn ode23_accepts_semantic_function_handle_rhs() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |function, args, _requested_outputs| {
                assert_eq!(function, 55);
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-y)) })
            },
        )));

        let rhs = Value::BoundFunctionHandle {
            name: "ode_decay".to_string(),
            function: 55,
        };
        let out = block_on(ode23_builtin(
            rhs,
            tspan_row(&[0.0, 1.0]),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let solution = match out {
            Value::Tensor(t) => t,
            other => panic!("unexpected output {other:?}"),
        };
        assert_eq!(solution.cols(), 1);
        let last = solution.data[solution.rows() - 1];
        assert!(last.is_finite());
        assert!(last > 0.0);
        assert!(last < 1.0);
    }

    /// Two trailing arguments are rejected with the InvalidArgument identifier.
    #[test]
    fn ode23_too_many_inputs_uses_stable_identifier() {
        let extra_args = vec![Value::Num(1.0), Value::Num(2.0)];
        let err = block_on(ode23_builtin(
            Value::FunctionHandle("decay".into()),
            tspan_row(&[0.0, 1.0]),
            Value::Num(1.0),
            extra_args,
        ))
        .expect_err("expected too many inputs error");
        assert_eq!(err.identifier(), ODE23_ERROR_INVALID_ARGUMENT.identifier);
    }

    /// The descriptor lists all four call forms in their documented order.
    #[test]
    fn ode23_descriptor_signatures_cover_surface() {
        let expected = [
            "y = ode23(odefun, tspan, y0)",
            "y = ode23(odefun, tspan, y0, options)",
            "[t, y] = ode23(odefun, tspan, y0)",
            "[t, y] = ode23(odefun, tspan, y0, options)",
        ];
        let labels: Vec<&str> = ODE23_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(labels, expected);
    }

    /// The descriptor exposes the three error codes in their documented order.
    #[test]
    fn ode23_descriptor_errors_have_stable_codes() {
        let expected = [
            "RM.ODE23.INVALID_ARGUMENT",
            "RM.ODE23.INVALID_INPUT",
            "RM.ODE23.INTERNAL",
        ];
        let codes: Vec<&str> = ODE23_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(codes, expected);
    }
}