Skip to main content

numra_integrate/
adaptive.rs

1//! Adaptive Gauss-Kronrod quadrature (G7K15).
2//!
3//! This is the workhorse integration routine, equivalent to SciPy's `quad`.
4//! Uses a 15-point Kronrod rule with embedded 7-point Gauss rule for error
5//! estimation, with adaptive bisection of subintervals ordered by error.
6//!
7//! Author: Moussa Leblouba
8//! Date: 9 February 2026
9//! Modified: 2 May 2026
10
11use alloc::collections::BinaryHeap;
12use core::cmp::Ordering;
13
14use numra_core::Scalar;
15
16use crate::error::IntegrationError;
17
18extern crate alloc;
19
20/// Options for adaptive quadrature.
21#[derive(Clone, Debug)]
22pub struct QuadOptions<S: Scalar> {
23    /// Absolute tolerance.
24    pub atol: S,
25    /// Relative tolerance.
26    pub rtol: S,
27    /// Maximum number of subinterval subdivisions.
28    pub max_subdivisions: usize,
29    /// Known singularities or discontinuities where the interval should be pre-split.
30    pub points: Vec<S>,
31}
32
33impl<S: Scalar> Default for QuadOptions<S> {
34    fn default() -> Self {
35        Self {
36            atol: S::from_f64(1.49e-8),
37            rtol: S::from_f64(1.49e-8),
38            max_subdivisions: 50,
39            points: Vec::new(),
40        }
41    }
42}
43
44impl<S: Scalar> QuadOptions<S> {
45    /// Set absolute tolerance.
46    pub fn atol(mut self, atol: S) -> Self {
47        self.atol = atol;
48        self
49    }
50
51    /// Set relative tolerance.
52    pub fn rtol(mut self, rtol: S) -> Self {
53        self.rtol = rtol;
54        self
55    }
56
57    /// Set maximum subdivisions.
58    pub fn max_subdivisions(mut self, max: usize) -> Self {
59        self.max_subdivisions = max;
60        self
61    }
62
63    /// Set known singularity/discontinuity points.
64    pub fn points(mut self, pts: Vec<S>) -> Self {
65        self.points = pts;
66        self
67    }
68}
69
70/// Result of numerical integration.
71#[derive(Clone, Debug)]
72pub struct QuadResult<S: Scalar> {
73    /// Estimated value of the integral.
74    pub value: S,
75    /// Estimated absolute error.
76    pub error_estimate: S,
77    /// Number of function evaluations.
78    pub n_evaluations: usize,
79    /// Number of subinterval subdivisions performed.
80    pub n_subdivisions: usize,
81}
82
83// ============================================================================
84// Gauss-Kronrod G7K15 nodes and weights on [-1, 1]
85//
86// 15 Kronrod nodes (symmetric), with the 7 Gauss nodes being a subset.
87// We store only the positive half (8 nodes for K15, 4 for G7 subset).
88// Source: QUADPACK (Piessens et al.), 15-digit precision.
89// ============================================================================
90
/// Non-negative abscissae of the 15-point Kronrod rule on `[-1, 1]`, listed
/// in ascending order starting with the centre node `0`.
///
/// The rule is symmetric, so the negative abscissae are the mirror images of
/// entries `1..=7` and are not stored.
const K15_NODES: [f64; 8] = [
    0.0,
    0.2077849550078985,
    0.4058451513773972,
    0.5860872354676911,
    0.7415311855993945,
    0.8648644233597691,
    0.9491079123427585,
    0.9914553711208126,
];
102
/// Weights of the 15-point Kronrod rule; `K15_WEIGHTS[i]` belongs to
/// `K15_NODES[i]`.
///
/// Every weight except the first applies to both the positive node and its
/// mirror image, so the full rule sums `w[0] + 2 * (w[1] + ... + w[7]) = 2`.
const K15_WEIGHTS: [f64; 8] = [
    0.2094821410847278,
    0.2044329400752989,
    0.1903505780647854,
    0.1690047266392679,
    0.1406532597155259,
    0.1047900103222502,
    0.0630920926299786,
    0.0229353220105292,
];
114
/// Weights of the embedded 7-point Gauss rule.
///
/// Entry `j` belongs to `K15_NODES[2 * j]`, i.e. to the nodes
/// 0, 0.4058..., 0.7415... and 0.9491... that the Gauss rule shares with the
/// Kronrod rule. As with the Kronrod weights, every entry except the first
/// applies to a symmetric pair of nodes.
const G7_WEIGHTS: [f64; 4] = [
    0.4179591836734694,
    0.3818300505051189,
    0.2797053914892767,
    0.1294849661688697,
];
123
124/// Apply G7K15 rule to a single interval [a, b].
125/// Returns (kronrod_result, gauss_result, n_evals).
126fn g7k15<S, F>(f: &mut F, a: S, b: S) -> (S, S, usize)
127where
128    S: Scalar,
129    F: FnMut(S) -> S,
130{
131    let mid = (a + b) * S::HALF;
132    let half_len = (b - a) * S::HALF;
133
134    let mut k15 = S::ZERO;
135    let mut g7 = S::ZERO;
136
137    // Node 0 (center)
138    let f_center = f(mid);
139    k15 += S::from_f64(K15_WEIGHTS[0]) * f_center;
140    g7 += S::from_f64(G7_WEIGHTS[0]) * f_center;
141
142    // Nodes 1, 3, 5, 7 are Kronrod-only (odd indices in the full 15-point rule)
143    for &i in &[1usize, 3, 5, 7] {
144        let x = half_len * S::from_f64(K15_NODES[i]);
145        let f_pos = f(mid + x);
146        let f_neg = f(mid - x);
147        k15 += S::from_f64(K15_WEIGHTS[i]) * (f_pos + f_neg);
148    }
149
150    // Nodes 2, 4, 6 are shared with G7 (even indices > 0)
151    for (g_idx, &k_idx) in [2usize, 4, 6].iter().enumerate() {
152        let x = half_len * S::from_f64(K15_NODES[k_idx]);
153        let f_pos = f(mid + x);
154        let f_neg = f(mid - x);
155        let fsum = f_pos + f_neg;
156        k15 += S::from_f64(K15_WEIGHTS[k_idx]) * fsum;
157        g7 += S::from_f64(G7_WEIGHTS[g_idx + 1]) * fsum;
158    }
159
160    (k15 * half_len, g7 * half_len, 15)
161}
162
163/// A subinterval with its integral estimate and error, for the priority queue.
164struct SubInterval<S: Scalar> {
165    a: S,
166    b: S,
167    result: S,
168    error: S,
169}
170
171impl<S: Scalar> PartialEq for SubInterval<S> {
172    fn eq(&self, other: &Self) -> bool {
173        self.error.to_f64() == other.error.to_f64()
174    }
175}
176
177impl<S: Scalar> Eq for SubInterval<S> {}
178
179impl<S: Scalar> PartialOrd for SubInterval<S> {
180    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
181        Some(self.cmp(other))
182    }
183}
184
185impl<S: Scalar> Ord for SubInterval<S> {
186    fn cmp(&self, other: &Self) -> Ordering {
187        // Max-heap by error
188        self.error
189            .to_f64()
190            .partial_cmp(&other.error.to_f64())
191            .unwrap_or(Ordering::Equal)
192    }
193}
194
195/// Adaptive Gauss-Kronrod quadrature (G7K15).
196///
197/// Integrates `f` over `[a, b]` using adaptive subdivision with a 15-point
198/// Kronrod rule and 7-point embedded Gauss rule for error estimation.
199///
200/// # Example
201///
202/// ```
203/// use numra_integrate::{quad, QuadOptions};
204///
205/// let result = quad(|x: f64| x.sin(), 0.0, std::f64::consts::PI, &QuadOptions::default()).unwrap();
206/// assert!((result.value - 2.0).abs() < 1e-10);
207/// ```
208pub fn quad<S, F>(
209    mut f: F,
210    a: S,
211    b: S,
212    opts: &QuadOptions<S>,
213) -> Result<QuadResult<S>, IntegrationError>
214where
215    S: Scalar,
216    F: FnMut(S) -> S,
217{
218    // Build initial list of subintervals, splitting at known singularity points
219    let mut breakpoints = Vec::new();
220    breakpoints.push(a);
221    for &p in &opts.points {
222        if p > a && p < b {
223            breakpoints.push(p);
224        }
225    }
226    breakpoints.push(b);
227    // Sort breakpoints
228    breakpoints.sort_by(|x, y| {
229        x.to_f64()
230            .partial_cmp(&y.to_f64())
231            .unwrap_or(Ordering::Equal)
232    });
233    // Remove duplicates
234    breakpoints.dedup_by(|a, b| ((*a) - (*b)).abs() < S::EPSILON);
235
236    let mut heap: BinaryHeap<SubInterval<S>> = BinaryHeap::new();
237    let mut total_result = S::ZERO;
238    let mut total_error = S::ZERO;
239    let mut total_evals = 0usize;
240    let mut n_subdivisions = 0usize;
241
242    // Initial pass: apply G7K15 to each breakpoint segment
243    for i in 0..breakpoints.len() - 1 {
244        let seg_a = breakpoints[i];
245        let seg_b = breakpoints[i + 1];
246        let (k15, g7, ne) = g7k15(&mut f, seg_a, seg_b);
247        let err = (k15 - g7).abs();
248        total_result += k15;
249        total_error += err;
250        total_evals += ne;
251        n_subdivisions += 1;
252
253        // Check for invalid values
254        if !k15.is_finite() {
255            let mid = (seg_a + seg_b) * S::HALF;
256            return Err(IntegrationError::InvalidValue { x: mid.to_f64() });
257        }
258
259        heap.push(SubInterval {
260            a: seg_a,
261            b: seg_b,
262            result: k15,
263            error: err,
264        });
265    }
266
267    // Check if already converged
268    let tol = opts.atol.max(opts.rtol * total_result.abs());
269    if total_error <= tol {
270        return Ok(QuadResult {
271            value: total_result,
272            error_estimate: total_error,
273            n_evaluations: total_evals,
274            n_subdivisions,
275        });
276    }
277
278    // Adaptive refinement
279    while n_subdivisions < opts.max_subdivisions {
280        let worst = match heap.pop() {
281            Some(w) => w,
282            None => break,
283        };
284
285        // Bisect the worst interval
286        let mid = (worst.a + worst.b) * S::HALF;
287
288        let (k15_l, g7_l, ne_l) = g7k15(&mut f, worst.a, mid);
289        let err_l = (k15_l - g7_l).abs();
290
291        let (k15_r, g7_r, ne_r) = g7k15(&mut f, mid, worst.b);
292        let err_r = (k15_r - g7_r).abs();
293
294        total_evals += ne_l + ne_r;
295        n_subdivisions += 1;
296
297        // Update totals: remove old, add new
298        total_result = total_result - worst.result + k15_l + k15_r;
299        total_error = total_error - worst.error + err_l + err_r;
300
301        if !k15_l.is_finite() || !k15_r.is_finite() {
302            return Err(IntegrationError::InvalidValue { x: mid.to_f64() });
303        }
304
305        heap.push(SubInterval {
306            a: worst.a,
307            b: mid,
308            result: k15_l,
309            error: err_l,
310        });
311        heap.push(SubInterval {
312            a: mid,
313            b: worst.b,
314            result: k15_r,
315            error: err_r,
316        });
317
318        let tol = opts.atol.max(opts.rtol * total_result.abs());
319        if total_error <= tol {
320            return Ok(QuadResult {
321                value: total_result,
322                error_estimate: total_error,
323                n_evaluations: total_evals,
324                n_subdivisions,
325            });
326        }
327    }
328
329    // Didn't converge but return the best result with an error
330    Err(IntegrationError::MaxSubdivisions {
331        subdivisions: n_subdivisions,
332        error_estimate: total_error.to_f64(),
333    })
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use core::f64::consts::{E, FRAC_PI_2, PI};

    /// Integral of sin(x) over [0, pi] equals 2.
    #[test]
    fn test_quad_sin() {
        let opts = QuadOptions::default();
        let outcome = quad(|x: f64| x.sin(), 0.0, PI, &opts).unwrap();
        assert_relative_eq!(outcome.value, 2.0, epsilon = 1e-10);
        assert!(outcome.error_estimate < 1e-10);
    }

    /// Integral of exp(x) over [0, 1] equals e - 1.
    #[test]
    fn test_quad_exp() {
        let opts = QuadOptions::default();
        let outcome = quad(|x: f64| x.exp(), 0.0, 1.0, &opts).unwrap();
        let expected = E - 1.0;
        assert_relative_eq!(outcome.value, expected, epsilon = 1e-12);
    }

    /// Integral of x^4 over [0, 1] equals 1/5.
    #[test]
    fn test_quad_polynomial() {
        let opts = QuadOptions::default();
        let outcome = quad(|x: f64| x.powi(4), 0.0, 1.0, &opts).unwrap();
        assert_relative_eq!(outcome.value, 0.2, epsilon = 1e-14);
    }

    /// Integral of 1/sqrt(x) over [0, 1] equals 2 (integrable singularity at x = 0).
    #[test]
    fn test_quad_singular_sqrt() {
        // Guarded integrand: returns 0 instead of dividing by zero at the endpoint.
        let integrand = |x: f64| {
            if x.abs() < 1e-300 {
                0.0
            } else {
                1.0 / x.sqrt()
            }
        };
        let opts = QuadOptions::default()
            .atol(1e-8)
            .rtol(1e-8)
            .max_subdivisions(100)
            .points(vec![0.0]);
        let outcome = quad(integrand, 0.0, 1.0, &opts).unwrap();
        assert_relative_eq!(outcome.value, 2.0, epsilon = 1e-6);
    }

    /// Integral of sin(100x) over [0, pi] equals (1 - cos(100*pi)) / 100 = 0.
    #[test]
    fn test_quad_oscillatory() {
        let opts = QuadOptions::default().max_subdivisions(200);
        let outcome = quad(|x: f64| (100.0 * x).sin(), 0.0, PI, &opts).unwrap();
        assert!(outcome.value.abs() < 1e-6);
    }

    /// Integral of cos(x) over [0, pi/2] equals 1, requested at tight tolerances.
    #[test]
    fn test_quad_tight_tolerance() {
        let opts = QuadOptions::default().atol(1e-14).rtol(1e-14);
        let outcome = quad(|x: f64| x.cos(), 0.0, FRAC_PI_2, &opts).unwrap();
        assert_relative_eq!(outcome.value, 1.0, epsilon = 1e-13);
    }

    /// Integral of sin(x) over [0, pi] equals 2, evaluated in single precision.
    #[test]
    fn test_quad_f32() {
        let opts = QuadOptions::<f32>::default().atol(1e-4).rtol(1e-4);
        let upper = core::f32::consts::PI;
        let outcome = quad(|x: f32| x.sin(), 0.0f32, upper, &opts).unwrap();
        assert!((outcome.value - 2.0).abs() < 1e-4);
    }

    /// Integral of exp(-x^2) over [-5, 5] is approximately sqrt(pi).
    #[test]
    fn test_quad_gaussian() {
        let opts = QuadOptions::default();
        let outcome = quad(|x: f64| (-x * x).exp(), -5.0, 5.0, &opts).unwrap();
        assert_relative_eq!(outcome.value, PI.sqrt(), epsilon = 1e-10);
    }
}