Skip to main content

fdars_core/
smooth_basis.rs

1//! Basis-penalized smoothing with continuous derivative penalties.
2//!
3//! This module implements `smooth.basis` from R's fda package. Unlike the
4//! discrete difference penalty used in P-splines (`basis.rs`), this uses
5//! continuous derivative penalties: `min ||y - Φc||² + λ·∫(Lf)² dt`.
6//!
7//! Key capabilities:
8//! - [`smooth_basis`] — Penalized least squares with continuous roughness penalty
9//! - [`smooth_basis_gcv`] — GCV-optimal smoothing parameter selection
10//! - [`bspline_penalty_matrix`] / [`fourier_penalty_matrix`] — Roughness penalty matrices
11
12use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
18// ─── Types ──────────────────────────────────────────────────────────────────
19
/// Basis system used for penalized smoothing.
///
/// `PartialEq` is derived so that basis specifications can be compared
/// directly (for example with `assert_eq!`). `Eq` is not derived because
/// the Fourier period is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
    /// B-spline basis of the given order (typically 4, i.e. cubic splines).
    Bspline { order: usize },
    /// Fourier basis with the given period.
    Fourier { period: f64 },
}
28
/// Functional data parameter object (basis + penalty specification).
///
/// This is the input bundle consumed by [`smooth_basis`]: it names the basis
/// system, the smoothing parameter, and carries the roughness penalty matrix
/// that has already been computed for that basis.
#[derive(Debug, Clone)]
pub struct FdPar {
    /// Type of basis system.
    pub basis_type: BasisType,
    /// Requested number of basis functions. [`smooth_basis`] rejects values
    /// below 2. For B-splines the number of basis functions actually
    /// evaluated is derived from this value and the order and may differ
    /// from it.
    pub nbasis: usize,
    /// Smoothing parameter λ multiplying the penalty matrix.
    pub lambda: f64,
    /// Derivative order of the penalty (typically 2). Informational only as
    /// far as [`smooth_basis`] is concerned: that function does not read this
    /// field, so `penalty_matrix` must already correspond to this order.
    pub lfd_order: usize,
    /// Precomputed K×K penalty matrix in column-major layout, where K is the
    /// number of basis functions actually evaluated for `basis_type` /
    /// `nbasis` (see [`bspline_penalty_matrix`] and
    /// [`fourier_penalty_matrix`]).
    pub penalty_matrix: Vec<f64>,
}
43
/// Result of basis-penalized smoothing.
///
/// Produced by [`smooth_basis`] and [`smooth_basis_gcv`]. Here n is the
/// number of curves, m the number of evaluation points per curve, and K the
/// number of basis functions actually used.
#[derive(Debug, Clone)]
pub struct SmoothBasisResult {
    /// Basis coefficients (n × K), one row per curve.
    pub coefficients: FdMatrix,
    /// Fitted values (n × m), one row per curve.
    pub fitted: FdMatrix,
    /// Effective degrees of freedom: the trace of the hat matrix of the
    /// penalized fit (shared by all curves).
    pub edf: f64,
    /// Generalized cross-validation score: mean squared residual over all
    /// n·m points divided by `(1 - edf / m)²`; infinite when that
    /// denominator is numerically zero.
    pub gcv: f64,
    /// AIC, computed as `N·ln(mse) + 2·edf` with `N = n·m`.
    pub aic: f64,
    /// BIC, computed as `N·ln(mse) + ln(N)·edf` with `N = n·m`.
    pub bic: f64,
    /// Roughness penalty matrix (K × K, column-major), copied from the
    /// `FdPar` that was used for the fit.
    pub penalty_matrix: Vec<f64>,
    /// Number of basis functions used (K).
    pub nbasis: usize,
}
64
65// ─── Penalty Matrices ───────────────────────────────────────────────────────
66
67/// Compute the roughness penalty matrix for B-splines via numerical quadrature.
68///
69/// R\[j,k\] = ∫ D^m B_j(t) · D^m B_k(t) dt
70///
71/// Uses Simpson's rule on a fine sub-grid for each knot interval.
72///
73/// # Arguments
74/// * `argvals` — Evaluation points (length m)
75/// * `nbasis` — Number of basis functions
76/// * `order` — B-spline order (typically 4 for cubic)
77/// * `lfd_order` — Derivative order for penalty (typically 2)
78///
79/// # Returns
80/// K × K penalty matrix in column-major layout (K = nbasis)
81pub fn bspline_penalty_matrix(
82    argvals: &[f64],
83    nbasis: usize,
84    order: usize,
85    lfd_order: usize,
86) -> Vec<f64> {
87    if nbasis < 2 || order < 1 || lfd_order >= order {
88        return vec![0.0; nbasis * nbasis];
89    }
90
91    let nknots = nbasis.saturating_sub(order).max(2);
92
93    // Create a fine quadrature grid (10 sub-points per original interval)
94    let n_sub = 10;
95    let t_min = argvals[0];
96    let t_max = argvals[argvals.len() - 1];
97    let n_quad = (argvals.len() - 1) * n_sub + 1;
98    let quad_t: Vec<f64> = (0..n_quad)
99        .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100        .collect();
101
102    // Evaluate B-spline basis on fine grid
103    let basis_fine = bspline_basis(&quad_t, nknots, order);
104    let actual_nbasis = basis_fine.len() / n_quad;
105
106    // Compute derivatives of B-spline basis numerically
107    let h = (t_max - t_min) / (n_quad - 1) as f64;
108    let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110    // Integration weights on fine grid
111    let weights = simpsons_weights(&quad_t);
112
113    // Compute penalty matrix: R[j,k] = ∫ D^m B_j · D^m B_k dt
114    integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Compute the roughness penalty matrix for a Fourier basis.
///
/// The penalty is diagonal. The basis is ordered as constant, then sin/cos
/// pairs of increasing frequency `k = 1, 2, …`; the entry for a function of
/// frequency `k` is `(2πk/T)^(2m)` with `T = period` and `m = lfd_order`.
///
/// The constant term has frequency 0, so its entry is 0 for every
/// `lfd_order >= 1`. For `lfd_order == 0` the formula gives `0^0 = 1`: the
/// penalty is then the identity, i.e. the constant is penalized like every
/// other basis function.
///
/// # Arguments
/// * `nbasis` — Number of basis functions
/// * `period` — Period of the Fourier basis
/// * `lfd_order` — Derivative order for penalty
///
/// # Returns
/// K × K penalty matrix in column-major layout (empty when `nbasis == 0`)
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let k = nbasis;
    let mut penalty = vec![0.0; k * k];
    if k == 0 {
        return penalty;
    }

    // Constant term: its derivative of any order >= 1 vanishes, but a
    // zero-order penalty penalizes the function itself.
    if lfd_order == 0 {
        penalty[0] = 1.0;
    }

    // Basis indices 1,2 share frequency 1; indices 3,4 share frequency 2; …
    for idx in 1..k {
        let freq = (idx + 1) / 2;
        let omega = 2.0 * PI * freq as f64 / period;
        // `idx + idx * k` addresses the diagonal in column-major layout.
        penalty[idx + idx * k] = omega.powi(2 * lfd_order as i32);
    }

    penalty
}
157
158// ─── Smoothing Functions ────────────────────────────────────────────────────
159
160/// Perform basis-penalized smoothing.
161///
162/// Solves `(Φ'Φ + λR)c = Φ'y` per curve via Cholesky decomposition.
163/// This implements `smooth.basis` from R's fda package.
164///
165/// # Arguments
166/// * `data` — Functional data matrix (n × m)
167/// * `argvals` — Evaluation points (length m)
168/// * `fdpar` — Functional parameter object specifying basis and penalty
169///
170/// # Returns
171/// [`SmoothBasisResult`] with coefficients, fitted values, and diagnostics.
172pub fn smooth_basis(data: &FdMatrix, argvals: &[f64], fdpar: &FdPar) -> Option<SmoothBasisResult> {
173    let (n, m) = data.shape();
174    if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
175        return None;
176    }
177
178    // Evaluate basis on argvals
179    let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
180    let k = actual_nbasis;
181
182    let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
183    let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
184
185    // (Φ'Φ + λR + εI) — small ridge ensures positive definiteness
186    let btb = b_mat.transpose() * &b_mat;
187    let ridge_eps = 1e-10;
188    let system: DMatrix<f64> =
189        &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
190
191    // Invert the penalized system
192    let system_inv = invert_penalized_system(&system, k)?;
193
194    // Hat matrix: H = Φ (Φ'Φ + λR)^{-1} Φ'  →  EDF = tr(H)
195    let h_mat = &b_mat * &system_inv * b_mat.transpose();
196    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
197
198    // Project all curves
199    let proj = &system_inv * b_mat.transpose();
200    let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
201
202    let total_points = (n * m) as f64;
203    let gcv = compute_gcv(total_rss, total_points, edf, m);
204    let mse = total_rss / total_points;
205    let aic = total_points * mse.max(1e-300).ln() + 2.0 * edf;
206    let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * edf;
207
208    Some(SmoothBasisResult {
209        coefficients: all_coefs,
210        fitted: all_fitted,
211        edf,
212        gcv,
213        aic,
214        bic,
215        penalty_matrix: fdpar.penalty_matrix.clone(),
216        nbasis: k,
217    })
218}
219
220/// Perform basis-penalized smoothing with GCV-optimal lambda.
221///
222/// Searches over a log-lambda grid and selects the lambda minimizing GCV.
223///
224/// # Arguments
225/// * `data` — Functional data matrix (n × m)
226/// * `argvals` — Evaluation points (length m)
227/// * `basis_type` — Type of basis system
228/// * `nbasis` — Number of basis functions
229/// * `lfd_order` — Derivative order for penalty
230/// * `log_lambda_range` — Range of log10(lambda) to search, e.g. (-8.0, 4.0)
231/// * `n_grid` — Number of grid points for the search
232pub fn smooth_basis_gcv(
233    data: &FdMatrix,
234    argvals: &[f64],
235    basis_type: &BasisType,
236    nbasis: usize,
237    lfd_order: usize,
238    log_lambda_range: (f64, f64),
239    n_grid: usize,
240) -> Option<SmoothBasisResult> {
241    let m = argvals.len();
242    if m == 0 || nbasis < 2 || n_grid < 2 {
243        return None;
244    }
245
246    // Compute penalty matrix once
247    let penalty = match basis_type {
248        BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
249        BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
250    };
251
252    let (lo, hi) = log_lambda_range;
253    let mut best_gcv = f64::INFINITY;
254    let mut best_result: Option<SmoothBasisResult> = None;
255
256    for i in 0..n_grid {
257        let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
258        let lam = 10.0_f64.powf(log_lam);
259
260        let fdpar = FdPar {
261            basis_type: basis_type.clone(),
262            nbasis,
263            lambda: lam,
264            lfd_order,
265            penalty_matrix: penalty.clone(),
266        };
267
268        if let Some(result) = smooth_basis(data, argvals, &fdpar) {
269            if result.gcv < best_gcv {
270                best_gcv = result.gcv;
271                best_result = Some(result);
272            }
273        }
274    }
275
276    best_result
277}
278
279// ─── Internal Helpers ───────────────────────────────────────────────────────
280
281/// Differentiate column-major basis matrix `lfd_order` times using gradient_uniform.
282fn differentiate_basis_columns(
283    basis: &[f64],
284    n_quad: usize,
285    nbasis: usize,
286    h: f64,
287    lfd_order: usize,
288) -> Vec<f64> {
289    let mut deriv = basis.to_vec();
290    for _ in 0..lfd_order {
291        let mut new_deriv = vec![0.0; n_quad * nbasis];
292        for j in 0..nbasis {
293            let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
294            let grad = crate::helpers::gradient_uniform(&col, h);
295            for i in 0..n_quad {
296                new_deriv[i + j * n_quad] = grad[i];
297            }
298        }
299        deriv = new_deriv;
300    }
301    deriv
302}
303
/// Assemble the symmetric penalty matrix R[j,k] = ∫ D^m B_j · D^m B_k dt.
///
/// `deriv_basis` is column-major with `k` columns of `n_quad` samples and
/// `weights` holds one quadrature weight per sample. Only the upper triangle
/// is integrated; each value is mirrored into the lower triangle. The result
/// is a `k × k` matrix in column-major layout.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let mut gram = vec![0.0; k * k];
    for row in 0..k {
        let left = &deriv_basis[row * n_quad..(row + 1) * n_quad];
        let quad_w = &weights[..n_quad];
        for col in row..k {
            let right = &deriv_basis[col * n_quad..(col + 1) * n_quad];
            // Weighted inner product of the two derivative columns,
            // accumulated in sample order.
            let integral = left
                .iter()
                .zip(right)
                .zip(quad_w)
                .fold(0.0, |acc, ((a, b), w)| acc + a * b * w);
            gram[row + col * k] = integral;
            gram[col + row * k] = integral;
        }
    }
    gram
}
324
325/// Evaluate basis functions on argvals, returning (flat column-major, actual_nbasis).
326fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
327    let m = argvals.len();
328    match basis_type {
329        BasisType::Bspline { order } => {
330            let nknots = nbasis.saturating_sub(*order).max(2);
331            let basis = bspline_basis(argvals, nknots, *order);
332            let actual = basis.len() / m;
333            (basis, actual)
334        }
335        BasisType::Fourier { period } => {
336            let basis = fourier_basis_with_period(argvals, nbasis, *period);
337            (basis, nbasis)
338        }
339    }
340}
341
342/// Invert the penalized system matrix via Cholesky or SVD pseudoinverse.
343fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
344    if let Some(chol) = system.clone().cholesky() {
345        return Some(chol.inverse());
346    }
347    // SVD fallback
348    let svd = nalgebra::SVD::new(system.clone(), true, true);
349    let u = svd.u.as_ref()?;
350    let v_t = svd.v_t.as_ref()?;
351    let max_sv: f64 = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
352    let eps = 1e-10 * max_sv;
353    let mut inv = DMatrix::<f64>::zeros(k, k);
354    for ii in 0..k {
355        for jj in 0..k {
356            let mut sum = 0.0;
357            for s in 0..k.min(svd.singular_values.len()) {
358                if svd.singular_values[s] > eps {
359                    sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
360                }
361            }
362            inv[(ii, jj)] = sum;
363        }
364    }
365    Some(inv)
366}
367
368/// Project all curves onto basis, returning (coefficients, fitted, total_rss).
369fn project_all_curves(
370    data: &FdMatrix,
371    b_mat: &DMatrix<f64>,
372    proj: &DMatrix<f64>,
373    n: usize,
374    m: usize,
375    k: usize,
376) -> (FdMatrix, FdMatrix, f64) {
377    let mut all_coefs = FdMatrix::zeros(n, k);
378    let mut all_fitted = FdMatrix::zeros(n, m);
379    let mut total_rss = 0.0;
380
381    for i in 0..n {
382        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
383        let y_vec = nalgebra::DVector::from_vec(curve.clone());
384        let coefs = proj * &y_vec;
385
386        for j in 0..k {
387            all_coefs[(i, j)] = coefs[j];
388        }
389        let fitted = b_mat * &coefs;
390        for j in 0..m {
391            all_fitted[(i, j)] = fitted[j];
392            let resid = curve[j] - fitted[j];
393            total_rss += resid * resid;
394        }
395    }
396
397    (all_coefs, all_fitted, total_rss)
398}
399
/// Generalized cross-validation score.
///
/// Returns `(rss / n_points) / (1 - edf / m)²`. When the denominator term
/// `1 - edf / m` is within `1e-10` of zero (or is NaN) the score is
/// `f64::INFINITY`.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let shrinkage = 1.0 - edf / m as f64;
    let well_defined = shrinkage.abs() > 1e-10;
    well_defined
        .then(|| (rss / n_points) / (shrinkage * shrinkage))
        .unwrap_or(f64::INFINITY)
}
409
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Uniform grid of `m` points on [0, 1] (requires `m >= 2`).
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Build an `n × m` data matrix whose entry `(i, j)` is `value(i, j)`.
    fn build_data(n: usize, m: usize, value: impl Fn(usize, usize) -> f64) -> FdMatrix {
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = value(i, j);
            }
        }
        data
    }

    /// Side length of a flat square matrix. The B-spline penalty may use a
    /// different number of basis functions than requested, so the size is
    /// recovered from the buffer instead of assumed.
    fn side_len(flat: &[f64]) -> usize {
        (flat.len() as f64).sqrt() as usize
    }

    /// Cubic B-spline parameter object with a second-derivative penalty.
    fn cubic_fdpar(nbasis: usize, lambda: f64, penalty_matrix: Vec<f64>) -> FdPar {
        FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda,
            lfd_order: 2,
            penalty_matrix,
        }
    }

    #[test]
    fn test_bspline_penalty_matrix_symmetric() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
        let actual_k = side_len(&penalty);
        for i in 0..actual_k {
            for j in 0..actual_k {
                assert!(
                    (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
                    "Penalty matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_bspline_penalty_matrix_positive_semidefinite() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
        let k = side_len(&penalty);
        // Necessary condition only: diagonal elements must be non-negative
        for i in 0..k {
            assert!(
                penalty[i + i * k] >= -1e-10,
                "Diagonal element {} is negative: {}",
                i,
                penalty[i + i * k]
            );
        }
    }

    #[test]
    fn test_fourier_penalty_diagonal() {
        let penalty = fourier_penalty_matrix(7, 1.0, 2);
        // Should be diagonal
        for i in 0..7 {
            for j in 0..7 {
                if i != j {
                    assert!(
                        penalty[i + j * 7].abs() < 1e-10,
                        "Off-diagonal ({},{}) = {}",
                        i,
                        j,
                        penalty[i + j * 7]
                    );
                }
            }
        }
        // Constant term should have zero penalty
        assert!(penalty[0].abs() < 1e-10);
        // Higher frequency terms should have larger penalties
        assert!(penalty[1 + 7] > 0.0);
        assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
    }

    #[test]
    fn test_smooth_basis_bspline() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        // Sine curves with a deterministic, curve-dependent linear drift
        let data = build_data(n, m, |i, j| {
            (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01)
        });

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
        let fdpar = cubic_fdpar(nbasis, 1e-4, penalty);

        let result = smooth_basis(&data, &t, &fdpar);
        assert!(result.is_some(), "smooth_basis should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted.shape(), (n, m));
        assert_eq!(res.coefficients.nrows(), n);
        assert!(res.edf > 0.0, "EDF should be positive");
        assert!(res.gcv > 0.0, "GCV should be positive");
    }

    #[test]
    fn test_smooth_basis_fourier() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let truth = |j: usize| (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
        let data = build_data(n, m, |_, j| truth(j));

        let nbasis = 7;
        let period = 1.0;
        let penalty = fourier_penalty_matrix(nbasis, period, 2);

        let fdpar = FdPar {
            basis_type: BasisType::Fourier { period },
            nbasis,
            lambda: 1e-6,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let result = smooth_basis(&data, &t, &fdpar);
        assert!(result.is_some());

        let res = result.unwrap();
        // Fourier basis should fit periodic data well
        for j in 0..m {
            let expected = truth(j);
            assert!(
                (res.fitted[(0, j)] - expected).abs() < 0.1,
                "Fourier fit poor at j={}: got {}, expected {}",
                j,
                res.fitted[(0, j)],
                expected
            );
        }
    }

    #[test]
    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        // Sine curves plus a deterministic pseudo-noise term in [0, 0.1)
        let data = build_data(n, m, |i, j| {
            (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0
        });

        let basis_type = BasisType::Bspline { order: 4 };
        let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
        assert!(result.is_some(), "GCV search should succeed");
    }

    #[test]
    fn test_smooth_basis_large_lambda_reduces_edf() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let data = build_data(n, m, |_, j| (2.0 * PI * t[j]).sin());

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);

        let fdpar_small = cubic_fdpar(nbasis, 1e-8, penalty.clone());
        let fdpar_large = cubic_fdpar(nbasis, 1e2, penalty);

        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();

        assert!(
            res_large.edf < res_small.edf,
            "Larger lambda should reduce EDF: {} vs {}",
            res_large.edf,
            res_small.edf
        );
    }
}