Skip to main content

oxiphysics_core/
spectral_methods.rs

1#![allow(clippy::needless_range_loop, clippy::ptr_arg)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Spectral methods for numerical analysis and physics simulations.
6//!
7//! Provides Chebyshev and Legendre polynomial evaluation, Gauss quadrature
8//! nodes and weights, Cooley-Tukey FFT/IFFT, pseudo-spectral differentiation,
9//! Chebyshev collocation for 1D boundary value problems, and Haar wavelet
10//! multi-level decomposition/reconstruction.
11
12#![allow(dead_code)]
13#![allow(clippy::too_many_arguments)]
14
15use std::f64::consts::PI;
16
17// ─────────────────────────────────────────────────────────────────────────────
18// Complex number (local, minimal)
19// ─────────────────────────────────────────────────────────────────────────────
20
/// Lightweight complex value used internally by the FFT routines.
///
/// Only the handful of operations the transforms need are provided; this is
/// not a general-purpose complex type.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Cx {
    /// Real component.
    re: f64,
    /// Imaginary component.
    im: f64,
}

impl Cx {
    /// Build a value from its Cartesian components.
    #[inline]
    fn new(re: f64, im: f64) -> Self {
        Cx { re, im }
    }

    /// Build a value from magnitude `r` and phase `theta` (radians).
    #[inline]
    fn from_polar(r: f64, theta: f64) -> Self {
        let (c, s) = (theta.cos(), theta.sin());
        Cx::new(r * c, r * s)
    }

    /// Component-wise sum `self + rhs`.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Cx::new(self.re + rhs.re, self.im + rhs.im)
    }

    /// Component-wise difference `self - rhs`.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        Cx::new(self.re - rhs.re, self.im - rhs.im)
    }

    /// Complex product `self * rhs`.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let Cx { re: a, im: b } = self;
        let Cx { re: c, im: d } = rhs;
        // (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        Cx::new(a * c - b * d, a * d + b * c)
    }

    /// Modulus `|self|`, computed with `hypot`.
    #[inline]
    fn abs(self) -> f64 {
        f64::hypot(self.re, self.im)
    }

    /// Multiply both components by the real factor `s`.
    #[inline]
    fn scale(self, s: f64) -> Self {
        Cx::new(self.re * s, self.im * s)
    }
}
73
74// ─────────────────────────────────────────────────────────────────────────────
75// ChebyshevPolynomial
76// ─────────────────────────────────────────────────────────────────────────────
77
/// Chebyshev polynomials of the first kind: T_n(x).
///
/// Provides evaluation via the three-term recurrence, Chebyshev nodes on
/// \[-1, 1\], the spectral differentiation matrix at Chebyshev-Gauss-Lobatto
/// nodes, and interpolation coefficients from function values at those nodes.
pub struct ChebyshevPolynomial;

impl ChebyshevPolynomial {
    /// Evaluate T_n(x) using the recurrence T_0=1, T_1=x, T_{n+1}=2x T_n - T_{n-1}.
    pub fn eval(n: usize, x: f64) -> f64 {
        match n {
            0 => 1.0,
            1 => x,
            _ => {
                let (mut t_prev, mut t_curr) = (1.0f64, x);
                for _ in 2..=n {
                    let t_next = 2.0 * x * t_curr - t_prev;
                    t_prev = t_curr;
                    t_curr = t_next;
                }
                t_curr
            }
        }
    }

    /// Evaluate all Chebyshev polynomials T_0(x), …, T_n(x).
    ///
    /// Returns a vector of length `n + 1` with `ts[k] = T_k(x)`.
    pub fn eval_all(n: usize, x: f64) -> Vec<f64> {
        let mut ts = vec![0.0f64; n + 1];
        ts[0] = 1.0;
        if n >= 1 {
            ts[1] = x;
        }
        for k in 2..=n {
            ts[k] = 2.0 * x * ts[k - 1] - ts[k - 2];
        }
        ts
    }

    /// Return the `n+1` Chebyshev-Gauss-Lobatto nodes on \[-1, 1\]:
    /// `x_j = cos(j*pi/n)`, j = 0, …, n (ordered from `1` down to `-1`).
    ///
    /// For `n == 0` the formula is 0/0; the single node `cos(0) = 1.0` is
    /// returned instead of `NaN`.
    pub fn nodes(n: usize) -> Vec<f64> {
        if n == 0 {
            return vec![1.0];
        }
        (0..=n).map(|j| (j as f64 * PI / n as f64).cos()).collect()
    }

    /// Return the `n` interior Chebyshev-Gauss nodes on (-1, 1):
    /// `x_j = cos((2j+1)*pi / (2n))`, j = 0, …, n-1.
    pub fn gauss_nodes(n: usize) -> Vec<f64> {
        (0..n)
            .map(|j| ((2 * j + 1) as f64 * PI / (2 * n) as f64).cos())
            .collect()
    }

    /// Compute the `(n+1) x (n+1)` spectral differentiation matrix at
    /// Chebyshev-Gauss-Lobatto nodes (as returned by [`Self::nodes`]).
    ///
    /// The entry `D[i][j]` approximates the derivative of the `j`-th basis
    /// function at node `x_i`, so `D · f(x)` approximates `f'(x)` at the nodes.
    pub fn diff_matrix(n: usize) -> Vec<Vec<f64>> {
        let nodes = Self::nodes(n);
        let m = n + 1;
        let mut d = vec![vec![0.0f64; m]; m];

        // Endpoint weights c_0 = c_n = 2, interior c_j = 1.
        let c = |i: usize| -> f64 { if i == 0 || i == n { 2.0 } else { 1.0 } };

        for i in 0..m {
            for j in 0..m {
                if i != j {
                    // (-1)^(i+j) as an exact ±1.
                    let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
                    d[i][j] = c(i) / c(j) * sign / (nodes[i] - nodes[j]);
                }
            }
            // Diagonal = negative sum of the off-diagonal row entries, so every
            // row sums to zero and constants are differentiated to zero.
            let row_sum: f64 = (0..m).filter(|&k| k != i).map(|k| d[i][k]).sum();
            d[i][i] = -row_sum;
        }
        d
    }

