Skip to main content

fdars_core/
smooth_basis.rs

//! Basis-penalized smoothing with continuous derivative penalties.
//!
//! This module implements `smooth.basis` from R's fda package. Unlike the
//! discrete difference penalty used in P-splines (`basis.rs`), this uses
//! continuous derivative penalties: `min ||y - Φc||² + λ·∫(Lf)² dt`.
//!
//! Key capabilities:
//! - [`smooth_basis`] — Penalized least squares with continuous roughness penalty
//! - [`smooth_basis_gcv`] — GCV-optimal smoothing parameter selection
//! - [`bspline_penalty_matrix`] / [`fourier_penalty_matrix`] — Roughness penalty matrices
//! - [`basis_nbasis_cv`] — Selection of the number of basis functions (GCV, CV, AIC, BIC)
11
12use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
18// ─── Types ──────────────────────────────────────────────────────────────────
19
/// Family of basis functions used to represent each smoothed curve.
#[derive(Clone, Debug, PartialEq)]
pub enum BasisType {
    /// B-spline system; `order` is the spline order (4 gives cubic splines).
    Bspline { order: usize },
    /// Fourier system whose sine/cosine terms repeat with the given `period`.
    Fourier { period: f64 },
}
28
29/// Functional data parameter object (basis + penalty specification).
30#[derive(Debug, Clone, PartialEq)]
31pub struct FdPar {
32    /// Type of basis system.
33    pub basis_type: BasisType,
34    /// Number of basis functions.
35    pub nbasis: usize,
36    /// Smoothing parameter.
37    pub lambda: f64,
38    /// Derivative order for the penalty (default: 2).
39    pub lfd_order: usize,
40    /// Precomputed K×K penalty matrix (column-major).
41    pub penalty_matrix: Vec<f64>,
42}
43
44/// Result of basis-penalized smoothing.
45#[derive(Debug, Clone, PartialEq)]
46pub struct SmoothBasisResult {
47    /// Basis coefficients (n × K).
48    pub coefficients: FdMatrix,
49    /// Fitted values (n × m).
50    pub fitted: FdMatrix,
51    /// Effective degrees of freedom.
52    pub edf: f64,
53    /// Generalized cross-validation score.
54    pub gcv: f64,
55    /// AIC.
56    pub aic: f64,
57    /// BIC.
58    pub bic: f64,
59    /// Roughness penalty matrix (K × K, column-major).
60    pub penalty_matrix: Vec<f64>,
61    /// Number of basis functions used.
62    pub nbasis: usize,
63}
64
65// ─── Penalty Matrices ───────────────────────────────────────────────────────
66
67/// Compute the roughness penalty matrix for B-splines via numerical quadrature.
68///
69/// R\[j,k\] = ∫ D^m B_j(t) · D^m B_k(t) dt
70///
71/// Uses Simpson's rule on a fine sub-grid for each knot interval.
72///
73/// # Arguments
74/// * `argvals` — Evaluation points (length m)
75/// * `nbasis` — Number of basis functions
76/// * `order` — B-spline order (typically 4 for cubic)
77/// * `lfd_order` — Derivative order for penalty (typically 2)
78///
79/// # Returns
80/// K × K penalty matrix in column-major layout (K = nbasis)
81pub fn bspline_penalty_matrix(
82    argvals: &[f64],
83    nbasis: usize,
84    order: usize,
85    lfd_order: usize,
86) -> Vec<f64> {
87    if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
88        return vec![0.0; nbasis * nbasis];
89    }
90
91    let nknots = nbasis.saturating_sub(order).max(2);
92
93    // Create a fine quadrature grid (10 sub-points per original interval)
94    let n_sub = 10;
95    let t_min = argvals[0];
96    let t_max = argvals[argvals.len() - 1];
97    let n_quad = (argvals.len() - 1) * n_sub + 1;
98    let quad_t: Vec<f64> = (0..n_quad)
99        .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100        .collect();
101
102    // Evaluate B-spline basis on fine grid
103    let basis_fine = bspline_basis(&quad_t, nknots, order);
104    let actual_nbasis = basis_fine.len() / n_quad;
105
106    // Compute derivatives of B-spline basis numerically
107    let h = (t_max - t_min) / (n_quad - 1) as f64;
108    let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110    // Integration weights on fine grid
111    let weights = simpsons_weights(&quad_t);
112
113    // Compute penalty matrix: R[j,k] = ∫ D^m B_j · D^m B_k dt
114    integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Compute the roughness penalty matrix for a Fourier basis.
///
/// The basis is ordered `[1, sin(ωt), cos(ωt), sin(2ωt), cos(2ωt), …]` with
/// `ω = 2π/T`. The penalty is diagonal:
/// * the sine/cosine pair at frequency `f` gets `(2πf/T)^(2m)`, with `m = lfd_order`;
/// * the constant term gets `0` for `m ≥ 1`, because every derivative of a constant vanishes;
/// * for `m = 0`, the constant term gets `1`, the same weight as every other term.
///
/// The last rule assumes the constant basis function has the same squared norm as
/// each sine/cosine term (R fda's √2-normalized convention); confirm against
/// `fourier_basis_with_period` if that changes.
///
/// # Arguments
/// * `nbasis` — Number of basis functions
/// * `period` — Period of the Fourier basis
/// * `lfd_order` — Derivative order for penalty
///
/// # Returns
/// K × K penalty matrix in column-major layout (empty when `nbasis == 0`).
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let k = nbasis;
    let mut penalty = vec![0.0; k * k];
    if k == 0 {
        return penalty;
    }

    // D^0 leaves the constant unchanged, so it must be penalized like the other terms.
    if lfd_order == 0 {
        penalty[0] = 1.0;
    }

    // Indices 1,2 are the frequency-1 sin/cos pair, 3,4 frequency 2, and so on.
    for idx in 1..k {
        let freq = (idx + 1) / 2;
        let omega = 2.0 * PI * freq as f64 / period;
        penalty[idx + idx * k] = omega.powi(2 * lfd_order as i32);
    }

    penalty
}
158
159// ─── Smoothing Functions ────────────────────────────────────────────────────
160
161/// Perform basis-penalized smoothing.
162///
163/// Solves `(Φ'Φ + λR)c = Φ'y` per curve via Cholesky decomposition.
164/// This implements `smooth.basis` from R's fda package.
165///
166/// # Arguments
167/// * `data` — Functional data matrix (n × m)
168/// * `argvals` — Evaluation points (length m)
169/// * `fdpar` — Functional parameter object specifying basis and penalty
170///
171/// # Returns
172/// [`SmoothBasisResult`] with coefficients, fitted values, and diagnostics.
173pub fn smooth_basis(
174    data: &FdMatrix,
175    argvals: &[f64],
176    fdpar: &FdPar,
177) -> Result<SmoothBasisResult, crate::FdarError> {
178    let (n, m) = data.shape();
179    if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
180        return Err(crate::FdarError::InvalidDimension {
181            parameter: "data/argvals/fdpar",
182            expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
183            actual: format!(
184                "n={}, m={}, argvals.len()={}, nbasis={}",
185                n,
186                m,
187                argvals.len(),
188                fdpar.nbasis
189            ),
190        });
191    }
192
193    // Evaluate basis on argvals
194    let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
195    let k = actual_nbasis;
196
197    let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
198    let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
199
200    // (Φ'Φ + λR + εI) — small ridge ensures positive definiteness
201    let btb = b_mat.transpose() * &b_mat;
202    let ridge_eps = 1e-10;
203    let system: DMatrix<f64> =
204        &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
205
206    // Invert the penalized system
207    let system_inv =
208        invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
209            operation: "matrix inversion",
210            detail: "failed to invert penalized system (Φ'Φ + λR)".to_string(),
211        })?;
212
213    // Hat matrix: H = Φ (Φ'Φ + λR)^{-1} Φ'  →  EDF = tr(H)
214    let h_mat = &b_mat * &system_inv * b_mat.transpose();
215    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
216
217    // Project all curves
218    let proj = &system_inv * b_mat.transpose();
219    let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
220
221    let total_points = (n * m) as f64;
222    let gcv = compute_gcv(total_rss, total_points, edf, m);
223    let mse = total_rss / total_points;
224    // Total effective degrees of freedom = n curves * per-curve edf
225    let total_edf = n as f64 * edf;
226    let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
227    let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
228
229    Ok(SmoothBasisResult {
230        coefficients: all_coefs,
231        fitted: all_fitted,
232        edf,
233        gcv,
234        aic,
235        bic,
236        penalty_matrix: fdpar.penalty_matrix.clone(),
237        nbasis: k,
238    })
239}
240
241/// Perform basis-penalized smoothing with GCV-optimal lambda.
242///
243/// Searches over a log-lambda grid and selects the lambda minimizing GCV.
244///
245/// # Arguments
246/// * `data` — Functional data matrix (n × m)
247/// * `argvals` — Evaluation points (length m)
248/// * `basis_type` — Type of basis system
249/// * `nbasis` — Number of basis functions
250/// * `lfd_order` — Derivative order for penalty
251/// * `log_lambda_range` — Range of log10(lambda) to search, e.g. (-8.0, 4.0)
252/// * `n_grid` — Number of grid points for the search
253pub fn smooth_basis_gcv(
254    data: &FdMatrix,
255    argvals: &[f64],
256    basis_type: &BasisType,
257    nbasis: usize,
258    lfd_order: usize,
259    log_lambda_range: (f64, f64),
260    n_grid: usize,
261) -> Option<SmoothBasisResult> {
262    let m = argvals.len();
263    if m == 0 || nbasis < 2 || n_grid < 2 {
264        return None;
265    }
266
267    // Compute penalty matrix once
268    let penalty = match basis_type {
269        BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
270        BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
271    };
272
273    let (lo, hi) = log_lambda_range;
274    let mut best_gcv = f64::INFINITY;
275    let mut best_result: Option<SmoothBasisResult> = None;
276
277    for i in 0..n_grid {
278        let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
279        let lam = 10.0_f64.powf(log_lam);
280
281        let fdpar = FdPar {
282            basis_type: basis_type.clone(),
283            nbasis,
284            lambda: lam,
285            lfd_order,
286            penalty_matrix: penalty.clone(),
287        };
288
289        if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
290            if result.gcv < best_gcv {
291                best_gcv = result.gcv;
292                best_result = Some(result);
293            }
294        }
295    }
296
297    best_result
298}
299
300// ─── Internal Helpers ───────────────────────────────────────────────────────
301
302/// Differentiate column-major basis matrix `lfd_order` times using gradient_uniform.
303fn differentiate_basis_columns(
304    basis: &[f64],
305    n_quad: usize,
306    nbasis: usize,
307    h: f64,
308    lfd_order: usize,
309) -> Vec<f64> {
310    let mut deriv = basis.to_vec();
311    for _ in 0..lfd_order {
312        let mut new_deriv = vec![0.0; n_quad * nbasis];
313        for j in 0..nbasis {
314            let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
315            let grad = crate::helpers::gradient_uniform(&col, h);
316            for i in 0..n_quad {
317                new_deriv[i + j * n_quad] = grad[i];
318            }
319        }
320        deriv = new_deriv;
321    }
322    deriv
323}
324
325/// Integrate symmetric penalty: R[j,k] = ∫ D^m B_j · D^m B_k dt.
326fn integrate_symmetric_penalty(
327    deriv_basis: &[f64],
328    weights: &[f64],
329    k: usize,
330    n_quad: usize,
331) -> Vec<f64> {
332    let mut penalty = vec![0.0; k * k];
333    for j in 0..k {
334        for l in j..k {
335            let mut val = 0.0;
336            for i in 0..n_quad {
337                val += deriv_basis[i + j * n_quad] * deriv_basis[i + l * n_quad] * weights[i];
338            }
339            penalty[j + l * k] = val;
340            penalty[l + j * k] = val;
341        }
342    }
343    penalty
344}
345
346/// Evaluate basis functions on argvals, returning (flat column-major, actual_nbasis).
347fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
348    let m = argvals.len();
349    match basis_type {
350        BasisType::Bspline { order } => {
351            let nknots = nbasis.saturating_sub(*order).max(2);
352            let basis = bspline_basis(argvals, nknots, *order);
353            let actual = basis.len() / m;
354            (basis, actual)
355        }
356        BasisType::Fourier { period } => {
357            let basis = fourier_basis_with_period(argvals, nbasis, *period);
358            (basis, nbasis)
359        }
360    }
361}
362
363/// Invert the penalized system matrix via Cholesky or SVD pseudoinverse.
364fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
365    if let Some(chol) = system.clone().cholesky() {
366        return Some(chol.inverse());
367    }
368    // SVD fallback
369    let svd = nalgebra::SVD::new(system.clone(), true, true);
370    let u = svd.u.as_ref()?;
371    let v_t = svd.v_t.as_ref()?;
372    let max_sv: f64 = svd.singular_values.iter().copied().fold(0.0_f64, f64::max);
373    let eps = 1e-10 * max_sv;
374    let mut inv = DMatrix::<f64>::zeros(k, k);
375    for ii in 0..k {
376        for jj in 0..k {
377            let mut sum = 0.0;
378            for s in 0..k.min(svd.singular_values.len()) {
379                if svd.singular_values[s] > eps {
380                    sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
381                }
382            }
383            inv[(ii, jj)] = sum;
384        }
385    }
386    Some(inv)
387}
388
389/// Project all curves onto basis, returning (coefficients, fitted, total_rss).
390fn project_all_curves(
391    data: &FdMatrix,
392    b_mat: &DMatrix<f64>,
393    proj: &DMatrix<f64>,
394    n: usize,
395    m: usize,
396    k: usize,
397) -> (FdMatrix, FdMatrix, f64) {
398    let mut all_coefs = FdMatrix::zeros(n, k);
399    let mut all_fitted = FdMatrix::zeros(n, m);
400    let mut total_rss = 0.0;
401
402    for i in 0..n {
403        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
404        let y_vec = nalgebra::DVector::from_vec(curve.clone());
405        let coefs = proj * &y_vec;
406
407        for j in 0..k {
408            all_coefs[(i, j)] = coefs[j];
409        }
410        let fitted = b_mat * &coefs;
411        for j in 0..m {
412            all_fitted[(i, j)] = fitted[j];
413            let resid = curve[j] - fitted[j];
414            total_rss += resid * resid;
415        }
416    }
417
418    (all_coefs, all_fitted, total_rss)
419}
420
/// Generalized cross-validation score `(rss / n_points) / (1 - edf/m)²`.
///
/// Returns `f64::INFINITY` when `|1 - edf/m|` is not above `1e-10` (including
/// the NaN case), i.e. when the smoother is effectively interpolating.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let shrink = 1.0 - edf / m as f64;
    let mse = rss / n_points;
    let well_conditioned = shrink.abs() > 1e-10;
    if well_conditioned {
        mse / (shrink * shrink)
    } else {
        f64::INFINITY
    }
}
430
431// ─── Nbasis Selection via CV ────────────────────────────────────────────────
432
/// Criterion used by [`basis_nbasis_cv`] to score each candidate number of basis
/// functions (lower is better).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BasisCriterion {
    /// Generalized cross-validation.
    Gcv,
    /// K-fold cross-validation of prediction error (number of folds given by the
    /// `n_folds` argument of [`basis_nbasis_cv`]).
    Cv,
    /// Akaike Information Criterion.
    Aic,
    /// Bayesian Information Criterion.
    Bic,
}
445
446/// Result of nbasis selection.
447#[derive(Debug, Clone, PartialEq)]
448pub struct BasisNbasisCvResult {
449    /// Optimal number of basis functions.
450    pub optimal_nbasis: usize,
451    /// Score for each nbasis tested.
452    pub scores: Vec<f64>,
453    /// Range of nbasis values tested.
454    pub nbasis_range: Vec<usize>,
455    /// Criterion used.
456    pub criterion: BasisCriterion,
457}
458
459/// Evaluate information criterion (GCV/AIC/BIC) for a range of nbasis values.
460fn evaluate_nbasis_info_criterion(
461    data: &FdMatrix,
462    argvals: &[f64],
463    nbasis_range: &[usize],
464    basis_type: &BasisType,
465    criterion: BasisCriterion,
466    lambda: f64,
467) -> Vec<f64> {
468    let mut scores = Vec::with_capacity(nbasis_range.len());
469    for &nb in nbasis_range {
470        if nb < 2 {
471            scores.push(f64::INFINITY);
472            continue;
473        }
474        let penalty = match basis_type {
475            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
476            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
477        };
478        let fdpar = FdPar {
479            basis_type: basis_type.clone(),
480            nbasis: nb,
481            lambda,
482            lfd_order: 2,
483            penalty_matrix: penalty,
484        };
485        match smooth_basis(data, argvals, &fdpar) {
486            Ok(result) => {
487                let score = match criterion {
488                    BasisCriterion::Gcv => result.gcv,
489                    BasisCriterion::Aic => result.aic,
490                    BasisCriterion::Bic => result.bic,
491                    _ => unreachable!(),
492                };
493                scores.push(score);
494            }
495            Err(_) => scores.push(f64::INFINITY),
496        }
497    }
498    scores
499}
500
501/// Evaluate nbasis via k-fold cross-validation of reconstruction error.
502fn evaluate_nbasis_cv(
503    data: &FdMatrix,
504    argvals: &[f64],
505    nbasis_range: &[usize],
506    basis_type: &BasisType,
507    lambda: f64,
508    n_folds: usize,
509) -> Vec<f64> {
510    let (n, m) = data.shape();
511    let n_folds = n_folds.max(2);
512    let folds = crate::cv::create_folds(n, n_folds, 42);
513    let mut scores = Vec::with_capacity(nbasis_range.len());
514
515    for &nb in nbasis_range {
516        if nb < 2 {
517            scores.push(f64::INFINITY);
518            continue;
519        }
520        let penalty = match basis_type {
521            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
522            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
523        };
524
525        let mut total_mse = 0.0;
526        let mut total_count = 0;
527
528        for fold in 0..n_folds {
529            let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
530            if train_idx.is_empty() || test_idx.is_empty() {
531                continue;
532            }
533            let train_data = crate::cv::subset_rows(data, &train_idx);
534            let fdpar = FdPar {
535                basis_type: basis_type.clone(),
536                nbasis: nb,
537                lambda,
538                lfd_order: 2,
539                penalty_matrix: penalty.clone(),
540            };
541
542            if let Ok(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
543                let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
544                let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
545                let r_mat =
546                    DMatrix::from_column_slice(actual_k, actual_k, &train_result.penalty_matrix);
547                let btb = b_mat.transpose() * &b_mat;
548                let ridge_eps = 1e-10;
549                let system: DMatrix<f64> = &btb
550                    + lambda * &r_mat
551                    + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
552
553                if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
554                    let proj = &system_inv * b_mat.transpose();
555                    for &ti in &test_idx {
556                        let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
557                        let y_vec = nalgebra::DVector::from_vec(curve.clone());
558                        let coefs = &proj * &y_vec;
559                        let fitted = &b_mat * &coefs;
560                        let mse: f64 =
561                            (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
562                        total_mse += mse;
563                        total_count += 1;
564                    }
565                }
566            }
567        }
568
569        if total_count > 0 {
570            scores.push(total_mse / f64::from(total_count));
571        } else {
572            scores.push(f64::INFINITY);
573        }
574    }
575    scores
576}
577
578/// Select the optimal number of basis functions using multiple criteria
579/// (R's `fdata2basis_cv`).
580pub fn basis_nbasis_cv(
581    data: &FdMatrix,
582    argvals: &[f64],
583    nbasis_range: &[usize],
584    basis_type: &BasisType,
585    criterion: BasisCriterion,
586    n_folds: usize,
587    lambda: f64,
588) -> Option<BasisNbasisCvResult> {
589    let (n, m) = data.shape();
590    if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
591        return None;
592    }
593
594    let scores = match criterion {
595        BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
596            evaluate_nbasis_info_criterion(
597                data,
598                argvals,
599                nbasis_range,
600                basis_type,
601                criterion,
602                lambda,
603            )
604        }
605        BasisCriterion::Cv => {
606            evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
607        }
608    };
609
610    let (best_idx, _) = scores
611        .iter()
612        .enumerate()
613        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
614
615    Some(BasisNbasisCvResult {
616        optimal_nbasis: nbasis_range[best_idx],
617        scores,
618        nbasis_range: nbasis_range.to_vec(),
619        criterion,
620    })
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points on [0, 1], both endpoints included.
    fn uniform_grid(m: usize) -> Vec<f64> {
        let denom = (m - 1) as f64;
        (0..m).map(|i| i as f64 / denom).collect()
    }

    /// Build an `n × t.len()` matrix whose entry (i, j) is `value(i, j, t[j])`.
    fn build_data<F: Fn(usize, usize, f64) -> f64>(n: usize, t: &[f64], value: F) -> FdMatrix {
        let mut data = FdMatrix::zeros(n, t.len());
        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                data[(i, j)] = value(i, j, tj);
            }
        }
        data
    }

    /// Cubic B-spline parameter object with a second-derivative penalty.
    fn cubic_fdpar(nbasis: usize, lambda: f64, penalty_matrix: Vec<f64>) -> FdPar {
        FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda,
            lfd_order: 2,
            penalty_matrix,
        }
    }

    /// Side length of a square matrix stored flat.
    fn side(flat: &[f64]) -> usize {
        (flat.len() as f64).sqrt() as usize
    }

    #[test]
    fn test_bspline_penalty_matrix_symmetric() {
        let penalty = bspline_penalty_matrix(&uniform_grid(101), 15, 4, 2);
        // The realised basis size may differ from the requested 15.
        let k = side(&penalty);
        for i in 0..k {
            for j in 0..k {
                let gap = (penalty[i + j * k] - penalty[j + i * k]).abs();
                assert!(gap < 1e-10, "Penalty matrix not symmetric at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_bspline_penalty_matrix_positive_semidefinite() {
        let penalty = bspline_penalty_matrix(&uniform_grid(101), 10, 4, 2);
        let k = side(&penalty);
        // Necessary condition for PSD: no negative diagonal entries.
        for (i, &d) in penalty.iter().step_by(k + 1).take(k).enumerate() {
            assert!(d >= -1e-10, "Diagonal element {} is negative: {}", i, d);
        }
    }

    #[test]
    fn test_fourier_penalty_diagonal() {
        const K: usize = 7;
        let penalty = fourier_penalty_matrix(K, 1.0, 2);
        for (idx, &v) in penalty.iter().enumerate() {
            let (row, col) = (idx % K, idx / K);
            if row != col {
                assert!(v.abs() < 1e-10, "Off-diagonal ({},{}) = {}", row, col, v);
            }
        }
        // The constant term is not penalized for a second-derivative penalty.
        assert!(penalty[0].abs() < 1e-10);
        // Higher frequencies carry larger penalties.
        assert!(penalty[1 + K] > 0.0);
        assert!(penalty[3 + 3 * K] > penalty[1 + K]);
    }

    #[test]
    fn test_smooth_basis_bspline() {
        let (n, m) = (5, 101);
        let t = uniform_grid(m);
        let data = build_data(n, &t, |i, j, tj| {
            (2.0 * PI * tj).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01)
        });

        let nbasis = 15;
        let fdpar = cubic_fdpar(nbasis, 1e-4, bspline_penalty_matrix(&t, nbasis, 4, 2));

        let outcome = smooth_basis(&data, &t, &fdpar);
        assert!(outcome.is_ok(), "smooth_basis should succeed");
        let res = outcome.unwrap();

        assert_eq!(res.fitted.shape(), (n, m));
        assert_eq!(res.coefficients.nrows(), n);
        assert!(res.edf > 0.0, "EDF should be positive");
        assert!(res.gcv > 0.0, "GCV should be positive");
    }

    #[test]
    fn test_smooth_basis_fourier() {
        let (n, m) = (3, 101);
        let t = uniform_grid(m);
        let data = build_data(n, &t, |_, _, tj| (2.0 * PI * tj).sin() + (4.0 * PI * tj).cos());

        let (nbasis, period) = (7, 1.0);
        let fdpar = FdPar {
            basis_type: BasisType::Fourier { period },
            nbasis,
            lambda: 1e-6,
            lfd_order: 2,
            penalty_matrix: fourier_penalty_matrix(nbasis, period, 2),
        };

        let outcome = smooth_basis(&data, &t, &fdpar);
        assert!(outcome.is_ok());
        let res = outcome.unwrap();

        // A Fourier basis should reproduce periodic data closely.
        for (j, &tj) in t.iter().enumerate() {
            let expected = (2.0 * PI * tj).sin() + (4.0 * PI * tj).cos();
            let got = res.fitted[(0, j)];
            assert!(
                (got - expected).abs() < 0.1,
                "Fourier fit poor at j={}: got {}, expected {}",
                j,
                got,
                expected
            );
        }
    }

    #[test]
    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
        let t = uniform_grid(101);
        let data = build_data(5, &t, |i, j, tj| {
            (2.0 * PI * tj).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0
        });

        let basis_type = BasisType::Bspline { order: 4 };
        let found = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
        assert!(found.is_some(), "GCV search should succeed");
    }

    #[test]
    fn test_smooth_basis_large_lambda_reduces_edf() {
        let t = uniform_grid(101);
        let data = build_data(3, &t, |_, _, tj| (2.0 * PI * tj).sin());

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
        let gentle = cubic_fdpar(nbasis, 1e-8, penalty.clone());
        let heavy = cubic_fdpar(nbasis, 1e2, penalty);

        let res_small = smooth_basis(&data, &t, &gentle).unwrap();
        let res_large = smooth_basis(&data, &t, &heavy).unwrap();

        assert!(
            res_large.edf < res_small.edf,
            "Larger lambda should reduce EDF: {} vs {}",
            res_large.edf,
            res_small.edf
        );
    }

    // ============== basis_nbasis_cv tests ==============

    #[test]
    fn test_basis_nbasis_cv_gcv() {
        let t = uniform_grid(101);
        let data = build_data(5, &t, |i, j, tj| {
            (2.0 * PI * tj).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0
        });

        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
        let outcome = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Gcv,
            5,
            1e-4,
        );
        assert!(outcome.is_some());
        let res = outcome.unwrap();
        assert!(nbasis_range.contains(&res.optimal_nbasis));
        assert_eq!(res.scores.len(), nbasis_range.len());
        assert_eq!(res.criterion, BasisCriterion::Gcv);
    }

    #[test]
    fn test_basis_nbasis_cv_aic_bic() {
        let t = uniform_grid(51);
        let data = build_data(5, &t, |_, _, tj| (2.0 * PI * tj).sin());
        let nbasis_range = vec![5, 7, 9, 11];

        for &criterion in [BasisCriterion::Aic, BasisCriterion::Bic].iter() {
            let outcome = basis_nbasis_cv(
                &data,
                &t,
                &nbasis_range,
                &BasisType::Bspline { order: 4 },
                criterion,
                5,
                0.0,
            );
            assert!(outcome.is_some());
        }
    }

    #[test]
    fn test_basis_nbasis_cv_kfold() {
        let t = uniform_grid(51);
        let data = build_data(10, &t, |i, j, tj| {
            (2.0 * PI * tj).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64
        });

        let nbasis_range = vec![5, 7, 9];
        let outcome = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Cv,
            5,
            1e-4,
        );
        assert!(outcome.is_some());
        let res = outcome.unwrap();
        assert!(nbasis_range.contains(&res.optimal_nbasis));
        assert_eq!(res.criterion, BasisCriterion::Cv);
    }
}