Skip to main content

fdars_core/
smooth_basis.rs

1//! Basis-penalized smoothing with continuous derivative penalties.
2//!
3//! This module implements `smooth.basis` from R's fda package. Unlike the
4//! discrete difference penalty used in P-splines (`basis.rs`), this uses
5//! continuous derivative penalties: `min ||y - Φc||² + λ·∫(Lf)² dt`.
6//!
7//! Key capabilities:
8//! - [`smooth_basis`] — Penalized least squares with continuous roughness penalty
9//! - [`smooth_basis_gcv`] — GCV-optimal smoothing parameter selection
10//! - [`bspline_penalty_matrix`] / [`fourier_penalty_matrix`] — Roughness penalty matrices
11
12use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
18// ─── Types ──────────────────────────────────────────────────────────────────
19
/// Basis system used for penalized smoothing.
///
/// The variant determines how basis functions are evaluated on `argvals` and
/// which roughness-penalty constructor applies ([`bspline_penalty_matrix`] or
/// [`fourier_penalty_matrix`]).
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
    /// B-spline basis with the given order (order 4 gives cubic splines).
    Bspline { order: usize },
    /// Fourier basis (a constant plus sine/cosine pairs) with the given period.
    Fourier { period: f64 },
}
28
/// Functional data parameter object: a basis specification plus its roughness penalty.
///
/// Consumed by [`smooth_basis`], which evaluates the basis described by
/// `basis_type` / `nbasis` and builds the penalized system from `lambda` and
/// `penalty_matrix`. `smooth_basis` uses `penalty_matrix` as given; it does not
/// re-derive the penalty from `lfd_order`.
#[derive(Debug, Clone)]
pub struct FdPar {
    /// Type of basis system.
    pub basis_type: BasisType,
    /// Requested number of basis functions. For B-splines the number of columns
    /// actually evaluated follows from the knot construction and may differ.
    pub nbasis: usize,
    /// Smoothing parameter λ multiplying the roughness penalty.
    pub lambda: f64,
    /// Derivative order `m` of the penalty ∫ (D^m f)² dt (commonly 2;
    /// [`basis_nbasis_cv`] always uses 2).
    pub lfd_order: usize,
    /// Precomputed K×K penalty matrix (column-major), where K is the number of
    /// basis columns actually evaluated; typically produced by
    /// [`bspline_penalty_matrix`] or [`fourier_penalty_matrix`].
    pub penalty_matrix: Vec<f64>,
}
43
/// Result of basis-penalized smoothing, as produced by [`smooth_basis`].
#[derive(Debug, Clone)]
pub struct SmoothBasisResult {
    /// Basis coefficients, one row per curve (n × K).
    pub coefficients: FdMatrix,
    /// Fitted values on `argvals`, one row per curve (n × m).
    pub fitted: FdMatrix,
    /// Effective degrees of freedom of a single curve's fit (trace of the hat matrix).
    pub edf: f64,
    /// Generalized cross-validation score, with the residuals pooled over all curves.
    pub gcv: f64,
    /// AIC computed from the pooled MSE, counting `n · edf` effective parameters.
    pub aic: f64,
    /// BIC computed from the pooled MSE, counting `n · edf` effective parameters.
    pub bic: f64,
    /// Roughness penalty matrix that was used (K × K, column-major).
    pub penalty_matrix: Vec<f64>,
    /// Number of basis functions actually used (K).
    pub nbasis: usize,
}
64
65// ─── Penalty Matrices ───────────────────────────────────────────────────────
66
67/// Compute the roughness penalty matrix for B-splines via numerical quadrature.
68///
69/// R\[j,k\] = ∫ D^m B_j(t) · D^m B_k(t) dt
70///
71/// Uses Simpson's rule on a fine sub-grid for each knot interval.
72///
73/// # Arguments
74/// * `argvals` — Evaluation points (length m)
75/// * `nbasis` — Number of basis functions
76/// * `order` — B-spline order (typically 4 for cubic)
77/// * `lfd_order` — Derivative order for penalty (typically 2)
78///
79/// # Returns
80/// K × K penalty matrix in column-major layout (K = nbasis)
81pub fn bspline_penalty_matrix(
82    argvals: &[f64],
83    nbasis: usize,
84    order: usize,
85    lfd_order: usize,
86) -> Vec<f64> {
87    if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
88        return vec![0.0; nbasis * nbasis];
89    }
90
91    let nknots = nbasis.saturating_sub(order).max(2);
92
93    // Create a fine quadrature grid (10 sub-points per original interval)
94    let n_sub = 10;
95    let t_min = argvals[0];
96    let t_max = argvals[argvals.len() - 1];
97    let n_quad = (argvals.len() - 1) * n_sub + 1;
98    let quad_t: Vec<f64> = (0..n_quad)
99        .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100        .collect();
101
102    // Evaluate B-spline basis on fine grid
103    let basis_fine = bspline_basis(&quad_t, nknots, order);
104    let actual_nbasis = basis_fine.len() / n_quad;
105
106    // Compute derivatives of B-spline basis numerically
107    let h = (t_max - t_min) / (n_quad - 1) as f64;
108    let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110    // Integration weights on fine grid
111    let weights = simpsons_weights(&quad_t);
112
113    // Compute penalty matrix: R[j,k] = ∫ D^m B_j · D^m B_k dt
114    integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Compute the roughness penalty matrix for a Fourier basis.
///
/// Assumes the basis is ordered as a constant followed by sin/cos pairs of
/// increasing frequency, normalized as in R's fda package (orthonormal over one
/// period). Under that convention the penalty ∫ (D^m φ_j)² dt is diagonal:
///
/// * constant term: `1` when `lfd_order == 0` (the identity operator keeps it,
///   and ∫ φ₀² = 1), otherwise `0`;
/// * both members of the `f`-th sin/cos pair: `(2πf/T)^(2m)`.
///
/// # Arguments
/// * `nbasis` — Number of basis functions
/// * `period` — Period of the Fourier basis
/// * `lfd_order` — Derivative order for penalty
///
/// # Returns
/// K × K penalty matrix in column-major layout (empty when `nbasis == 0`).
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let k = nbasis;
    let mut penalty = vec![0.0; k * k];
    if k == 0 {
        return penalty;
    }

    // Any derivative of order >= 1 annihilates the constant; order 0 does not.
    if lfd_order == 0 {
        penalty[0] = 1.0;
    }

    // Basis index idx >= 1 belongs to frequency (idx + 1) / 2: 1, 1, 2, 2, 3, ...
    for idx in 1..k {
        let freq = (idx + 1) / 2;
        let omega = 2.0 * PI * freq as f64 / period;
        penalty[idx + idx * k] = omega.powi(2 * lfd_order as i32);
    }

    penalty
}
158
159// ─── Smoothing Functions ────────────────────────────────────────────────────
160
161/// Perform basis-penalized smoothing.
162///
163/// Solves `(Φ'Φ + λR)c = Φ'y` per curve via Cholesky decomposition.
164/// This implements `smooth.basis` from R's fda package.
165///
166/// # Arguments
167/// * `data` — Functional data matrix (n × m)
168/// * `argvals` — Evaluation points (length m)
169/// * `fdpar` — Functional parameter object specifying basis and penalty
170///
171/// # Returns
172/// [`SmoothBasisResult`] with coefficients, fitted values, and diagnostics.
173pub fn smooth_basis(data: &FdMatrix, argvals: &[f64], fdpar: &FdPar) -> Option<SmoothBasisResult> {
174    let (n, m) = data.shape();
175    if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
176        return None;
177    }
178
179    // Evaluate basis on argvals
180    let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
181    let k = actual_nbasis;
182
183    let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
184    let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
185
186    // (Φ'Φ + λR + εI) — small ridge ensures positive definiteness
187    let btb = b_mat.transpose() * &b_mat;
188    let ridge_eps = 1e-10;
189    let system: DMatrix<f64> =
190        &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
191
192    // Invert the penalized system
193    let system_inv = invert_penalized_system(&system, k)?;
194
195    // Hat matrix: H = Φ (Φ'Φ + λR)^{-1} Φ'  →  EDF = tr(H)
196    let h_mat = &b_mat * &system_inv * b_mat.transpose();
197    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
198
199    // Project all curves
200    let proj = &system_inv * b_mat.transpose();
201    let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
202
203    let total_points = (n * m) as f64;
204    let gcv = compute_gcv(total_rss, total_points, edf, m);
205    let mse = total_rss / total_points;
206    // Total effective degrees of freedom = n curves * per-curve edf
207    let total_edf = n as f64 * edf;
208    let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
209    let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
210
211    Some(SmoothBasisResult {
212        coefficients: all_coefs,
213        fitted: all_fitted,
214        edf,
215        gcv,
216        aic,
217        bic,
218        penalty_matrix: fdpar.penalty_matrix.clone(),
219        nbasis: k,
220    })
221}
222
223/// Perform basis-penalized smoothing with GCV-optimal lambda.
224///
225/// Searches over a log-lambda grid and selects the lambda minimizing GCV.
226///
227/// # Arguments
228/// * `data` — Functional data matrix (n × m)
229/// * `argvals` — Evaluation points (length m)
230/// * `basis_type` — Type of basis system
231/// * `nbasis` — Number of basis functions
232/// * `lfd_order` — Derivative order for penalty
233/// * `log_lambda_range` — Range of log10(lambda) to search, e.g. (-8.0, 4.0)
234/// * `n_grid` — Number of grid points for the search
235pub fn smooth_basis_gcv(
236    data: &FdMatrix,
237    argvals: &[f64],
238    basis_type: &BasisType,
239    nbasis: usize,
240    lfd_order: usize,
241    log_lambda_range: (f64, f64),
242    n_grid: usize,
243) -> Option<SmoothBasisResult> {
244    let m = argvals.len();
245    if m == 0 || nbasis < 2 || n_grid < 2 {
246        return None;
247    }
248
249    // Compute penalty matrix once
250    let penalty = match basis_type {
251        BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
252        BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
253    };
254
255    let (lo, hi) = log_lambda_range;
256    let mut best_gcv = f64::INFINITY;
257    let mut best_result: Option<SmoothBasisResult> = None;
258
259    for i in 0..n_grid {
260        let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
261        let lam = 10.0_f64.powf(log_lam);
262
263        let fdpar = FdPar {
264            basis_type: basis_type.clone(),
265            nbasis,
266            lambda: lam,
267            lfd_order,
268            penalty_matrix: penalty.clone(),
269        };
270
271        if let Some(result) = smooth_basis(data, argvals, &fdpar) {
272            if result.gcv < best_gcv {
273                best_gcv = result.gcv;
274                best_result = Some(result);
275            }
276        }
277    }
278
279    best_result
280}
281
282// ─── Internal Helpers ───────────────────────────────────────────────────────
283
284/// Differentiate column-major basis matrix `lfd_order` times using gradient_uniform.
285fn differentiate_basis_columns(
286    basis: &[f64],
287    n_quad: usize,
288    nbasis: usize,
289    h: f64,
290    lfd_order: usize,
291) -> Vec<f64> {
292    let mut deriv = basis.to_vec();
293    for _ in 0..lfd_order {
294        let mut new_deriv = vec![0.0; n_quad * nbasis];
295        for j in 0..nbasis {
296            let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
297            let grad = crate::helpers::gradient_uniform(&col, h);
298            for i in 0..n_quad {
299                new_deriv[i + j * n_quad] = grad[i];
300            }
301        }
302        deriv = new_deriv;
303    }
304    deriv
305}
306
/// Build the symmetric Gram matrix `R[j,l] = Σ_i w_i · D_j(t_i) · D_l(t_i)`.
/// This is the quadrature approximation of ∫ D^m B_j · D^m B_l dt.
///
/// `deriv_basis` is a column-major `n_quad × k` matrix of derivative values.
/// The first `n_quad` entries of `weights` are the quadrature weights. Only
/// pairs with `j <= l` are summed; each value is mirrored across the diagonal.
/// Returns the `k × k` result in column-major layout.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    if k == 0 {
        return Vec::new();
    }

    let quad_weights = &weights[..n_quad];
    let column = move |c: usize| &deriv_basis[c * n_quad..(c + 1) * n_quad];

    let mut gram = vec![0.0; k * k];
    for row in 0..k {
        let left = column(row);
        for col in row..k {
            let right = column(col);
            let value = left
                .iter()
                .zip(right)
                .zip(quad_weights)
                .fold(0.0, |acc, ((&a, &b), &w)| acc + a * b * w);
            gram[row + col * k] = value;
            gram[col + row * k] = value;
        }
    }
    gram
}
327
328/// Evaluate basis functions on argvals, returning (flat column-major, actual_nbasis).
329fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
330    let m = argvals.len();
331    match basis_type {
332        BasisType::Bspline { order } => {
333            let nknots = nbasis.saturating_sub(*order).max(2);
334            let basis = bspline_basis(argvals, nknots, *order);
335            let actual = basis.len() / m;
336            (basis, actual)
337        }
338        BasisType::Fourier { period } => {
339            let basis = fourier_basis_with_period(argvals, nbasis, *period);
340            (basis, nbasis)
341        }
342    }
343}
344
345/// Invert the penalized system matrix via Cholesky or SVD pseudoinverse.
346fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
347    if let Some(chol) = system.clone().cholesky() {
348        return Some(chol.inverse());
349    }
350    // SVD fallback
351    let svd = nalgebra::SVD::new(system.clone(), true, true);
352    let u = svd.u.as_ref()?;
353    let v_t = svd.v_t.as_ref()?;
354    let max_sv: f64 = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
355    let eps = 1e-10 * max_sv;
356    let mut inv = DMatrix::<f64>::zeros(k, k);
357    for ii in 0..k {
358        for jj in 0..k {
359            let mut sum = 0.0;
360            for s in 0..k.min(svd.singular_values.len()) {
361                if svd.singular_values[s] > eps {
362                    sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
363                }
364            }
365            inv[(ii, jj)] = sum;
366        }
367    }
368    Some(inv)
369}
370
371/// Project all curves onto basis, returning (coefficients, fitted, total_rss).
372fn project_all_curves(
373    data: &FdMatrix,
374    b_mat: &DMatrix<f64>,
375    proj: &DMatrix<f64>,
376    n: usize,
377    m: usize,
378    k: usize,
379) -> (FdMatrix, FdMatrix, f64) {
380    let mut all_coefs = FdMatrix::zeros(n, k);
381    let mut all_fitted = FdMatrix::zeros(n, m);
382    let mut total_rss = 0.0;
383
384    for i in 0..n {
385        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
386        let y_vec = nalgebra::DVector::from_vec(curve.clone());
387        let coefs = proj * &y_vec;
388
389        for j in 0..k {
390            all_coefs[(i, j)] = coefs[j];
391        }
392        let fitted = b_mat * &coefs;
393        for j in 0..m {
394            all_fitted[(i, j)] = fitted[j];
395            let resid = curve[j] - fitted[j];
396            total_rss += resid * resid;
397        }
398    }
399
400    (all_coefs, all_fitted, total_rss)
401}
402
/// Generalized cross-validation score `(RSS / N) / (1 − edf/m)²`.
///
/// Arguments:
/// * `rss` — residual sum of squares pooled over `n_points` observations;
/// * `edf` — per-curve effective degrees of freedom;
/// * `m` — number of evaluation points per curve.
///
/// Returns `f64::INFINITY` when `1 − edf/m` is within `1e-10` of zero or is NaN.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let mean_square = rss / n_points;
    let shrink = 1.0 - edf / m as f64;
    if shrink.abs() > 1e-10 {
        mean_square / (shrink * shrink)
    } else {
        f64::INFINITY
    }
}
412
413// ─── Nbasis Selection via CV ────────────────────────────────────────────────
414
/// Criterion used by [`basis_nbasis_cv`] to score each candidate `nbasis`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BasisCriterion {
    /// Generalized cross-validation (the `gcv` field of [`SmoothBasisResult`]).
    Gcv,
    /// K-fold cross-validation. The number of folds is the `n_folds` argument of
    /// [`basis_nbasis_cv`]; at least 2 folds are used.
    Cv,
    /// Akaike Information Criterion (the `aic` field of [`SmoothBasisResult`]).
    Aic,
    /// Bayesian Information Criterion (the `bic` field of [`SmoothBasisResult`]).
    Bic,
}
427
/// Result of nbasis selection by [`basis_nbasis_cv`].
#[derive(Debug, Clone)]
pub struct BasisNbasisCvResult {
    /// Candidate from `nbasis_range` with the smallest score.
    pub optimal_nbasis: usize,
    /// Score for each candidate, aligned with `nbasis_range`; lower is better.
    /// Candidates that could not be evaluated (e.g. `nbasis < 2`) score `f64::INFINITY`.
    pub scores: Vec<f64>,
    /// The candidate nbasis values, in the order they were tested.
    pub nbasis_range: Vec<usize>,
    /// Criterion used to compute `scores`.
    pub criterion: BasisCriterion,
}
440
441/// Select the optimal number of basis functions using multiple criteria
442/// (R's `fdata2basis_cv`).
443///
444/// For `Gcv`, `Aic`, `Bic`: loops over nbasis values, calling `smooth_basis`
445/// and extracting the corresponding score.
446///
447/// For `Cv`: k-fold cross-validation. For each nbasis, for each fold,
448/// projects training curves to basis and measures reconstruction error on test curves.
449///
450/// # Arguments
451/// * `data` — Functional data matrix (n × m)
452/// * `argvals` — Evaluation points (length m)
453/// * `nbasis_range` — Candidate nbasis values to test (e.g., `&[4, 6, 8, ..., 20]`)
454/// * `basis_type` — Basis type
455/// * `criterion` — Selection criterion
456/// * `n_folds` — Number of folds (only used when criterion = Cv)
457/// * `lambda` — Smoothing parameter (default 0)
458pub fn basis_nbasis_cv(
459    data: &FdMatrix,
460    argvals: &[f64],
461    nbasis_range: &[usize],
462    basis_type: &BasisType,
463    criterion: BasisCriterion,
464    n_folds: usize,
465    lambda: f64,
466) -> Option<BasisNbasisCvResult> {
467    let (n, m) = data.shape();
468    if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
469        return None;
470    }
471
472    let mut scores = Vec::with_capacity(nbasis_range.len());
473
474    match criterion {
475        BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
476            for &nb in nbasis_range {
477                if nb < 2 {
478                    scores.push(f64::INFINITY);
479                    continue;
480                }
481
482                let penalty = match basis_type {
483                    BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
484                    BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
485                };
486
487                let fdpar = FdPar {
488                    basis_type: basis_type.clone(),
489                    nbasis: nb,
490                    lambda,
491                    lfd_order: 2,
492                    penalty_matrix: penalty,
493                };
494
495                match smooth_basis(data, argvals, &fdpar) {
496                    Some(result) => {
497                        let score = match criterion {
498                            BasisCriterion::Gcv => result.gcv,
499                            BasisCriterion::Aic => result.aic,
500                            BasisCriterion::Bic => result.bic,
501                            _ => unreachable!(),
502                        };
503                        scores.push(score);
504                    }
505                    None => scores.push(f64::INFINITY),
506                }
507            }
508        }
509        BasisCriterion::Cv => {
510            let folds = crate::cv::create_folds(n, n_folds.max(2), 42);
511
512            for &nb in nbasis_range {
513                if nb < 2 {
514                    scores.push(f64::INFINITY);
515                    continue;
516                }
517
518                let penalty = match basis_type {
519                    BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
520                    BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
521                };
522
523                let mut total_mse = 0.0;
524                let mut total_count = 0;
525
526                for fold in 0..n_folds.max(2) {
527                    let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
528                    if train_idx.is_empty() || test_idx.is_empty() {
529                        continue;
530                    }
531
532                    let train_data = crate::cv::subset_rows(data, &train_idx);
533
534                    let fdpar = FdPar {
535                        basis_type: basis_type.clone(),
536                        nbasis: nb,
537                        lambda,
538                        lfd_order: 2,
539                        penalty_matrix: penalty.clone(),
540                    };
541
542                    if let Some(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
543                        // Evaluate basis
544                        let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
545                        let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
546
547                        // For each test curve, project to basis using training coefficients
548                        // and measure reconstruction error
549                        let r_mat = DMatrix::from_column_slice(
550                            actual_k,
551                            actual_k,
552                            &train_result.penalty_matrix,
553                        );
554                        let btb = b_mat.transpose() * &b_mat;
555                        let ridge_eps = 1e-10;
556                        let system: DMatrix<f64> = &btb
557                            + lambda * &r_mat
558                            + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
559
560                        if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
561                            let proj = &system_inv * b_mat.transpose();
562
563                            for &ti in &test_idx {
564                                let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
565                                let y_vec = nalgebra::DVector::from_vec(curve.clone());
566                                let coefs = &proj * &y_vec;
567                                let fitted = &b_mat * &coefs;
568
569                                let mse: f64 =
570                                    (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>()
571                                        / m as f64;
572                                total_mse += mse;
573                                total_count += 1;
574                            }
575                        }
576                    }
577                }
578
579                if total_count > 0 {
580                    scores.push(total_mse / total_count as f64);
581                } else {
582                    scores.push(f64::INFINITY);
583                }
584            }
585        }
586    }
587
588    // Find optimal
589    let (best_idx, _) = scores
590        .iter()
591        .enumerate()
592        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
593
594    Some(BasisNbasisCvResult {
595        optimal_nbasis: nbasis_range[best_idx],
596        scores,
597        nbasis_range: nbasis_range.to_vec(),
598        criterion,
599    })
600}
601
/// Unit tests for the penalty matrices, `smooth_basis`, the GCV lambda search
/// and `basis_nbasis_cv`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points from 0 to 1 inclusive (intended for `m >= 2`).
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// The B-spline penalty must satisfy R[i,j] == R[j,i].
    #[test]
    fn test_bspline_penalty_matrix_symmetric() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
        let _k = 15; // may differ from actual due to knot construction
        // Recover the actual dimension from the flat K×K buffer.
        let actual_k = (penalty.len() as f64).sqrt() as usize;
        for i in 0..actual_k {
            for j in 0..actual_k {
                assert!(
                    (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
                    "Penalty matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    /// Checks a necessary (not sufficient) condition for positive
    /// semi-definiteness: a non-negative diagonal.
    #[test]
    fn test_bspline_penalty_matrix_positive_semidefinite() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
        let k = (penalty.len() as f64).sqrt() as usize;
        // Diagonal elements should be non-negative
        for i in 0..k {
            assert!(
                penalty[i + i * k] >= -1e-10,
                "Diagonal element {} is negative: {}",
                i,
                penalty[i + i * k]
            );
        }
    }

    /// The Fourier penalty is diagonal, zero for the constant term (lfd_order 2)
    /// and grows with frequency: entry (3,3) is frequency 2, entry (1,1) frequency 1.
    #[test]
    fn test_fourier_penalty_diagonal() {
        let penalty = fourier_penalty_matrix(7, 1.0, 2);
        // Should be diagonal
        for i in 0..7 {
            for j in 0..7 {
                if i != j {
                    assert!(
                        penalty[i + j * 7].abs() < 1e-10,
                        "Off-diagonal ({},{}) = {}",
                        i,
                        j,
                        penalty[i + j * 7]
                    );
                }
            }
        }
        // Constant term should have zero penalty
        assert!(penalty[0].abs() < 1e-10);
        // Higher frequency terms should have larger penalties
        assert!(penalty[1 + 7] > 0.0);
        assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
    }

    /// Smoke test for B-spline smoothing: output shapes and positive EDF/GCV.
    #[test]
    fn test_smooth_basis_bspline() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        // Sine curves plus a deterministic linear offset (a stand-in for noise)
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
            }
        }

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
        let _actual_k = (penalty.len() as f64).sqrt() as usize;

        let fdpar = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e-4,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let result = smooth_basis(&data, &t, &fdpar);
        assert!(result.is_some(), "smooth_basis should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted.shape(), (n, m));
        assert_eq!(res.coefficients.nrows(), n);
        assert!(res.edf > 0.0, "EDF should be positive");
        assert!(res.gcv > 0.0, "GCV should be positive");
    }

    /// A 7-function Fourier basis with tiny lambda should reproduce the periodic
    /// signal sin(2πt) + cos(4πt) to within 0.1 at every point.
    #[test]
    fn test_smooth_basis_fourier() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
            }
        }

        let nbasis = 7;
        let period = 1.0;
        let penalty = fourier_penalty_matrix(nbasis, period, 2);

        let fdpar = FdPar {
            basis_type: BasisType::Fourier { period },
            nbasis,
            lambda: 1e-6,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let result = smooth_basis(&data, &t, &fdpar);
        assert!(result.is_some());

        let res = result.unwrap();
        // Fourier basis should fit periodic data well
        for j in 0..m {
            let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
            assert!(
                (res.fitted[(0, j)] - expected).abs() < 0.1,
                "Fourier fit poor at j={}: got {}, expected {}",
                j,
                res.fitted[(0, j)],
                expected
            );
        }
    }

    /// The GCV lambda search over 25 grid points should return a result.
    #[test]
    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                // Deterministic pseudo-noise from a modular pattern.
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
            }
        }

        let basis_type = BasisType::Bspline { order: 4 };
        let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
        assert!(result.is_some(), "GCV search should succeed");
    }

    /// Increasing lambda shrinks the fit toward the penalty's null space, so EDF must drop.
    #[test]
    fn test_smooth_basis_large_lambda_reduces_edf() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin();
            }
        }

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
        let _actual_k = (penalty.len() as f64).sqrt() as usize;

        let fdpar_small = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e-8,
            lfd_order: 2,
            penalty_matrix: penalty.clone(),
        };
        let fdpar_large = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e2,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();

        assert!(
            res_large.edf < res_small.edf,
            "Larger lambda should reduce EDF: {} vs {}",
            res_large.edf,
            res_small.edf
        );
    }

    // ============== basis_nbasis_cv tests ==============

    /// GCV selection returns one score per candidate and a candidate from the range.
    #[test]
    fn test_basis_nbasis_cv_gcv() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
            }
        }

        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
        let result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Gcv,
            5,
            1e-4,
        );
        assert!(result.is_some());
        let res = result.unwrap();
        assert!(nbasis_range.contains(&res.optimal_nbasis));
        assert_eq!(res.scores.len(), nbasis_range.len());
        assert_eq!(res.criterion, BasisCriterion::Gcv);
    }

    /// AIC and BIC selection both succeed with lambda = 0 (ridge-only system).
    #[test]
    fn test_basis_nbasis_cv_aic_bic() {
        let m = 51;
        let n = 5;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin();
            }
        }

        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
        let aic_result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Aic,
            5,
            0.0,
        );
        let bic_result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Bic,
            5,
            0.0,
        );
        assert!(aic_result.is_some());
        assert!(bic_result.is_some());
    }

    /// 5-fold cross-validation selection returns a candidate from the range.
    #[test]
    fn test_basis_nbasis_cv_kfold() {
        let m = 51;
        let n = 10;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
            }
        }

        let nbasis_range: Vec<usize> = vec![5, 7, 9];
        let result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Cv,
            5,
            1e-4,
        );
        assert!(result.is_some());
        let res = result.unwrap();
        assert!(nbasis_range.contains(&res.optimal_nbasis));
        assert_eq!(res.criterion, BasisCriterion::Cv);
    }
}