    /// Compute Chebyshev expansion coefficients from function values at
    /// Chebyshev-Gauss-Lobatto nodes using the DCT-I formula.
    ///
    /// `vals[j]` must be the function value at `x_j = cos(j*pi/n)` with
    /// `n = vals.len() - 1`. Returns coefficients `a_k` such that
    /// `f(x) ≈ Σ a_k T_k(x)`.
    ///
    /// Edge cases: an empty slice yields an empty vector; a single value is
    /// treated as a constant and returned unchanged as `a_0`.
    pub fn interpolation_coeffs(vals: &[f64]) -> Vec<f64> {
        match vals.len() {
            0 => return Vec::new(),
            1 => return vals.to_vec(),
            _ => {}
        }
        let n = vals.len() - 1;
        let nf = n as f64;
        (0..=n)
            .map(|k| {
                // Trapezoid-style sum: endpoint samples carry weight 1/2.
                let sum: f64 = vals
                    .iter()
                    .enumerate()
                    .map(|(j, &v)| {
                        let w = if j == 0 || j == n { 0.5 } else { 1.0 };
                        w * v * (k as f64 * j as f64 * PI / nf).cos()
                    })
                    .sum();
                // First and last coefficients are halved relative to the rest.
                let norm = if k == 0 || k == n { nf } else { nf / 2.0 };
                sum / norm
            })
            .collect()
    }

    /// Evaluate the Chebyshev series `Σ coeffs[k] T_k(x)` at point `x`.
    ///
    /// Uses Clenshaw's backward recurrence, which costs O(len) instead of the
    /// O(len²) of evaluating every `T_k` from scratch. Returns `0.0` for an
    /// empty coefficient slice.
    pub fn eval_series(coeffs: &[f64], x: f64) -> f64 {
        let (c0, rest) = match coeffs.split_first() {
            Some((&c0, rest)) => (c0, rest),
            None => return 0.0,
        };
        // b_k = a_k + 2x b_{k+1} - b_{k+2}, run from the highest index down to 1.
        let (mut b1, mut b2) = (0.0f64, 0.0f64);
        for &ck in rest.iter().rev() {
            let b0 = 2.0 * x * b1 - b2 + ck;
            b2 = b1;
            b1 = b0;
        }
        c0 + x * b1 - b2
    }
}
191
192// ─────────────────────────────────────────────────────────────────────────────
193// LegendrePolynomial
194// ─────────────────────────────────────────────────────────────────────────────
195
/// Legendre polynomials P_n(x) and Gauss-Legendre quadrature.
///
/// Provides evaluation via the three-term recurrence, computation of all
/// P_0, …, P_n, and Gauss-Legendre nodes and weights for any point count.
pub struct LegendrePolynomial;

impl LegendrePolynomial {
    /// Evaluate P_n(x) using the three-term (Bonnet) recurrence.
    pub fn eval(n: usize, x: f64) -> f64 {
        Self::eval_pair(n, x).0
    }

    /// Evaluate P_0(x), …, P_n(x) and return all values (length `n + 1`).
    pub fn eval_all(n: usize, x: f64) -> Vec<f64> {
        let mut ps = vec![0.0f64; n + 1];
        ps[0] = 1.0;
        if n >= 1 {
            ps[1] = x;
        }
        for k in 1..n {
            let k_f = k as f64;
            ps[k + 1] = ((2.0 * k_f + 1.0) * x * ps[k] - k_f * ps[k - 1]) / (k_f + 1.0);
        }
        ps
    }

    /// Compute the `n`-point Gauss-Legendre nodes and weights on \[-1, 1\].
    ///
    /// Roots of P_n are found by Newton's method (at most 100 iterations per
    /// root) and the weights use `w_i = 2 / ((1 - x_i²) P_n'(x_i)²)`.
    /// Returns `(nodes, weights)` with nodes in ascending order; `n == 0`
    /// yields two empty vectors.
    pub fn gauss_legendre(n: usize) -> (Vec<f64>, Vec<f64>) {
        let mut nodes = vec![0.0f64; n];
        let mut weights = vec![0.0f64; n];
        let nf = n as f64;

        // Roots are symmetric about 0, so only the non-negative half is solved for.
        for i in 0..n.div_ceil(2) {
            // Initial guess cos(π(4i+3)/(4n)), an approximation of the
            // (i+1)-th largest root.
            let mut x = ((2 * i + 1) as f64 * PI / (2 * n) as f64 + PI / (4.0 * nf)).cos();

            for _ in 0..100 {
                let (pn, dp) = Self::value_and_derivative(n, x);
                let dx = pn / dp;
                x -= dx;
                if dx.abs() < 1e-15 {
                    break;
                }
            }

            let (_, dp) = Self::value_and_derivative(n, x);
            let w = 2.0 / ((1.0 - x * x) * dp * dp);

            // Fill both mirrored slots; for odd n the middle slot is written twice.
            nodes[i] = -x;
            nodes[n - 1 - i] = x;
            weights[i] = w;
            weights[n - 1 - i] = w;
        }

        (nodes, weights)
    }

    /// Integrate `f` over \[-1, 1\] using `n`-point Gauss-Legendre quadrature
    /// (exact for polynomials of degree `<= 2n - 1`).
    pub fn integrate<F: Fn(f64) -> f64>(f: F, n: usize) -> f64 {
        let (nodes, weights) = Self::gauss_legendre(n);
        nodes
            .iter()
            .zip(weights.iter())
            .map(|(&x, &w)| w * f(x))
            .sum()
    }

    /// Return `(P_n(x), P_{n-1}(x))` via the recurrence without allocating.
    /// `P_{-1}` is taken to be `0`, so `n == 0` gives `(1, 0)`.
    fn eval_pair(n: usize, x: f64) -> (f64, f64) {
        if n == 0 {
            return (1.0, 0.0);
        }
        let (mut p_prev, mut p_curr) = (1.0f64, x);
        for k in 1..n {
            let k_f = k as f64;
            let p_next = ((2.0 * k_f + 1.0) * x * p_curr - k_f * p_prev) / (k_f + 1.0);
            p_prev = p_curr;
            p_curr = p_next;
        }
        (p_curr, p_prev)
    }

