1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 LogicalArray, Tensor, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::math::optim::common::{call_function, value_to_scalar};
15use crate::builtins::math::optim::type_resolvers::numerical_integral_type;
16use crate::{build_runtime_error, BuiltinResult, RuntimeError};
17
/// Builtin name used for error attribution and for callback scalar conversion.
const NAME: &str = "quad";
/// Absolute error tolerance used when `tol` is omitted or passed as an empty value.
const DEFAULT_TOL: f64 = 1e-6;
/// Number of successive interval halvings allowed before the adaptive solver
/// reports non-convergence.
const MAX_DEPTH: usize = 30;
/// Upper bound on integrand evaluations for a single `quad` call.
const MAX_FUN_EVALS: usize = 100_000;
22
23const QUAD_OUTPUT_Q: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
24 name: "q",
25 ty: BuiltinParamType::NumericScalar,
26 arity: BuiltinParamArity::Required,
27 default: None,
28 description: "Numerical integral estimate.",
29}];
30
31const QUAD_OUTPUT_Q_FCNT: [BuiltinParamDescriptor; 2] = [
32 BuiltinParamDescriptor {
33 name: "q",
34 ty: BuiltinParamType::NumericScalar,
35 arity: BuiltinParamArity::Required,
36 default: None,
37 description: "Numerical integral estimate.",
38 },
39 BuiltinParamDescriptor {
40 name: "fcnt",
41 ty: BuiltinParamType::NumericScalar,
42 arity: BuiltinParamArity::Required,
43 default: None,
44 description: "Number of integrand evaluations.",
45 },
46];
47
48const QUAD_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
49 BuiltinParamDescriptor {
50 name: "fun",
51 ty: BuiltinParamType::Any,
52 arity: BuiltinParamArity::Required,
53 default: None,
54 description: "Scalar integrand callback.",
55 },
56 BuiltinParamDescriptor {
57 name: "a",
58 ty: BuiltinParamType::Any,
59 arity: BuiltinParamArity::Required,
60 default: None,
61 description: "Lower integration bound.",
62 },
63 BuiltinParamDescriptor {
64 name: "b",
65 ty: BuiltinParamType::Any,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Upper integration bound.",
69 },
70];
71
72const QUAD_INPUTS_TOL_TRACE_ARGS: [BuiltinParamDescriptor; 6] = [
73 BuiltinParamDescriptor {
74 name: "fun",
75 ty: BuiltinParamType::Any,
76 arity: BuiltinParamArity::Required,
77 default: None,
78 description: "Scalar integrand callback.",
79 },
80 BuiltinParamDescriptor {
81 name: "a",
82 ty: BuiltinParamType::Any,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Lower integration bound.",
86 },
87 BuiltinParamDescriptor {
88 name: "b",
89 ty: BuiltinParamType::Any,
90 arity: BuiltinParamArity::Required,
91 default: None,
92 description: "Upper integration bound.",
93 },
94 BuiltinParamDescriptor {
95 name: "tol",
96 ty: BuiltinParamType::NumericScalar,
97 arity: BuiltinParamArity::Optional,
98 default: Some("1e-6"),
99 description: "Absolute error tolerance. Empty uses the default.",
100 },
101 BuiltinParamDescriptor {
102 name: "trace",
103 ty: BuiltinParamType::Any,
104 arity: BuiltinParamArity::Optional,
105 default: Some("false"),
106 description: "Nonzero value prints legacy [fcnEvals, a, b-a, Q] trace rows.",
107 },
108 BuiltinParamDescriptor {
109 name: "p",
110 ty: BuiltinParamType::Any,
111 arity: BuiltinParamArity::Variadic,
112 default: None,
113 description: "Additional arguments forwarded to the integrand.",
114 },
115];
116
117const QUAD_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
118 BuiltinSignatureDescriptor {
119 label: "q = quad(fun, a, b)",
120 inputs: &QUAD_INPUTS_CORE,
121 outputs: &QUAD_OUTPUT_Q,
122 },
123 BuiltinSignatureDescriptor {
124 label: "q = quad(fun, a, b, tol, trace, p1, p2, ...)",
125 inputs: &QUAD_INPUTS_TOL_TRACE_ARGS,
126 outputs: &QUAD_OUTPUT_Q,
127 },
128 BuiltinSignatureDescriptor {
129 label: "[q, fcnt] = quad(fun, a, b)",
130 inputs: &QUAD_INPUTS_CORE,
131 outputs: &QUAD_OUTPUT_Q_FCNT,
132 },
133 BuiltinSignatureDescriptor {
134 label: "[q, fcnt] = quad(fun, a, b, tol, trace, p1, p2, ...)",
135 inputs: &QUAD_INPUTS_TOL_TRACE_ARGS,
136 outputs: &QUAD_OUTPUT_Q_FCNT,
137 },
138];
139
140const QUAD_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
141 code: "RM.QUAD.INVALID_ARGUMENT",
142 identifier: Some("RunMat:quad:InvalidArgument"),
143 when: "Tolerance, trace flag, or argument grammar is invalid.",
144 message: "quad: invalid argument",
145};
146
147const QUAD_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
148 code: "RM.QUAD.INVALID_INPUT",
149 identifier: Some("RunMat:quad:InvalidInput"),
150 when: "Bounds, integrand values, or adaptive solver semantics are invalid.",
151 message: "quad: invalid input",
152};
153
154const QUAD_ERROR_TOO_MANY_OUTPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
155 code: "RM.QUAD.TOO_MANY_OUTPUTS",
156 identifier: Some("RunMat:quad:TooManyOutputs"),
157 when: "`quad` is called with more than two requested output arguments.",
158 message: "quad: too many output arguments",
159};
160
161const QUAD_ERRORS: [BuiltinErrorDescriptor; 3] = [
162 QUAD_ERROR_INVALID_ARGUMENT,
163 QUAD_ERROR_INVALID_INPUT,
164 QUAD_ERROR_TOO_MANY_OUTPUTS,
165];
166
167pub const QUAD_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
168 signatures: &QUAD_SIGNATURES,
169 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
170 completion_policy: BuiltinCompletionPolicy::Public,
171 errors: &QUAD_ERRORS,
172};
173
174fn quad_error_with_detail(
175 error: &'static BuiltinErrorDescriptor,
176 detail: impl AsRef<str>,
177) -> RuntimeError {
178 let detail = detail.as_ref();
179 let message = if detail.starts_with("quad:") {
180 detail.to_string()
181 } else {
182 format!("{}: {detail}", error.message)
183 };
184 let mut builder = build_runtime_error(message).with_builtin(NAME);
185 if let Some(identifier) = error.identifier {
186 builder = builder.with_identifier(identifier);
187 }
188 builder.build()
189}
190
191fn quad_map_error(err: RuntimeError, fallback: &'static BuiltinErrorDescriptor) -> RuntimeError {
192 if err.identifier().is_some() {
193 err
194 } else {
195 quad_error_with_detail(fallback, err.message())
196 }
197}
198
199fn validate_requested_outputs() -> BuiltinResult<()> {
200 if matches!(crate::output_count::current_output_count(), Some(n) if n > 2) {
201 return Err(quad_error_with_detail(
202 &QUAD_ERROR_TOO_MANY_OUTPUTS,
203 "quad: too many output arguments; maximum is 2",
204 ));
205 }
206 Ok(())
207}
208
/// GPU planning metadata for `quad`.
///
/// The spec declares no supported precisions and no provider hooks. Per its
/// notes, the adaptive integration loop runs on the host; only the user's
/// callback may itself use GPU-aware builtins.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::quad")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "quad",
    op_kind: GpuOpKind::Custom("legacy-adaptive-simpson"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Inputs are gathered to the host (see `gather_if_needed_async` calls in this module).
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host adaptive Simpson solver. Callback computations may use GPU-aware builtins, but the adaptive integration loop runs on the CPU.",
};
224
/// Fusion planner metadata for `quad`.
///
/// No elementwise or reduction template is supplied. The notes state that the
/// builtin repeatedly invokes user code and terminates fusion planning.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::quad")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "quad",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Legacy adaptive quadrature repeatedly invokes user code and terminates fusion planning.",
};
236
237#[runtime_builtin(
238 name = "quad",
239 category = "math/optim",
240 summary = "Approximate finite scalar definite integrals using legacy adaptive Simpson quadrature.",
241 keywords = "quad,numerical integration,adaptive simpson,quadrature,function handle",
242 accel = "sink",
243 type_resolver(numerical_integral_type),
244 descriptor(crate::builtins::math::optim::quad::QUAD_DESCRIPTOR),
245 builtin_path = "crate::builtins::math::optim::quad"
246)]
247async fn quad_builtin(
248 function: Value,
249 a: Value,
250 b: Value,
251 rest: Vec<Value>,
252) -> BuiltinResult<Value> {
253 validate_requested_outputs()?;
254 let options = QuadOptions::parse(rest)
255 .await
256 .map_err(|err| quad_map_error(err, &QUAD_ERROR_INVALID_ARGUMENT))?;
257 let a = scalar_real("lower bound", a)
258 .await
259 .map_err(|err| quad_map_error(err, &QUAD_ERROR_INVALID_INPUT))?;
260 let b = scalar_real("upper bound", b)
261 .await
262 .map_err(|err| quad_map_error(err, &QUAD_ERROR_INVALID_INPUT))?;
263
264 let result = if a == b {
265 QuadResult {
266 q: 0.0,
267 func_count: 0,
268 }
269 } else {
270 let sign = if b < a { -1.0 } else { 1.0 };
271 let lo = a.min(b);
272 let hi = a.max(b);
273 let mut result = integrate_quad(&function, lo, hi, &options)
274 .await
275 .map_err(|err| quad_map_error(err, &QUAD_ERROR_INVALID_INPUT))?;
276 result.q *= sign;
277 result
278 };
279
280 finalize(result)
281}
282
283struct QuadOptions {
284 tol: f64,
285 trace: bool,
286 extra_args: Vec<Value>,
287}
288
289impl QuadOptions {
290 async fn parse(rest: Vec<Value>) -> BuiltinResult<Self> {
291 let mut values = rest.into_iter();
292 let tol = match values.next() {
293 Some(value) => parse_optional_tol(value).await?,
294 None => DEFAULT_TOL,
295 };
296 let trace = match values.next() {
297 Some(value) => parse_optional_trace(value).await?,
298 None => false,
299 };
300 Ok(Self {
301 tol,
302 trace,
303 extra_args: values.collect(),
304 })
305 }
306}
307
308async fn parse_optional_tol(value: Value) -> BuiltinResult<f64> {
309 let value = crate::dispatcher::gather_if_needed_async(&value).await?;
310 if is_empty_value(&value) {
311 return Ok(DEFAULT_TOL);
312 }
313 let tol = scalar_real_sync("tolerance", value, &QUAD_ERROR_INVALID_ARGUMENT)?;
314 if tol > 0.0 {
315 Ok(tol)
316 } else {
317 Err(quad_error_with_detail(
318 &QUAD_ERROR_INVALID_ARGUMENT,
319 "tolerance must be a positive finite scalar",
320 ))
321 }
322}
323
324async fn parse_optional_trace(value: Value) -> BuiltinResult<bool> {
325 let value = crate::dispatcher::gather_if_needed_async(&value).await?;
326 if is_empty_value(&value) {
327 return Ok(false);
328 }
329 Ok(scalar_real_sync("trace", value, &QUAD_ERROR_INVALID_ARGUMENT)? != 0.0)
330}
331
332fn is_empty_value(value: &Value) -> bool {
333 match value {
334 Value::Tensor(Tensor { data, .. }) => data.is_empty(),
335 Value::LogicalArray(LogicalArray { data, .. }) => data.is_empty(),
336 _ => false,
337 }
338}
339
340async fn scalar_real(label: &str, value: Value) -> BuiltinResult<f64> {
341 let value = crate::dispatcher::gather_if_needed_async(&value).await?;
342 scalar_real_sync(label, value, &QUAD_ERROR_INVALID_INPUT)
343}
344
345fn scalar_real_sync(
346 label: &str,
347 value: Value,
348 error: &'static BuiltinErrorDescriptor,
349) -> BuiltinResult<f64> {
350 let parsed = match value {
351 Value::Num(n) => n,
352 Value::Int(i) => i.to_f64(),
353 Value::Bool(flag) => {
354 if flag {
355 1.0
356 } else {
357 0.0
358 }
359 }
360 Value::Tensor(Tensor { data, .. }) if data.len() == 1 => data[0],
361 Value::LogicalArray(LogicalArray { data, .. }) if data.len() == 1 => {
362 if data[0] != 0 {
363 1.0
364 } else {
365 0.0
366 }
367 }
368 other => {
369 return Err(quad_error_with_detail(
370 error,
371 format!("{label} must be a finite real scalar, got {other:?}"),
372 ))
373 }
374 };
375 if parsed.is_finite() {
376 Ok(parsed)
377 } else {
378 Err(quad_error_with_detail(
379 error,
380 format!("{label} must be finite"),
381 ))
382 }
383}
384
/// Outcome of one integration: the estimate plus the integrand call count.
#[derive(Copy, Clone)]
struct QuadResult {
    /// Integral estimate (`quad_builtin` negates it for reversed bounds).
    q: f64,
    /// Number of integrand evaluations, reported as the `fcnt` output.
    func_count: usize,
}
390
391async fn integrate_quad(
392 function: &Value,
393 a: f64,
394 b: f64,
395 options: &QuadOptions,
396) -> BuiltinResult<QuadResult> {
397 let fa = call_integrand(function, a, &options.extra_args).await?;
398 let c = midpoint(a, b);
399 let fc = call_integrand(function, c, &options.extra_args).await?;
400 let fb = call_integrand(function, b, &options.extra_args).await?;
401 let whole = simpson(a, b, fa, fc, fb);
402 let mut func_count = 3usize;
403 let mut trace = options.trace.then_some(QuadTrace);
404 let q = adaptive_simpson(
405 function,
406 &options.extra_args,
407 SimpsonState {
408 a,
409 b,
410 fa,
411 fc,
412 fb,
413 whole,
414 tol: options.tol,
415 depth: MAX_DEPTH,
416 },
417 &mut func_count,
418 &mut trace,
419 )
420 .await?;
421 Ok(QuadResult { q, func_count })
422}
423
/// One pending subinterval of the adaptive Simpson recursion.
#[derive(Copy, Clone)]
struct SimpsonState {
    /// Left endpoint.
    a: f64,
    /// Right endpoint.
    b: f64,
    /// Integrand value at `a`.
    fa: f64,
    /// Integrand value at the midpoint of `[a, b]`.
    fc: f64,
    /// Integrand value at `b`.
    fb: f64,
    /// Simpson estimate over the whole subinterval.
    whole: f64,
    /// Absolute tolerance budget for this subinterval (halved per level).
    tol: f64,
    /// Remaining halvings before non-convergence is reported.
    depth: usize,
}
435
436#[async_recursion::async_recursion(?Send)]
437async fn adaptive_simpson(
438 function: &Value,
439 extra_args: &[Value],
440 state: SimpsonState,
441 func_count: &mut usize,
442 trace: &mut Option<QuadTrace>,
443) -> BuiltinResult<f64> {
444 if *func_count + 2 > MAX_FUN_EVALS {
445 return Err(quad_error_with_detail(
446 &QUAD_ERROR_INVALID_INPUT,
447 "exceeded maximum function evaluations",
448 ));
449 }
450
451 let c = midpoint(state.a, state.b);
452 let d = midpoint(state.a, c);
453 let e = midpoint(c, state.b);
454 let fd = call_integrand(function, d, extra_args).await?;
455 let fe = call_integrand(function, e, extra_args).await?;
456 *func_count += 2;
457
458 let left = simpson(state.a, c, state.fa, fd, state.fc);
459 let right = simpson(c, state.b, state.fc, fe, state.fb);
460 let refined = left + right;
461 let error = refined - state.whole;
462 if let Some(trace) = trace {
463 trace.record(*func_count, state.a, state.b, refined, error);
464 }
465 if error.abs() <= 15.0 * state.tol {
466 return Ok(refined + error / 15.0);
467 }
468 if state.depth == 0 {
469 return Err(quad_error_with_detail(
470 &QUAD_ERROR_INVALID_INPUT,
471 "adaptive Simpson quadrature did not converge",
472 ));
473 }
474
475 let left_value = adaptive_simpson(
476 function,
477 extra_args,
478 SimpsonState {
479 a: state.a,
480 b: c,
481 fa: state.fa,
482 fc: fd,
483 fb: state.fc,
484 whole: left,
485 tol: state.tol * 0.5,
486 depth: state.depth - 1,
487 },
488 func_count,
489 trace,
490 )
491 .await?;
492 let right_value = adaptive_simpson(
493 function,
494 extra_args,
495 SimpsonState {
496 a: c,
497 b: state.b,
498 fa: state.fc,
499 fc: fe,
500 fb: state.fb,
501 whole: right,
502 tol: state.tol * 0.5,
503 depth: state.depth - 1,
504 },
505 func_count,
506 trace,
507 )
508 .await?;
509 Ok(left_value + right_value)
510}
511
/// Returns the midpoint of `[a, b]` without overflowing.
///
/// The usual form `a + (b - a) / 2` is kept whenever the width `b - a` is
/// finite, so results are bit-identical for ordinary bounds. For extreme
/// finite bounds (e.g. `a = -f64::MAX`, `b = f64::MAX`), the width overflows
/// to infinity. That form would then return `±inf`, and the integrand would be
/// sampled outside `[a, b]`. Halving each endpoint first keeps the result
/// finite and inside the interval.
fn midpoint(a: f64, b: f64) -> f64 {
    let width = b - a;
    if width.is_finite() {
        a + width * 0.5
    } else {
        a * 0.5 + b * 0.5
    }
}
515
/// Simpson's rule over `[a, b]` given the endpoint values `fa`, `fb` and the
/// midpoint value `fm`: `(b - a) * (fa + 4 fm + fb) / 6`.
fn simpson(a: f64, b: f64, fa: f64, fm: f64, fb: f64) -> f64 {
    let width = b - a;
    let weighted_sum = fa + 4.0 * fm + fb;
    width * weighted_sum / 6.0
}
519
520async fn call_integrand(function: &Value, x: f64, extra_args: &[Value]) -> BuiltinResult<f64> {
521 let mut args = Vec::with_capacity(1 + extra_args.len());
522 args.push(Value::Num(x));
523 args.extend(extra_args.iter().cloned());
524 let value = call_function(function, args).await?;
525 let value = crate::dispatcher::gather_if_needed_async(&value).await?;
526 value_to_scalar(NAME, value)
527}
528
529struct QuadTrace;
530
531impl QuadTrace {
532 fn record(&mut self, func_count: usize, a: f64, b: f64, q: f64, _err: f64) {
533 crate::console::record_console_line(
534 crate::console::ConsoleStream::Stdout,
535 format!(
536 " {func_count:>5} {a:13.6e} {width:13.6e} {q:13.6e}",
537 width = b - a,
538 ),
539 );
540 }
541}
542
543fn finalize(result: QuadResult) -> BuiltinResult<Value> {
544 let q = Value::Num(result.q);
545 let fcnt = Value::Num(result.func_count as f64);
546 match crate::output_count::current_output_count() {
547 None => Ok(q),
548 Some(0) => Ok(Value::OutputList(Vec::new())),
549 Some(1) => Ok(crate::output_count::output_list_with_padding(1, vec![q])),
550 Some(2) => Ok(crate::output_count::output_list_with_padding(
551 2,
552 vec![q, fcnt],
553 )),
554 Some(_) => Err(quad_error_with_detail(
555 &QUAD_ERROR_TOO_MANY_OUTPUTS,
556 "quad: too many output arguments; maximum is 2",
557 )),
558 }
559}
560
/// Unit tests for `quad`.
///
/// Several tests install thread-local state through guards (`_guard`,
/// `_invoker`) or reset the console buffer. Statement order matters, and the
/// guards must stay alive until the call under test completes.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::Arc;

    // The integral of sin over [0, pi] is exactly 2; uses the default tolerance.
    #[test]
    fn quad_integrates_sine_with_default_tolerance() {
        let result = block_on(quad_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(std::f64::consts::PI),
            Vec::new(),
        ))
        .expect("quad");
        match result {
            Value::Num(value) => assert!((value - 2.0).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // x^2 on [0, 1] integrates to 1/3; an explicit 1e-10 tolerance is requested.
    #[test]
    fn quad_respects_tighter_tolerance_on_polynomial() {
        // Route bound-handle calls to a closure computing x * x. The guard is
        // held for the test's duration (presumably uninstalling on drop).
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, requested_outputs| {
                assert_eq!(requested_outputs, 1);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected x, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x * x)) })
            },
        )));

        let result = block_on(quad_builtin(
            Value::BoundFunctionHandle {
                name: "square".to_string(),
                function: 7,
            },
            Value::Num(0.0),
            Value::Num(1.0),
            vec![Value::Num(1.0e-10)],
        ))
        .expect("quad");
        match result {
            Value::Num(value) => assert!((value - (1.0 / 3.0)).abs() < 1.0e-10),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Requesting two outputs yields [q, fcnt]. fcnt is at least 5
    // (3 seed evaluations + 2 for the first refinement).
    #[test]
    fn quad_two_outputs_include_function_count() {
        let _guard = crate::output_count::push_output_count(Some(2));
        let result = block_on(quad_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(std::f64::consts::PI),
            Vec::new(),
        ))
        .expect("quad");
        match result {
            Value::OutputList(outputs) => {
                assert_eq!(outputs.len(), 2);
                assert!(matches!(&outputs[0], Value::Num(value) if (value - 2.0).abs() < 1.0e-6));
                assert!(matches!(&outputs[1], Value::Num(fcnt) if *fcnt >= 5.0));
            }
            other => panic!("unexpected value {other:?}"),
        }
    }

    // trace = 1 prints rows of [fcnt, a, b - a, q]. The first row comes after
    // 3 seed + 2 refinement evaluations and covers the whole interval.
    #[test]
    fn quad_trace_records_rows() {
        crate::console::reset_thread_buffer();
        let result = block_on(quad_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(std::f64::consts::PI),
            vec![Value::Num(1.0e-6), Value::Num(1.0)],
        ))
        .expect("quad");
        assert!(matches!(result, Value::Num(_)));

        let joined = crate::console::take_thread_buffer()
            .into_iter()
            .map(|entry| entry.text)
            .collect::<String>();
        let first_row: Vec<&str> = joined
            .lines()
            .next()
            .expect("expected at least one trace row")
            .split_whitespace()
            .collect();
        assert_eq!(first_row.len(), 4, "{joined}");
        assert_eq!(first_row[0], "5", "{joined}");
        assert!((first_row[1].parse::<f64>().unwrap() - 0.0).abs() < 1.0e-12);
        assert!(
            (first_row[2].parse::<f64>().unwrap() - std::f64::consts::PI).abs() < 1.0e-6,
            "{joined}"
        );
    }

    // Empty placeholders select the default tol/trace. The trailing 3.0 is
    // forwarded as the integrand's second argument, and the integral of
    // 3 * x over [0, 2] is 6.
    #[test]
    fn quad_forwards_extra_arguments_after_tol_and_trace() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |function, args, requested_outputs| {
                assert_eq!(function, 42);
                assert_eq!(requested_outputs, 1);
                assert_eq!(args.len(), 2);
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected x, got {other:?}"),
                };
                let scale = match &args[1] {
                    Value::Num(value) => *value,
                    other => panic!("expected scale, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(scale * x)) })
            },
        )));

        let result = block_on(quad_builtin(
            Value::BoundFunctionHandle {
                name: "scaled_line".to_string(),
                function: 42,
            },
            Value::Num(0.0),
            Value::Num(2.0),
            vec![
                Value::Tensor(Tensor::zeros(vec![0, 0])),
                Value::Tensor(Tensor::zeros(vec![0, 0])),
                Value::Num(3.0),
            ],
        ))
        .expect("quad");
        match result {
            Value::Num(value) => assert!((value - 6.0).abs() < 1.0e-8),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // sin over a full period [0, 2 pi] integrates to 0.
    #[test]
    fn quad_handles_oscillatory_integrand() {
        let result = block_on(quad_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(2.0 * std::f64::consts::PI),
            vec![Value::Num(1.0e-8)],
        ))
        .expect("quad");
        match result {
            Value::Num(value) => assert!(value.abs() < 1.0e-8),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // sqrt(x) has an unbounded derivative at 0; the integral over [0, 1] is 2/3.
    #[test]
    fn quad_handles_integrable_endpoint_shape() {
        let _invoker = crate::user_functions::install_semantic_function_invoker(Some(Arc::new(
            |_function, args, _requested_outputs| {
                let x = match &args[0] {
                    Value::Num(value) => *value,
                    other => panic!("expected x, got {other:?}"),
                };
                Box::pin(async move { Ok(Value::Num(x.sqrt())) })
            },
        )));

        let result = block_on(quad_builtin(
            Value::BoundFunctionHandle {
                name: "sqrt_fn".to_string(),
                function: 9,
            },
            Value::Num(0.0),
            Value::Num(1.0),
            vec![Value::Num(1.0e-7)],
        ))
        .expect("quad");
        match result {
            Value::Num(value) => assert!((value - (2.0 / 3.0)).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Swapping the bounds negates the integral: the integral of sin from pi to 0 is -2.
    #[test]
    fn quad_reversed_bounds_negate_result() {
        let result = block_on(quad_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(std::f64::consts::PI),
            Value::Num(0.0),
            Vec::new(),
        ))
        .expect("quad");
        match result {
            Value::Num(value) => assert!((value + 2.0).abs() < 1.0e-6),
            other => panic!("unexpected value {other:?}"),
        }
    }

    // Three requested outputs must fail with the TooManyOutputs identifier.
    #[test]
    fn quad_rejects_more_than_two_outputs() {
        let _guard = crate::output_count::push_output_count(Some(3));
        let err = block_on(quad_builtin(
            Value::FunctionHandle("sin".into()),
            Value::Num(0.0),
            Value::Num(1.0),
            Vec::new(),
        ))
        .expect_err("too many outputs should fail");
        assert_eq!(err.identifier(), Some("RunMat:quad:TooManyOutputs"));
    }

    // Pins the advertised signature labels and the error-code catalogue order.
    #[test]
    fn quad_descriptor_signatures_cover_legacy_forms() {
        let labels: Vec<&str> = QUAD_DESCRIPTOR
            .signatures
            .iter()
            .map(|signature| signature.label)
            .collect();
        assert_eq!(
            labels,
            vec![
                "q = quad(fun, a, b)",
                "q = quad(fun, a, b, tol, trace, p1, p2, ...)",
                "[q, fcnt] = quad(fun, a, b)",
                "[q, fcnt] = quad(fun, a, b, tol, trace, p1, p2, ...)",
            ]
        );

        let codes: Vec<&str> = QUAD_DESCRIPTOR
            .errors
            .iter()
            .map(|error| error.code)
            .collect();
        assert_eq!(
            codes,
            vec![
                "RM.QUAD.INVALID_ARGUMENT",
                "RM.QUAD.INVALID_INPUT",
                "RM.QUAD.TOO_MANY_OUTPUTS",
            ]
        );
    }
}