1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13use crate::builtins::math::ode::common::{
14 build_ode_output, ode_options_from_struct, parse_ode_input, parse_options, solve_ode, OdeMethod,
15};
16use crate::builtins::math::ode::type_resolvers::ode_solution_type;
17use crate::{build_runtime_error, BuiltinResult, RuntimeError};
18
/// Builtin name handed to the shared ODE helpers and attached to errors built in this file.
const NAME: &str = "ode15s";
20
21const ODE15S_OUTPUT_Y: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
22 name: "y",
23 ty: BuiltinParamType::NumericArray,
24 arity: BuiltinParamArity::Required,
25 default: None,
26 description: "Solution states evaluated over tspan.",
27}];
28
29const ODE15S_OUTPUT_TY: [BuiltinParamDescriptor; 2] = [
30 BuiltinParamDescriptor {
31 name: "t",
32 ty: BuiltinParamType::NumericArray,
33 arity: BuiltinParamArity::Required,
34 default: None,
35 description: "Time points selected by solver.",
36 },
37 BuiltinParamDescriptor {
38 name: "y",
39 ty: BuiltinParamType::NumericArray,
40 arity: BuiltinParamArity::Required,
41 default: None,
42 description: "Solution states at each returned time point.",
43 },
44];
45
46const ODE15S_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
47 BuiltinParamDescriptor {
48 name: "odefun",
49 ty: BuiltinParamType::Any,
50 arity: BuiltinParamArity::Required,
51 default: None,
52 description: "ODE right-hand-side callback f(t,y).",
53 },
54 BuiltinParamDescriptor {
55 name: "tspan",
56 ty: BuiltinParamType::Any,
57 arity: BuiltinParamArity::Required,
58 default: None,
59 description: "Time interval or monotonic time vector.",
60 },
61 BuiltinParamDescriptor {
62 name: "y0",
63 ty: BuiltinParamType::Any,
64 arity: BuiltinParamArity::Required,
65 default: None,
66 description: "Initial state vector/value.",
67 },
68];
69
70const ODE15S_INPUTS_WITH_OPTIONS: [BuiltinParamDescriptor; 4] = [
71 BuiltinParamDescriptor {
72 name: "odefun",
73 ty: BuiltinParamType::Any,
74 arity: BuiltinParamArity::Required,
75 default: None,
76 description: "ODE right-hand-side callback f(t,y).",
77 },
78 BuiltinParamDescriptor {
79 name: "tspan",
80 ty: BuiltinParamType::Any,
81 arity: BuiltinParamArity::Required,
82 default: None,
83 description: "Time interval or monotonic time vector.",
84 },
85 BuiltinParamDescriptor {
86 name: "y0",
87 ty: BuiltinParamType::Any,
88 arity: BuiltinParamArity::Required,
89 default: None,
90 description: "Initial state vector/value.",
91 },
92 BuiltinParamDescriptor {
93 name: "options",
94 ty: BuiltinParamType::Any,
95 arity: BuiltinParamArity::Optional,
96 default: None,
97 description: "Optional struct with tolerances and step controls.",
98 },
99];
100
101const ODE15S_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
102 BuiltinSignatureDescriptor {
103 label: "y = ode15s(odefun, tspan, y0)",
104 inputs: &ODE15S_INPUTS_CORE,
105 outputs: &ODE15S_OUTPUT_Y,
106 },
107 BuiltinSignatureDescriptor {
108 label: "y = ode15s(odefun, tspan, y0, options)",
109 inputs: &ODE15S_INPUTS_WITH_OPTIONS,
110 outputs: &ODE15S_OUTPUT_Y,
111 },
112 BuiltinSignatureDescriptor {
113 label: "[t, y] = ode15s(odefun, tspan, y0)",
114 inputs: &ODE15S_INPUTS_CORE,
115 outputs: &ODE15S_OUTPUT_TY,
116 },
117 BuiltinSignatureDescriptor {
118 label: "[t, y] = ode15s(odefun, tspan, y0, options)",
119 inputs: &ODE15S_INPUTS_WITH_OPTIONS,
120 outputs: &ODE15S_OUTPUT_TY,
121 },
122];
123
124const ODE15S_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
125 code: "RM.ODE15S.INVALID_ARGUMENT",
126 identifier: Some("RunMat:ode15s:InvalidArgument"),
127 when: "Input argument count/options struct grammar is invalid.",
128 message: "ode15s: invalid argument",
129};
130
131const ODE15S_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
132 code: "RM.ODE15S.INVALID_INPUT",
133 identifier: Some("RunMat:ode15s:InvalidInput"),
134 when: "ODE input/state/callback semantics are invalid for integration.",
135 message: "ode15s: invalid input",
136};
137
138const ODE15S_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
139 code: "RM.ODE15S.INTERNAL",
140 identifier: Some("RunMat:ode15s:Internal"),
141 when: "Internal output materialization fails.",
142 message: "ode15s: internal runtime failure",
143};
144
/// All error descriptors published through `ODE15S_DESCRIPTOR`.
///
/// The element order is observable (the descriptor tests compare the codes in sequence).
const ODE15S_ERRORS: [BuiltinErrorDescriptor; 3] = [
    ODE15S_ERROR_INVALID_ARGUMENT,
    ODE15S_ERROR_INVALID_INPUT,
    ODE15S_ERROR_INTERNAL,
];
150
151pub const ODE15S_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152 signatures: &ODE15S_SIGNATURES,
153 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
154 completion_policy: BuiltinCompletionPolicy::Public,
155 errors: &ODE15S_ERRORS,
156};
157
158fn ode15s_error_with_detail(
159 error: &'static BuiltinErrorDescriptor,
160 detail: impl AsRef<str>,
161) -> RuntimeError {
162 let detail = detail.as_ref();
163 let message = if detail.starts_with("ode15s:") {
164 detail.to_string()
165 } else {
166 format!("{}: {}", error.message, detail)
167 };
168 let mut builder = build_runtime_error(message).with_builtin(NAME);
169 if let Some(identifier) = error.identifier {
170 builder = builder.with_identifier(identifier);
171 }
172 builder.build()
173}
174
175fn ode15s_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
176 if err.identifier().is_some() {
177 err
178 } else {
179 ode15s_error_with_detail(fallback, err.message())
180 }
181}
182
// GPU spec registered for `ode15s` via the `register_gpu_spec` attribute.
// It declares a custom op kind with no supported precisions and no provider hooks; per the
// `notes` string below, integration itself runs on the host.
// Plain `//` comments are used so the tokens handed to the attribute macro stay unchanged.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::ode::ode15s")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ode15s",
    op_kind: GpuOpKind::Custom("ode-solve-stiff"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Stiff ODE integration runs on the host. RHS callbacks may call GPU-aware builtins.",
};
198
// Fusion spec registered for `ode15s` via the `register_fusion_spec` attribute.
// It declares neither an elementwise nor a reduction form; per the `notes` string below, the
// builtin terminates fusion planning.
// Plain `//` comments are used so the tokens handed to the attribute macro stay unchanged.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::ode::ode15s")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ode15s",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "ODE integration repeatedly invokes user callbacks and terminates fusion planning.",
};
209
210#[runtime_builtin(
211 name = "ode15s",
212 category = "math/ode",
213 summary = "Solve stiff ODE systems with adaptive implicit integration.",
214 keywords = "ode15s,ode,stiff,implicit,adaptive step",
215 accel = "sink",
216 type_resolver(ode_solution_type),
217 descriptor(crate::builtins::math::ode::ode15s::ODE15S_DESCRIPTOR),
218 builtin_path = "crate::builtins::math::ode::ode15s"
219)]
220async fn ode15s_builtin(
221 function: Value,
222 tspan: Value,
223 y0: Value,
224 rest: Vec<Value>,
225) -> BuiltinResult<Value> {
226 if rest.len() > 1 {
227 return Err(ode15s_error_with_detail(
228 &ODE15S_ERROR_INVALID_ARGUMENT,
229 "too many input arguments",
230 ));
231 }
232 let options = parse_options(NAME, rest.first())
233 .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_ARGUMENT))?;
234 let opts = ode_options_from_struct(NAME, options.as_ref())
235 .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_ARGUMENT))?;
236 let input = parse_ode_input(NAME, tspan, y0)
237 .await
238 .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_INPUT))?;
239 let result = solve_ode(NAME, OdeMethod::Ode15s, &function, &input, &opts)
240 .await
241 .map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INVALID_INPUT))?;
242 build_ode_output(NAME, result).map_err(|err| ode15s_map_error(err, &ODE15S_ERROR_INTERNAL))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::{StructValue, Tensor};
    use std::sync::Arc;

    /// Reads the scalar state passed to an RHS callback; panics on any non-`Num` value.
    fn scalar_state(value: &Value) -> f64 {
        match value {
            Value::Num(n) => *n,
            other => panic!("expected scalar state, got {other:?}"),
        }
    }

    /// Builds the 1x2 tensor `[0, t_end]` passed as `tspan`.
    fn tspan_to(t_end: f64) -> Value {
        Value::Tensor(Tensor::new(vec![0.0, t_end], vec![1, 2]).unwrap())
    }

    /// Checks that `out` is a one-column tensor and returns its entry at index `rows - 1`.
    #[track_caller]
    fn final_state(out: Value) -> f64 {
        match out {
            Value::Tensor(t) => {
                assert_eq!(t.cols(), 1);
                t.data[t.rows() - 1]
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    #[test]
    fn ode15s_handles_linear_stiff_decay() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        // RHS: y' = -15 y.
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-15.0 * y)) })
            },
        )));

        let out = block_on(ode15s_builtin(
            Value::FunctionHandle("stiff_decay".into()),
            tspan_to(1.0),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let last = final_state(out);
        assert!(last.is_finite());
        assert!(last > 0.0);
        assert!(last < 0.1);
    }

    #[test]
    fn ode15s_accepts_picard_unstable_stiff_step_with_newton() {
        let _resolver =
            crate::user_functions::install_semantic_function_resolver(Some(Arc::new(|_name| {
                Some(0)
            })));
        // RHS: y' = -1000 y.
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |_function, args, _requested_outputs| {
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-1000.0 * y)) })
            },
        )));

        // Option fields are inserted in the order listed here.
        let mut options = StructValue::new();
        for (field, value) in [
            ("RelTol", 1.0e6),
            ("AbsTol", 1.0e6),
            ("InitialStep", 0.1),
            ("MaxStep", 0.1),
            ("MaxSteps", 2.0),
        ] {
            options.insert(field, Value::Num(value));
        }

        let out = block_on(ode15s_builtin(
            Value::FunctionHandle("very_stiff_decay".into()),
            tspan_to(0.1),
            Value::Num(1.0),
            vec![Value::Struct(options)],
        ))
        .unwrap();

        let last = final_state(out);
        assert!(last.is_finite());
        assert!(last > 0.0);
        assert!(last < 0.02);
    }

    #[test]
    fn ode15s_accepts_semantic_function_handle_rhs() {
        // No resolver is installed here: the handle below is already bound to function id 57.
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            move |function, args, _requested_outputs| {
                assert_eq!(function, 57);
                let y = scalar_state(&args[1]);
                Box::pin(async move { Ok(Value::Num(-15.0 * y)) })
            },
        )));

        let rhs = Value::BoundFunctionHandle {
            name: "ode_stiff_decay".to_string(),
            function: 57,
        };
        let out = block_on(ode15s_builtin(
            rhs,
            tspan_to(1.0),
            Value::Num(1.0),
            Vec::new(),
        ))
        .unwrap();

        let last = final_state(out);
        assert!(last.is_finite());
        assert!(last > 0.0);
        assert!(last < 0.1);
    }

    #[test]
    fn ode15s_too_many_inputs_uses_stable_identifier() {
        // Two trailing arguments: one more than the single optional `options` slot.
        let surplus = vec![Value::Num(1.0), Value::Num(2.0)];
        let err = block_on(ode15s_builtin(
            Value::FunctionHandle("stiff_decay".into()),
            tspan_to(1.0),
            Value::Num(1.0),
            surplus,
        ))
        .expect_err("expected too many inputs error");
        assert_eq!(err.identifier(), ODE15S_ERROR_INVALID_ARGUMENT.identifier);
    }

    #[test]
    fn ode15s_descriptor_signatures_cover_surface() {
        let expected = vec![
            "y = ode15s(odefun, tspan, y0)",
            "y = ode15s(odefun, tspan, y0, options)",
            "[t, y] = ode15s(odefun, tspan, y0)",
            "[t, y] = ode15s(odefun, tspan, y0, options)",
        ];
        let labels: Vec<&str> = ODE15S_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(labels, expected);
    }

    #[test]
    fn ode15s_descriptor_errors_have_stable_codes() {
        let expected = vec![
            "RM.ODE15S.INVALID_ARGUMENT",
            "RM.ODE15S.INVALID_INPUT",
            "RM.ODE15S.INTERNAL",
        ];
        let codes: Vec<&str> = ODE15S_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(codes, expected);
    }
}