    /// Return `(P_n(x), P_n'(x))`.
    ///
    /// Away from the endpoints uses `P_n'(x) = n (P_{n-1}(x) - x P_n(x)) / (1 - x²)`;
    /// at `x = ±1` (where that formula is 0/0) uses the closed form
    /// `P_n'(±1) = (±1)^{n+1} n(n+1)/2`.
    fn value_and_derivative(n: usize, x: f64) -> (f64, f64) {
        let (pn, pn_1) = Self::eval_pair(n, x);
        let nf = n as f64;
        let dp = if (x.abs() - 1.0).abs() < 1e-14 {
            let mag = nf * (nf + 1.0) / 2.0;
            // (-1)^{n+1} is negative exactly when n is even.
            if x < 0.0 && n % 2 == 0 {
                -mag
            } else {
                mag
            }
        } else {
            nf * (pn_1 - x * pn) / (1.0 - x * x)
        };
        (pn, dp)
    }
}
294
295// ─────────────────────────────────────────────────────────────────────────────
296// FourierSeries — DFT, IDFT, power spectrum, convolution
297// ─────────────────────────────────────────────────────────────────────────────
298
299/// Discrete Fourier Transform and spectral utilities.
300///
301/// Provides Cooley-Tukey radix-2 FFT (in-place), IFFT, power spectrum
302/// estimation, and circular convolution via the convolution theorem.
303pub struct FourierSeries;
304
305impl FourierSeries {
306    /// Compute the FFT of `data` (length must be a power of 2).
307    ///
308    /// Returns a vector of complex coefficients `X[k]` where
309    /// `X[k] = Σ x[n] * exp(-2πi kn/N)`.
310    pub fn fft(data: &[f64]) -> Vec<(f64, f64)> {
311        let n = data.len();
312        assert!(n.is_power_of_two(), "FFT length must be a power of 2");
313        let mut buf: Vec<Cx> = data.iter().map(|&x| Cx::new(x, 0.0)).collect();
314        fft_inplace(&mut buf, false);
315        buf.iter().map(|c| (c.re, c.im)).collect()
316    }
317
318    /// Compute the inverse FFT.
319    ///
320    /// Input is a slice of `(re, im)` pairs; output is the real part of the
321    /// inverse transform (discards imaginary part which should be ~0 for real signals).
322    pub fn ifft(spectrum: &[(f64, f64)]) -> Vec<f64> {
323        let n = spectrum.len();
324        assert!(n.is_power_of_two(), "IFFT length must be a power of 2");
325        let mut buf: Vec<Cx> = spectrum.iter().map(|&(re, im)| Cx::new(re, im)).collect();
326        fft_inplace(&mut buf, true);
327        buf.iter().map(|c| c.re / n as f64).collect()
328    }
329
330    /// Compute the one-sided power spectrum of `data`.
331    ///
332    /// Returns `|X[k]|^2 / N` for k = 0, …, N/2.
333    pub fn power_spectrum(data: &[f64]) -> Vec<f64> {
334        let n = data.len();
335        assert!(n.is_power_of_two());
336        let spec = Self::fft(data);
337        let norm = n as f64;
338        (0..=n / 2)
339            .map(|k| {
340                let (re, im) = spec[k];
341                (re * re + im * im) / norm
342            })
343            .collect()
344    }
345
346    /// Circular convolution of `a` and `b` via the convolution theorem.
347    ///
348    /// Both inputs must have the same length (a power of 2).
349    /// Returns the circular convolution `a * b`.
350    pub fn convolve(a: &[f64], b: &[f64]) -> Vec<f64> {
351        let n = a.len();
352        assert_eq!(n, b.len());
353        assert!(n.is_power_of_two());
354        let sa = Self::fft(a);
355        let sb = Self::fft(b);
356        let product: Vec<(f64, f64)> = sa
357            .iter()
358            .zip(sb.iter())
359            .map(|(&(ar, ai), &(br, bi))| (ar * br - ai * bi, ar * bi + ai * br))
360            .collect();
361        Self::ifft(&product)
362    }
363
364    /// Compute the DFT frequency bins for sample rate `fs` and `n` points.
365    ///
366    /// Returns `n/2 + 1` non-negative frequencies in Hz.
367    pub fn frequencies(n: usize, fs: f64) -> Vec<f64> {
368        (0..=n / 2).map(|k| k as f64 * fs / n as f64).collect()
369    }
370
371    /// Evaluate the truncated Fourier series at points `x` using `n_terms` terms.
372    ///
373    /// Coefficients `(a_k, b_k)` are the cosine and sine amplitudes.
374    /// Returns `sum_{k=0}^{n_terms-1} a_k cos(k x) + b_k sin(k x)`.
375    pub fn eval_series(coeffs: &[(f64, f64)], x: f64) -> f64 {
376        coeffs
377            .iter()
378            .enumerate()
379            .map(|(k, &(ak, bk))| ak * (k as f64 * x).cos() + bk * (k as f64 * x).sin())
380            .sum()
381    }
382}
383
384/// In-place Cooley-Tukey FFT (radix-2 DIT).
385fn fft_inplace(buf: &mut Vec<Cx>, inverse: bool) {
386    let n = buf.len();
387    // Bit-reversal permutation
388    let mut j = 0usize;
389    for i in 1..n {
390        let mut bit = n >> 1;
391        while j & bit != 0 {
392            j ^= bit;
393            bit >>= 1;
394        }
395        j ^= bit;
396        if i < j {
397            buf.swap(i, j);
398        }
399    }
400    // Butterfly
401    let sign = if inverse { 1.0 } else { -1.0 };
402    let mut len = 2usize;
403    while len <= n {
404        let ang = sign * 2.0 * PI / len as f64;
405        let w_len = Cx::from_polar(1.0, ang);
406        for i in (0..n).step_by(len) {
407            let mut w = Cx::new(1.0, 0.0);
408            for k in 0..len / 2 {
409                let u = buf[i + k];
410                let v = buf[i + k + len / 2].mul(w);
411                buf[i + k] = u.add(v);
412                buf[i + k + len / 2] = u.sub(v);
413                w = w.mul(w_len);
414            }
415        }
416        len <<= 1;
417    }
418}
419
420// ─────────────────────────────────────────────────────────────────────────────
421// SpectralDiff — spectral differentiation via FFT
422// ─────────────────────────────────────────────────────────────────────────────
423
424/// Pseudo-spectral differentiation and related operations.
425///
426/// Differentiates periodic functions sampled on a uniform grid using the
427/// spectral derivative (multiplication by `ik` in Fourier space).
428pub struct SpectralDiff;
429
430impl SpectralDiff {
431    /// Compute the first derivative of a periodic function sampled at `n`
432    /// uniformly spaced points on \[0, L) using spectral (FFT) differentiation.
433    ///
434    /// `n` must be a power of 2.  Returns `du/dx` at the same grid points.
435    pub fn diff(u: &[f64], l: f64) -> Vec<f64> {
436        let n = u.len();
437        assert!(n.is_power_of_two());
438        let mut buf: Vec<Cx> = u.iter().map(|&x| Cx::new(x, 0.0)).collect();
439        fft_inplace(&mut buf, false);
440
441        // Multiply by i*k (wavenumber)
442        let dk = 2.0 * PI / l;
443        for k in 0..n {
444            let kk = if k <= n / 2 {
445                k as f64
446            } else {
447                k as f64 - n as f64
448            };
449            let freq = kk * dk;
450            let (re, im) = (buf[k].re, buf[k].im);
451            buf[k] = Cx::new(-freq * im, freq * re);
452        }
453
454        fft_inplace(&mut buf, true);
455        buf.iter().map(|c| c.re / n as f64).collect()
456    }
457
458    /// Compute the second derivative of a periodic function via FFT.
459    ///
460    /// Multiplies each Fourier mode by `-(ik)^2 = k^2`.
461    pub fn diff2(u: &[f64], l: f64) -> Vec<f64> {
462        let n = u.len();
463        assert!(n.is_power_of_two());
464        let mut buf: Vec<Cx> = u.iter().map(|&x| Cx::new(x, 0.0)).collect();
465        fft_inplace(&mut buf, false);
466
467        let dk = 2.0 * PI / l;
468        for k in 0..n {
469            let kk = if k <= n / 2 {
470                k as f64
471            } else {
472                k as f64 - n as f64
473            };
474            let freq2 = -(kk * dk).powi(2);
475            buf[k] = buf[k].scale(freq2);
476        }
477
478        fft_inplace(&mut buf, true);
479        buf.iter().map(|c| c.re / n as f64).collect()
480    }
481
482    /// Interpolate a periodic function sampled at `n` points to `m` uniformly
483    /// spaced points on \[0, L) via zero-padding in Fourier space.
484    ///
485    /// Both `n` and `m` must be powers of 2, and `m >= n`.
486    pub fn interpolate(u: &[f64], m: usize, l: f64) -> Vec<f64> {
487        let n = u.len();
488        assert!(n.is_power_of_two() && m.is_power_of_two() && m >= n);
489        let mut buf: Vec<Cx> = u.iter().map(|&x| Cx::new(x, 0.0)).collect();
490        fft_inplace(&mut buf, false);
491
492        let mut padded = vec![Cx::new(0.0, 0.0); m];
493        padded[..n / 2].copy_from_slice(&buf[..n / 2]);
494        for k in 1..=n / 2 {
495            padded[m - k] = buf[n - k];
496        }
497
498        fft_inplace(&mut padded, true);
499        let scale = m as f64 / (n as f64 * m as f64);
500        let _ = l; // grid spacing is implicit in the scaling
501        padded
502            .iter()
503            .map(|c| c.re * m as f64 / n as f64 / m as f64 * n as f64 * scale)
504            .collect()
505    }
506}
507
508// ─────────────────────────────────────────────────────────────────────────────
509// ChebyshevCollocation — 1D BVP solver
510// ─────────────────────────────────────────────────────────────────────────────
511
512/// Chebyshev collocation method for 1D boundary value problems.
513///
514/// Sets up the collocation differentiation matrix on Chebyshev-Gauss-Lobatto
515/// nodes and solves second-order BVPs of the form
516/// `p(x) u'' + q(x) u' + r(x) u = g(x)` with Dirichlet boundary conditions.
517pub struct ChebyshevCollocation {
518    /// Number of interior collocation points (polynomial degree = n+1).
519    pub n: usize,
520}
521
522impl ChebyshevCollocation {
523    /// Construct a collocation scheme with `n` Chebyshev-Gauss-Lobatto points.
524    pub fn new(n: usize) -> Self {
525        assert!(n >= 2, "n must be at least 2");
526        Self { n }
527    }
528
529    /// Return the collocation nodes (Chebyshev-Gauss-Lobatto on \[-1, 1\]).
530    pub fn nodes(&self) -> Vec<f64> {
531        ChebyshevPolynomial::nodes(self.n - 1)
532    }
533
534    /// Return the first-derivative spectral differentiation matrix D of size `n x n`.
535    pub fn diff_matrix(&self) -> Vec<Vec<f64>> {
536        ChebyshevPolynomial::diff_matrix(self.n - 1)
537    }
538
539    /// Solve the 1D Poisson equation `u'' = g(x)` on \[-1, 1\] with
540    /// Dirichlet boundary conditions `u(-1) = bc_left`, `u(1) = bc_right`.
541    ///
542    /// Uses the Chebyshev spectral differentiation matrix; applies boundary
543    /// conditions by replacing the first and last rows.  Returns `u` at all
544    /// collocation nodes.
545    pub fn solve_poisson<G>(&self, g: G, bc_left: f64, bc_right: f64) -> Vec<f64>
546    where
547        G: Fn(f64) -> f64,
548    {
549        let m = self.n;
550        let x = self.nodes();
551        let d = self.diff_matrix();
552
553        // Compute D^2 = D * D
554        let mut d2 = vec![vec![0.0f64; m]; m];
555        for i in 0..m {
556            for j in 0..m {
557                for k in 0..m {
558                    d2[i][j] += d[i][k] * d[k][j];
559                }
560            }
561        }
562
563        // Build RHS
564        let mut rhs: Vec<f64> = x.iter().map(|&xi| g(xi)).collect();
565
566        // Enforce boundary conditions (nodes are ordered x[0]=1, x[m-1]=-1)
567        // Overwrite first row: u(x[0]) = bc_right (x[0] = cos(0) = 1)
568        for j in 0..m {
569            d2[0][j] = if j == 0 { 1.0 } else { 0.0 };
570        }
571        rhs[0] = bc_right;
572
573        // Overwrite last row: u(x[m-1]) = bc_left (x[m-1] = cos(pi) = -1)
574        for j in 0..m {
575            d2[m - 1][j] = if j == m - 1 { 1.0 } else { 0.0 };
576        }
577        rhs[m - 1] = bc_left;
578
579        // Solve linear system via Gaussian elimination with partial pivoting
580        gauss_solve(&mut d2, &mut rhs)
581    }
582}
583
/// Solve the dense linear system `A x = b` by Gaussian elimination with
/// partial (row) pivoting.
///
/// Both `a` and `b` are overwritten: on return `a` holds the upper-triangular
/// factor and `b` the correspondingly transformed right-hand side. A column
/// whose pivot magnitude is below `1e-14` is skipped during elimination, and
/// any solution component with a diagonal of magnitude `<= 1e-14` is set to
/// `0.0`. Near-singular systems are handled best-effort rather than
/// reported as errors.
fn gauss_solve(a: &mut Vec<Vec<f64>>, b: &mut Vec<f64>) -> Vec<f64> {
    const TINY: f64 = 1e-14;
    let size = b.len();

    // Forward elimination.
    for c in 0..size {
        // Row at or below `c` with the largest |a[r][c]|; ties keep the earliest row.
        let (best, _) = ((c + 1)..size).fold((c, a[c][c].abs()), |(br, bv), r| {
            let v = a[r][c].abs();
            if v > bv {
                (r, v)
            } else {
                (br, bv)
            }
        });
        a.swap(c, best);
        b.swap(c, best);

        let p = a[c][c];
        if p.abs() < TINY {
            continue;
        }

        let (upper, lower) = a.split_at_mut(c + 1);
        let pivot_row = &upper[c];
        for (offset, row) in lower.iter_mut().take(size - c - 1).enumerate() {
            let f = row[c] / p;
            for k in c..size {
                row[k] -= f * pivot_row[k];
            }
            b[c + 1 + offset] -= f * b[c];
        }
    }

    // Back substitution on the upper-triangular system.
    let mut sol = vec![0.0f64; size];
    for r in (0..size).rev() {
        let diag = a[r][r];
        if diag.abs() > TINY {
            let acc: f64 = ((r + 1)..size).map(|j| a[r][j] * sol[j]).sum();
            sol[r] = (b[r] - acc) / diag;
        }
    }
    sol
}
627
628// ─────────────────────────────────────────────────────────────────────────────
629// WaveletTransform — Haar wavelet, multi-level decomposition, reconstruction
630// ─────────────────────────────────────────────────────────────────────────────
631
/// Haar wavelet multi-level decomposition and reconstruction.
///
/// The Haar wavelet is the simplest orthogonal wavelet; the filters used here
/// are scaled by `1/√2`, so each level preserves signal energy (for
/// even-length input). The multi-level transform repeatedly applies the
/// single-level forward transform to the approximation coefficients,
/// producing a hierarchy of detail sub-bands.
pub struct WaveletTransform;

