Skip to main content

oxiphysics_core/
quadrature.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Numerical quadrature (numerical integration) methods.
5//!
6//! Provides Gauss-Legendre, Gauss-Lobatto, Gauss-Chebyshev, Gauss-Hermite,
7//! Gauss-Laguerre, Clenshaw-Curtis, Simpson's rule, Romberg, adaptive
8//! Gauss-Kronrod (G7K15), double-exponential (tanh-sinh), and 2D/3D tensor
9//! product quadrature.
10
11#![allow(dead_code)]
12#![allow(clippy::needless_range_loop)]
13
14use std::f64::consts::PI;
15
16// ── helpers ─────────────────────────────────────────────────────────────────
17
/// Evaluate the Legendre polynomial P_n(x) and its derivative P_n'(x).
///
/// Both values come from forward recurrences:
///
/// * `k P_k = (2k - 1) x P_{k-1} - (k - 1) P_{k-2}`
/// * `P_k'  = k P_{k-1} + x P_{k-1}'`
///
/// The derivative recurrence never divides by `1 - x²`. It is therefore exact at
/// the endpoints `x = ±1` (where `P_n'(±1) = (±1)^{n-1} n(n+1)/2`) and free of
/// cancellation near them. It is also valid for `|x| > 1`.
///
/// Returns `(P_n(x), P_n'(x))`.
fn legendre_p_and_dp(n: usize, x: f64) -> (f64, f64) {
    // Start from P_0 = 1, P_0' = 0 (and P_{-1} = 0, which is never weighted).
    let mut p_prev = 0.0_f64; // P_{k-1}
    let mut p = 1.0_f64; // P_k
    let mut dp = 0.0_f64; // P_k'
    for k in 1..=n {
        let kf = k as f64;
        // Both updates must read the *old* P_{k-1} (= p) and P_{k-1}' (= dp).
        let dp_next = kf * p + x * dp;
        let p_next = ((2.0 * kf - 1.0) * x * p - (kf - 1.0) * p_prev) / kf;
        p_prev = p;
        p = p_next;
        dp = dp_next;
    }
    (p, dp)
}
40
41// ── Gauss-Legendre ───────────────────────────────────────────────────────────
42
43/// Compute the *n*-point Gauss-Legendre nodes and weights on \[-1, 1\].
44///
45/// Returns a `Vec` of `(node, weight)` pairs sorted in ascending node order.
46/// Nodes and weights are found via Newton iteration on the Legendre polynomial.
47pub fn gauss_legendre_weights(n: usize) -> Vec<(f64, f64)> {
48    assert!(n >= 1, "n must be at least 1");
49    let mut nw = Vec::with_capacity(n);
50    // Only compute half — the rule is symmetric about x=0.
51    let half = n.div_ceil(2);
52    for i in 1..=half {
53        // Initial guess (Golub-Welsch / Tricomi approximation)
54        let theta = PI * ((i as f64) - 0.25) / ((n as f64) + 0.5);
55        let mut x = theta.cos();
56        // Newton iteration
57        for _ in 0..100 {
58            let (p, dp) = legendre_p_and_dp(n, x);
59            let dx = p / dp;
60            x -= dx;
61            if dx.abs() < 1e-15 {
62                break;
63            }
64        }
65        let (_, dp) = legendre_p_and_dp(n, x);
66        let w = 2.0 / ((1.0 - x * x) * dp * dp);
67        // Symmetric counterpart
68        nw.push((-x, w));
69        if 2 * i - 1 != n {
70            // not the midpoint for odd n
71            nw.push((x, w));
72        }
73    }
74    nw.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
75    nw
76}
77
78/// Integrate `f` on \[a, b\] using the *n*-point Gauss-Legendre rule.
79///
80/// The interval is linearly mapped from \[-1, 1\] to \[a, b\].
81pub fn gauss_legendre_integrate(f: &dyn Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
82    let nw = gauss_legendre_weights(n);
83    let mid = 0.5 * (a + b);
84    let half = 0.5 * (b - a);
85    nw.iter()
86        .map(|(xi, wi)| wi * f(mid + half * xi))
87        .sum::<f64>()
88        * half
89}
90
91// ── Gauss-Lobatto ────────────────────────────────────────────────────────────
92
93/// Compute the *n*-point Gauss-Lobatto nodes and weights on \[-1, 1\].
94///
95/// The endpoints ±1 are always included. Requires n ≥ 2.
96/// Interior nodes are roots of P_{n-1}'(x) found by Newton iteration.
97pub fn gauss_lobatto_weights(n: usize) -> Vec<(f64, f64)> {
98    assert!(n >= 2, "Gauss-Lobatto requires n >= 2");
99    let mut nw = Vec::with_capacity(n);
100
101    // Endpoints always included with weight 2 / (n*(n-1))
102    let w_end = 2.0 / ((n as f64) * ((n as f64) - 1.0));
103    nw.push((-1.0_f64, w_end));
104
105    // Interior nodes: zeros of P_{n-1}'(x)
106    let m = n - 2; // number of interior nodes
107    let half = m.div_ceil(2);
108    for i in 1..=half {
109        // Good initial guess for P_{n-1}' zeros
110        let theta = PI * (i as f64) / ((n as f64) - 1.0);
111        let mut x = -theta.cos();
112        for _ in 0..100 {
113            let (p, dp) = legendre_p_and_dp(n - 1, x);
114            // Second derivative: (1-x^2) P'' = -2x P' + ... use recurrence
115            // P_{n-1}''(x) via forward diff (simpler)
116            let eps = 1e-8;
117            let (_, dp_p) = legendre_p_and_dp(n - 1, x + eps);
118            let ddp = (dp_p - dp) / eps;
119            let dx = dp / ddp;
120            x -= dx;
121            if dx.abs() < 1e-14 {
122                break;
123            }
124            let _ = p;
125        }
126        let (pval, _) = legendre_p_and_dp(n - 1, x);
127        let w = 2.0 / ((n as f64 - 1.0) * (n as f64) * pval * pval);
128        nw.push((-x, w));
129        if 2 * i - 1 != m {
130            nw.push((x, w));
131        }
132    }
133
134    nw.push((1.0_f64, w_end));
135    nw.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
136    nw
137}
138
139// ── Gauss-Chebyshev (first kind) ─────────────────────────────────────────────
140
/// Compute the *n*-point Gauss-Chebyshev (first kind) nodes and weights on \[-1, 1\].
///
/// The rule integrates `f(x) / sqrt(1 - x²)`. The nodes are the roots of
/// T_n(x), `cos((2k + 1) π / (2n))` for `k = 0..n`, and are returned in that
/// order (from the node nearest +1 down to the node nearest -1). Every weight
/// equals `π / n`.
///
/// # Panics
/// Panics if `n == 0`.
pub fn gauss_chebyshev_weights(n: usize) -> Vec<(f64, f64)> {
    assert!(n >= 1, "n must be at least 1");
    let weight = PI / (n as f64);
    let denom = 2.0 * n as f64;
    let mut rule = Vec::with_capacity(n);
    for k in 0..n {
        let odd = (2 * k + 1) as f64;
        rule.push(((odd * PI / denom).cos(), weight));
    }
    rule
}
155
156// ── Simpson's rule ────────────────────────────────────────────────────────────
157
/// Composite Simpson's rule on \[a, b\] with `n` subintervals.
///
/// Simpson's rule works on pairs of panels, so the panel count must be even
/// and non-zero. An odd `n` is rounded up to the next even number, and
/// `n == 0` is treated as 2 (a single Simpson step), which avoids a division
/// by zero.
///
/// Exact for polynomials of degree ≤ 3.
pub fn simpsons_rule(f: &dyn Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
    let n = (n + n % 2).max(2);
    let h = (b - a) / (n as f64);
    let mut sum = f(a) + f(b);
    for i in 1..n {
        let x = a + (i as f64) * h;
        // Interior weights alternate 4, 2, 4, …, 4.
        let coeff = if i % 2 == 0 { 2.0 } else { 4.0 };
        sum += coeff * f(x);
    }
    sum * h / 3.0
}
171
172// ── Romberg integration ───────────────────────────────────────────────────────
173
/// Romberg integration of `f` over \[a, b\].
///
/// Trapezoidal estimates with 1, 2, 4, …, 2^`max_level` panels are refined by
/// Richardson extrapolation. Only the two most recent rows of the tableau are
/// kept in memory.
///
/// The routine returns as soon as two consecutive diagonal entries differ by
/// less than `tol`; otherwise it returns the last diagonal entry. A
/// `max_level` of 0 is treated as 1.
pub fn romberg_integration(
    f: &dyn Fn(f64) -> f64,
    a: f64,
    b: f64,
    max_level: usize,
    tol: f64,
) -> f64 {
    let levels = max_level.max(1);
    let width = b - a;

    // Row 0: single-panel trapezoid.
    let mut prev_row = vec![0.5 * width * (f(a) + f(b))];

    for level in 1..=levels {
        let panels = 1usize << level;
        let step = width / (panels as f64);
        // Only the odd-indexed abscissae are new at this level.
        let fresh = (0..panels / 2).fold(0.0, |acc, k| acc + f(a + (2 * k + 1) as f64 * step));

        let mut row = Vec::with_capacity(level + 1);
        row.push(0.5 * prev_row[0] + step * fresh);
        for j in 1..=level {
            let factor = 4.0_f64.powi(j as i32);
            let extrapolated = (factor * row[j - 1] - prev_row[j - 1]) / (factor - 1.0);
            row.push(extrapolated);
        }

        if (row[level] - prev_row[level - 1]).abs() < tol {
            return row[level];
        }
        prev_row = row;
    }
    prev_row[levels]
}
213
214// ── Gauss-Kronrod G7K15 ────────────────────────────────────────────────────
215
/// Abscissae of the 15-point Kronrod rule on \[-1, 1\] (non-negative half).
///
/// Listed from the outermost node inwards. The last entry (index 7) is the
/// centre `x = 0`; the negative nodes are mirror images of indices 0–6.
/// The entries at odd indices (1, 3, 5), together with the centre, are the
/// nodes of the embedded 7-point Gauss rule.
const GK15_NODES: [f64; 8] = [
    0.991_455_371_120_813, // Kronrod only
    0.949_107_912_342_758, // Gauss
    0.864_864_423_359_769, // Kronrod only
    0.741_531_185_599_394, // Gauss
    0.586_087_235_467_691, // Kronrod only
    0.405_845_151_377_397, // Gauss
    0.207_784_955_007_898, // Kronrod only
    0.0,                   // centre (Gauss)
];

