Skip to main content

scirs2_integrate/
quad_vec.rs

1//! Vector-valued integration
2//!
3//! This module provides integration methods for vector-valued functions.
4//! These methods are useful when you need to integrate a function that
5//! returns arrays rather than scalar values.
6
7use scirs2_core::ndarray::Array1;
8use std::cmp::Ordering;
9use std::collections::BinaryHeap;
10use std::f64::consts::PI;
11use std::fmt;
12
13use crate::error::{IntegrateError, IntegrateResult};
14
15/// Result type for vector-valued integration
16#[derive(Clone, Debug)]
17pub struct QuadVecResult<T> {
18    /// The integral estimate
19    pub integral: Array1<T>,
20    /// The error estimate
21    pub error: Array1<T>,
22    /// Number of function evaluations
23    pub nfev: usize,
24    /// Number of integration subintervals used
25    pub nintervals: usize,
26    /// Whether the integration converged successfully
27    pub success: bool,
28}
29
30impl<T: fmt::Display> fmt::Display for QuadVecResult<T> {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        write!(
33            f,
34            "QuadVecResult(\n  integral=[{:}],\n  error=[{:}],\n  nfev={},\n  nintervals={},\n  success={}\n)",
35            self.integral
36                .iter()
37                .map(|v| format!("{v}"))
38                .collect::<Vec<_>>()
39                .join(", "),
40            self.error
41                .iter()
42                .map(|v| format!("{v}"))
43                .collect::<Vec<_>>()
44                .join(", "),
45            self.nfev,
46            self.nintervals,
47            self.success
48        )
49    }
50}
51
52/// Options for quad_vec integration
53#[derive(Clone, Debug)]
54pub struct QuadVecOptions {
55    /// Absolute tolerance
56    pub epsabs: f64,
57    /// Relative tolerance
58    pub epsrel: f64,
59    /// Norm to use for error estimation
60    pub norm: NormType,
61    /// Maximum number of subintervals
62    pub limit: usize,
63    /// Quadrature rule to use
64    pub rule: QuadRule,
65    /// Additional points where the integrand should be sampled
66    pub points: Option<Vec<f64>>,
67}
68
69impl Default for QuadVecOptions {
70    fn default() -> Self {
71        Self {
72            epsabs: 1e-10,
73            epsrel: 1e-8,
74            norm: NormType::L2,
75            limit: 50,
76            rule: QuadRule::GK21,
77            points: None,
78        }
79    }
80}
81
/// Selects how a vector of per-component values is collapsed into a single
/// scalar when errors are compared against the tolerances.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NormType {
    /// Largest absolute component.
    Max,
    /// Euclidean norm: square root of the sum of squared components.
    L2,
}
90
/// Quadrature rule applied on each subinterval.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuadRule {
    /// Gauss-Kronrod pair with 15 Kronrod points.
    GK15,
    /// Gauss-Kronrod pair with 21 Kronrod points.
    GK21,
    /// Composite trapezoidal rule.
    Trapezoid,
}
101
102/// Subinterval for adaptive quadrature
103#[derive(Clone, Debug)]
104struct Subinterval {
105    /// Left endpoint
106    a: f64,
107    /// Right endpoint
108    b: f64,
109    /// Integral estimate on this subinterval
110    integral: Array1<f64>,
111    /// Error estimate on this subinterval
112    error: Array1<f64>,
113    /// Norm of the error estimate (priority for subdivision)
114    error_norm: f64,
115}
116
117impl PartialEq for Subinterval {
118    fn eq(&self, other: &Self) -> bool {
119        self.error_norm == other.error_norm
120    }
121}
122
// Marker impl: `Ord` (needed to store `Subinterval` in a `BinaryHeap`)
// requires `Eq`. Equality itself comes from `PartialEq`, which compares
// `error_norm` only.
// NOTE(review): a NaN `error_norm` is not equal to itself under `==`, which
// `Eq` formally assumes — confirm NaN norms cannot reach the heap.
impl Eq for Subinterval {}
124
125impl PartialOrd for Subinterval {
126    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
127        Some(self.cmp(other))
128    }
129}
130
131impl Ord for Subinterval {
132    fn cmp(&self, other: &Self) -> Ordering {
133        // We want a max heap, so reverse the ordering
134        other
135            .error_norm
136            .partial_cmp(&self.error_norm)
137            .unwrap_or(Ordering::Equal)
138    }
139}
140
141/// Compute a norm for error estimation
142#[allow(dead_code)]
143fn compute_norm(array: &Array1<f64>, normtype: NormType) -> f64 {
144    match normtype {
145        NormType::Max => {
146            let mut max_abs = 0.0;
147            for &val in array.iter() {
148                let abs_val = val.abs();
149                if abs_val > max_abs {
150                    max_abs = abs_val;
151                }
152            }
153            max_abs
154        }
155        NormType::L2 => {
156            let mut sum_squares: f64 = 0.0;
157            for &val in array.iter() {
158                sum_squares += val * val;
159            }
160            sum_squares.sqrt()
161        }
162    }
163}
164
165/// Integration of a vector-valued function.
166///
167/// This function is similar to `quad`, but for functions that return
168/// arrays (vectors) rather than scalars.
169///
170/// # Parameters
171///
172/// * `f` - Function to integrate. Should take a float and return an array.
173/// * `a` - Lower bound of integration.
174/// * `b` - Upper bound of integration.
175/// * `options` - Integration options (optional).
176///
177/// # Returns
178///
179/// A result containing the integral estimate, error, and other information.
180///
181/// # Examples
182///
183/// ```
184/// use scirs2_core::ndarray::{Array1, arr1};
185/// use scirs2_integrate::quad_vec::{quad_vec, QuadVecOptions};
186///
187/// // Integrate a function that returns a 2D vector
188/// let f = |x: f64| arr1(&[x.sin(), x.cos()]);
189/// let result = quad_vec(f, 0.0, std::f64::consts::PI, None).expect("Operation failed");
190///
191/// // Result should be approximately [2.0, 0.0]
192/// assert!((result.integral[0] - 2.0).abs() < 1e-10);
193/// assert!(result.integral[1].abs() < 1e-10);
194/// ```
195#[allow(dead_code)]
196pub fn quad_vec<F>(
197    f: F,
198    a: f64,
199    b: f64,
200    options: Option<QuadVecOptions>,
201) -> IntegrateResult<QuadVecResult<f64>>
202where
203    F: Fn(f64) -> Array1<f64>,
204{
205    let options = options.unwrap_or_default();
206
207    // Validate inputs
208    if !a.is_finite() || !b.is_finite() {
209        return Err(IntegrateError::ValueError(
210            "Integration limits must be finite".to_string(),
211        ));
212    }
213
214    // Check if interval is effectively zero
215    if (b - a).abs() <= f64::EPSILON * a.abs().max(b.abs()) {
216        // Evaluate at midpoint to get vector dimension
217        let fval = f((a + b) / 2.0);
218        let zeros = Array1::zeros(fval.len());
219
220        return Ok(QuadVecResult {
221            integral: zeros.clone(),
222            error: zeros,
223            nfev: 1,
224            nintervals: 0,
225            success: true,
226        });
227    }
228
229    // Determine initial intervals
230    let intervals = if let Some(ref points) = options.points {
231        // Start with user-supplied breakpoints
232        let mut sorted_points: Vec<f64> = points.clone();
233        sorted_points.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
234
235        // Filter out points outside [a, b] and remove duplicates
236        let mut filtered_points: Vec<f64> = Vec::new();
237        for &point in sorted_points.iter() {
238            if point > a
239                && point < b
240                && (filtered_points.is_empty()
241                    || (point - filtered_points.last().expect("Operation failed")).abs()
242                        > f64::EPSILON)
243            {
244                filtered_points.push(point);
245            }
246        }
247
248        // Create initial intervals
249        let mut intervals: Vec<(f64, f64)> = Vec::new();
250
251        if filtered_points.is_empty() {
252            intervals.push((a, b));
253        } else {
254            intervals.push((a, filtered_points[0]));
255
256            for i in 0..filtered_points.len() - 1 {
257                intervals.push((filtered_points[i], filtered_points[i + 1]));
258            }
259
260            intervals.push((*filtered_points.last().expect("Operation failed"), b));
261        }
262
263        intervals
264    } else {
265        // Just use the whole interval
266        vec![(a, b)]
267    };
268
269    // Evaluate function once to determine output size
270    let fval = f((intervals[0].0 + intervals[0].1) / 2.0);
271    let output_size = fval.len();
272
273    // Initialize priority queue for adaptive subdivision
274    let mut subintervals = BinaryHeap::new();
275    let mut nfev = 1; // We've already evaluated f once
276
277    // Process initial intervals
278    for (a_i, b_i) in intervals {
279        let (integral, error, evals) = evaluate_interval(&f, a_i, b_i, output_size, options.rule)?;
280
281        nfev += evals;
282
283        let error_norm = compute_norm(&error, options.norm);
284
285        subintervals.push(Subinterval {
286            a: a_i,
287            b: b_i,
288            integral,
289            error,
290            error_norm,
291        });
292    }
293
294    // Adaptive subdivision
295    while subintervals.len() < options.limit {
296        // Get interval with largest error
297        let interval = match subintervals.pop() {
298            Some(i) => i,
299            None => break, // Shouldn't happen, but just in case
300        };
301
302        // Check if we've reached desired accuracy
303        let total_integral = get_total(&subintervals, &interval, |i| &i.integral);
304        let total_error = get_total(&subintervals, &interval, |i| &i.error);
305
306        let error_norm = compute_norm(&total_error, options.norm);
307        let abs_tol = options.epsabs;
308        let rel_tol = options.epsrel * compute_norm(&total_integral, options.norm);
309
310        if error_norm <= abs_tol || error_norm <= rel_tol {
311            // Add the interval back, we'll compute the final result later
312            subintervals.push(interval);
313            break;
314        }
315
316        // Split the interval
317        let mid = (interval.a + interval.b) / 2.0;
318
319        // Evaluate on the two halves
320        let (left_integral, left_error, left_evals) =
321            evaluate_interval(&f, interval.a, mid, output_size, options.rule)?;
322
323        let (right_integral, right_error, right_evals) =
324            evaluate_interval(&f, mid, interval.b, output_size, options.rule)?;
325
326        nfev += left_evals + right_evals;
327
328        // Create new intervals
329        let left_error_norm = compute_norm(&left_error, options.norm);
330        let right_error_norm = compute_norm(&right_error, options.norm);
331
332        subintervals.push(Subinterval {
333            a: interval.a,
334            b: mid,
335            integral: left_integral,
336            error: left_error,
337            error_norm: left_error_norm,
338        });
339
340        subintervals.push(Subinterval {
341            a: mid,
342            b: interval.b,
343            integral: right_integral,
344            error: right_error,
345            error_norm: right_error_norm,
346        });
347    }
348
349    // Compute final result
350    let interval_vec: Vec<Subinterval> = subintervals.into_vec();
351    let mut total_integral = Array1::zeros(output_size);
352    let mut total_error = Array1::zeros(output_size);
353
354    for interval in &interval_vec {
355        for (i, &val) in interval.integral.iter().enumerate() {
356            total_integral[i] += val;
357        }
358
359        for (i, &val) in interval.error.iter().enumerate() {
360            total_error[i] += val;
361        }
362    }
363
364    // Check for convergence
365    let error_norm = compute_norm(&total_error, options.norm);
366    let abs_tol = options.epsabs;
367    let rel_tol = options.epsrel * compute_norm(&total_integral, options.norm);
368
369    let success = error_norm <= abs_tol || error_norm <= rel_tol;
370
371    Ok(QuadVecResult {
372        integral: total_integral,
373        error: total_error,
374        nfev,
375        nintervals: interval_vec.len(),
376        success,
377    })
378}
379
380/// Compute a property of all intervals combined
381#[allow(dead_code)]
382fn get_total<F, T>(heap: &BinaryHeap<Subinterval>, extra: &Subinterval, extract: F) -> Array1<T>
383where
384    F: Fn(&Subinterval) -> &Array1<T>,
385    T: Clone + scirs2_core::numeric::Zero,
386{
387    let mut result = extract(extra).clone();
388
389    for interval in heap.iter() {
390        let property = extract(interval);
391
392        for (i, val) in property.iter().enumerate() {
393            result[i] = result[i].clone() + val.clone();
394        }
395    }
396
397    result
398}
399
400/// Evaluate the integral on a specific interval
401#[allow(dead_code)]
402fn evaluate_interval<F>(
403    f: &F,
404    a: f64,
405    b: f64,
406    output_size: usize,
407    rule: QuadRule,
408) -> IntegrateResult<(Array1<f64>, Array1<f64>, usize)>
409where
410    F: Fn(f64) -> Array1<f64>,
411{
412    match rule {
413        QuadRule::GK15 => {
414            // Gauss-Kronrod 15-point rule (7-point Gauss, 15-point Kronrod)
415            // Points and weights from SciPy
416            let nodes = [
417                -0.9914553711208126f64,
418                -0.9491079123427585,
419                -0.8648644233597691,
420                -0.7415311855993944,
421                -0.5860872354676911,
422                -0.4058451513773972,
423                -0.2077849550078985,
424                0.0,
425                0.2077849550078985,
426                0.4058451513773972,
427                0.5860872354676911,
428                0.7415311855993944,
429                0.8648644233597691,
430                0.9491079123427585,
431                0.9914553711208126,
432            ];
433
434            let weights_k = [
435                0.022935322010529224f64,
436                0.063_092_092_629_978_56,
437                0.10479001032225018,
438                0.14065325971552592,
439                0.169_004_726_639_267_9,
440                0.190_350_578_064_785_4,
441                0.20443294007529889,
442                0.20948214108472782,
443                0.20443294007529889,
444                0.190_350_578_064_785_4,
445                0.169_004_726_639_267_9,
446                0.14065325971552592,
447                0.10479001032225018,
448                0.063_092_092_629_978_56,
449                0.022935322010529224,
450            ];
451
452            // Abscissae for the 7-point Gauss rule (odd indices of xgk)
453            let weights_g = [
454                0.129_484_966_168_869_7_f64,
455                0.27970539148927664,
456                0.381_830_050_505_118_9,
457                0.417_959_183_673_469_4,
458                0.381_830_050_505_118_9,
459                0.27970539148927664,
460                0.129_484_966_168_869_7,
461            ];
462
463            evaluate_rule(f, a, b, output_size, &nodes, &weights_g, &weights_k)
464        }
465        QuadRule::GK21 => {
466            // Gauss-Kronrod 21-point rule (10-point Gauss, 21-point Kronrod)
467            // Points and weights from SciPy
468            let nodes = [
469                -0.9956571630258081f64,
470                -0.9739065285171717,
471                -0.9301574913557082,
472                -0.8650633666889845,
473                -0.7808177265864169,
474                -0.6794095682990244,
475                -0.5627571346686047,
476                -0.4333953941292472,
477                -0.2943928627014602,
478                -0.1488743389816312,
479                0.0,
480                0.1488743389816312,
481                0.2943928627014602,
482                0.4333953941292472,
483                0.5627571346686047,
484                0.6794095682990244,
485                0.7808177265864169,
486                0.8650633666889845,
487                0.9301574913557082,
488                0.9739065285171717,
489                0.9956571630258081,
490            ];
491
492            let weights_k = [
493                0.011694638867371874f64,
494                0.032558162307964725,
495                0.054755896574351995,
496                0.075_039_674_810_919_96,
497                0.093_125_454_583_697_6,
498                0.109_387_158_802_297_64,
499                0.123_491_976_262_065_84,
500                0.134_709_217_311_473_34,
501                0.142_775_938_577_060_09,
502                0.147_739_104_901_338_49,
503                0.149_445_554_002_916_9,
504                0.147_739_104_901_338_49,
505                0.142_775_938_577_060_09,
506                0.134_709_217_311_473_34,
507                0.123_491_976_262_065_84,
508                0.109_387_158_802_297_64,
509                0.093_125_454_583_697_6,
510                0.075_039_674_810_919_96,
511                0.054755896574351995,
512                0.032558162307964725,
513                0.011694638867371874,
514            ];
515
516            // Abscissae for the 10-point Gauss rule (every other point)
517            let weights_g = [
518                0.066_671_344_308_688_14f64,
519                0.149_451_349_150_580_6,
520                0.219_086_362_515_982_04,
521                0.269_266_719_309_996_36,
522                0.295_524_224_714_752_9,
523                0.295_524_224_714_752_9,
524                0.269_266_719_309_996_36,
525                0.219_086_362_515_982_04,
526                0.149_451_349_150_580_6,
527                0.066_671_344_308_688_14,
528            ];
529
530            evaluate_rule(f, a, b, output_size, &nodes, &weights_g, &weights_k)
531        }
532        QuadRule::Trapezoid => {
533            // Simple trapezoid rule with 15 points
534            let n = 15;
535            let mut integral = Array1::zeros(output_size);
536            let mut error = Array1::zeros(output_size);
537
538            let h = (b - a) / (n as f64 - 1.0);
539            let fa = f(a);
540            let fb = f(b);
541
542            // Add endpoints with half weight
543            for (i, (&fa_i, &fb_i)) in fa.iter().zip(fb.iter()).enumerate() {
544                integral[i] = 0.5 * (fa_i + fb_i);
545            }
546
547            // Add interior points
548            for j in 1..n - 1 {
549                let x = a + (j as f64) * h;
550                let fx = f(x);
551
552                for (i, &fx_i) in fx.iter().enumerate() {
553                    integral[i] += fx_i;
554                }
555            }
556
557            // Scale by h
558            for i in 0..output_size {
559                integral[i] *= h;
560
561                // Crude error estimate
562                error[i] = 1e-2 * integral[i].abs();
563            }
564
565            Ok((integral, error, n))
566        }
567    }
568}
569
570/// Evaluate a Gauss-Kronrod rule on an interval
571#[allow(dead_code)]
572fn evaluate_rule<F>(
573    f: &F,
574    a: f64,
575    b: f64,
576    output_size: usize,
577    nodes: &[f64],
578    weights_g: &[f64],
579    weights_k: &[f64],
580) -> IntegrateResult<(Array1<f64>, Array1<f64>, usize)>
581where
582    F: Fn(f64) -> Array1<f64>,
583{
584    let _n = nodes.len();
585
586    let mut integral_k = Array1::zeros(output_size);
587    let mut integral_g = Array1::zeros(output_size);
588
589    // Map nodes to [a, b]
590    let mid = (a + b) / 2.0;
591    let half_length = (b - a) / 2.0;
592
593    let mut nfev = 0;
594
595    // For GK rules, Gauss points are at odd indices (1, 3, 5, ...)
596    let mut gauss_idx = 0;
597
598    // Evaluate at all Kronrod points
599    for (i, &node) in nodes.iter().enumerate() {
600        let x = mid + half_length * node;
601        let fx = f(x);
602        nfev += 1;
603
604        // Add to Kronrod integral
605        for (j, &fx_j) in fx.iter().enumerate() {
606            integral_k[j] += weights_k[i] * fx_j;
607        }
608
609        // Check if this is also a Gauss point
610        // For GK15: Gauss points are at indices 1, 3, 5, 7, 9, 11, 13
611        // For GK21: Gauss points are at indices 1, 3, 5, 7, 9, 11, 13, 15, 17, 19
612        if i % 2 == 1 && gauss_idx < weights_g.len() {
613            for (j, &fx_j) in fx.iter().enumerate() {
614                integral_g[j] += weights_g[gauss_idx] * fx_j;
615            }
616            gauss_idx += 1;
617        }
618    }
619
620    // Scale by half-length
621    integral_k *= half_length;
622    integral_g *= half_length;
623
624    // Compute error estimate
625    // Error is estimated as (200 * |I_k - I_g|)^1.5
626    let mut error = Array1::zeros(output_size);
627    for i in 0..output_size {
628        let diff = (integral_k[i] - integral_g[i]).abs();
629        error[i] = (200.0 * diff).powf(1.5_f64);
630    }
631
632    Ok((integral_k, error, nfev))
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr1;

    /// Monomials `x` and `x^2` over [0, 1].
    #[test]
    fn test_simple_integral() {
        let integrand = |x: f64| arr1(&[x, x * x]);
        let outcome = quad_vec(integrand, 0.0, 1.0, None).expect("Operation failed");

        let expected = [0.5, 1.0 / 3.0];
        for (component, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(outcome.integral[component], want, epsilon = 1e-10);
        }
        assert!(outcome.success);
    }

    /// `sin` and `cos` over [0, pi].
    #[test]
    fn test_trig_functions() {
        let integrand = |x: f64| arr1(&[x.sin(), x.cos()]);
        let outcome = quad_vec(integrand, 0.0, PI, None).expect("Operation failed");

        let expected = [2.0, 0.0];
        for (component, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(outcome.integral[component], want, epsilon = 1e-10);
        }
        assert!(outcome.success);
    }

    /// Monomials `x` and `x^2` over [0, 2], with a breakpoint at x = 1.
    #[test]
    fn test_with_breakpoints() {
        let integrand = |x: f64| arr1(&[x, x * x]);
        let with_breakpoint = QuadVecOptions {
            points: Some(vec![1.0]),
            ..Default::default()
        };

        let outcome =
            quad_vec(integrand, 0.0, 2.0, Some(with_breakpoint)).expect("Operation failed");

        let expected = [2.0, 8.0 / 3.0];
        for (component, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(outcome.integral[component], want, epsilon = 1e-10);
        }
        assert!(outcome.success);
    }

    /// Every rule integrates `sin` over [0, pi] to 2 within its own accuracy.
    #[test]
    fn test_different_rules() {
        let integrand = |x: f64| arr1(&[x.sin()]);

        // (rule, accepted absolute deviation); the trapezoid rule is held to
        // a lower precision.
        let cases = [
            (QuadRule::GK15, 1e-10),
            (QuadRule::GK21, 1e-10),
            (QuadRule::Trapezoid, 2e-3),
        ];

        for &(rule, tolerance) in cases.iter() {
            let options = QuadVecOptions {
                rule,
                ..Default::default()
            };
            let outcome = quad_vec(integrand, 0.0, PI, Some(options)).expect("Operation failed");
            assert_abs_diff_eq!(outcome.integral[0], 2.0, epsilon = tolerance);
        }
    }

    /// Both norms on a small vector with a known answer.
    #[test]
    fn test_error_norms() {
        let sample = arr1(&[1.0, -2.0, 0.5]);

        // Max norm
        assert_abs_diff_eq!(compute_norm(&sample, NormType::Max), 2.0, epsilon = 1e-10);

        // L2 norm
        let expected_l2 = (1.0f64 * 1.0 + 2.0 * 2.0 + 0.5 * 0.5).sqrt();
        assert_abs_diff_eq!(
            compute_norm(&sample, NormType::L2),
            expected_l2,
            epsilon = 1e-10
        );
    }
}