1use runmat_builtins::{LogicalArray, StructValue, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::optim::common::{call_function, optim_error};
11use crate::builtins::math::optim::type_resolvers::numerical_integral_type;
12use crate::BuiltinResult;
13
/// Builtin name passed to `optim_error` when reporting failures.
const NAME: &str = "integral";
/// Default absolute error tolerance used when `AbsTol` is not supplied.
const DEFAULT_ABS_TOL: f64 = 1.0e-10;
/// Default relative error tolerance used when `RelTol` is not supplied.
const DEFAULT_REL_TOL: f64 = 1.0e-6;
/// Default cap on integrand evaluations when `MaxFunEvals` is not supplied.
const DEFAULT_MAX_FUN_EVALS: usize = 10_000;
/// Maximum bisection depth of the adaptive Simpson recursion; exhausting it
/// produces a "did not converge" error.
const MAX_DEPTH: usize = 30;
19
/// GPU metadata for `integral`.
///
/// No device precisions or provider hooks are declared, so the quadrature loop
/// itself always runs on the host (see `notes`); only the user callback may
/// dispatch GPU-aware builtins.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::integral")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "integral",
    op_kind: GpuOpKind::Custom("adaptive-quadrature"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host adaptive quadrature solver. Callback computations may use GPU-aware builtins, but the adaptive integration loop runs on the CPU.",
};
35
/// Fusion metadata for `integral`.
///
/// The builtin is declared as neither elementwise nor a reduction: it calls
/// back into user code repeatedly, so it acts as a fusion boundary.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::integral")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "integral",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Adaptive integration repeatedly invokes user code and terminates fusion planning.",
};
46
47#[runtime_builtin(
48 name = "integral",
49 category = "math/optim",
50 summary = "Approximate a finite scalar definite integral using adaptive quadrature.",
51 keywords = "integral,numerical integration,adaptive quadrature,quadrature,function handle",
52 accel = "sink",
53 type_resolver(numerical_integral_type),
54 builtin_path = "crate::builtins::math::optim::integral"
55)]
56async fn integral_builtin(
57 function: Value,
58 a: Value,
59 b: Value,
60 rest: Vec<Value>,
61) -> BuiltinResult<Value> {
62 let options = IntegralOptions::parse(rest)?;
63 let a = scalar_bound("lower bound", a).await?;
64 let b = scalar_bound("upper bound", b).await?;
65 if a == b {
66 return Ok(Value::Num(0.0));
67 }
68
69 let sign = if b < a { -1.0 } else { 1.0 };
70 let lo = a.min(b);
71 let hi = a.max(b);
72 let result = integrate_finite_scalar(&function, lo, hi, &options).await?;
73 Ok(Value::Num(sign * result))
74}
75
76#[derive(Clone, Copy)]
77struct IntegralOptions {
78 abs_tol: f64,
79 rel_tol: f64,
80 max_fun_evals: usize,
81}
82
83impl IntegralOptions {
84 fn parse(rest: Vec<Value>) -> BuiltinResult<Self> {
85 let mut options = Self {
86 abs_tol: DEFAULT_ABS_TOL,
87 rel_tol: DEFAULT_REL_TOL,
88 max_fun_evals: DEFAULT_MAX_FUN_EVALS,
89 };
90 if rest.is_empty() {
91 return Ok(options);
92 }
93 if rest.len() == 1 {
94 return match &rest[0] {
95 Value::Struct(fields) => {
96 options.apply_struct(fields)?;
97 Ok(options)
98 }
99 other => Err(optim_error(
100 NAME,
101 format!("integral: expected option name/value pairs, got {other:?}"),
102 )),
103 };
104 }
105 if !rest.len().is_multiple_of(2) {
106 return Err(optim_error(
107 NAME,
108 "integral: expected option name/value pairs",
109 ));
110 }
111 for pair in rest.chunks(2) {
112 let name = option_name(&pair[0])?;
113 options.apply_option(&name, &pair[1])?;
114 }
115 options.validate()?;
116 Ok(options)
117 }
118
119 fn apply_struct(&mut self, fields: &StructValue) -> BuiltinResult<()> {
120 for (name, value) in &fields.fields {
121 self.apply_option(name, value)?;
122 }
123 self.validate()
124 }
125
126 fn apply_option(&mut self, name: &str, value: &Value) -> BuiltinResult<()> {
127 match name.to_ascii_lowercase().as_str() {
128 "abstol" => self.abs_tol = numeric_option("AbsTol", value)?,
129 "reltol" => self.rel_tol = numeric_option("RelTol", value)?,
130 "maxfunevals" | "maxintervalcount" => {
131 let parsed = integer_option(name, value)?;
132 if parsed < 5 {
133 return Err(optim_error(
134 NAME,
135 "integral: MaxFunEvals must be an integer scalar >= 5",
136 ));
137 }
138 self.max_fun_evals = parsed;
139 }
140 "arrayvalued" => {
141 if bool_option("ArrayValued", value)? {
142 return Err(optim_error(
143 NAME,
144 "integral: ArrayValued true is not supported yet",
145 ));
146 }
147 }
148 other => {
149 return Err(optim_error(
150 NAME,
151 format!("integral: unsupported option {other}"),
152 ))
153 }
154 }
155 Ok(())
156 }
157
158 fn validate(&self) -> BuiltinResult<()> {
159 if self.abs_tol < 0.0 {
160 return Err(optim_error(NAME, "integral: AbsTol must be nonnegative"));
161 }
162 if self.rel_tol < 0.0 {
163 return Err(optim_error(NAME, "integral: RelTol must be nonnegative"));
164 }
165 if self.abs_tol == 0.0 && self.rel_tol == 0.0 {
166 return Err(optim_error(
167 NAME,
168 "integral: AbsTol and RelTol cannot both be zero",
169 ));
170 }
171 Ok(())
172 }
173}
174
175fn option_name(value: &Value) -> BuiltinResult<String> {
176 match value {
177 Value::String(s) => Ok(s.clone()),
178 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
179 Value::CharArray(chars) if chars.rows == 1 => Ok(chars.data.iter().collect()),
180 other => Err(optim_error(
181 NAME,
182 format!("integral: option names must be strings, got {other:?}"),
183 )),
184 }
185}
186
187async fn scalar_bound(label: &str, value: Value) -> BuiltinResult<f64> {
188 let value = crate::dispatcher::gather_if_needed_async(&value).await?;
189 let parsed = match value {
190 Value::Num(n) => n,
191 Value::Int(i) => i.to_f64(),
192 Value::Bool(b) => {
193 if b {
194 1.0
195 } else {
196 0.0
197 }
198 }
199 Value::Tensor(tensor) if tensor.data.len() == 1 => tensor.data[0],
200 Value::LogicalArray(LogicalArray { data, .. }) if data.len() == 1 => {
201 if data[0] != 0 {
202 1.0
203 } else {
204 0.0
205 }
206 }
207 other => {
208 return Err(optim_error(
209 NAME,
210 format!("integral: {label} must be a finite real scalar, got {other:?}"),
211 ))
212 }
213 };
214 if parsed.is_finite() {
215 Ok(parsed)
216 } else {
217 Err(optim_error(
218 NAME,
219 format!("integral: {label} must be finite"),
220 ))
221 }
222}
223
224fn numeric_option(name: &str, value: &Value) -> BuiltinResult<f64> {
225 let parsed = match value {
226 Value::Num(n) => *n,
227 Value::Int(i) => i.to_f64(),
228 Value::Bool(b) => {
229 if *b {
230 1.0
231 } else {
232 0.0
233 }
234 }
235 Value::Tensor(Tensor { data, .. }) if data.len() == 1 => data[0],
236 Value::LogicalArray(LogicalArray { data, .. }) if data.len() == 1 => {
237 if data[0] != 0 {
238 1.0
239 } else {
240 0.0
241 }
242 }
243 other => {
244 return Err(optim_error(
245 NAME,
246 format!("integral: option {name} must be numeric, got {other:?}"),
247 ))
248 }
249 };
250 if parsed.is_finite() {
251 Ok(parsed)
252 } else {
253 Err(optim_error(
254 NAME,
255 format!("integral: option {name} must be finite"),
256 ))
257 }
258}
259
260fn integer_option(name: &str, value: &Value) -> BuiltinResult<usize> {
261 let parsed = numeric_option(name, value)?;
262 if parsed < 0.0 {
263 return Err(optim_error(
264 NAME,
265 format!("integral: option {name} must be nonnegative"),
266 ));
267 }
268 if parsed.fract() != 0.0 {
269 return Err(optim_error(
270 NAME,
271 format!("integral: option {name} must be an integer scalar"),
272 ));
273 }
274 Ok(parsed as usize)
275}
276
277fn bool_option(name: &str, value: &Value) -> BuiltinResult<bool> {
278 match value {
279 Value::Bool(flag) => Ok(*flag),
280 Value::Num(n) if *n == 0.0 || *n == 1.0 => Ok(*n != 0.0),
281 Value::Int(i) => {
282 let raw = i.to_i64();
283 if raw == 0 || raw == 1 {
284 Ok(raw != 0)
285 } else {
286 Err(optim_error(
287 NAME,
288 format!("integral: option {name} must be logical scalar"),
289 ))
290 }
291 }
292 other => Err(optim_error(
293 NAME,
294 format!("integral: option {name} must be logical scalar, got {other:?}"),
295 )),
296 }
297}
298
299async fn integrate_finite_scalar(
300 function: &Value,
301 a: f64,
302 b: f64,
303 options: &IntegralOptions,
304) -> BuiltinResult<f64> {
305 let fa = call_integrand(function, a).await?;
306 let m = 0.5 * (a + b);
307 let fm = call_integrand(function, m).await?;
308 let fb = call_integrand(function, b).await?;
309 let mut evals = 3usize;
310 let whole = simpson(a, b, fa, fm, fb);
311 let tol = options.abs_tol.max(options.rel_tol * whole.abs());
312 adaptive_simpson(
313 function,
314 SimpsonState {
315 a,
316 b,
317 fa,
318 fm,
319 fb,
320 whole,
321 tol,
322 depth: MAX_DEPTH,
323 },
324 &mut evals,
325 options.max_fun_evals,
326 )
327 .await
328}
329
/// One panel of the adaptive Simpson recursion.
#[derive(Clone, Copy)]
struct SimpsonState {
    /// Left endpoint of the panel.
    a: f64,
    /// Right endpoint of the panel.
    b: f64,
    /// Integrand value at `a`.
    fa: f64,
    /// Integrand value at the panel midpoint.
    fm: f64,
    /// Integrand value at `b`.
    fb: f64,
    /// Simpson estimate over the whole panel, compared against the sum of its halves.
    whole: f64,
    /// Error tolerance allotted to this panel (each child receives half).
    tol: f64,
    /// Remaining bisection levels before the recursion reports non-convergence.
    depth: usize,
}
341
342#[async_recursion::async_recursion(?Send)]
343async fn adaptive_simpson(
344 function: &Value,
345 state: SimpsonState,
346 evals: &mut usize,
347 max_fun_evals: usize,
348) -> BuiltinResult<f64> {
349 if *evals + 2 > max_fun_evals {
350 return Err(optim_error(
351 NAME,
352 "integral: exceeded maximum function evaluations",
353 ));
354 }
355
356 let c = 0.5 * (state.a + state.b);
357 let d = 0.5 * (state.a + c);
358 let e = 0.5 * (c + state.b);
359 let fd = call_integrand(function, d).await?;
360 let fe = call_integrand(function, e).await?;
361 *evals += 2;
362
363 let left = simpson(state.a, c, state.fa, fd, state.fm);
364 let right = simpson(c, state.b, state.fm, fe, state.fb);
365 let refined = left + right;
366 let error = refined - state.whole;
367 if error.abs() <= 15.0 * state.tol {
368 return Ok(refined + error / 15.0);
369 }
370 if state.depth == 0 {
371 return Err(optim_error(
372 NAME,
373 "integral: adaptive quadrature did not converge",
374 ));
375 }
376
377 let left_value = adaptive_simpson(
378 function,
379 SimpsonState {
380 a: state.a,
381 b: c,
382 fa: state.fa,
383 fm: fd,
384 fb: state.fm,
385 whole: left,
386 tol: state.tol * 0.5,
387 depth: state.depth - 1,
388 },
389 evals,
390 max_fun_evals,
391 )
392 .await?;
393 let right_value = adaptive_simpson(
394 function,
395 SimpsonState {
396 a: c,
397 b: state.b,
398 fa: state.fm,
399 fm: fe,
400 fb: state.fb,
401 whole: right,
402 tol: state.tol * 0.5,
403 depth: state.depth - 1,
404 },
405 evals,
406 max_fun_evals,
407 )
408 .await?;
409 Ok(left_value + right_value)
410}
411
/// Simpson's rule on `[a, b]`: `(b - a) * (fa + 4 fm + fb) / 6`.
///
/// `fm` is expected to be the integrand sampled at the midpoint of `[a, b]`.
fn simpson(a: f64, b: f64, fa: f64, fm: f64, fb: f64) -> f64 {
    let width = b - a;
    let weighted = fa + 4.0 * fm + fb;
    width * weighted / 6.0
}
415
416async fn call_integrand(function: &Value, x: f64) -> BuiltinResult<f64> {
417 let value = call_function(function, vec![Value::Num(x)]).await?;
418 let value = crate::dispatcher::gather_if_needed_async(&value).await?;
419 match value {
420 Value::Num(n) if n.is_finite() => Ok(n),
421 Value::Int(i) => Ok(i.to_f64()),
422 Value::Bool(b) => Ok(if b { 1.0 } else { 0.0 }),
423 Value::Tensor(tensor) if tensor.data.len() == 1 && tensor.data[0].is_finite() => {
424 Ok(tensor.data[0])
425 }
426 Value::LogicalArray(logical) if logical.data.len() == 1 => {
427 Ok(if logical.data[0] != 0 { 1.0 } else { 0.0 })
428 }
429 Value::Num(_) | Value::Tensor(_) => Err(optim_error(
430 NAME,
431 "integral: function value must be a finite real scalar",
432 )),
433 other => Err(optim_error(
434 NAME,
435 format!("integral: function value must be real numeric scalar, got {other:?}"),
436 )),
437 }
438}
439
/// Unit tests for the `integral` builtin: correctness on simple integrands,
/// bound handling, integrand validation, and option parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    // Test-only builtins, registered under the names referenced below through
    // `Value::FunctionHandle`.

    /// Integrand `x -> x^2`.
    #[runtime_builtin(
        name = "__integral_square",
        type_resolver(crate::builtins::math::optim::type_resolvers::numerical_integral_type),
        builtin_path = "crate::builtins::math::optim::integral::tests"
    )]
    async fn square_helper(x: Value) -> crate::BuiltinResult<Value> {
        let x = scalar_bound("x", x).await?;
        Ok(Value::Num(x * x))
    }

    /// Integrand that always returns a 1x2 vector (invalid for scalar integration).
    #[runtime_builtin(
        name = "__integral_vector",
        type_resolver(crate::builtins::math::optim::type_resolvers::numerical_integral_type),
        builtin_path = "crate::builtins::math::optim::integral::tests"
    )]
    async fn vector_helper(_x: Value) -> crate::BuiltinResult<Value> {
        Ok(Value::Tensor(
            Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap(),
        ))
    }

    /// Integrand that always returns NaN (invalid; also used to prove the
    /// zero-width shortcut never calls the integrand).
    #[runtime_builtin(
        name = "__integral_nan",
        type_resolver(crate::builtins::math::optim::type_resolvers::numerical_integral_type),
        builtin_path = "crate::builtins::math::optim::integral::tests"
    )]
    async fn nan_helper(_x: Value) -> crate::BuiltinResult<Value> {
        Ok(Value::Num(f64::NAN))
    }

    /// Runs `integral(function, a, b)` synchronously with default options.
    fn run(function: Value, a: f64, b: f64) -> crate::BuiltinResult<Value> {
        block_on(integral_builtin(
            function,
            Value::Num(a),
            Value::Num(b),
            Vec::new(),
        ))
    }

    // integral(sin, 0, pi) == 2
    #[test]
    fn integrates_named_sine_function() {
        let result = run(
            Value::FunctionHandle("sin".into()),
            0.0,
            std::f64::consts::PI,
        )
        .expect("integral");
        match result {
            Value::Num(value) => assert!((value - 2.0).abs() < 1.0e-7),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // integral(x^2, 0, 1) == 1/3
    #[test]
    fn integrates_polynomial_helper() {
        let result =
            run(Value::FunctionHandle("__integral_square".into()), 0.0, 1.0).expect("integral");
        match result {
            Value::Num(value) => assert!((value - (1.0 / 3.0)).abs() < 1.0e-9),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Swapping the bounds flips the sign: integral(sin, pi, 0) == -2
    #[test]
    fn reversed_bounds_negate_result() {
        let result = run(
            Value::FunctionHandle("sin".into()),
            std::f64::consts::PI,
            0.0,
        )
        .expect("integral");
        match result {
            Value::Num(value) => assert!((value + 2.0).abs() < 1.0e-7),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // A NaN-producing integrand would error if called, so success proves no callback ran.
    #[test]
    fn zero_width_interval_returns_zero_without_callback() {
        let result =
            run(Value::FunctionHandle("__integral_nan".into()), 1.0, 1.0).expect("integral");
        assert!(matches!(result, Value::Num(0.0)));
    }

    #[test]
    fn rejects_vector_valued_integrand_for_initial_scope() {
        let err = run(Value::FunctionHandle("__integral_vector".into()), 0.0, 1.0).unwrap_err();
        assert!(err.message().contains("finite real scalar"));
    }

    #[test]
    fn rejects_nonfinite_integrand_values() {
        let err = run(Value::FunctionHandle("__integral_nan".into()), 0.0, 1.0).unwrap_err();
        assert!(err.message().contains("finite real scalar"));
    }

    // Name/value pairs tighten the tolerances and still converge.
    #[test]
    fn accepts_tolerance_name_value_options() {
        let result = block_on(integral_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(std::f64::consts::PI),
            vec![
                Value::from("AbsTol"),
                Value::Num(1.0e-12),
                Value::from("RelTol"),
                Value::Num(1.0e-8),
            ],
        ))
        .expect("integral");
        match result {
            Value::Num(value) => assert!((value - 2.0).abs() < 1.0e-8),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // MaxFunEvals below 5 cannot cover the initial samples plus one refinement.
    #[test]
    fn rejects_too_small_max_fun_evals() {
        let err = block_on(integral_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(1.0),
            vec![Value::from("MaxFunEvals"), Value::Num(4.0)],
        ))
        .unwrap_err();
        assert!(err.message().contains("integer scalar >= 5"));
    }

    #[test]
    fn rejects_fractional_max_fun_evals() {
        let err = block_on(integral_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(1.0),
            vec![Value::from("MaxFunEvals"), Value::Num(5.5)],
        ))
        .unwrap_err();
        assert!(err.message().contains("integer scalar"));
    }
}