Skip to main content

runmat_runtime/builtins/math/optim/
integral.rs

1//! MATLAB-compatible `integral` builtin for finite scalar numerical integration.
2
3use runmat_builtins::{LogicalArray, StructValue, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::math::optim::common::{call_function, optim_error};
11use crate::builtins::math::optim::type_resolvers::numerical_integral_type;
12use crate::BuiltinResult;
13
/// Builtin name passed to `optim_error` for every error raised by this module.
const NAME: &str = "integral";
/// Default absolute error tolerance (`AbsTol`).
const DEFAULT_ABS_TOL: f64 = 1e-10;
/// Default relative error tolerance (`RelTol`).
const DEFAULT_REL_TOL: f64 = 1e-6;
/// Default cap on integrand evaluations (`MaxFunEvals` / `MaxIntervalCount`).
const DEFAULT_MAX_FUN_EVALS: usize = 10000;
/// Number of bisection levels adaptive Simpson may use before reporting non-convergence.
const MAX_DEPTH: usize = 30;
19
/// GPU planning metadata for `integral`.
///
/// No device implementation is advertised: `supported_precisions` and `provider_hooks`
/// are empty. Per `notes`, the adaptive integration loop runs on the CPU, while the
/// user's integrand callback may itself use GPU-aware builtins.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::integral")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "integral",
    // Custom op labelled "adaptive-quadrature": the host loop implemented in this module.
    op_kind: GpuOpKind::Custom("adaptive-quadrature"),
    // Empty: no device precisions are supported for this builtin.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): presumably asks the planner to gather device inputs to the host up
    // front; this module also gathers bounds and integrand results explicitly.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host adaptive quadrature solver. Callback computations may use GPU-aware builtins, but the adaptive integration loop runs on the CPU.",
};
35
/// Fusion planning metadata for `integral`.
///
/// Both `elementwise` and `reduction` are `None`, so no fusible template is offered.
/// Per `notes`, the repeated invocations of user code end fusion planning at this call.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::integral")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "integral",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Adaptive integration repeatedly invokes user code and terminates fusion planning.",
};
46
47#[runtime_builtin(
48    name = "integral",
49    category = "math/optim",
50    summary = "Approximate a finite scalar definite integral using adaptive quadrature.",
51    keywords = "integral,numerical integration,adaptive quadrature,quadrature,function handle",
52    accel = "sink",
53    type_resolver(numerical_integral_type),
54    builtin_path = "crate::builtins::math::optim::integral"
55)]
56async fn integral_builtin(
57    function: Value,
58    a: Value,
59    b: Value,
60    rest: Vec<Value>,
61) -> BuiltinResult<Value> {
62    let options = IntegralOptions::parse(rest)?;
63    let a = scalar_bound("lower bound", a).await?;
64    let b = scalar_bound("upper bound", b).await?;
65    if a == b {
66        return Ok(Value::Num(0.0));
67    }
68
69    let sign = if b < a { -1.0 } else { 1.0 };
70    let lo = a.min(b);
71    let hi = a.max(b);
72    let result = integrate_finite_scalar(&function, lo, hi, &options).await?;
73    Ok(Value::Num(sign * result))
74}
75
/// Tolerances and evaluation budget that control the adaptive quadrature.
#[derive(Copy, Clone)]
struct IntegralOptions {
    /// Absolute error tolerance (`AbsTol`); validated to be nonnegative.
    abs_tol: f64,
    /// Relative error tolerance (`RelTol`); validated to be nonnegative.
    rel_tol: f64,
    /// Maximum number of integrand evaluations (`MaxFunEvals` / `MaxIntervalCount`).
    max_fun_evals: usize,
}
82
83impl IntegralOptions {
84    fn parse(rest: Vec<Value>) -> BuiltinResult<Self> {
85        let mut options = Self {
86            abs_tol: DEFAULT_ABS_TOL,
87            rel_tol: DEFAULT_REL_TOL,
88            max_fun_evals: DEFAULT_MAX_FUN_EVALS,
89        };
90        if rest.is_empty() {
91            return Ok(options);
92        }
93        if rest.len() == 1 {
94            return match &rest[0] {
95                Value::Struct(fields) => {
96                    options.apply_struct(fields)?;
97                    Ok(options)
98                }
99                other => Err(optim_error(
100                    NAME,
101                    format!("integral: expected option name/value pairs, got {other:?}"),
102                )),
103            };
104        }
105        if !rest.len().is_multiple_of(2) {
106            return Err(optim_error(
107                NAME,
108                "integral: expected option name/value pairs",
109            ));
110        }
111        for pair in rest.chunks(2) {
112            let name = option_name(&pair[0])?;
113            options.apply_option(&name, &pair[1])?;
114        }
115        options.validate()?;
116        Ok(options)
117    }
118
119    fn apply_struct(&mut self, fields: &StructValue) -> BuiltinResult<()> {
120        for (name, value) in &fields.fields {
121            self.apply_option(name, value)?;
122        }
123        self.validate()
124    }
125
126    fn apply_option(&mut self, name: &str, value: &Value) -> BuiltinResult<()> {
127        match name.to_ascii_lowercase().as_str() {
128            "abstol" => self.abs_tol = numeric_option("AbsTol", value)?,
129            "reltol" => self.rel_tol = numeric_option("RelTol", value)?,
130            "maxfunevals" | "maxintervalcount" => {
131                let parsed = integer_option(name, value)?;
132                if parsed < 5 {
133                    return Err(optim_error(
134                        NAME,
135                        "integral: MaxFunEvals must be an integer scalar >= 5",
136                    ));
137                }
138                self.max_fun_evals = parsed;
139            }
140            "arrayvalued" => {
141                if bool_option("ArrayValued", value)? {
142                    return Err(optim_error(
143                        NAME,
144                        "integral: ArrayValued true is not supported yet",
145                    ));
146                }
147            }
148            other => {
149                return Err(optim_error(
150                    NAME,
151                    format!("integral: unsupported option {other}"),
152                ))
153            }
154        }
155        Ok(())
156    }
157
158    fn validate(&self) -> BuiltinResult<()> {
159        if self.abs_tol < 0.0 {
160            return Err(optim_error(NAME, "integral: AbsTol must be nonnegative"));
161        }
162        if self.rel_tol < 0.0 {
163            return Err(optim_error(NAME, "integral: RelTol must be nonnegative"));
164        }
165        if self.abs_tol == 0.0 && self.rel_tol == 0.0 {
166            return Err(optim_error(
167                NAME,
168                "integral: AbsTol and RelTol cannot both be zero",
169            ));
170        }
171        Ok(())
172    }
173}
174
175fn option_name(value: &Value) -> BuiltinResult<String> {
176    match value {
177        Value::String(s) => Ok(s.clone()),
178        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
179        Value::CharArray(chars) if chars.rows == 1 => Ok(chars.data.iter().collect()),
180        other => Err(optim_error(
181            NAME,
182            format!("integral: option names must be strings, got {other:?}"),
183        )),
184    }
185}
186
187async fn scalar_bound(label: &str, value: Value) -> BuiltinResult<f64> {
188    let value = crate::dispatcher::gather_if_needed_async(&value).await?;
189    let parsed = match value {
190        Value::Num(n) => n,
191        Value::Int(i) => i.to_f64(),
192        Value::Bool(b) => {
193            if b {
194                1.0
195            } else {
196                0.0
197            }
198        }
199        Value::Tensor(tensor) if tensor.data.len() == 1 => tensor.data[0],
200        Value::LogicalArray(LogicalArray { data, .. }) if data.len() == 1 => {
201            if data[0] != 0 {
202                1.0
203            } else {
204                0.0
205            }
206        }
207        other => {
208            return Err(optim_error(
209                NAME,
210                format!("integral: {label} must be a finite real scalar, got {other:?}"),
211            ))
212        }
213    };
214    if parsed.is_finite() {
215        Ok(parsed)
216    } else {
217        Err(optim_error(
218            NAME,
219            format!("integral: {label} must be finite"),
220        ))
221    }
222}
223
224fn numeric_option(name: &str, value: &Value) -> BuiltinResult<f64> {
225    let parsed = match value {
226        Value::Num(n) => *n,
227        Value::Int(i) => i.to_f64(),
228        Value::Bool(b) => {
229            if *b {
230                1.0
231            } else {
232                0.0
233            }
234        }
235        Value::Tensor(Tensor { data, .. }) if data.len() == 1 => data[0],
236        Value::LogicalArray(LogicalArray { data, .. }) if data.len() == 1 => {
237            if data[0] != 0 {
238                1.0
239            } else {
240                0.0
241            }
242        }
243        other => {
244            return Err(optim_error(
245                NAME,
246                format!("integral: option {name} must be numeric, got {other:?}"),
247            ))
248        }
249    };
250    if parsed.is_finite() {
251        Ok(parsed)
252    } else {
253        Err(optim_error(
254            NAME,
255            format!("integral: option {name} must be finite"),
256        ))
257    }
258}
259
260fn integer_option(name: &str, value: &Value) -> BuiltinResult<usize> {
261    let parsed = numeric_option(name, value)?;
262    if parsed < 0.0 {
263        return Err(optim_error(
264            NAME,
265            format!("integral: option {name} must be nonnegative"),
266        ));
267    }
268    if parsed.fract() != 0.0 {
269        return Err(optim_error(
270            NAME,
271            format!("integral: option {name} must be an integer scalar"),
272        ));
273    }
274    Ok(parsed as usize)
275}
276
277fn bool_option(name: &str, value: &Value) -> BuiltinResult<bool> {
278    match value {
279        Value::Bool(flag) => Ok(*flag),
280        Value::Num(n) if *n == 0.0 || *n == 1.0 => Ok(*n != 0.0),
281        Value::Int(i) => {
282            let raw = i.to_i64();
283            if raw == 0 || raw == 1 {
284                Ok(raw != 0)
285            } else {
286                Err(optim_error(
287                    NAME,
288                    format!("integral: option {name} must be logical scalar"),
289                ))
290            }
291        }
292        other => Err(optim_error(
293            NAME,
294            format!("integral: option {name} must be logical scalar, got {other:?}"),
295        )),
296    }
297}
298
299async fn integrate_finite_scalar(
300    function: &Value,
301    a: f64,
302    b: f64,
303    options: &IntegralOptions,
304) -> BuiltinResult<f64> {
305    let fa = call_integrand(function, a).await?;
306    let m = 0.5 * (a + b);
307    let fm = call_integrand(function, m).await?;
308    let fb = call_integrand(function, b).await?;
309    let mut evals = 3usize;
310    let whole = simpson(a, b, fa, fm, fb);
311    let tol = options.abs_tol.max(options.rel_tol * whole.abs());
312    adaptive_simpson(
313        function,
314        SimpsonState {
315            a,
316            b,
317            fa,
318            fm,
319            fb,
320            whole,
321            tol,
322            depth: MAX_DEPTH,
323        },
324        &mut evals,
325        options.max_fun_evals,
326    )
327    .await
328}
329
/// One panel awaiting adaptive refinement, with its already-evaluated samples.
#[derive(Copy, Clone)]
struct SimpsonState {
    /// Left endpoint of the panel.
    a: f64,
    /// Right endpoint of the panel.
    b: f64,
    /// Integrand value at `a`.
    fa: f64,
    /// Integrand value at the panel midpoint.
    fm: f64,
    /// Integrand value at `b`.
    fb: f64,
    /// Simpson estimate over the whole panel built from `fa`, `fm`, `fb`.
    whole: f64,
    /// Error tolerance allotted to this panel.
    tol: f64,
    /// Remaining bisection levels before non-convergence is reported.
    depth: usize,
}
341
342#[async_recursion::async_recursion(?Send)]
343async fn adaptive_simpson(
344    function: &Value,
345    state: SimpsonState,
346    evals: &mut usize,
347    max_fun_evals: usize,
348) -> BuiltinResult<f64> {
349    if *evals + 2 > max_fun_evals {
350        return Err(optim_error(
351            NAME,
352            "integral: exceeded maximum function evaluations",
353        ));
354    }
355
356    let c = 0.5 * (state.a + state.b);
357    let d = 0.5 * (state.a + c);
358    let e = 0.5 * (c + state.b);
359    let fd = call_integrand(function, d).await?;
360    let fe = call_integrand(function, e).await?;
361    *evals += 2;
362
363    let left = simpson(state.a, c, state.fa, fd, state.fm);
364    let right = simpson(c, state.b, state.fm, fe, state.fb);
365    let refined = left + right;
366    let error = refined - state.whole;
367    if error.abs() <= 15.0 * state.tol {
368        return Ok(refined + error / 15.0);
369    }
370    if state.depth == 0 {
371        return Err(optim_error(
372            NAME,
373            "integral: adaptive quadrature did not converge",
374        ));
375    }
376
377    let left_value = adaptive_simpson(
378        function,
379        SimpsonState {
380            a: state.a,
381            b: c,
382            fa: state.fa,
383            fm: fd,
384            fb: state.fm,
385            whole: left,
386            tol: state.tol * 0.5,
387            depth: state.depth - 1,
388        },
389        evals,
390        max_fun_evals,
391    )
392    .await?;
393    let right_value = adaptive_simpson(
394        function,
395        SimpsonState {
396            a: c,
397            b: state.b,
398            fa: state.fm,
399            fm: fe,
400            fb: state.fb,
401            whole: right,
402            tol: state.tol * 0.5,
403            depth: state.depth - 1,
404        },
405        evals,
406        max_fun_evals,
407    )
408    .await?;
409    Ok(left_value + right_value)
410}
411
/// Simpson's rule on `[a, b]`, given samples at `a`, the midpoint, and `b`.
fn simpson(a: f64, b: f64, fa: f64, fm: f64, fb: f64) -> f64 {
    let width = b - a;
    let weighted_sum = fa + 4.0 * fm + fb;
    width * weighted_sum / 6.0
}
415
416async fn call_integrand(function: &Value, x: f64) -> BuiltinResult<f64> {
417    let value = call_function(function, vec![Value::Num(x)]).await?;
418    let value = crate::dispatcher::gather_if_needed_async(&value).await?;
419    match value {
420        Value::Num(n) if n.is_finite() => Ok(n),
421        Value::Int(i) => Ok(i.to_f64()),
422        Value::Bool(b) => Ok(if b { 1.0 } else { 0.0 }),
423        Value::Tensor(tensor) if tensor.data.len() == 1 && tensor.data[0].is_finite() => {
424            Ok(tensor.data[0])
425        }
426        Value::LogicalArray(logical) if logical.data.len() == 1 => {
427            Ok(if logical.data[0] != 0 { 1.0 } else { 0.0 })
428        }
429        Value::Num(_) | Value::Tensor(_) => Err(optim_error(
430            NAME,
431            "integral: function value must be a finite real scalar",
432        )),
433        other => Err(optim_error(
434            NAME,
435            format!("integral: function value must be real numeric scalar, got {other:?}"),
436        )),
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::f64::consts::PI;

    // Test-only builtin returning x^2 for a scalar x.
    #[runtime_builtin(
        name = "__integral_square",
        type_resolver(crate::builtins::math::optim::type_resolvers::numerical_integral_type),
        builtin_path = "crate::builtins::math::optim::integral::tests"
    )]
    async fn square_helper(x: Value) -> crate::BuiltinResult<Value> {
        let x = scalar_bound("x", x).await?;
        Ok(Value::Num(x * x))
    }

    // Test-only builtin that always returns a 1x2 row vector.
    #[runtime_builtin(
        name = "__integral_vector",
        type_resolver(crate::builtins::math::optim::type_resolvers::numerical_integral_type),
        builtin_path = "crate::builtins::math::optim::integral::tests"
    )]
    async fn vector_helper(_x: Value) -> crate::BuiltinResult<Value> {
        Ok(Value::Tensor(
            Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap(),
        ))
    }

    // Test-only builtin that always returns NaN.
    #[runtime_builtin(
        name = "__integral_nan",
        type_resolver(crate::builtins::math::optim::type_resolvers::numerical_integral_type),
        builtin_path = "crate::builtins::math::optim::integral::tests"
    )]
    async fn nan_helper(_x: Value) -> crate::BuiltinResult<Value> {
        Ok(Value::Num(f64::NAN))
    }

    /// Wraps a builtin name as a function handle value.
    fn handle(name: &str) -> Value {
        Value::FunctionHandle(name.into())
    }

    /// Runs `integral` synchronously with numeric bounds and explicit trailing options.
    fn run_with(
        function: Value,
        a: f64,
        b: f64,
        rest: Vec<Value>,
    ) -> crate::BuiltinResult<Value> {
        block_on(integral_builtin(
            function,
            Value::Num(a),
            Value::Num(b),
            rest,
        ))
    }

    /// Runs `integral` synchronously with numeric bounds and no options.
    fn run(function: Value, a: f64, b: f64) -> crate::BuiltinResult<Value> {
        run_with(function, a, b, Vec::new())
    }

    /// Unwraps a successful scalar result, panicking on errors or non-scalar values.
    fn expect_num(result: crate::BuiltinResult<Value>) -> f64 {
        match result.expect("integral") {
            Value::Num(value) => value,
            other => panic!("unexpected value {other:?}"),
        }
    }

    #[test]
    fn integrates_named_sine_function() {
        let value = expect_num(run(handle("sin"), 0.0, PI));
        assert!((value - 2.0).abs() < 1.0e-7);
    }

    #[test]
    fn integrates_polynomial_helper() {
        let value = expect_num(run(handle("__integral_square"), 0.0, 1.0));
        assert!((value - (1.0 / 3.0)).abs() < 1.0e-9);
    }

    #[test]
    fn reversed_bounds_negate_result() {
        let value = expect_num(run(handle("sin"), PI, 0.0));
        assert!((value + 2.0).abs() < 1.0e-7);
    }

    #[test]
    fn zero_width_interval_returns_zero_without_callback() {
        let value = expect_num(run(handle("__integral_nan"), 1.0, 1.0));
        assert_eq!(value, 0.0);
    }

    #[test]
    fn rejects_vector_valued_integrand_for_initial_scope() {
        let err = run(handle("__integral_vector"), 0.0, 1.0).unwrap_err();
        assert!(err.message().contains("finite real scalar"));
    }

    #[test]
    fn rejects_nonfinite_integrand_values() {
        let err = run(handle("__integral_nan"), 0.0, 1.0).unwrap_err();
        assert!(err.message().contains("finite real scalar"));
    }

    #[test]
    fn accepts_tolerance_name_value_options() {
        let options = vec![
            Value::from("AbsTol"),
            Value::Num(1.0e-12),
            Value::from("RelTol"),
            Value::Num(1.0e-8),
        ];
        let value = expect_num(run_with(handle("sin"), 0.0, PI, options));
        assert!((value - 2.0).abs() < 1.0e-8);
    }

    #[test]
    fn rejects_too_small_max_fun_evals() {
        let options = vec![Value::from("MaxFunEvals"), Value::Num(4.0)];
        let err = run_with(handle("sin"), 0.0, 1.0, options).unwrap_err();
        assert!(err.message().contains("integer scalar >= 5"));
    }

    #[test]
    fn rejects_fractional_max_fun_evals() {
        let options = vec![Value::from("MaxFunEvals"), Value::Num(5.5)];
        let err = run_with(handle("sin"), 0.0, 1.0, options).unwrap_err();
        assert!(err.message().contains("integer scalar"));
    }
}