1use scirs2_core::ndarray::Array1;
8use std::cmp::Ordering;
9use std::collections::BinaryHeap;
10use std::f64::consts::PI;
11use std::fmt;
12
13use crate::error::{IntegrateError, IntegrateResult};
14
/// Result of [`quad_vec`]: per-component integral and error estimates plus
/// bookkeeping about the adaptive refinement.
#[derive(Clone, Debug)]
pub struct QuadVecResult<T> {
    /// Estimated integral, one entry per component returned by the integrand.
    pub integral: Array1<T>,
    /// Estimated absolute error per component, summed over all subintervals
    /// of the final partition.
    pub error: Array1<T>,
    /// Number of integrand evaluations performed.
    pub nfev: usize,
    /// Number of subintervals in the final partition (0 when the integration
    /// interval is degenerate).
    pub nintervals: usize,
    /// Whether the norm of `error` met `epsabs` or `epsrel * norm(integral)`.
    pub success: bool,
}
29
30impl<T: fmt::Display> fmt::Display for QuadVecResult<T> {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 write!(
33 f,
34 "QuadVecResult(\n integral=[{:}],\n error=[{:}],\n nfev={},\n nintervals={},\n success={}\n)",
35 self.integral
36 .iter()
37 .map(|v| format!("{v}"))
38 .collect::<Vec<_>>()
39 .join(", "),
40 self.error
41 .iter()
42 .map(|v| format!("{v}"))
43 .collect::<Vec<_>>()
44 .join(", "),
45 self.nfev,
46 self.nintervals,
47 self.success
48 )
49 }
50}
51
/// Options controlling [`quad_vec`].
#[derive(Clone, Debug)]
pub struct QuadVecOptions {
    /// Absolute tolerance on the norm of the total error estimate.
    pub epsabs: f64,
    /// Relative tolerance: the result is accepted when the error norm is at
    /// most `epsrel * norm(integral)`.
    pub epsrel: f64,
    /// Norm used to reduce vector-valued error and integral estimates to scalars.
    pub norm: NormType,
    /// Maximum number of subintervals; refinement stops once the partition
    /// reaches this size.
    pub limit: usize,
    /// Quadrature rule applied to each subinterval.
    pub rule: QuadRule,
    /// Optional breakpoints used to split the integration range before
    /// adaptive refinement; points outside the range are ignored.
    pub points: Option<Vec<f64>>,
}
68
69impl Default for QuadVecOptions {
70 fn default() -> Self {
71 Self {
72 epsabs: 1e-10,
73 epsrel: 1e-8,
74 norm: NormType::L2,
75 limit: 50,
76 rule: QuadRule::GK21,
77 points: None,
78 }
79 }
80}
81
/// Norm used to reduce a per-component array (error or integral estimate)
/// to a single scalar for tolerance checks and refinement priority.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NormType {
    /// Maximum absolute value over all components.
    Max,
    /// Euclidean norm: square root of the sum of squared components.
    L2,
}
90
/// Quadrature rule applied to each subinterval by [`quad_vec`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QuadRule {
    /// 15-point Gauss–Kronrod rule with an embedded 7-point Gauss rule.
    GK15,
    /// 21-point Gauss–Kronrod rule with an embedded 10-point Gauss rule.
    GK21,
    /// Composite trapezoid rule on 15 equally spaced points per subinterval.
    Trapezoid,
}
101
/// One piece of the adaptive partition, stored in the refinement heap.
#[derive(Clone, Debug)]
struct Subinterval {
    /// Start point of the piece as passed to the quadrature rule.
    a: f64,
    /// End point of the piece as passed to the quadrature rule.
    b: f64,
    /// Integral estimate over this piece, one entry per component.
    integral: Array1<f64>,
    /// Error estimate over this piece, one entry per component.
    error: Array1<f64>,
    /// `error` reduced to a scalar with the configured norm; the heap
    /// ordering of subintervals is based on this value.
    error_norm: f64,
}
116
117impl PartialEq for Subinterval {
118 fn eq(&self, other: &Self) -> bool {
119 self.error_norm == other.error_norm
120 }
121}
122
123impl Eq for Subinterval {}
124
125impl PartialOrd for Subinterval {
126 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
127 Some(self.cmp(other))
128 }
129}
130
131impl Ord for Subinterval {
132 fn cmp(&self, other: &Self) -> Ordering {
133 other
135 .error_norm
136 .partial_cmp(&self.error_norm)
137 .unwrap_or(Ordering::Equal)
138 }
139}
140
141#[allow(dead_code)]
143fn compute_norm(array: &Array1<f64>, normtype: NormType) -> f64 {
144 match normtype {
145 NormType::Max => {
146 let mut max_abs = 0.0;
147 for &val in array.iter() {
148 let abs_val = val.abs();
149 if abs_val > max_abs {
150 max_abs = abs_val;
151 }
152 }
153 max_abs
154 }
155 NormType::L2 => {
156 let mut sum_squares: f64 = 0.0;
157 for &val in array.iter() {
158 sum_squares += val * val;
159 }
160 sum_squares.sqrt()
161 }
162 }
163}
164
165#[allow(dead_code)]
196pub fn quad_vec<F>(
197 f: F,
198 a: f64,
199 b: f64,
200 options: Option<QuadVecOptions>,
201) -> IntegrateResult<QuadVecResult<f64>>
202where
203 F: Fn(f64) -> Array1<f64>,
204{
205 let options = options.unwrap_or_default();
206
207 if !a.is_finite() || !b.is_finite() {
209 return Err(IntegrateError::ValueError(
210 "Integration limits must be finite".to_string(),
211 ));
212 }
213
214 if (b - a).abs() <= f64::EPSILON * a.abs().max(b.abs()) {
216 let fval = f((a + b) / 2.0);
218 let zeros = Array1::zeros(fval.len());
219
220 return Ok(QuadVecResult {
221 integral: zeros.clone(),
222 error: zeros,
223 nfev: 1,
224 nintervals: 0,
225 success: true,
226 });
227 }
228
229 let intervals = if let Some(ref points) = options.points {
231 let mut sorted_points: Vec<f64> = points.clone();
233 sorted_points.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
234
235 let mut filtered_points: Vec<f64> = Vec::new();
237 for &point in sorted_points.iter() {
238 if point > a
239 && point < b
240 && (filtered_points.is_empty()
241 || (point - filtered_points.last().expect("Operation failed")).abs()
242 > f64::EPSILON)
243 {
244 filtered_points.push(point);
245 }
246 }
247
248 let mut intervals: Vec<(f64, f64)> = Vec::new();
250
251 if filtered_points.is_empty() {
252 intervals.push((a, b));
253 } else {
254 intervals.push((a, filtered_points[0]));
255
256 for i in 0..filtered_points.len() - 1 {
257 intervals.push((filtered_points[i], filtered_points[i + 1]));
258 }
259
260 intervals.push((*filtered_points.last().expect("Operation failed"), b));
261 }
262
263 intervals
264 } else {
265 vec![(a, b)]
267 };
268
269 let fval = f((intervals[0].0 + intervals[0].1) / 2.0);
271 let output_size = fval.len();
272
273 let mut subintervals = BinaryHeap::new();
275 let mut nfev = 1; for (a_i, b_i) in intervals {
279 let (integral, error, evals) = evaluate_interval(&f, a_i, b_i, output_size, options.rule)?;
280
281 nfev += evals;
282
283 let error_norm = compute_norm(&error, options.norm);
284
285 subintervals.push(Subinterval {
286 a: a_i,
287 b: b_i,
288 integral,
289 error,
290 error_norm,
291 });
292 }
293
294 while subintervals.len() < options.limit {
296 let interval = match subintervals.pop() {
298 Some(i) => i,
299 None => break, };
301
302 let total_integral = get_total(&subintervals, &interval, |i| &i.integral);
304 let total_error = get_total(&subintervals, &interval, |i| &i.error);
305
306 let error_norm = compute_norm(&total_error, options.norm);
307 let abs_tol = options.epsabs;
308 let rel_tol = options.epsrel * compute_norm(&total_integral, options.norm);
309
310 if error_norm <= abs_tol || error_norm <= rel_tol {
311 subintervals.push(interval);
313 break;
314 }
315
316 let mid = (interval.a + interval.b) / 2.0;
318
319 let (left_integral, left_error, left_evals) =
321 evaluate_interval(&f, interval.a, mid, output_size, options.rule)?;
322
323 let (right_integral, right_error, right_evals) =
324 evaluate_interval(&f, mid, interval.b, output_size, options.rule)?;
325
326 nfev += left_evals + right_evals;
327
328 let left_error_norm = compute_norm(&left_error, options.norm);
330 let right_error_norm = compute_norm(&right_error, options.norm);
331
332 subintervals.push(Subinterval {
333 a: interval.a,
334 b: mid,
335 integral: left_integral,
336 error: left_error,
337 error_norm: left_error_norm,
338 });
339
340 subintervals.push(Subinterval {
341 a: mid,
342 b: interval.b,
343 integral: right_integral,
344 error: right_error,
345 error_norm: right_error_norm,
346 });
347 }
348
349 let interval_vec: Vec<Subinterval> = subintervals.into_vec();
351 let mut total_integral = Array1::zeros(output_size);
352 let mut total_error = Array1::zeros(output_size);
353
354 for interval in &interval_vec {
355 for (i, &val) in interval.integral.iter().enumerate() {
356 total_integral[i] += val;
357 }
358
359 for (i, &val) in interval.error.iter().enumerate() {
360 total_error[i] += val;
361 }
362 }
363
364 let error_norm = compute_norm(&total_error, options.norm);
366 let abs_tol = options.epsabs;
367 let rel_tol = options.epsrel * compute_norm(&total_integral, options.norm);
368
369 let success = error_norm <= abs_tol || error_norm <= rel_tol;
370
371 Ok(QuadVecResult {
372 integral: total_integral,
373 error: total_error,
374 nfev,
375 nintervals: interval_vec.len(),
376 success,
377 })
378}
379
380#[allow(dead_code)]
382fn get_total<F, T>(heap: &BinaryHeap<Subinterval>, extra: &Subinterval, extract: F) -> Array1<T>
383where
384 F: Fn(&Subinterval) -> &Array1<T>,
385 T: Clone + scirs2_core::numeric::Zero,
386{
387 let mut result = extract(extra).clone();
388
389 for interval in heap.iter() {
390 let property = extract(interval);
391
392 for (i, val) in property.iter().enumerate() {
393 result[i] = result[i].clone() + val.clone();
394 }
395 }
396
397 result
398}
399
400#[allow(dead_code)]
402fn evaluate_interval<F>(
403 f: &F,
404 a: f64,
405 b: f64,
406 output_size: usize,
407 rule: QuadRule,
408) -> IntegrateResult<(Array1<f64>, Array1<f64>, usize)>
409where
410 F: Fn(f64) -> Array1<f64>,
411{
412 match rule {
413 QuadRule::GK15 => {
414 let nodes = [
417 -0.9914553711208126f64,
418 -0.9491079123427585,
419 -0.8648644233597691,
420 -0.7415311855993944,
421 -0.5860872354676911,
422 -0.4058451513773972,
423 -0.2077849550078985,
424 0.0,
425 0.2077849550078985,
426 0.4058451513773972,
427 0.5860872354676911,
428 0.7415311855993944,
429 0.8648644233597691,
430 0.9491079123427585,
431 0.9914553711208126,
432 ];
433
434 let weights_k = [
435 0.022935322010529224f64,
436 0.063_092_092_629_978_56,
437 0.10479001032225018,
438 0.14065325971552592,
439 0.169_004_726_639_267_9,
440 0.190_350_578_064_785_4,
441 0.20443294007529889,
442 0.20948214108472782,
443 0.20443294007529889,
444 0.190_350_578_064_785_4,
445 0.169_004_726_639_267_9,
446 0.14065325971552592,
447 0.10479001032225018,
448 0.063_092_092_629_978_56,
449 0.022935322010529224,
450 ];
451
452 let weights_g = [
454 0.129_484_966_168_869_7_f64,
455 0.27970539148927664,
456 0.381_830_050_505_118_9,
457 0.417_959_183_673_469_4,
458 0.381_830_050_505_118_9,
459 0.27970539148927664,
460 0.129_484_966_168_869_7,
461 ];
462
463 evaluate_rule(f, a, b, output_size, &nodes, &weights_g, &weights_k)
464 }
465 QuadRule::GK21 => {
466 let nodes = [
469 -0.9956571630258081f64,
470 -0.9739065285171717,
471 -0.9301574913557082,
472 -0.8650633666889845,
473 -0.7808177265864169,
474 -0.6794095682990244,
475 -0.5627571346686047,
476 -0.4333953941292472,
477 -0.2943928627014602,
478 -0.1488743389816312,
479 0.0,
480 0.1488743389816312,
481 0.2943928627014602,
482 0.4333953941292472,
483 0.5627571346686047,
484 0.6794095682990244,
485 0.7808177265864169,
486 0.8650633666889845,
487 0.9301574913557082,
488 0.9739065285171717,
489 0.9956571630258081,
490 ];
491
492 let weights_k = [
493 0.011694638867371874f64,
494 0.032558162307964725,
495 0.054755896574351995,
496 0.075_039_674_810_919_96,
497 0.093_125_454_583_697_6,
498 0.109_387_158_802_297_64,
499 0.123_491_976_262_065_84,
500 0.134_709_217_311_473_34,
501 0.142_775_938_577_060_09,
502 0.147_739_104_901_338_49,
503 0.149_445_554_002_916_9,
504 0.147_739_104_901_338_49,
505 0.142_775_938_577_060_09,
506 0.134_709_217_311_473_34,
507 0.123_491_976_262_065_84,
508 0.109_387_158_802_297_64,
509 0.093_125_454_583_697_6,
510 0.075_039_674_810_919_96,
511 0.054755896574351995,
512 0.032558162307964725,
513 0.011694638867371874,
514 ];
515
516 let weights_g = [
518 0.066_671_344_308_688_14f64,
519 0.149_451_349_150_580_6,
520 0.219_086_362_515_982_04,
521 0.269_266_719_309_996_36,
522 0.295_524_224_714_752_9,
523 0.295_524_224_714_752_9,
524 0.269_266_719_309_996_36,
525 0.219_086_362_515_982_04,
526 0.149_451_349_150_580_6,
527 0.066_671_344_308_688_14,
528 ];
529
530 evaluate_rule(f, a, b, output_size, &nodes, &weights_g, &weights_k)
531 }
532 QuadRule::Trapezoid => {
533 let n = 15;
535 let mut integral = Array1::zeros(output_size);
536 let mut error = Array1::zeros(output_size);
537
538 let h = (b - a) / (n as f64 - 1.0);
539 let fa = f(a);
540 let fb = f(b);
541
542 for (i, (&fa_i, &fb_i)) in fa.iter().zip(fb.iter()).enumerate() {
544 integral[i] = 0.5 * (fa_i + fb_i);
545 }
546
547 for j in 1..n - 1 {
549 let x = a + (j as f64) * h;
550 let fx = f(x);
551
552 for (i, &fx_i) in fx.iter().enumerate() {
553 integral[i] += fx_i;
554 }
555 }
556
557 for i in 0..output_size {
559 integral[i] *= h;
560
561 error[i] = 1e-2 * integral[i].abs();
563 }
564
565 Ok((integral, error, n))
566 }
567 }
568}
569
570#[allow(dead_code)]
572fn evaluate_rule<F>(
573 f: &F,
574 a: f64,
575 b: f64,
576 output_size: usize,
577 nodes: &[f64],
578 weights_g: &[f64],
579 weights_k: &[f64],
580) -> IntegrateResult<(Array1<f64>, Array1<f64>, usize)>
581where
582 F: Fn(f64) -> Array1<f64>,
583{
584 let _n = nodes.len();
585
586 let mut integral_k = Array1::zeros(output_size);
587 let mut integral_g = Array1::zeros(output_size);
588
589 let mid = (a + b) / 2.0;
591 let half_length = (b - a) / 2.0;
592
593 let mut nfev = 0;
594
595 let mut gauss_idx = 0;
597
598 for (i, &node) in nodes.iter().enumerate() {
600 let x = mid + half_length * node;
601 let fx = f(x);
602 nfev += 1;
603
604 for (j, &fx_j) in fx.iter().enumerate() {
606 integral_k[j] += weights_k[i] * fx_j;
607 }
608
609 if i % 2 == 1 && gauss_idx < weights_g.len() {
613 for (j, &fx_j) in fx.iter().enumerate() {
614 integral_g[j] += weights_g[gauss_idx] * fx_j;
615 }
616 gauss_idx += 1;
617 }
618 }
619
620 integral_k *= half_length;
622 integral_g *= half_length;
623
624 let mut error = Array1::zeros(output_size);
627 for i in 0..output_size {
628 let diff = (integral_k[i] - integral_g[i]).abs();
629 error[i] = (200.0 * diff).powf(1.5_f64);
630 }
631
632 Ok((integral_k, error, nfev))
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr1;

    #[test]
    fn test_simple_integral() {
        let f = |x: f64| arr1(&[x, x * x]);

        let result = quad_vec(f, 0.0, 1.0, None).expect("Operation failed");

        assert_abs_diff_eq!(result.integral[0], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(result.integral[1], 1.0 / 3.0, epsilon = 1e-10);
        assert!(result.success);
    }

    #[test]
    fn test_trig_functions() {
        let f = |x: f64| arr1(&[x.sin(), x.cos()]);

        let result = quad_vec(f, 0.0, PI, None).expect("Operation failed");

        assert_abs_diff_eq!(result.integral[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result.integral[1], 0.0, epsilon = 1e-10);
        assert!(result.success);
    }

    #[test]
    fn test_with_breakpoints() {
        let f = |x: f64| arr1(&[x, x * x]);

        let options = QuadVecOptions {
            points: Some(vec![1.0]),
            ..Default::default()
        };

        let result = quad_vec(f, 0.0, 2.0, Some(options)).expect("Operation failed");

        assert_abs_diff_eq!(result.integral[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result.integral[1], 8.0 / 3.0, epsilon = 1e-10);
        assert!(result.success);
    }

    #[test]
    fn test_different_rules() {
        let f = |x: f64| arr1(&[x.sin()]);

        let options_gk15 = QuadVecOptions {
            rule: QuadRule::GK15,
            ..Default::default()
        };

        let options_gk21 = QuadVecOptions {
            rule: QuadRule::GK21,
            ..Default::default()
        };

        let options_trapezoid = QuadVecOptions {
            rule: QuadRule::Trapezoid,
            ..Default::default()
        };

        let result_gk15 = quad_vec(f, 0.0, PI, Some(options_gk15)).expect("Operation failed");
        let result_gk21 = quad_vec(f, 0.0, PI, Some(options_gk21)).expect("Operation failed");
        let result_trapezoid =
            quad_vec(f, 0.0, PI, Some(options_trapezoid)).expect("Operation failed");

        assert_abs_diff_eq!(result_gk15.integral[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result_gk21.integral[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result_trapezoid.integral[0], 2.0, epsilon = 2e-3);
    }

    #[test]
    fn test_error_norms() {
        let arr = arr1(&[1.0, -2.0, 0.5]);

        let max_norm = compute_norm(&arr, NormType::Max);
        assert_abs_diff_eq!(max_norm, 2.0, epsilon = 1e-10);

        let l2_norm = compute_norm(&arr, NormType::L2);
        assert_abs_diff_eq!(
            l2_norm,
            (1.0f64 * 1.0 + 2.0 * 2.0 + 0.5 * 0.5).sqrt(),
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_reversed_limits_negate_integral() {
        let f = |x: f64| arr1(&[x, x * x]);

        let result = quad_vec(f, 1.0, 0.0, None).expect("Operation failed");

        assert_abs_diff_eq!(result.integral[0], -0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(result.integral[1], -1.0 / 3.0, epsilon = 1e-10);
        assert!(result.success);
    }

    #[test]
    fn test_nonfinite_limits_rejected() {
        let f = |x: f64| arr1(&[x]);

        assert!(quad_vec(f, 0.0, f64::INFINITY, None).is_err());
        assert!(quad_vec(f, f64::NAN, 1.0, None).is_err());
    }

    #[test]
    fn test_zero_length_interval() {
        let f = |x: f64| arr1(&[x, 1.0]);

        let result = quad_vec(f, 2.0, 2.0, None).expect("Operation failed");

        assert_eq!(result.integral.len(), 2);
        assert!(result.integral.iter().all(|&v| v == 0.0));
        assert!(result.error.iter().all(|&v| v == 0.0));
        assert_eq!(result.nfev, 1);
        assert_eq!(result.nintervals, 0);
        assert!(result.success);
    }
}