/// Kronrod weights: `GK15_WEIGHTS[i]` belongs to `GK15_NODES[i]`.
///
/// Counting both mirrored halves, the full 15-point rule's weights sum to 2.
const GK15_WEIGHTS: [f64; 8] = [
    0.022_935_322_010_529,
    0.063_092_092_629_979,
    0.104_790_010_322_250,
    0.140_653_259_715_525,
    0.169_004_726_639_267,
    0.190_350_578_064_785,
    0.204_432_940_075_298,
    0.209_482_141_084_728, // centre
];

/// Embedded 7-point Gauss weights.
///
/// `G7_WEIGHTS[j]` belongs to `GK15_NODES[2j + 1]` for `j = 0..3`.
/// `G7_WEIGHTS[3]` is the centre weight.
const G7_WEIGHTS: [f64; 4] = [
    0.129_484_966_168_870,
    0.279_705_391_489_277,
    0.381_830_050_505_119,
    0.417_959_183_673_469, // centre
];
247
248/// Adaptive Gauss-Kronrod G7K15 integration on \[a, b\].
249///
250/// Estimates the error as |G15 - G7| and recursively bisects intervals that
251/// exceed `tol`. The recursion halts at `max_depth`.
252pub fn adaptive_gauss_kronrod(
253    f: &dyn Fn(f64) -> f64,
254    a: f64,
255    b: f64,
256    tol: f64,
257    max_depth: usize,
258) -> f64 {
259    gk15_recursive(f, a, b, tol, max_depth, 0)
260}
261
262/// Recursive helper for [`adaptive_gauss_kronrod`].
263fn gk15_recursive(
264    f: &dyn Fn(f64) -> f64,
265    a: f64,
266    b: f64,
267    tol: f64,
268    max_depth: usize,
269    depth: usize,
270) -> f64 {
271    let mid = 0.5 * (a + b);
272    let half = 0.5 * (b - a);
273
274    // Evaluate K15 and G7
275    let mut gk = 0.0_f64;
276    let mut g7 = 0.0_f64;
277
278    // x = 0 node (index 7)
279    let f0 = f(mid);
280    gk += GK15_WEIGHTS[7] * f0;
281    g7 += G7_WEIGHTS[3] * f0;
282
283    // Symmetric pairs
284    for i in 0..7 {
285        let xi = GK15_NODES[i];
286        let fplus = f(mid + half * xi);
287        let fminus = f(mid - half * xi);
288        gk += GK15_WEIGHTS[i] * (fplus + fminus);
289
290        // G7 uses nodes at indices 1, 3, 5 (i.e. i=1,3,5 in GK15_NODES)
291        if i == 1 || i == 3 || i == 5 {
292            let gi = match i {
293                1 => 0,
294                3 => 1,
295                5 => 2,
296                _ => unreachable!(),
297            };
298            g7 += G7_WEIGHTS[gi] * (fplus + fminus);
299        }
300    }
301
302    gk *= half;
303    g7 *= half;
304
305    let err = (gk - g7).abs();
306    if depth >= max_depth || err < tol {
307        return gk;
308    }
309
310    gk15_recursive(f, a, mid, tol * 0.5, max_depth, depth + 1)
311        + gk15_recursive(f, mid, b, tol * 0.5, max_depth, depth + 1)
312}
313
314// ── Double-exponential (tanh-sinh) ────────────────────────────────────────────
315
/// Tanh-sinh (double-exponential) quadrature on \[a, b\].
///
/// The substitution x = tanh(π/2 · sinh(t)) maps ℝ → (-1, 1). This gives
/// exponential convergence for smooth integrands, including integrands with
/// singularities at the endpoints.
///
/// Parameters:
/// * `n` — number of points per side; the total is at most `2n + 1`.
/// * `h` — step size in `t`, expected to be positive (typical value 0.1).
/// * `a`, `b` — must be finite.
///
/// Node placement is computed from the endpoint complement
/// `1 - tanh(u) = 2 / (1 + e^{2u})`, and the weight uses `sech²(u)` written
/// in terms of that complement. This avoids cancellation and the overflow of
/// `cosh(u)²`. Because of this, `f` is never evaluated exactly at `a` or `b`:
/// a node that rounds onto an endpoint is skipped, and the loop stops early
/// once both sides have reached their endpoints in `f64`.
pub fn double_exponential(f: &dyn Fn(f64) -> f64, a: f64, b: f64, n: usize, h: f64) -> f64 {
    let mid = 0.5 * (a + b);
    let half = 0.5 * (b - a);

    // t = 0: node at the midpoint, where d/dt tanh(π/2 sinh t) = π/2.
    let mut sum = f(mid) * (0.5 * PI);

    for k in 1..=n {
        let t = k as f64 * h;
        let u = 0.5 * PI * t.sinh();
        // comp = 1 - tanh(u) without cancellation; e^{2u} overflow gives comp = 0.
        let comp = 2.0 / (1.0 + (2.0 * u).exp());
        if comp == 0.0 {
            break;
        }
        // dx/dt = π/2 · cosh(t) · sech²(u), with sech²(u) = 1 - tanh²(u) = comp (2 - comp).
        let weight = 0.5 * PI * t.cosh() * comp * (2.0 - comp);

        // mid ± half·tanh(u), written relative to the endpoint it approaches.
        let xp = b - half * comp;
        let xm = a + half * comp;
        let use_p = xp != b && xp.is_finite();
        let use_m = xm != a && xm.is_finite();
        if !use_p && !use_m {
            // Later nodes only move closer to the endpoints.
            break;
        }
        if use_p {
            sum += f(xp) * weight;
        }
        if use_m {
            sum += f(xm) * weight;
        }
    }

    sum * h * half
}
348
349// ── Clenshaw-Curtis ───────────────────────────────────────────────────────────
350
/// Compute *n*-point Clenshaw-Curtis nodes and weights on \[-1, 1\].
///
/// With `N = n - 1` panels, the nodes are the Chebyshev extrema
/// `cos(k π / N)` for `k = 0..n`. They are returned in that order, from +1
/// down to -1.
///
/// The weights come from the direct cosine-series formula
///
/// `w_k = (2 c_k / N) · (1 - 2 Σ_{m=1}^{⌊N/2⌋} b_m cos(2 m k π / N) / (4m² - 1))`,
///
/// where `c_k = 1/2` at the two endpoints and `1` elsewhere, and `b_m = 1/2`
/// when `2m = N` and `1` otherwise. This is an O(n²) evaluation, not an
/// FFT-based one. The weights sum to 2.
///
/// # Panics
/// Panics if `n < 2`.
pub fn clenshaw_curtis_weights(n: usize) -> Vec<(f64, f64)> {
    assert!(n >= 2, "Clenshaw-Curtis requires n >= 2");
    let panels = n - 1;
    let panels_f = panels as f64;
    let terms = panels / 2;

    let weight_of = |k: usize| -> f64 {
        let endpoint_factor = if k == 0 || k == panels { 0.5 } else { 1.0 };
        let series: f64 = (1..=terms)
            .map(|m| {
                let half_term = if 2 * m == panels { 0.5 } else { 1.0 };
                let angle = 2.0 * m as f64 * k as f64 * PI / panels_f;
                half_term * angle.cos() / (4.0 * (m as f64).powi(2) - 1.0)
            })
            .sum();
        2.0 * endpoint_factor * (1.0 - 2.0 * series) / panels_f
    };

    let mut rule = Vec::with_capacity(n);
    for k in 0..n {
        let node = (k as f64 * PI / panels_f).cos();
        rule.push((node, weight_of(k)));
    }
    rule
}
387
388// ── Gauss-Hermite ─────────────────────────────────────────────────────────────
389
/// Compute the *n*-point Gauss-Hermite nodes and weights.
///
/// The rule integrates `f(x) · exp(-x²)` over (-∞, +∞). `Σ wᵢ f(xᵢ)` is
/// exact for polynomials of degree ≤ 2n - 1, and the weights sum to √π.
/// The nodes are the roots of the *physicist's* Hermite polynomial H_n(x).
///
/// Method:
/// * Polynomials are evaluated with the orthonormal recurrence
///   `p̃_j = x √(2/j) p̃_{j-1} - √((j-1)/j) p̃_{j-2}`, with `p̃_0 = π^{-1/4}`.
///   This never overflows, unlike `2^n n!`.
/// * `p̃_n'(x) = √(2n) p̃_{n-1}(x)`.
/// * Weights are `2 / p̃_n'(xᵢ)²`.
/// * Roots are found from the largest downwards by Newton iteration, starting
///   from the asymptotic guesses of Numerical Recipes (`gauher`). Each guess
///   extrapolates from the roots already found, so every Newton run converges
///   to a distinct root.
///
/// Returns `(node, weight)` pairs sorted by ascending node.
///
/// # Panics
/// Panics if `n == 0`.
pub fn gauss_hermite_weights(n: usize) -> Vec<(f64, f64)> {
    assert!(n >= 1, "n must be at least 1");
    let nf = n as f64;
    let p0 = PI.powf(-0.25); // orthonormal p̃_0
    let deriv_scale = (2.0 * nf).sqrt(); // p̃_n' = deriv_scale · p̃_{n-1}

    // Returns (p̃_n(x), p̃_{n-1}(x)).
    let eval = |x: f64| -> (f64, f64) {
        let mut p_prev = 0.0_f64;
        let mut p = p0;
        for j in 1..=n {
            let jf = j as f64;
            let p_next = x * (2.0 / jf).sqrt() * p - ((jf - 1.0) / jf).sqrt() * p_prev;
            p_prev = p;
            p = p_next;
        }
        (p, p_prev)
    };

    let half = n.div_ceil(2);
    let mut roots: Vec<f64> = Vec::with_capacity(half);
    let mut nw = Vec::with_capacity(n);
    let mut z = 0.0_f64;
    for i in 1..=half {
        // Initial guess for the i-th largest root (Numerical Recipes `gauher`).
        z = match i {
            1 => (2.0 * nf + 1.0).sqrt() - 1.85575 * (2.0 * nf + 1.0).powf(-0.16667),
            2 => z - 1.14 * nf.powf(0.426) / z,
            3 => 1.86 * z - 0.86 * roots[0],
            4 => 1.91 * z - 0.91 * roots[1],
            _ => 2.0 * z - roots[i - 3],
        };
        for _ in 0..100 {
            let (p, p_prev) = eval(z);
            let dz = p / (deriv_scale * p_prev);
            if !dz.is_finite() {
                break;
            }
            z -= dz;
            if dz.abs() <= 1e-14 * z.abs().max(1.0) {
                break;
            }
        }
        let (_, p_prev) = eval(z);
        let dp = deriv_scale * p_prev;
        let w = 2.0 / (dp * dp);
        roots.push(z);
        nw.push((-z, w));
        // For odd n the last root is the centre node; store it only once.
        if 2 * i - 1 != n {
            nw.push((z, w));
        }
    }
    nw.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    nw
}
475
476// ── Gauss-Laguerre ────────────────────────────────────────────────────────────
477
/// Evaluate the Laguerre polynomial L_n(x) and its derivative L_n'(x).
///
/// Both values come from forward recurrences:
///
/// * `k L_k = (2k - 1 - x) L_{k-1} - (k - 1) L_{k-2}`
/// * `L_k'  = L_{k-1}' - L_{k-1}`
///
/// The derivative recurrence follows from `L_n' = -Σ_{k<n} L_k`. (Note that
/// `L_n' = -L_{n-1}` is *not* an identity for the α = 0 polynomials.) It is
/// valid at every `x`, including `x = 0`, where `L_n'(0) = -n`.
///
/// Returns `(L_n(x), L_n'(x))`.
fn laguerre_p_dp(n: usize, x: f64) -> (f64, f64) {
    if n == 0 {
        return (1.0, 0.0);
    }
    let mut l_prev = 1.0_f64; // L_{k-1}, starting at L_0
    let mut l = 1.0 - x; // L_k, starting at L_1
    let mut dl = -1.0_f64; // L_k', starting at L_1'
    for k in 2..=n {
        let kf = k as f64;
        let l_next = ((2.0 * kf - 1.0 - x) * l - (kf - 1.0) * l_prev) / kf;
        // L_k' = L_{k-1}' - L_{k-1}  (l is still L_{k-1} here)
        dl -= l;
        l_prev = l;
        l = l_next;
    }
    (l, dl)
}
503
/// Compute the *n*-point Gauss-Laguerre nodes and weights.
///
/// The rule integrates `f(x) · exp(-x)` over \[0, +∞). `Σ wᵢ f(xᵢ)` is exact
/// for polynomials of degree ≤ 2n - 1, and the weights sum to 1.
///
/// Method:
/// * Nodes are the zeros of L_n(x), found in ascending order by Newton
///   iteration. Newton starts from the asymptotic guesses of Numerical Recipes
///   (`gaulag`, α = 0); each guess extrapolates from the roots already found,
///   so every run converges to a distinct zero.
/// * The Newton derivative uses the exact identity
///   `x L_n'(x) = n (L_n(x) - L_{n-1}(x))`.
/// * Weights are `wᵢ = xᵢ / (n · L_{n-1}(xᵢ))²`.
///
/// Returns `(node, weight)` pairs sorted by ascending node.
///
/// # Panics
/// Panics if `n == 0`.
pub fn gauss_laguerre_weights(n: usize) -> Vec<(f64, f64)> {
    assert!(n >= 1, "n must be at least 1");
    let nf = n as f64;

    // Returns (L_n(x), L_{n-1}(x)) via the three-term recurrence.
    let eval = |x: f64| -> (f64, f64) {
        let mut l_prev = 0.0_f64;
        let mut l = 1.0_f64;
        for j in 1..=n {
            let jf = j as f64;
            let l_next = ((2.0 * jf - 1.0 - x) * l - (jf - 1.0) * l_prev) / jf;
            l_prev = l;
            l = l_next;
        }
        (l, l_prev)
    };

    let mut nodes: Vec<f64> = Vec::with_capacity(n);
    let mut z = 0.0_f64;
    for i in 1..=n {
        // Initial guess for the i-th smallest zero (Numerical Recipes `gaulag`, α = 0).
        z = match i {
            1 => 3.0 / (1.0 + 2.4 * nf),
            2 => z + 15.0 / (1.0 + 2.5 * nf),
            _ => {
                let ai = (i - 2) as f64;
                z + (1.0 + 2.55 * ai) / (1.9 * ai) * (z - nodes[i - 3])
            }
        };
        for _ in 0..100 {
            let (l, l_prev) = eval(z);
            let dl = nf * (l - l_prev) / z;
            let dz = l / dl;
            if !dz.is_finite() {
                break;
            }
            z -= dz;
            if dz.abs() <= 1e-14 * z.abs().max(1.0) {
                break;
            }
        }
        nodes.push(z);
    }

    let mut nw: Vec<(f64, f64)> = nodes
        .into_iter()
        .map(|x| {
            let (_, l_prev) = eval(x);
            (x, x / (nf * l_prev).powi(2))
        })
        .collect();
    nw.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    nw
}
547
548// ── 2D tensor-product integration ────────────────────────────────────────────
549
550/// Integrate `f(x, y)` on \[ax, bx\] × \[ay, by\] using tensor-product Gauss-Legendre.
551///
552/// `nx` and `ny` are the number of quadrature points in each dimension.
553pub fn integrate_2d(
554    f: &dyn Fn(f64, f64) -> f64,
555    ax: f64,
556    bx: f64,
557    ay: f64,
558    by: f64,
559    nx: usize,
560    ny: usize,
561) -> f64 {
562    let nwx = gauss_legendre_weights(nx);
563    let nwy = gauss_legendre_weights(ny);
564    let midx = 0.5 * (ax + bx);
565    let halfx = 0.5 * (bx - ax);
566    let midy = 0.5 * (ay + by);
567    let halfy = 0.5 * (by - ay);
568    let mut sum = 0.0;
569    for (xi, wi) in &nwx {
570        let x = midx + halfx * xi;
571        for (yj, wj) in &nwy {
572            let y = midy + halfy * yj;
573            sum += wi * wj * f(x, y);
574        }
575    }
576    sum * halfx * halfy
577}
578
579/// Integrate `f(x, y, z)` on a box given by `bounds = [(ax,bx), (ay,by), (az,bz)]`.
580///
581/// `n = [nx, ny, nz]` are the number of quadrature points per dimension.
582pub fn integrate_3d(
583    f: &dyn Fn(f64, f64, f64) -> f64,
584    bounds: [(f64, f64); 3],
585    n: [usize; 3],
586) -> f64 {
587    let [(ax, bx), (ay, by), (az, bz)] = bounds;
588    let nwx = gauss_legendre_weights(n[0]);
589    let nwy = gauss_legendre_weights(n[1]);
590    let nwz = gauss_legendre_weights(n[2]);
591    let midx = 0.5 * (ax + bx);
592    let halfx = 0.5 * (bx - ax);
593    let midy = 0.5 * (ay + by);
594    let halfy = 0.5 * (by - ay);
595    let midz = 0.5 * (az + bz);
596    let halfz = 0.5 * (bz - az);
597    let mut sum = 0.0;
598    for (xi, wi) in &nwx {
599        let x = midx + halfx * xi;
600        for (yj, wj) in &nwy {
601            let y = midy + halfy * yj;
602            for (zk, wk) in &nwz {
603                let z = midz + halfz * zk;
604                sum += wi * wj * wk * f(x, y, z);
605            }
606        }
607    }
608    sum * halfx * halfy * halfz
609}
610
611// ── AdaptiveIntegrator ────────────────────────────────────────────────────────
612
613/// Adaptive integrator that tracks function evaluation count.
614///
615/// Uses recursive adaptive G7K15 bisection internally.
616#[derive(Debug, Clone)]
617pub struct AdaptiveIntegrator {
618    /// Absolute tolerance for the error estimate.
619    pub tol: f64,
620    /// Maximum number of function evaluations allowed.
621    pub max_evals: usize,
622    /// Number of function evaluations used in the last call to `integrate`.
623    pub calls: usize,
624}
625
626impl AdaptiveIntegrator {
627    /// Create a new `AdaptiveIntegrator` with the given tolerance and evaluation budget.
628    pub fn new(tol: f64, max_evals: usize) -> Self {
629        Self {
630            tol,
631            max_evals,
632            calls: 0,
633        }
634    }
635
636    /// Integrate `f` on \[a, b\].
637    ///
638    /// Returns `(value, error_estimate)`.  The error estimate is |G15 - G7|
639    /// accumulated over all subintervals.
640    pub fn integrate(&mut self, f: &dyn Fn(f64) -> f64, a: f64, b: f64) -> (f64, f64) {
641        self.calls = 0;
642        let max_depth = (self.max_evals / 15).max(1).ilog2() as usize + 1;
643        let (val, err) = self.adaptive_internal(f, a, b, self.tol, max_depth, 0);
644        (val, err)
645    }
646
647    /// Internal recursive worker.
648    fn adaptive_internal(
649        &mut self,
650        f: &dyn Fn(f64) -> f64,
651        a: f64,
652        b: f64,
653        tol: f64,
654        max_depth: usize,
655        depth: usize,
656    ) -> (f64, f64) {
657        let mid = 0.5 * (a + b);
658        let half = 0.5 * (b - a);
659
660        let mut gk = 0.0_f64;
661        let mut g7 = 0.0_f64;
662        self.calls += 1;
663        let f0 = f(mid);
664        gk += GK15_WEIGHTS[7] * f0;
665        g7 += G7_WEIGHTS[3] * f0;
666
667        for i in 0..7 {
668            let xi = GK15_NODES[i];
669            self.calls += 2;
670            let fplus = f(mid + half * xi);
671            let fminus = f(mid - half * xi);
672            gk += GK15_WEIGHTS[i] * (fplus + fminus);
673            if i == 1 || i == 3 || i == 5 {
674                let gi = match i {
675                    1 => 0,
676                    3 => 1,
677                    5 => 2,
678                    _ => unreachable!(),
679                };
680                g7 += G7_WEIGHTS[gi] * (fplus + fminus);
681            }
682        }
683        gk *= half;
684        g7 *= half;
685
686        let err = (gk - g7).abs();
687        if depth >= max_depth || err < tol || self.calls >= self.max_evals {
688            return (gk, err);
689        }
690
691        let (vl, el) = self.adaptive_internal(f, a, mid, tol * 0.5, max_depth, depth + 1);
692        let (vr, er) = self.adaptive_internal(f, mid, b, tol * 0.5, max_depth, depth + 1);
693        (vl + vr, el + er)
694    }
695}
696
697// ─────────────────────────────────────────────────────────────────────────────
698// Tests
699// ─────────────────────────────────────────────────────────────────────────────
700
701#[cfg(test)]
702mod tests {
703    use super::*;
704
705    const TOL: f64 = 1e-9;
706
707    // ── gauss_legendre_weights ──────────────────────────────────────────────
708
709    #[test]
710    fn test_gl_weights_n1_integrates_constant() {
711        // ∫₋₁¹ 1 dx = 2
712        let nw = gauss_legendre_weights(1);
713        let sum: f64 = nw.iter().map(|(_, w)| w).sum();
714        assert!(
715            (sum - 2.0).abs() < TOL,
716            "weights should sum to 2, got {sum}"
717        );
718    }
719
720    #[test]
721    fn test_gl_weights_n2_nodes_and_weights() {
722        let nw = gauss_legendre_weights(2);
723        assert_eq!(nw.len(), 2);
724        // Nodes: ±1/√3
725        let node = 1.0_f64 / 3.0_f64.sqrt();
726        assert!((nw[0].0 + node).abs() < 1e-12);
727        assert!((nw[1].0 - node).abs() < 1e-12);
728        // Both weights = 1
729        assert!((nw[0].1 - 1.0).abs() < 1e-12);
730        assert!((nw[1].1 - 1.0).abs() < 1e-12);
731    }
732
733    #[test]
734    fn test_gl_weights_n5_sum_to_two() {
735        let nw = gauss_legendre_weights(5);
736        let sum: f64 = nw.iter().map(|(_, w)| w).sum();
737        assert!((sum - 2.0).abs() < 1e-12, "5-pt GL weights sum = {sum}");
738    }
739
740    #[test]
741    fn test_gl_nodes_in_minus_one_to_one() {
742        for n in [1, 2, 3, 5, 8, 10] {
743            for (x, _) in gauss_legendre_weights(n) {
744                assert!(
745                    (-1.0..=1.0).contains(&x),
746                    "node {x} out of [-1,1] for n={n}"
747                );
748            }
749        }
750    }
751
752    #[test]
753    fn test_gl_n5_is_sorted() {
754        let nw = gauss_legendre_weights(5);
755        for i in 1..nw.len() {
756            assert!(nw[i - 1].0 <= nw[i].0, "nodes not sorted at index {i}");
757        }
758    }
759
760    // ── gauss_legendre_integrate ────────────────────────────────────────────
761
762    #[test]
763    fn test_gl_integrate_constant() {
764        // ∫₀¹ 3 dx = 3
765        let result = gauss_legendre_integrate(&|_| 3.0, 0.0, 1.0, 3);
766        assert!((result - 3.0).abs() < TOL);
767    }
768
769    #[test]
770    fn test_gl_integrate_linear() {
771        // ∫₀² x dx = 2
772        let result = gauss_legendre_integrate(&|x| x, 0.0, 2.0, 3);
773        assert!((result - 2.0).abs() < TOL);
774    }
775
776    #[test]
777    fn test_gl_integrate_quadratic() {
778        // ∫₀¹ x² dx = 1/3
779        let result = gauss_legendre_integrate(&|x| x * x, 0.0, 1.0, 2);
780        assert!((result - 1.0 / 3.0).abs() < TOL);
781    }
782
783    #[test]
784    fn test_gl_integrate_sin() {
785        // ∫₀^π sin(x) dx = 2
786        let result = gauss_legendre_integrate(&|x| x.sin(), 0.0, PI, 10);
787        assert!((result - 2.0).abs() < 1e-10, "sin integral = {result}");
788    }
789
790    #[test]
791    fn test_gl_integrate_exp() {
792        // ∫₀¹ eˣ dx = e - 1
793        let exact = std::f64::consts::E - 1.0;
794        let result = gauss_legendre_integrate(&|x| x.exp(), 0.0, 1.0, 8);
795        assert!((result - exact).abs() < 1e-12, "exp integral = {result}");
796    }
797
798    // ── gauss_chebyshev_weights ─────────────────────────────────────────────
799
800    #[test]
801    fn test_chebyshev_n4_weights_sum_to_pi() {
802        let nw = gauss_chebyshev_weights(4);
803        let sum: f64 = nw.iter().map(|(_, w)| w).sum();
804        assert!((sum - PI).abs() < 1e-12, "Chebyshev weights sum = {sum}");
805    }
806
807    #[test]
808    fn test_chebyshev_nodes_on_unit_circle() {
809        for (x, _) in gauss_chebyshev_weights(6) {
810            assert!(x.abs() <= 1.0 + 1e-12, "node {x} out of range");
811        }
812    }
813
814    // ── simpsons_rule ────────────────────────────────────────────────────────
815
816    #[test]
817    fn test_simpsons_constant() {
818        let r = simpsons_rule(&|_| 5.0, 0.0, 1.0, 2);
819        assert!((r - 5.0).abs() < TOL);
820    }
821
822    #[test]
823    fn test_simpsons_polynomial() {
824        // ∫₀¹ x³ dx = 1/4 ; Simpson's is exact for degree ≤ 3
825        let r = simpsons_rule(&|x| x * x * x, 0.0, 1.0, 4);
826        assert!((r - 0.25).abs() < TOL);
827    }
828
829    #[test]
830    fn test_simpsons_odd_n_rounded_up() {
831        // n=3 is odd → rounded to 4 internally
832        let r = simpsons_rule(&|x| x * x, 0.0, 1.0, 3);
833        assert!((r - 1.0 / 3.0).abs() < 1e-10, "result = {r}");
834    }
835
836    #[test]
837    fn test_simpsons_sin() {
838        let r = simpsons_rule(&|x| x.sin(), 0.0, PI, 100);
839        assert!((r - 2.0).abs() < 1e-6, "sin integral = {r}");
840    }
841
842    // ── romberg ──────────────────────────────────────────────────────────────
843
844    #[test]
845    fn test_romberg_constant() {
846        let r = romberg_integration(&|_| 7.0, 0.0, 1.0, 5, 1e-12);
847        assert!((r - 7.0).abs() < 1e-10);
848    }
849
850    #[test]
851    fn test_romberg_exp() {
852        let exact = std::f64::consts::E - 1.0;
853        let r = romberg_integration(&|x| x.exp(), 0.0, 1.0, 8, 1e-12);
854        assert!((r - exact).abs() < 1e-10, "romberg exp = {r}");
855    }
856
857    #[test]
858    fn test_romberg_sin_over_pi() {
859        let r = romberg_integration(&|x| x.sin(), 0.0, PI, 8, 1e-12);
860        assert!((r - 2.0).abs() < 1e-10, "romberg sin = {r}");
861    }
862
863    // ── adaptive_gauss_kronrod ───────────────────────────────────────────────
864
865    #[test]
866    fn test_agk_constant() {
867        let r = adaptive_gauss_kronrod(&|_| 3.0, 0.0, 2.0, 1e-10, 10);
868        assert!((r - 6.0).abs() < 1e-8);
869    }
870
871    #[test]
872    fn test_agk_sin() {
873        let r = adaptive_gauss_kronrod(&|x| x.sin(), 0.0, PI, 1e-10, 10);
874        assert!((r - 2.0).abs() < 1e-8, "agk sin = {r}");
875    }
876
877    #[test]
878    fn test_agk_exp_negative() {
879        // ∫₀^∞ e^{-x} dx = 1 ; approximate with [0, 20]
880        let r = adaptive_gauss_kronrod(&|x| (-x).exp(), 0.0, 20.0, 1e-10, 15);
881        assert!((r - 1.0).abs() < 1e-8, "agk exp = {r}");
882    }
883
884    // ── double_exponential ───────────────────────────────────────────────────
885
886    #[test]
887    fn test_de_constant() {
888        let r = double_exponential(&|_| 1.0, 0.0, 1.0, 50, 0.1);
889        assert!((r - 1.0).abs() < 1e-6, "DE const = {r}");
890    }
891
892    #[test]
893    fn test_de_sin() {
894        let r = double_exponential(&|x| x.sin(), 0.0, PI, 100, 0.05);
895        assert!((r - 2.0).abs() < 1e-6, "DE sin = {r}");
896    }
897
898    // ── clenshaw_curtis ──────────────────────────────────────────────────────
899
900    #[test]
901    fn test_cc_weights_sum_to_two() {
902        for n in [2, 3, 5, 8] {
903            let nw = clenshaw_curtis_weights(n);
904            let sum: f64 = nw.iter().map(|(_, w)| w).sum();
905            assert!((sum - 2.0).abs() < 1e-8, "CC weights sum for n={n}: {sum}");
906        }
907    }
908
909    #[test]
910    fn test_cc_endpoints_are_minus_one_and_one() {
911        let nw = clenshaw_curtis_weights(4);
912        let xs: Vec<f64> = nw.iter().map(|(x, _)| *x).collect();
913        assert!(
914            xs.iter().any(|x| (x + 1.0).abs() < 1e-12),
915            "should include -1"
916        );
917        assert!(
918            xs.iter().any(|x| (x - 1.0).abs() < 1e-12),
919            "should include +1"
920        );
921    }
922
923    // ── gauss_laguerre ───────────────────────────────────────────────────────
924
925    #[test]
926    fn test_laguerre_n1() {
927        let nw = gauss_laguerre_weights(1);
928        // Single node at x=1, weight=1
929        assert_eq!(nw.len(), 1);
930        assert!((nw[0].0 - 1.0).abs() < 0.1, "n=1 node ≈ 1, got {}", nw[0].0);
931    }
932
933    #[test]
934    fn test_laguerre_nodes_positive() {
935        for n in [1, 2, 3, 4, 5] {
936            for (x, _) in gauss_laguerre_weights(n) {
937                assert!(x > 0.0, "Laguerre node must be positive, got {x}");
938            }
939        }
940    }
941
942    #[test]
943    fn test_laguerre_integrates_exp_neg_x() {
944        // ∫₀^∞ e^{-x} * 1 dx = 1  (f(x)=1 with Laguerre weight e^{-x})
945        // The quadrature approximates this as sum_i w_i * f(x_i) = sum_i w_i
946        // The weights already absorb the e^{-x} factor, so their sum ≈ 1.
947        // We test with f(x)=1 and verify the rule is reasonable.
948        let nw = gauss_laguerre_weights(5);
949        // Verify nodes are positive and sorted
950        for i in 1..nw.len() {
951            assert!(nw[i].0 > nw[i - 1].0, "nodes should be sorted");
952        }
953        // All weights should be positive
954        for (_, w) in &nw {
955            assert!(*w > 0.0, "weights should be positive");
956        }
957    }
958
959    // ── integrate_2d ─────────────────────────────────────────────────────────
960
961    #[test]
962    fn test_integrate_2d_constant() {
963        // ∫₀¹ ∫₀¹ 1 dy dx = 1
964        let r = integrate_2d(&|_, _| 1.0, 0.0, 1.0, 0.0, 1.0, 3, 3);
965        assert!((r - 1.0).abs() < TOL);
966    }
967
968    #[test]
969    fn test_integrate_2d_product() {
970        // ∫₀¹ ∫₀¹ x * y dy dx = 1/4
971        let r = integrate_2d(&|x, y| x * y, 0.0, 1.0, 0.0, 1.0, 3, 3);
972        assert!((r - 0.25).abs() < TOL, "2D product = {r}");
973    }
974
975    #[test]
976    fn test_integrate_2d_sin_cos() {
977        // ∫₀^{π/2} ∫₀^{π/2} sin(x)cos(y) dy dx = 1
978        let r = integrate_2d(
979            &|x, y| x.sin() * y.cos(),
980            0.0,
981            PI / 2.0,
982            0.0,
983            PI / 2.0,
984            8,
985            8,
986        );
987        assert!((r - 1.0).abs() < 1e-10, "2D sin*cos = {r}");
988    }
989
990    // ── integrate_3d ─────────────────────────────────────────────────────────
991
992    #[test]
993    fn test_integrate_3d_constant() {
994        // ∫₀¹³ 1 dx dy dz = 1
995        let r = integrate_3d(
996            &|_, _, _| 1.0,
997            [(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
998            [3, 3, 3],
999        );
1000        assert!((r - 1.0).abs() < TOL);
1001    }
1002
1003    #[test]
1004    fn test_integrate_3d_xyz() {
1005        // ∫₀¹ ∫₀¹ ∫₀¹ xyz dz dy dx = (1/2)^3 = 1/8
1006        let r = integrate_3d(
1007            &|x, y, z| x * y * z,
1008            [(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
1009            [4, 4, 4],
1010        );
1011        assert!((r - 0.125).abs() < TOL, "3D xyz = {r}");
1012    }
1013
1014    // ── AdaptiveIntegrator ────────────────────────────────────────────────────
1015
1016    #[test]
1017    fn test_adaptive_integrator_constant() {
1018        let mut ai = AdaptiveIntegrator::new(1e-8, 1000);
1019        let (val, _err) = ai.integrate(&|_| 4.0, 0.0, 1.0);
1020        assert!((val - 4.0).abs() < 1e-6, "val = {val}");
1021    }
1022
1023    #[test]
1024    fn test_adaptive_integrator_sin() {
1025        let mut ai = AdaptiveIntegrator::new(1e-8, 5000);
1026        let (val, err) = ai.integrate(&|x| x.sin(), 0.0, PI);
1027        assert!((val - 2.0).abs() < 1e-6, "val = {val}, err = {err}");
1028    }
1029
1030    #[test]
1031    fn test_adaptive_integrator_tracks_calls() {
1032        let mut ai = AdaptiveIntegrator::new(1e-6, 1000);
1033        let _ = ai.integrate(&|x| x.cos(), 0.0, 1.0);
1034        assert!(ai.calls > 0, "calls should be > 0 after integration");
1035    }
1036
1037    #[test]
1038    fn test_adaptive_integrator_error_nonnegative() {
1039        let mut ai = AdaptiveIntegrator::new(1e-8, 2000);
1040        let (_, err) = ai.integrate(&|x| x * x, 0.0, 1.0);
1041        assert!(err >= 0.0, "error estimate must be non-negative");
1042    }
1043
1044    // ── gauss_lobatto ────────────────────────────────────────────────────────
1045
1046    #[test]
1047    fn test_lobatto_n2_endpoints() {
1048        let nw = gauss_lobatto_weights(2);
1049        assert_eq!(nw.len(), 2);
1050        assert!((nw[0].0 + 1.0).abs() < 1e-12, "first node should be -1");
1051        assert!((nw[1].0 - 1.0).abs() < 1e-12, "last node should be 1");
1052    }
1053
1054    #[test]
1055    fn test_lobatto_weights_sum_to_two() {
1056        for n in [2, 3, 4, 5] {
1057            let nw = gauss_lobatto_weights(n);
1058            let sum: f64 = nw.iter().map(|(_, w)| w).sum();
1059            assert!(
1060                (sum - 2.0).abs() < 1e-8,
1061                "GL lobatto n={n} weights sum = {sum}"
1062            );
1063        }
1064    }
1065
1066    #[test]
1067    fn test_lobatto_n4_count() {
1068        let nw = gauss_lobatto_weights(4);
1069        assert_eq!(nw.len(), 4);
1070    }
1071
1072    // ── gauss_hermite ────────────────────────────────────────────────────────
1073
1074    #[test]
1075    fn test_hermite_n1() {
1076        let nw = gauss_hermite_weights(1);
1077        assert_eq!(nw.len(), 1);
1078        // Node should be at x=0
1079        assert!(
1080            nw[0].0.abs() < 1e-10,
1081            "n=1 hermite node at 0, got {}",
1082            nw[0].0
1083        );
1084    }
1085
1086    #[test]
1087    fn test_hermite_n2_nodes_symmetric() {
1088        let nw = gauss_hermite_weights(2);
1089        assert_eq!(nw.len(), 2);
1090        assert!(
1091            (nw[0].0 + nw[1].0).abs() < 1e-10,
1092            "nodes should be symmetric"
1093        );
1094    }
1095
1096    #[test]
1097    fn test_hermite_n3_count() {
1098        let nw = gauss_hermite_weights(3);
1099        assert_eq!(nw.len(), 3);
1100    }
1101
1102    // ── cross-method agreement ────────────────────────────────────────────────
1103
1104    #[test]
1105    fn test_methods_agree_on_sin() {
1106        let f = &|x: f64| x.sin();
1107        let gl = gauss_legendre_integrate(f, 0.0, PI, 10);
1108        let simp = simpsons_rule(f, 0.0, PI, 100);
1109        let romb = romberg_integration(f, 0.0, PI, 8, 1e-12);
1110        let agk = adaptive_gauss_kronrod(f, 0.0, PI, 1e-10, 10);
1111        assert!((gl - 2.0).abs() < 1e-10);
1112        assert!((simp - 2.0).abs() < 1e-6);
1113        assert!((romb - 2.0).abs() < 1e-10);
1114        assert!((agk - 2.0).abs() < 1e-8);
1115    }
1116
1117    #[test]
1118    fn test_methods_agree_on_polynomial() {
1119        // ∫₋₁¹ (x⁴ - 2x² + 1) dx = 2 - 4/3 + 2 = 16/15
1120        let f = &|x: f64| x.powi(4) - 2.0 * x * x + 1.0;
1121        let exact = 16.0 / 15.0;
1122        let gl = gauss_legendre_integrate(f, -1.0, 1.0, 5);
1123        let romb = romberg_integration(f, -1.0, 1.0, 6, 1e-12);
1124        assert!((gl - exact).abs() < 1e-12);
1125        assert!((romb - exact).abs() < 1e-10);
1126    }
1127}