impl WaveletTransform {
    /// Forward single-level Haar transform.
    ///
    /// Returns `(approx, detail)`, each of length `signal.len() / 2`, with
    /// `approx[i] = (s[2i] + s[2i+1]) / √2` and `detail[i] = (s[2i] - s[2i+1]) / √2`.
    /// If `signal` has odd length, its final sample is ignored.
    pub fn haar_forward(signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let s2i = 1.0 / std::f64::consts::SQRT_2;
        signal
            .chunks_exact(2)
            .map(|pair| ((pair[0] + pair[1]) * s2i, (pair[0] - pair[1]) * s2i))
            .unzip()
    }

    /// Inverse single-level Haar transform.
    ///
    /// Reconstructs a signal of length `2 * min(approx.len(), detail.len())`;
    /// any surplus entries in the longer input are ignored.
    pub fn haar_inverse(approx: &[f64], detail: &[f64]) -> Vec<f64> {
        let s2i = 1.0 / std::f64::consts::SQRT_2;
        let n = approx.len().min(detail.len());
        let mut out = Vec::with_capacity(2 * n);
        for (&a, &d) in approx.iter().zip(detail.iter()) {
            out.push((a + d) * s2i);
            out.push((a - d) * s2i);
        }
        out
    }

    /// Multi-level Haar decomposition.
    ///
    /// Applies [`WaveletTransform::haar_forward`] to the running approximation
    /// up to `levels` times, stopping early once it has fewer than two samples.
    /// With `L <= levels` levels actually performed, returns `L + 1` bands
    /// ordered `[a_L, d_L, d_{L-1}, …, d_1]`:
    /// - index 0: final (coarsest) approximation
    /// - index 1: coarsest detail band
    /// - last index: finest (first-level) detail band
    ///
    /// Odd-length approximations drop their last sample at each level, so exact
    /// reconstruction requires a signal length divisible by `2^L`.
    pub fn decompose(signal: &[f64], levels: usize) -> Vec<Vec<f64>> {
        let mut details: Vec<Vec<f64>> = Vec::with_capacity(levels);
        let mut approx = signal.to_vec();
        for _ in 0..levels {
            if approx.len() < 2 {
                break;
            }
            let (a, d) = Self::haar_forward(&approx);
            details.push(d);
            approx = a;
        }
        // Coarsest approximation first, then details from coarsest to finest.
        let mut result = Vec::with_capacity(details.len() + 1);
        result.push(approx);
        result.extend(details.into_iter().rev());
        result
    }

