Skip to main content

scirs2_stats/regression/
functional.rs

1//! Scalar-on-Function Regression (Functional Linear Model)
2//!
3//! Models the relationship between a scalar response Y and a functional
4//! predictor X(t):
5//!
6//!   Y_i = α + ∫ β(t) X_i(t) dt + ε_i
7//!
8//! The unknown coefficient function β(t) is expanded in a basis:
9//!
10//!   β(t) = Σ_k c_k φ_k(t)
11//!
12//! so the integral becomes Z_ij = ∫ X_i(t) φ_j(t) dt, giving
13//!
14//!   Y = Z c + ε
15//!
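//! with a roughness penalty λ c' Ω c, where Ω = D'D and D takes second
//! differences of adjacent coefficients. This discrete penalty stands in for
//! λ ∫ [β''(t)]² dt; for the B-spline basis it is the usual P-spline penalty.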
17//!
18//! # Basis choices
19//!
20//! | Variant | Details |
21//! |---------|---------|
22//! | `BSpline{n_basis, degree}` | B-splines via de Boor recursion, equidistant knots |
23//! | `Fourier{n_basis}` | Fourier (sin/cos) basis on \[0,1\] |
24//! | `Wavelet{n_basis}` | Haar wavelet basis (power-of-two levels) |
25//!
26//! # References
27//!
28//! - Ramsay, J.O. & Silverman, B.W. (2005). *Functional Data Analysis* (2nd ed.). Springer.
29//! - Cardot, H., Ferraty, F. & Sarda, P. (1999). Functional linear model.
30//!   *Statistics & Probability Letters* 45, 11-22.
31
32use crate::error::{StatsError, StatsResult};
33
34// ---------------------------------------------------------------------------
35// Public types
36// ---------------------------------------------------------------------------
37
/// Basis used to represent the coefficient function β(t).
///
/// Every basis is defined on the interval running from the first to the last
/// point of the grid it is evaluated on.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FunctionalBasis {
    /// B-spline basis with `n_basis` functions of degree `degree`.
    ///
    /// Uses a clamped knot vector with equidistant interior knots. `n_basis`
    /// should exceed `degree`; smaller counts cannot form a properly clamped
    /// knot vector.
    BSpline {
        /// Number of B-spline basis functions.
        n_basis: usize,
        /// Polynomial degree (e.g., 3 for cubic splines).
        degree: usize,
    },
    /// Fourier (trigonometric) basis with `n_basis` functions
    /// `1, cos(2πs), sin(2πs), cos(4πs), sin(4πs), …` in the normalised
    /// coordinate `s ∈ [0, 1]`.
    Fourier {
        /// Number of Fourier basis functions. An odd count pairs every cosine
        /// with its sine; an even count ends on a cosine.
        n_basis: usize,
    },
    /// Haar wavelet basis with `n_basis` functions.
    ///
    /// The first function is the constant scaling function; the remaining
    /// ones are Haar wavelets filled level by level (1, 2, 4, … wavelets per
    /// level). Any count is accepted: `1 + (2^L - 1)`, i.e. a power of two,
    /// yields complete levels, other counts truncate the finest level.
    Wavelet {
        /// Number of wavelet basis functions (including the scaling function).
        n_basis: usize,
    },
}
60
61/// Configuration for functional regression.
62#[non_exhaustive]
63#[derive(Debug, Clone)]
64pub struct FofConfig {
65    /// Basis for β(t) (default: cubic B-spline with 10 functions).
66    pub basis: FunctionalBasis,
67    /// Roughness penalty weight λ ≥ 0 (default: 0.01).
68    pub lambda: f64,
69    /// Number of quadrature grid points for ∫ X_i(t) φ_j(t) dt (default: 100).
70    pub n_grid: usize,
71}
72
73impl Default for FofConfig {
74    fn default() -> Self {
75        Self {
76            basis: FunctionalBasis::BSpline {
77                n_basis: 10,
78                degree: 3,
79            },
80            lambda: 0.01,
81            n_grid: 100,
82        }
83    }
84}
85
/// Output of [`FunctionalRegression::fit`].
#[derive(Debug, Clone)]
pub struct FofResult {
    /// Estimated basis-expansion coefficients c, with β(t) = Σ_k c_k φ_k(t).
    pub beta_coefs: Vec<f64>,
    /// The estimated β(t) at each point of `grid`.
    pub beta_values: Vec<f64>,
    /// Equally spaced evaluation points (`FofConfig::n_grid` of them) covering
    /// the range of the observation grid.
    pub grid: Vec<f64>,
    /// Estimated intercept α.
    pub intercept: f64,
    /// In-sample coefficient of determination R² (reported as 1.0 when the
    /// response has zero variance).
    pub r_squared: f64,
    /// Generalised cross-validation score GCV(λ) at the configured λ
    /// (`f64::INFINITY` when it cannot be computed).
    pub gcv_score: f64,
}
102
103/// Scalar-on-function regression estimator.
104#[derive(Debug, Clone)]
105pub struct FunctionalRegression {
106    config: FofConfig,
107    /// Stored basis coefficients after fitting (None before fit).
108    beta_coefs: Option<Vec<f64>>,
109    /// Fitted intercept.
110    intercept: Option<f64>,
111    /// Grid used during fit (for prediction).
112    fit_grid: Option<Vec<f64>>,
113}
114
115impl FunctionalRegression {
116    /// Create a new estimator with the given configuration.
117    pub fn new(config: FofConfig) -> Self {
118        Self {
119            config,
120            beta_coefs: None,
121            intercept: None,
122            fit_grid: None,
123        }
124    }
125
126    /// Fit the model.
127    ///
128    /// # Arguments
129    ///
130    /// * `data`     – observed functional predictors, shape `(n_obs, n_time)`.
131    ///                Each row is a discretised curve X_i(t).
132    /// * `response` – scalar response values, length `n_obs`.
133    /// * `grid`     – time points at which the curves are observed, length `n_time`.
134    ///
135    /// # Errors
136    ///
137    /// Returns `StatsError` when dimensions are inconsistent, `n_time < 2`,
138    /// `n_obs < n_basis + 1`, or the penalised system is singular.
139    pub fn fit(
140        &mut self,
141        data: &[Vec<f64>],
142        response: &[f64],
143        grid: &[f64],
144    ) -> StatsResult<FofResult> {
145        let n_obs = data.len();
146        if n_obs == 0 {
147            return Err(StatsError::InsufficientData(
148                "need at least one observation".to_owned(),
149            ));
150        }
151        let n_time = grid.len();
152        if n_time < 2 {
153            return Err(StatsError::InvalidArgument(
154                "grid must have at least 2 points".to_owned(),
155            ));
156        }
157        if response.len() != n_obs {
158            return Err(StatsError::DimensionMismatch(format!(
159                "response length {} != n_obs {}",
160                response.len(),
161                n_obs
162            )));
163        }
164        for (i, row) in data.iter().enumerate() {
165            if row.len() != n_time {
166                return Err(StatsError::DimensionMismatch(format!(
167                    "data[{}] has {} time points, expected {}",
168                    i,
169                    row.len(),
170                    n_time
171                )));
172            }
173        }
174
175        let n_basis = self.n_basis_fns();
176        if n_obs < n_basis + 1 {
177            return Err(StatsError::InsufficientData(format!(
178                "need n_obs >= n_basis+1 = {} but got {}",
179                n_basis + 1,
180                n_obs
181            )));
182        }
183
184        // --- 1. Evaluate basis at observation grid -------------------------
185        let phi = self.evaluate_basis(grid); // shape: n_time × n_basis
186
187        // --- 2. Build Z matrix (n_obs × n_basis): Z_ij = ∫ X_i(t) φ_j(t) dt
188        //        Using trapezoidal rule over the observation grid.
189        let z = build_z_matrix(data, &phi, grid); // n_obs × n_basis
190
191        // --- 3. Roughness penalty matrix Ω (n_basis × n_basis) ------------
192        let omega = self.roughness_penalty(n_basis);
193
194        // --- 4. Centre the response and add intercept ---------------------
195        let y_mean = response.iter().sum::<f64>() / n_obs as f64;
196        let y_centred: Vec<f64> = response.iter().map(|&y| y - y_mean).collect();
197
198        // Centre Z columns too (to decorrelate intercept from slopes)
199        let z_col_means: Vec<f64> = (0..n_basis)
200            .map(|j| z.iter().map(|row| row[j]).sum::<f64>() / n_obs as f64)
201            .collect();
202        let z_centred: Vec<Vec<f64>> = z
203            .iter()
204            .map(|row| {
205                row.iter()
206                    .enumerate()
207                    .map(|(j, &v)| v - z_col_means[j])
208                    .collect()
209            })
210            .collect();
211
212        // --- 5. Penalised least-squares: (Z'Z + λΩ) c = Z' y_c ----------
213        let coefs = penalized_ls(&z_centred, &y_centred, &omega, self.config.lambda)?;
214
215        // Recover intercept: α = ȳ - z̄' c
216        let intercept = y_mean
217            - z_col_means
218                .iter()
219                .zip(coefs.iter())
220                .map(|(&zm, &c)| zm * c)
221                .sum::<f64>();
222
223        // --- 6. Fitted values and R² --------------------------------------
224        let y_hat: Vec<f64> = z
225            .iter()
226            .map(|row| {
227                intercept
228                    + row
229                        .iter()
230                        .zip(coefs.iter())
231                        .map(|(&z_ij, &c)| z_ij * c)
232                        .sum::<f64>()
233            })
234            .collect();
235
236        let ss_res: f64 = response
237            .iter()
238            .zip(y_hat.iter())
239            .map(|(&y, &yh)| (y - yh).powi(2))
240            .sum();
241        let ss_tot: f64 = response.iter().map(|&y| (y - y_mean).powi(2)).sum();
242        let r_squared = if ss_tot > 0.0 {
243            1.0 - ss_res / ss_tot
244        } else {
245            1.0
246        };
247
248        // --- 7. GCV score --------------------------------------------------
249        // GCV(λ) = (1/n) ||y - ŷ||² / (1 - h̄)²
250        // where h̄ = trace(H) / n and H = Z (Z'Z + λΩ)⁻¹ Z'
251        let gcv_score = compute_gcv(&z, &y_hat, response, &omega, self.config.lambda, n_obs);
252
253        // --- 8. Evaluate β on evaluation grid -----------------------------
254        let eval_grid = linspace(grid[0], *grid.last().unwrap_or(&1.0), self.config.n_grid);
255        let phi_eval = self.evaluate_basis(&eval_grid);
256        let beta_values: Vec<f64> = eval_grid
257            .iter()
258            .enumerate()
259            .map(|(t, _)| {
260                phi_eval[t]
261                    .iter()
262                    .zip(coefs.iter())
263                    .map(|(&p, &c)| p * c)
264                    .sum()
265            })
266            .collect();
267
268        // Store state for prediction
269        self.beta_coefs = Some(coefs.clone());
270        self.intercept = Some(intercept);
271        self.fit_grid = Some(grid.to_vec());
272
273        Ok(FofResult {
274            beta_coefs: coefs,
275            beta_values,
276            grid: eval_grid,
277            intercept,
278            r_squared,
279            gcv_score,
280        })
281    }
282
283    /// Predict scalar responses for new functional observations.
284    ///
285    /// # Arguments
286    ///
287    /// * `new_data` – new functional observations, shape `(n_new, n_time)`.
288    /// * `grid`     – the same time grid used during `fit`.
289    ///
290    /// # Errors
291    ///
292    /// Returns `StatsError` if the model has not been fitted yet, or if
293    /// dimension mismatches exist.
294    pub fn predict(&self, new_data: &[Vec<f64>], grid: &[f64]) -> StatsResult<Vec<f64>> {
295        let coefs = self
296            .beta_coefs
297            .as_ref()
298            .ok_or_else(|| StatsError::ComputationError("model not fitted yet".to_owned()))?;
299        let intercept = self.intercept.unwrap_or(0.0);
300
301        let n_basis = self.n_basis_fns();
302        let phi = self.evaluate_basis(grid);
303        let z = build_z_matrix(new_data, &phi, grid);
304
305        let preds = z
306            .iter()
307            .map(|row| {
308                intercept
309                    + row
310                        .iter()
311                        .zip(coefs.iter())
312                        .take(n_basis)
313                        .map(|(&z_ij, &c)| z_ij * c)
314                        .sum::<f64>()
315            })
316            .collect();
317
318        Ok(preds)
319    }
320
321    // -----------------------------------------------------------------------
322    // Private helpers
323    // -----------------------------------------------------------------------
324
325    /// Return the number of basis functions.
326    fn n_basis_fns(&self) -> usize {
327        match self.config.basis {
328            FunctionalBasis::BSpline { n_basis, .. } => n_basis,
329            FunctionalBasis::Fourier { n_basis } => n_basis,
330            FunctionalBasis::Wavelet { n_basis } => n_basis,
331        }
332    }
333
334    /// Evaluate basis functions at each grid point.
335    ///
336    /// Returns a matrix of shape `(n_grid, n_basis)` where `phi[t][k]` = φ_k(grid[t]).
337    fn evaluate_basis(&self, grid: &[f64]) -> Vec<Vec<f64>> {
338        match self.config.basis {
339            FunctionalBasis::BSpline { n_basis, degree } => bspline_basis(grid, n_basis, degree),
340            FunctionalBasis::Fourier { n_basis } => fourier_basis(grid, n_basis),
341            FunctionalBasis::Wavelet { n_basis } => wavelet_basis(grid, n_basis),
342        }
343    }
344
345    /// Second-difference roughness penalty matrix Ω.
346    ///
347    /// Ω = D' D where D is the (n_basis-2) × n_basis second-difference matrix.
348    fn roughness_penalty(&self, n_basis: usize) -> Vec<Vec<f64>> {
349        roughness_penalty(n_basis)
350    }
351}
352
353// ---------------------------------------------------------------------------
354// B-spline basis (de Boor recursion)
355// ---------------------------------------------------------------------------
356
357/// Compute B-spline basis functions at the given grid points.
358///
359/// Uses the de Boor recursion on equidistant knots. Returns a matrix of
360/// shape `(n_grid, n_basis)`.
361pub fn bspline_basis(grid: &[f64], n_basis: usize, degree: usize) -> Vec<Vec<f64>> {
362    let n_grid = grid.len();
363    if n_basis == 0 || n_grid == 0 {
364        return vec![vec![]; n_grid];
365    }
366
367    let t_min = grid[0];
368    let t_max = *grid.last().unwrap_or(&1.0);
369
370    // Number of internal knots: n_basis - degree + 1 (for clamped B-splines)
371    // Total knot vector length: n_basis + degree + 1
372    let n_knots = n_basis + degree + 1;
373    let knots = build_clamped_knots(t_min, t_max, n_knots, degree);
374
375    let mut phi = vec![vec![0.0f64; n_basis]; n_grid];
376    for (t_idx, &t) in grid.iter().enumerate() {
377        for k in 0..n_basis {
378            phi[t_idx][k] = de_boor_basis(t, k, degree, &knots);
379        }
380    }
381    phi
382}
383
/// Build a clamped B-spline knot vector of length `n_knots` on `[t_min, t_max]`.
///
/// Layout: `degree + 1` copies of `t_min`, then
/// `n_knots - 2 * (degree + 1)` equidistant interior knots (none if that is
/// negative), then copies of `t_max` up to `n_knots` entries — exactly
/// `degree + 1` of them whenever `n_knots >= 2 * (degree + 1)`.
/// In the degenerate case `n_knots < degree + 1` the result is `n_knots`
/// copies of `t_min`.
fn build_clamped_knots(t_min: f64, t_max: f64, n_knots: usize, degree: usize) -> Vec<f64> {
    let n_clamp = degree + 1;
    let n_interior = n_knots.saturating_sub(2 * n_clamp);

    let mut knots = vec![t_min; n_clamp];
    knots.extend(
        (1..=n_interior)
            .map(|i| t_min + (t_max - t_min) * (i as f64) / (n_interior + 1) as f64),
    );
    // Pad with the right-hand clamp; truncates in the degenerate case.
    knots.resize(n_knots, t_max);
    knots
}
410
/// Value of the B-spline basis function B_{k,p}(t) via the Cox–de Boor recursion.
///
/// * `k`     – index of the basis function (0-based)
/// * `p`     – degree
/// * `knots` – knot vector
///
/// Returns 0 when `k + p + 1` is not a valid knot index. For degree 0 the
/// function is the indicator of `[knots[k], knots[k+1])`, closed on the right
/// when `knots[k+1]` is the last knot, so that t = t_max is covered. Terms
/// whose knot span is (numerically) empty are dropped.
fn de_boor_basis(t: f64, k: usize, p: usize, knots: &[f64]) -> f64 {
    const EPS: f64 = 1e-14;

    if k + p + 1 >= knots.len() {
        return 0.0;
    }

    if p == 0 {
        let lo = knots[k];
        let hi = knots[k + 1];
        let inside = t >= lo && t < hi;
        let closes_domain = (t - hi).abs() < EPS
            && hi >= knots.last().copied().unwrap_or(f64::INFINITY);
        return if inside || closes_domain { 1.0 } else { 0.0 };
    }

    // Weighted degree-(p-1) function `idx`; zero when the span `den` is empty.
    let term = |num: f64, den: f64, idx: usize| -> f64 {
        if den.abs() > EPS {
            num / den * de_boor_basis(t, idx, p - 1, knots)
        } else {
            0.0
        }
    };

    term(t - knots[k], knots[k + p] - knots[k], k)
        + term(
            knots[k + p + 1] - t,
            knots[k + p + 1] - knots[k + 1],
            k + 1,
        )
}
449
450// ---------------------------------------------------------------------------
451// Fourier basis
452// ---------------------------------------------------------------------------
453
/// Fourier basis evaluated on `grid`, after mapping it linearly onto [0, 1].
///
/// Column order is `1, cos(2πs), sin(2πs), cos(4πs), sin(4πs), …`: column
/// `k ≥ 1` has frequency `(k + 1) / 2` and is a cosine for odd `k`, a sine
/// for even `k`. Returns a `(grid.len(), n_basis)` matrix (rows empty when
/// `n_basis == 0`).
fn fourier_basis(grid: &[f64], n_basis: usize) -> Vec<Vec<f64>> {
    if grid.is_empty() || n_basis == 0 {
        return vec![Vec::new(); grid.len()];
    }

    let origin = grid[0];
    // Guard against a zero-width grid.
    let width = (grid[grid.len() - 1] - origin).max(1e-12);
    let two_pi = 2.0 * std::f64::consts::PI;

    grid.iter()
        .map(|&t| {
            let s = (t - origin) / width;
            (0..n_basis)
                .map(|k| {
                    if k == 0 {
                        return 1.0;
                    }
                    let freq = (k + 1) / 2;
                    let angle = two_pi * freq as f64 * s;
                    if k % 2 == 1 {
                        angle.cos()
                    } else {
                        angle.sin()
                    }
                })
                .collect()
        })
        .collect()
}
488
489// ---------------------------------------------------------------------------
490// Haar wavelet basis
491// ---------------------------------------------------------------------------
492
/// Haar wavelet basis evaluated on `grid`, after mapping it linearly onto [0, 1].
///
/// Column 0 is the constant scaling function, equal to 1 at every grid point.
/// Column `k ≥ 1` is the Haar wavelet at level `L = ⌊log₂ k⌋` with shift
/// `k − 2^L`: it equals `+√(2^L)` on the first half and `−√(2^L)` on the
/// second half of `[shift/2^L, (shift+1)/2^L)`, and 0 elsewhere (including
/// s = 1). Returns a `(grid.len(), n_basis)` matrix (rows empty when
/// `n_basis == 0`).
fn wavelet_basis(grid: &[f64], n_basis: usize) -> Vec<Vec<f64>> {
    if grid.is_empty() || n_basis == 0 {
        return vec![Vec::new(); grid.len()];
    }

    let origin = grid[0];
    // Guard against a zero-width grid.
    let width = (grid[grid.len() - 1] - origin).max(1e-12);

    // Value of basis function `k` at normalised position `s`.
    let haar = |k: usize, s: f64| -> f64 {
        if k == 0 {
            return 1.0;
        }
        let level = (usize::BITS - 1 - k.leading_zeros()) as usize;
        let count = 1usize << level;
        let shift = k - count;
        let n = count as f64;
        let lo = shift as f64 / n;
        let mid = (shift as f64 + 0.5) / n;
        let hi = (shift + 1) as f64 / n;
        let amp = n.sqrt();
        if s >= lo && s < mid {
            amp
        } else if s >= mid && s < hi {
            -amp
        } else {
            0.0
        }
    };

    grid.iter()
        .map(|&t| {
            let s = (t - origin) / width;
            (0..n_basis).map(|k| haar(k, s)).collect()
        })
        .collect()
}
540
541// ---------------------------------------------------------------------------
542// Z matrix: numerical integration ∫ X_i(t) φ_j(t) dt
543// ---------------------------------------------------------------------------
544
/// Build the design matrix Z with Z_{ij} = ∫ X_i(t) φ_j(t) dt.
///
/// The integral uses the composite trapezoidal rule on `grid`. Each row of
/// `data` and `phi` must have at least `grid.len()` entries (indexing panics
/// otherwise). The number of columns is taken from the first row of `phi`.
/// Returns a `(data.len(), n_basis)` matrix.
fn build_z_matrix(data: &[Vec<f64>], phi: &[Vec<f64>], grid: &[f64]) -> Vec<Vec<f64>> {
    let n_basis = phi.first().map_or(0, |row| row.len());
    let n_steps = grid.len().saturating_sub(1);

    data.iter()
        .map(|curve| {
            (0..n_basis)
                .map(|j| {
                    // Σ_t (h/2) (f_t + f_{t+1}) with f_t = X(t) φ_j(t)
                    (0..n_steps).fold(0.0f64, |acc, t| {
                        let h = grid[t + 1] - grid[t];
                        let left = curve[t] * phi[t][j];
                        let right = curve[t + 1] * phi[t + 1][j];
                        acc + 0.5 * h * (left + right)
                    })
                })
                .collect()
        })
        .collect()
}
570
571// ---------------------------------------------------------------------------
572// Roughness penalty matrix
573// ---------------------------------------------------------------------------
574
/// Second-difference roughness penalty matrix Ω = D'D.
///
/// D is the `(n_basis - 2) × n_basis` second-difference operator whose row
/// `r` holds the stencil `(1, -2, 1)` in columns `r, r+1, r+2`. For
/// `n_basis < 3` there are no second differences and Ω is the zero matrix.
pub fn roughness_penalty(n_basis: usize) -> Vec<Vec<f64>> {
    const STENCIL: [f64; 3] = [1.0, -2.0, 1.0];

    let mut omega = vec![vec![0.0f64; n_basis]; n_basis];
    // Ω is the sum over rows of D of each row's outer product with itself.
    for r in 0..n_basis.saturating_sub(2) {
        for (i, &di) in STENCIL.iter().enumerate() {
            for (j, &dj) in STENCIL.iter().enumerate() {
                omega[r + i][r + j] += di * dj;
            }
        }
    }
    omega
}
598
599// ---------------------------------------------------------------------------
600// Penalised least squares solver
601// ---------------------------------------------------------------------------
602
603/// Solve the penalised LS system: (Z'Z + λΩ) c = Z'y.
604///
605/// Uses Cholesky decomposition via an LDL' variant that falls back to
606/// positive-definite Gaussian elimination.
607pub fn penalized_ls(
608    z: &[Vec<f64>],
609    y: &[f64],
610    omega: &[Vec<f64>],
611    lambda: f64,
612) -> StatsResult<Vec<f64>> {
613    let n = z.first().map(|r| r.len()).unwrap_or(0);
614    if n == 0 {
615        return Ok(Vec::new());
616    }
617
618    // A = Z'Z + λΩ
619    let mut a = vec![vec![0.0f64; n]; n];
620    for obs in z {
621        for j in 0..n {
622            for k in 0..n {
623                a[j][k] += obs[j] * obs[k];
624            }
625        }
626    }
627    for j in 0..n {
628        for k in 0..n {
629            a[j][k] += lambda * omega[j][k];
630        }
631    }
632
633    // b = Z'y
634    let mut b = vec![0.0f64; n];
635    for (obs, &yi) in z.iter().zip(y) {
636        for j in 0..n {
637            b[j] += obs[j] * yi;
638        }
639    }
640
641    // Solve A c = b via Gaussian elimination with partial pivoting
642    gauss_solve(&a, &b)
643}
644
645/// Gaussian elimination with partial pivoting.
646fn gauss_solve(a: &[Vec<f64>], b: &[f64]) -> StatsResult<Vec<f64>> {
647    let n = a.len();
648    if n == 0 {
649        return Ok(Vec::new());
650    }
651
652    // Augmented matrix [A | b]
653    let mut m: Vec<Vec<f64>> = a
654        .iter()
655        .zip(b.iter())
656        .map(|(row, &bi)| {
657            let mut r = row.clone();
658            r.push(bi);
659            r
660        })
661        .collect();
662
663    for col in 0..n {
664        // Find pivot
665        let pivot_row = (col..n).max_by(|&r1, &r2| {
666            m[r1][col]
667                .abs()
668                .partial_cmp(&m[r2][col].abs())
669                .unwrap_or(std::cmp::Ordering::Equal)
670        });
671        let pivot_row = pivot_row
672            .ok_or_else(|| StatsError::ComputationError("singular penalised system".to_owned()))?;
673
674        m.swap(col, pivot_row);
675
676        let pivot = m[col][col];
677        if pivot.abs() < 1e-300 {
678            return Err(StatsError::ComputationError(
679                "near-singular penalised normal equations; increase lambda".to_owned(),
680            ));
681        }
682
683        for row in (col + 1)..n {
684            let factor = m[row][col] / pivot;
685            for k in col..=n {
686                let val = m[col][k];
687                m[row][k] -= factor * val;
688            }
689        }
690    }
691
692    // Back-substitution
693    let mut x = vec![0.0f64; n];
694    for i in (0..n).rev() {
695        let mut sum = m[i][n];
696        for j in (i + 1)..n {
697            sum -= m[i][j] * x[j];
698        }
699        x[i] = sum / m[i][i];
700    }
701    Ok(x)
702}
703
704// ---------------------------------------------------------------------------
705// GCV score
706// ---------------------------------------------------------------------------
707
708/// Compute the GCV score.
709///
710/// GCV(λ) = (1/n) ||y - ŷ||² / (1 - trace(H)/n)²
711///
712/// where H = Z (Z'Z + λΩ)⁻¹ Z' (the hat matrix).
713/// We approximate `trace(H)` via the diagonal sum of H = Z A⁻¹ Z'.
714fn compute_gcv(
715    z: &[Vec<f64>],
716    y_hat: &[f64],
717    response: &[f64],
718    omega: &[Vec<f64>],
719    lambda: f64,
720    n_obs: usize,
721) -> f64 {
722    let n = z.first().map(|r| r.len()).unwrap_or(0);
723    if n == 0 || n_obs == 0 {
724        return f64::INFINITY;
725    }
726
727    // Build A = Z'Z + λΩ
728    let mut a = vec![vec![0.0f64; n]; n];
729    for obs in z {
730        for j in 0..n {
731            for k in 0..n {
732                a[j][k] += obs[j] * obs[k];
733            }
734        }
735    }
736    for j in 0..n {
737        for k in 0..n {
738            a[j][k] += lambda * omega[j][k];
739        }
740    }
741
742    // Invert A via Gauss-Jordan (small n_basis, typically ≤ 20)
743    let a_inv = match invert_matrix(&a) {
744        Ok(inv) => inv,
745        Err(_) => return f64::INFINITY,
746    };
747
748    // trace(H) = trace(Z A⁻¹ Z') = Σ_i (z_i' A⁻¹ z_i)
749    let tr_h: f64 = z
750        .iter()
751        .map(|zi| {
752            // A⁻¹ z_i
753            let az: Vec<f64> = (0..n)
754                .map(|j| (0..n).map(|k| a_inv[j][k] * zi[k]).sum::<f64>())
755                .collect();
756            // z_i' (A⁻¹ z_i)
757            zi.iter().zip(az.iter()).map(|(&v, &w)| v * w).sum::<f64>()
758        })
759        .sum();
760
761    let df_hat = tr_h / n_obs as f64;
762    if (1.0 - df_hat).abs() < 1e-10 {
763        return f64::INFINITY;
764    }
765
766    let ss_res: f64 = response
767        .iter()
768        .zip(y_hat.iter())
769        .map(|(&y, &yh)| (y - yh).powi(2))
770        .sum();
771
772    (ss_res / n_obs as f64) / (1.0 - df_hat).powi(2)
773}
774
775/// Invert an n×n matrix using Gauss-Jordan elimination.
776fn invert_matrix(a: &[Vec<f64>]) -> StatsResult<Vec<Vec<f64>>> {
777    let n = a.len();
778    // Augment with identity
779    let mut m: Vec<Vec<f64>> = a
780        .iter()
781        .enumerate()
782        .map(|(i, row)| {
783            let mut r = row.clone();
784            r.resize(2 * n, 0.0);
785            r[n + i] = 1.0;
786            r
787        })
788        .collect();
789
790    for col in 0..n {
791        // Pivot
792        let pivot_row = (col..n).max_by(|&r1, &r2| {
793            m[r1][col]
794                .abs()
795                .partial_cmp(&m[r2][col].abs())
796                .unwrap_or(std::cmp::Ordering::Equal)
797        });
798        let pivot_row =
799            pivot_row.ok_or_else(|| StatsError::ComputationError("singular matrix".to_owned()))?;
800        m.swap(col, pivot_row);
801
802        let pivot = m[col][col];
803        if pivot.abs() < 1e-300 {
804            return Err(StatsError::ComputationError("singular matrix".to_owned()));
805        }
806        let scale = 1.0 / pivot;
807        for k in 0..(2 * n) {
808            m[col][k] *= scale;
809        }
810        for row in 0..n {
811            if row != col {
812                let factor = m[row][col];
813                for k in 0..(2 * n) {
814                    let val = m[col][k];
815                    m[row][k] -= factor * val;
816                }
817            }
818        }
819    }
820
821    let inv: Vec<Vec<f64>> = m.iter().map(|row| row[n..].to_vec()).collect();
822    Ok(inv)
823}
824
825// ---------------------------------------------------------------------------
826// Utilities
827// ---------------------------------------------------------------------------
828
/// Return `n` evenly spaced values from `start` to `end`, both inclusive.
///
/// `n == 0` gives an empty vector; `n == 1` gives `[start]`.
fn linspace(start: f64, end: f64, n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let width = end - start;
            let intervals = (n - 1) as f64;
            (0..n)
                .map(|i| start + width * i as f64 / intervals)
                .collect()
        }
    }
}
841
842// ---------------------------------------------------------------------------
843// Tests
844// ---------------------------------------------------------------------------
845
846#[cfg(test)]
847mod tests {
848    use super::*;
849
850    // -----------------------------------------------------------------------
851    // Helpers
852    // -----------------------------------------------------------------------
853
854    /// Simple LCG random number generator to avoid external dependencies.
855    fn lcg(s: &mut u64) -> f64 {
856        *s = s
857            .wrapping_mul(6_364_136_223_846_793_005)
858            .wrapping_add(1_442_695_040_888_963_407);
859        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
860    }
861
862    /// Generate `n` sample paths of a smooth function X_i(t) = a_i * f(t) + noise.
863    fn smooth_data(n_obs: usize, n_time: usize, seed: u64) -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>) {
864        let grid: Vec<f64> = (0..n_time)
865            .map(|i| i as f64 / (n_time - 1) as f64)
866            .collect();
867        let mut rng = seed;
868        let mut data = Vec::with_capacity(n_obs);
869        let mut response = Vec::with_capacity(n_obs);
870        // β(t) = sin(πt), so ∫ β(t) X_i(t) dt = a_i ∫ sin(πt) sin(2πt) dt + noise
871        // We use X_i(t) = a_i * sin(2πt) so we can compute expected ∫
872        for _ in 0..n_obs {
873            let a = lcg(&mut rng) * 2.0 - 1.0; // a in [-1,1]
874            let curve: Vec<f64> = grid
875                .iter()
876                .map(|&t| a * (2.0 * std::f64::consts::PI * t).sin())
877                .collect();
878            // ∫₀¹ sin(πt) a sin(2πt) dt = a ∫₀¹ sin(πt)sin(2πt)dt ≈ 0 (orthogonal on [0,1])
879            // Use a simpler β(t)=t: Y_i = ∫₀¹ t * a sin(2πt) dt = a ∫₀¹ t sin(2πt) dt
880            let integral: f64 = grid
881                .windows(2)
882                .map(|w| {
883                    let t0 = w[0];
884                    let t1 = w[1];
885                    let dt = t1 - t0;
886                    let f0 = t0 * a * (2.0 * std::f64::consts::PI * t0).sin();
887                    let f1 = t1 * a * (2.0 * std::f64::consts::PI * t1).sin();
888                    0.5 * dt * (f0 + f1)
889                })
890                .sum();
891            response.push(integral + (lcg(&mut rng) - 0.5) * 0.01); // tiny noise
892            data.push(curve);
893        }
894        (data, response, grid)
895    }
896
897    // -----------------------------------------------------------------------
898    // Config defaults
899    // -----------------------------------------------------------------------
900
901    #[test]
902    fn test_fof_config_default() {
903        let cfg = FofConfig::default();
904        assert_eq!(
905            cfg.basis,
906            FunctionalBasis::BSpline {
907                n_basis: 10,
908                degree: 3
909            }
910        );
911        assert!((cfg.lambda - 0.01).abs() < 1e-15);
912        assert_eq!(cfg.n_grid, 100);
913    }
914
915    // -----------------------------------------------------------------------
916    // B-spline basis
917    // -----------------------------------------------------------------------
918
919    #[test]
920    fn test_bspline_basis_partition_of_unity() {
921        // B-splines of any degree form a partition of unity: Σ_k φ_k(t) = 1
922        let grid: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
923        let phi = bspline_basis(&grid, 8, 3);
924        for (t_idx, _) in grid.iter().enumerate() {
925            let s: f64 = phi[t_idx].iter().sum();
926            assert!(
927                (s - 1.0).abs() < 1e-8,
928                "partition of unity at t={t_idx}: sum={s}"
929            );
930        }
931    }
932
933    #[test]
934    fn test_bspline_basis_non_negative() {
935        let grid: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
936        let phi = bspline_basis(&grid, 10, 3);
937        for row in &phi {
938            for &v in row {
939                assert!(v >= -1e-10, "negative B-spline value: {v}");
940            }
941        }
942    }
943
944    // -----------------------------------------------------------------------
945    // Roughness penalty
946    // -----------------------------------------------------------------------
947
948    #[test]
949    fn test_roughness_penalty_symmetry() {
950        let omega = roughness_penalty(8);
951        let n = omega.len();
952        for i in 0..n {
953            for j in 0..n {
954                assert!(
955                    (omega[i][j] - omega[j][i]).abs() < 1e-14,
956                    "Omega not symmetric at ({i},{j})"
957                );
958            }
959        }
960    }
961
962    #[test]
963    fn test_roughness_penalty_psd() {
964        // Ω = D'D is positive semi-definite; verify all diagonal entries ≥ 0
965        let omega = roughness_penalty(6);
966        for (i, row) in omega.iter().enumerate() {
967            assert!(row[i] >= 0.0, "negative diagonal in Omega");
968        }
969    }
970
971    // -----------------------------------------------------------------------
972    // Penalised LS
973    // -----------------------------------------------------------------------
974
975    #[test]
976    fn test_penalized_ls_identity() {
977        // With Z = I, Ω = 0, λ=0: solution should be y
978        let n = 4;
979        let z: Vec<Vec<f64>> = (0..n)
980            .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
981            .collect();
982        let omega = vec![vec![0.0; n]; n];
983        let y = vec![1.0, 2.0, 3.0, 4.0];
984        let c = penalized_ls(&z, &y, &omega, 0.0).expect("penalized_ls failed");
985        for (ci, &yi) in c.iter().zip(y.iter()) {
986            assert!((ci - yi).abs() < 1e-10, "expected {yi}, got {ci}");
987        }
988    }
989
990    // -----------------------------------------------------------------------
991    // Fit and predict
992    // -----------------------------------------------------------------------
993
994    #[test]
995    fn test_fit_r_squared_high_on_clean_data() {
996        let (data, response, grid) = smooth_data(50, 40, 42);
997        let config = FofConfig {
998            basis: FunctionalBasis::BSpline {
999                n_basis: 8,
1000                degree: 3,
1001            },
1002            lambda: 1e-4,
1003            n_grid: 50,
1004        };
1005        let mut model = FunctionalRegression::new(config);
1006        let result = model.fit(&data, &response, &grid).expect("fit failed");
1007        // R² should be high (> 0.9) on near-noise-free data
1008        assert!(result.r_squared > 0.9, "R² too low: {}", result.r_squared);
1009    }
1010
1011    #[test]
1012    fn test_predict_length() {
1013        let (data, response, grid) = smooth_data(30, 30, 7);
1014        let config = FofConfig {
1015            basis: FunctionalBasis::BSpline {
1016                n_basis: 6,
1017                degree: 3,
1018            },
1019            lambda: 0.01,
1020            n_grid: 50,
1021        };
1022        let mut model = FunctionalRegression::new(config);
1023        model.fit(&data, &response, &grid).expect("fit failed");
1024
1025        let (new_data, _, _) = smooth_data(10, 30, 99);
1026        let preds = model.predict(&new_data, &grid).expect("predict failed");
1027        assert_eq!(preds.len(), 10, "predict length mismatch");
1028    }
1029
1030    #[test]
1031    fn test_predict_before_fit_returns_error() {
1032        let config = FofConfig::default();
1033        let model = FunctionalRegression::new(config);
1034        let grid: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
1035        let data = vec![vec![0.0; 10]];
1036        let res = model.predict(&data, &grid);
1037        assert!(res.is_err(), "predict before fit should return error");
1038    }
1039
1040    #[test]
1041    fn test_fit_with_fourier_basis() {
1042        let (data, response, grid) = smooth_data(40, 40, 123);
1043        let config = FofConfig {
1044            basis: FunctionalBasis::Fourier { n_basis: 9 },
1045            lambda: 0.01,
1046            n_grid: 50,
1047        };
1048        let mut model = FunctionalRegression::new(config);
1049        let result = model.fit(&data, &response, &grid).expect("fit failed");
1050        assert!(result.r_squared >= 0.0 && result.r_squared <= 1.0 + 1e-10);
1051    }
1052
1053    #[test]
1054    fn test_fit_with_wavelet_basis() {
1055        let (data, response, grid) = smooth_data(40, 40, 456);
1056        let config = FofConfig {
1057            basis: FunctionalBasis::Wavelet { n_basis: 8 },
1058            lambda: 0.01,
1059            n_grid: 50,
1060        };
1061        let mut model = FunctionalRegression::new(config);
1062        let result = model.fit(&data, &response, &grid).expect("fit failed");
1063        assert!(result.r_squared >= 0.0, "r_squared should be non-negative");
1064    }
1065
1066    #[test]
1067    fn test_gcv_score_finite() {
1068        let (data, response, grid) = smooth_data(30, 30, 11);
1069        let config = FofConfig {
1070            basis: FunctionalBasis::BSpline {
1071                n_basis: 6,
1072                degree: 3,
1073            },
1074            lambda: 0.1,
1075            n_grid: 40,
1076        };
1077        let mut model = FunctionalRegression::new(config);
1078        let result = model.fit(&data, &response, &grid).expect("fit failed");
1079        assert!(result.gcv_score.is_finite(), "GCV should be finite");
1080        assert!(result.gcv_score >= 0.0, "GCV should be non-negative");
1081    }
1082
1083    #[test]
1084    fn test_beta_values_length() {
1085        let (data, response, grid) = smooth_data(30, 30, 13);
1086        let n_grid = 60;
1087        let config = FofConfig {
1088            basis: FunctionalBasis::BSpline {
1089                n_basis: 6,
1090                degree: 3,
1091            },
1092            lambda: 0.01,
1093            n_grid,
1094        };
1095        let mut model = FunctionalRegression::new(config);
1096        let result = model.fit(&data, &response, &grid).expect("fit failed");
1097        assert_eq!(result.beta_values.len(), n_grid);
1098        assert_eq!(result.grid.len(), n_grid);
1099    }
1100
1101    #[test]
1102    fn test_dimension_mismatch_error() {
1103        let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
1104        let response = vec![1.0]; // wrong length
1105        let grid = vec![0.0, 0.5, 1.0];
1106        let config = FofConfig::default();
1107        let mut model = FunctionalRegression::new(config);
1108        let res = model.fit(&data, &response, &grid);
1109        assert!(res.is_err(), "should return dimension mismatch error");
1110    }
1111
1112    #[test]
1113    fn test_gcv_varies_with_lambda() {
1114        // GCV should not be constant as lambda changes
1115        let (data, response, grid) = smooth_data(40, 40, 77);
1116        let lambdas = [1e-4, 1e-2, 1.0];
1117        let mut gcv_scores = Vec::new();
1118        for &lam in &lambdas {
1119            let config = FofConfig {
1120                basis: FunctionalBasis::BSpline {
1121                    n_basis: 6,
1122                    degree: 3,
1123                },
1124                lambda: lam,
1125                n_grid: 40,
1126            };
1127            let mut model = FunctionalRegression::new(config);
1128            let result = model.fit(&data, &response, &grid).expect("fit failed");
1129            gcv_scores.push(result.gcv_score);
1130        }
1131        // At least two GCV scores should differ
1132        let all_same = gcv_scores.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-14);
1133        assert!(!all_same, "GCV should vary with lambda");
1134    }
1135}