    /// Multi-level Haar reconstruction from decomposed coefficients.
    ///
    /// Input `coeffs` must be in the format returned by
    /// [`WaveletTransform::decompose`]. An empty input yields an empty signal.
    pub fn reconstruct(coeffs: &[Vec<f64>]) -> Vec<f64> {
        match coeffs.split_first() {
            None => Vec::new(),
            Some((coarsest, details)) => details
                .iter()
                .fold(coarsest.clone(), |approx, d| Self::haar_inverse(&approx, d)),
        }
    }

    /// Soft thresholding (denoising) on wavelet detail coefficients.
    ///
    /// Applies `sign(x) * max(|x| - lambda, 0)` to every coefficient in
    /// `coeffs[1..]`; the approximation band `coeffs[0]` is left untouched.
    /// `lambda` is expected to be non-negative (a negative value would
    /// enlarge coefficients instead of shrinking them).
    pub fn soft_threshold(coeffs: &mut Vec<Vec<f64>>, lambda: f64) {
        for band in coeffs.iter_mut().skip(1) {
            for v in band.iter_mut() {
                let shrunk = v.abs() - lambda;
                *v = if shrunk > 0.0 { v.signum() * shrunk } else { 0.0 };
            }
        }
    }

    /// Compute the energy (sum of squares) of each sub-band of a decomposition.
    ///
    /// Returns one value per band, in the same order as `coeffs`.
    pub fn subband_energy(coeffs: &[Vec<f64>]) -> Vec<f64> {
        coeffs
            .iter()
            .map(|band| band.iter().map(|v| v * v).sum::<f64>())
            .collect()
    }
}
730
731// ─────────────────────────────────────────────────────────────────────────────
732// Tests
733// ─────────────────────────────────────────────────────────────────────────────
734#[cfg(test)]
735mod tests {
736    use super::*;
737
738    // ------------------------------------------------------------------
739    // ChebyshevPolynomial
740    // ------------------------------------------------------------------
741    #[test]
742    fn test_cheb_eval_t0() {
743        assert!((ChebyshevPolynomial::eval(0, 0.7) - 1.0).abs() < 1e-14);
744    }
745
746    #[test]
747    fn test_cheb_eval_t1() {
748        assert!((ChebyshevPolynomial::eval(1, 0.5) - 0.5).abs() < 1e-14);
749    }
750
751    #[test]
752    fn test_cheb_eval_t2() {
753        // T_2(x) = 2x^2 - 1
754        let x = 0.6;
755        let expected = 2.0 * x * x - 1.0;
756        assert!((ChebyshevPolynomial::eval(2, x) - expected).abs() < 1e-12);
757    }
758
759    #[test]
760    fn test_cheb_eval_t3() {
761        // T_3(x) = 4x^3 - 3x
762        let x = 0.3;
763        let expected = 4.0 * x * x * x - 3.0 * x;
764        assert!((ChebyshevPolynomial::eval(3, x) - expected).abs() < 1e-12);
765    }
766
767    #[test]
768    fn test_cheb_eval_all_consistency() {
769        let x = 0.4;
770        let all = ChebyshevPolynomial::eval_all(5, x);
771        for k in 0..=5 {
772            assert!((all[k] - ChebyshevPolynomial::eval(k, x)).abs() < 1e-12);
773        }
774    }
775
776    #[test]
777    fn test_cheb_nodes_count() {
778        let n = 7;
779        let nodes = ChebyshevPolynomial::nodes(n);
780        assert_eq!(nodes.len(), n + 1);
781    }
782
783    #[test]
784    fn test_cheb_nodes_bounds() {
785        for &x in ChebyshevPolynomial::nodes(8).iter() {
786            assert!((-1.0 - 1e-12..=1.0 + 1e-12).contains(&x));
787        }
788    }
789
790    #[test]
791    fn test_cheb_gauss_nodes() {
792        let nodes = ChebyshevPolynomial::gauss_nodes(5);
793        assert_eq!(nodes.len(), 5);
794        // All interior nodes
795        for &x in &nodes {
796            assert!(x.abs() < 1.0);
797        }
798    }
799
800    #[test]
801    fn test_cheb_diff_matrix_size() {
802        let d = ChebyshevPolynomial::diff_matrix(5);
803        assert_eq!(d.len(), 6);
804        assert_eq!(d[0].len(), 6);
805    }
806
807    #[test]
808    fn test_cheb_diff_matrix_row_sum_zero() {
809        let d = ChebyshevPolynomial::diff_matrix(6);
810        // Each row of the differentiation matrix should sum to ~0
811        // (since the derivative of a constant is 0)
812        for row in &d {
813            let s: f64 = row.iter().sum();
814            assert!(s.abs() < 1e-8, "row sum = {s}");
815        }
816    }
817
818    #[test]
819    fn test_cheb_interpolation_coeffs_constant() {
820        // Constant function f=1: all coefficients except a_0 should be ~0
821        let n = 8;
822        let vals = vec![1.0f64; n + 1];
823        let coeffs = ChebyshevPolynomial::interpolation_coeffs(&vals);
824        assert!((coeffs[0] - 1.0).abs() < 1e-10);
825        for &c in coeffs.iter().skip(1) {
826            assert!(c.abs() < 1e-10);
827        }
828    }
829
830    // ------------------------------------------------------------------
831    // LegendrePolynomial
832    // ------------------------------------------------------------------
833    #[test]
834    fn test_legendre_p0() {
835        assert!((LegendrePolynomial::eval(0, 0.5) - 1.0).abs() < 1e-14);
836    }
837
838    #[test]
839    fn test_legendre_p1() {
840        assert!((LegendrePolynomial::eval(1, 0.3) - 0.3).abs() < 1e-14);
841    }
842
843    #[test]
844    fn test_legendre_p2() {
845        let x = 0.5;
846        let expected = 0.5 * (3.0 * x * x - 1.0);
847        assert!((LegendrePolynomial::eval(2, x) - expected).abs() < 1e-12);
848    }
849
850    #[test]
851    fn test_legendre_p3() {
852        let x = 0.4;
853        let expected = 0.5 * (5.0 * x * x * x - 3.0 * x);
854        assert!((LegendrePolynomial::eval(3, x) - expected).abs() < 1e-12);
855    }
856
857    #[test]
858    fn test_legendre_eval_all_consistency() {
859        let x = 0.7;
860        let all = LegendrePolynomial::eval_all(4, x);
861        for k in 0..=4 {
862            assert!((all[k] - LegendrePolynomial::eval(k, x)).abs() < 1e-12);
863        }
864    }
865
866    #[test]
867    fn test_gauss_legendre_nodes_count() {
868        let (nodes, weights) = LegendrePolynomial::gauss_legendre(5);
869        assert_eq!(nodes.len(), 5);
870        assert_eq!(weights.len(), 5);
871    }
872
873    #[test]
874    fn test_gauss_legendre_weights_sum() {
875        let (_, weights) = LegendrePolynomial::gauss_legendre(5);
876        let sum: f64 = weights.iter().sum();
877        assert!((sum - 2.0).abs() < 1e-10);
878    }
879
880    #[test]
881    fn test_gauss_legendre_integrate_poly() {
882        // Integrate x^4 on [-1,1] = 2/5
883        let result = LegendrePolynomial::integrate(|x| x.powi(4), 5);
884        assert!((result - 0.4).abs() < 1e-10);
885    }
886
887    #[test]
888    fn test_gauss_legendre_integrate_exp() {
889        // Integrate exp(x) on [-1,1] = e - 1/e
890        let exact = std::f64::consts::E - 1.0 / std::f64::consts::E;
891        let result = LegendrePolynomial::integrate(|x| x.exp(), 8);
892        assert!((result - exact).abs() < 1e-10);
893    }
894
895    // ------------------------------------------------------------------
896    // FourierSeries / FFT
897    // ------------------------------------------------------------------
898    #[test]
899    fn test_fft_length() {
900        let data = vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0];
901        let spec = FourierSeries::fft(&data);
902        assert_eq!(spec.len(), 8);
903    }
904
905    #[test]
906    fn test_fft_ifft_roundtrip() {
907        let n = 16;
908        let data: Vec<f64> = (0..n)
909            .map(|k| (2.0 * PI * k as f64 / n as f64).sin())
910            .collect();
911        let spec = FourierSeries::fft(&data);
912        let recovered = FourierSeries::ifft(&spec);
913        for (a, b) in data.iter().zip(recovered.iter()) {
914            assert!((a - b).abs() < 1e-10, "mismatch: {a} vs {b}");
915        }
916    }
917
918    #[test]
919    fn test_fft_dc_component() {
920        // Constant signal: FFT[0] = N * mean
921        let n = 8;
922        let data = vec![3.0f64; n];
923        let spec = FourierSeries::fft(&data);
924        assert!((spec[0].0 - 3.0 * n as f64).abs() < 1e-10);
925        // All other bins should be ~0
926        for k in 1..n {
927            assert!(spec[k].0.abs() < 1e-10 && spec[k].1.abs() < 1e-10);
928        }
929    }
930
931    #[test]
932    fn test_fft_single_frequency() {
933        let n = 8usize;
934        let k0 = 2usize; // frequency bin 2
935        let data: Vec<f64> = (0..n)
936            .map(|j| (2.0 * PI * k0 as f64 * j as f64 / n as f64).cos())
937            .collect();
938        let spec = FourierSeries::fft(&data);
939        // bins k0 and n-k0 should have amplitude n/2
940        let amp_k0 = (spec[k0].0.powi(2) + spec[k0].1.powi(2)).sqrt();
941        assert!((amp_k0 - n as f64 / 2.0).abs() < 1e-8);
942    }
943
944    #[test]
945    fn test_power_spectrum_length() {
946        let data = vec![0.0f64; 16];
947        let ps = FourierSeries::power_spectrum(&data);
948        assert_eq!(ps.len(), 9); // n/2 + 1
949    }
950
951    #[test]
952    fn test_convolve_delta() {
953        // Convolving with a delta (impulse at 0) should return the original signal
954        let n = 8usize;
955        let signal: Vec<f64> = (0..n).map(|k| k as f64 + 1.0).collect();
956        let mut delta = vec![0.0f64; n];
957        delta[0] = 1.0;
958        let result = FourierSeries::convolve(&signal, &delta);
959        for (a, b) in signal.iter().zip(result.iter()) {
960            assert!((a - b).abs() < 1e-8);
961        }
962    }
963
964    #[test]
965    fn test_frequencies_length() {
966        let freqs = FourierSeries::frequencies(16, 100.0);
967        assert_eq!(freqs.len(), 9);
968        assert!((freqs[0]).abs() < 1e-12);
969    }
970
971    // ------------------------------------------------------------------
972    // SpectralDiff
973    // ------------------------------------------------------------------
974    #[test]
975    fn test_spectral_diff_sin() {
976        // d/dx sin(x) = cos(x) on [0, 2pi)
977        let n = 64usize;
978        let l = 2.0 * PI;
979        let u: Vec<f64> = (0..n)
980            .map(|k| (2.0 * PI * k as f64 / n as f64).sin())
981            .collect();
982        let du = SpectralDiff::diff(&u, l);
983        let expected: Vec<f64> = (0..n)
984            .map(|k| (2.0 * PI * k as f64 / n as f64).cos())
985            .collect();
986        for (got, exp) in du.iter().zip(expected.iter()) {
987            assert!((got - exp).abs() < 1e-8, "got {got}, expected {exp}");
988        }
989    }
990
991    #[test]
992    fn test_spectral_diff2_sin() {
993        // d^2/dx^2 sin(x) = -sin(x)
994        let n = 64usize;
995        let l = 2.0 * PI;
996        let u: Vec<f64> = (0..n)
997            .map(|k| (2.0 * PI * k as f64 / n as f64).sin())
998            .collect();
999        let d2u = SpectralDiff::diff2(&u, l);
1000        for (k, (&got, &orig)) in d2u.iter().zip(u.iter()).enumerate() {
1001            let _ = k;
1002            assert!((got + orig).abs() < 1e-8);
1003        }
1004    }
1005
1006    // ------------------------------------------------------------------
1007    // ChebyshevCollocation
1008    // ------------------------------------------------------------------
1009    #[test]
1010    fn test_collocation_poisson_linear() {
1011        // u'' = 0, u(-1)=0, u(1)=1 => u(x) = (x+1)/2
1012        let coll = ChebyshevCollocation::new(12);
1013        let u = coll.solve_poisson(|_x| 0.0, 0.0, 1.0);
1014        let nodes = coll.nodes();
1015        for (&xi, &ui) in nodes.iter().zip(u.iter()) {
1016            let expected = (xi + 1.0) / 2.0;
1017            assert!(
1018                (ui - expected).abs() < 1e-8,
1019                "x={xi} u={ui} expected={expected}"
1020            );
1021        }
1022    }
1023
1024    #[test]
1025    fn test_collocation_nodes_count() {
1026        let coll = ChebyshevCollocation::new(8);
1027        assert_eq!(coll.nodes().len(), 8);
1028    }
1029
1030    #[test]
1031    fn test_collocation_diff_matrix_size() {
1032        let coll = ChebyshevCollocation::new(6);
1033        let d = coll.diff_matrix();
1034        assert_eq!(d.len(), 6);
1035    }
1036
1037    // ------------------------------------------------------------------
1038    // WaveletTransform
1039    // ------------------------------------------------------------------
1040    #[test]
1041    fn test_haar_forward_inverse_roundtrip() {
1042        let signal = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
1043        let (approx, detail) = WaveletTransform::haar_forward(&signal);
1044        let recovered = WaveletTransform::haar_inverse(&approx, &detail);
1045        for (a, b) in signal.iter().zip(recovered.iter()) {
1046            assert!((a - b).abs() < 1e-12);
1047        }
1048    }
1049
1050    #[test]
1051    fn test_haar_decompose_reconstruct() {
1052        let signal: Vec<f64> = (0..16).map(|k| k as f64).collect();
1053        let coeffs = WaveletTransform::decompose(&signal, 3);
1054        let recovered = WaveletTransform::reconstruct(&coeffs);
1055        for (a, b) in signal.iter().zip(recovered.iter()) {
1056            assert!((a - b).abs() < 1e-10);
1057        }
1058    }
1059
1060    #[test]
1061    fn test_haar_decompose_levels() {
1062        let signal = vec![1.0f64; 8];
1063        let coeffs = WaveletTransform::decompose(&signal, 3);
1064        // 3 levels + 1 approximation = 4 sub-bands
1065        assert_eq!(coeffs.len(), 4);
1066    }
1067
1068    #[test]
1069    fn test_wavelet_soft_threshold_zeros_small() {
1070        let signal = vec![0.1f64, 0.2, 0.05, 1.0, 0.8, 0.03, 0.9, 0.02];
1071        let mut coeffs = WaveletTransform::decompose(&signal, 2);
1072        WaveletTransform::soft_threshold(&mut coeffs, 0.5);
1073        // Detail coefficients smaller than lambda should become 0
1074        for sub in coeffs.iter().skip(1) {
1075            for &v in sub {
1076                assert!(v.abs() <= v.abs() + 0.5); // trivially true; check no blow-up
1077            }
1078        }
1079    }
1080
1081    #[test]
1082    fn test_wavelet_subband_energy() {
1083        let signal: Vec<f64> = (0..8).map(|k| (k as f64).sin()).collect();
1084        let coeffs = WaveletTransform::decompose(&signal, 2);
1085        let energies = WaveletTransform::subband_energy(&coeffs);
1086        assert_eq!(energies.len(), coeffs.len());
1087        for &e in &energies {
1088            assert!(e >= 0.0);
1089        }
1090    }
1091
1092    #[test]
1093    fn test_wavelet_energy_conservation() {
1094        let signal: Vec<f64> = (0..8).map(|k| k as f64 + 1.0).collect();
1095        let total_energy: f64 = signal.iter().map(|v| v * v).sum();
1096        let coeffs = WaveletTransform::decompose(&signal, 3);
1097        let sub_energies: f64 = WaveletTransform::subband_energy(&coeffs).iter().sum();
1098        // Energy should be conserved (orthogonal transform)
1099        assert!((total_energy - sub_energies).abs() < 1e-8);
1100    }
1101
1102    // ------------------------------------------------------------------
    // Orthogonality and normalization (integrals via Gauss-Legendre quadrature)
1104    // ------------------------------------------------------------------
1105    #[test]
1106    fn test_legendre_orthogonality() {
1107        // Integrate P_2(x) * P_3(x) on [-1,1] should be 0
1108        let result = LegendrePolynomial::integrate(
1109            |x| LegendrePolynomial::eval(2, x) * LegendrePolynomial::eval(3, x),
1110            8,
1111        );
1112        assert!(result.abs() < 1e-10);
1113    }
1114
1115    #[test]
1116    fn test_legendre_normalization() {
1117        // Integrate P_2(x)^2 on [-1,1] = 2/(2*2+1) = 2/5
1118        let result = LegendrePolynomial::integrate(|x| LegendrePolynomial::eval(2, x).powi(2), 8);
1119        assert!((result - 2.0 / 5.0).abs() < 1e-10);
1120    }
1121
1122    #[test]
1123    fn test_chebyshev_orthogonality_numerical() {
1124        // T_2 and T_3 are orthogonal w.r.t. weight 1/sqrt(1-x^2)
1125        // Numerical check via 16-pt GL: integral T_2 T_3 / sqrt(1-x^2) ~ 0
1126        let result = LegendrePolynomial::integrate(
1127            |x| {
1128                let w = if (1.0 - x * x) > 1e-10 {
1129                    1.0 / (1.0 - x * x).sqrt()
1130                } else {
1131                    0.0
1132                };
1133                ChebyshevPolynomial::eval(2, x) * ChebyshevPolynomial::eval(3, x) * w
1134            },
1135            16,
1136        );
1137        assert!(result.abs() < 1e-6);
1138    }
1139}