Skip to main content

fdars_core/
smooth_basis.rs

1//! Basis-penalized smoothing with continuous derivative penalties.
2//!
3//! This module implements `smooth.basis` from R's fda package. Unlike the
4//! discrete difference penalty used in P-splines (`basis.rs`), this uses
5//! continuous derivative penalties: `min ||y - Φc||² + λ·∫(Lf)² dt`.
6//!
7//! Key capabilities:
8//! - [`smooth_basis`] — Penalized least squares with continuous roughness penalty
9//! - [`smooth_basis_gcv`] — GCV-optimal smoothing parameter selection
10//! - [`bspline_penalty_matrix`] / [`fourier_penalty_matrix`] — Roughness penalty matrices
11
12use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
18// ─── Types ──────────────────────────────────────────────────────────────────
19
/// Basis system used to represent curves in penalized smoothing.
///
/// The variant decides which basis evaluator is called on the observation
/// grid and which roughness-penalty routine builds the penalty matrix.
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
    /// B-spline basis with given order (typically 4 for cubic).
    Bspline { order: usize },
    /// Fourier basis with given period.
    Fourier { period: f64 },
}
28
/// Functional data parameter object (basis + penalty specification).
///
/// Bundles everything [`smooth_basis`] needs besides the data itself: the
/// basis system, its size, the smoothing weight and the precomputed
/// roughness penalty.
#[derive(Debug, Clone, PartialEq)]
pub struct FdPar {
    /// Type of basis system.
    pub basis_type: BasisType,
    /// Requested number of basis functions. For B-splines the number of
    /// columns actually evaluated is derived from the basis evaluator's
    /// output and may differ from this value.
    pub nbasis: usize,
    /// Smoothing parameter (weight λ on the roughness penalty).
    pub lambda: f64,
    /// Derivative order for the penalty (default: 2).
    ///
    /// Informational only for [`smooth_basis`], which uses the precomputed
    /// `penalty_matrix` and does not read this field.
    pub lfd_order: usize,
    /// Precomputed K×K penalty matrix (column-major).
    ///
    /// K must equal the number of basis functions actually evaluated, because
    /// [`smooth_basis`] reshapes this vector into a K×K matrix.
    pub penalty_matrix: Vec<f64>,
}
43
/// Result of basis-penalized smoothing.
///
/// All curves share one basis and one penalty, so `edf` is a per-curve
/// quantity while `gcv`, `aic` and `bic` are pooled over all n × m points.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct SmoothBasisResult {
    /// Basis coefficients (n × K).
    pub coefficients: FdMatrix,
    /// Fitted values (n × m).
    pub fitted: FdMatrix,
    /// Effective degrees of freedom: trace of the hat matrix for one curve.
    pub edf: f64,
    /// Generalized cross-validation score: pooled mean squared residual
    /// divided by `(1 - edf / m)²`; infinite when that denominator vanishes.
    pub gcv: f64,
    /// AIC: `n·m·ln(mse) + 2·n·edf`, with `mse` floored at 1e-300.
    pub aic: f64,
    /// BIC: `n·m·ln(mse) + ln(n·m)·n·edf`, with `mse` floored at 1e-300.
    pub bic: f64,
    /// Roughness penalty matrix (K × K, column-major), copied from the input.
    pub penalty_matrix: Vec<f64>,
    /// Number of basis functions actually used (K).
    pub nbasis: usize,
}
65
66// ─── Penalty Matrices ───────────────────────────────────────────────────────
67
68/// Compute the roughness penalty matrix for B-splines via numerical quadrature.
69///
70/// R\[j,k\] = ∫ D^m B_j(t) · D^m B_k(t) dt
71///
72/// Uses Simpson's rule on a fine sub-grid for each knot interval.
73///
74/// # Arguments
75/// * `argvals` — Evaluation points (length m)
76/// * `nbasis` — Number of basis functions
77/// * `order` — B-spline order (typically 4 for cubic)
78/// * `lfd_order` — Derivative order for penalty (typically 2)
79///
80/// # Returns
81/// K × K penalty matrix in column-major layout (K = nbasis)
82pub fn bspline_penalty_matrix(
83    argvals: &[f64],
84    nbasis: usize,
85    order: usize,
86    lfd_order: usize,
87) -> Vec<f64> {
88    if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
89        return vec![0.0; nbasis * nbasis];
90    }
91
92    let nknots = nbasis.saturating_sub(order).max(2);
93
94    // Create a fine quadrature grid (10 sub-points per original interval)
95    let n_sub = 10;
96    let t_min = argvals[0];
97    let t_max = argvals[argvals.len() - 1];
98    let n_quad = (argvals.len() - 1) * n_sub + 1;
99    let quad_t: Vec<f64> = (0..n_quad)
100        .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
101        .collect();
102
103    // Evaluate B-spline basis on fine grid
104    let basis_fine = bspline_basis(&quad_t, nknots, order);
105    let actual_nbasis = basis_fine.len() / n_quad;
106
107    // Compute derivatives of B-spline basis numerically
108    let h = (t_max - t_min) / (n_quad - 1) as f64;
109    let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
110
111    // Integration weights on fine grid
112    let weights = simpsons_weights(&quad_t);
113
114    // Compute penalty matrix: R[j,k] = ∫ D^m B_j · D^m B_k dt
115    integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
116}
117
/// Compute the roughness penalty matrix for a Fourier basis.
///
/// The penalty is diagonal. The sin/cos pair with frequency index `f`
/// (f = 1, 2, …) gets the eigenvalue `(2πf/T)^(2m)`, where `T` is `period`
/// and `m` is `lfd_order`. The leading constant function has a zero
/// derivative for every `m >= 1` and is therefore unpenalized. For `m == 0`
/// the operator is the identity, so the constant receives the same weight (1)
/// as every other basis function and the result is the identity matrix.
///
/// # Arguments
/// * `nbasis` — Number of basis functions
/// * `period` — Period of the Fourier basis
/// * `lfd_order` — Derivative order for penalty
///
/// # Returns
/// K × K penalty matrix in column-major layout
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let k = nbasis;
    let mut penalty = vec![0.0; k * k];
    if k == 0 {
        return penalty;
    }

    // D^0 of the constant is the constant itself, so it is only left
    // unpenalized when at least one derivative is taken.
    if lfd_order == 0 {
        penalty[0] = 1.0;
    }

    let exponent = 2 * lfd_order as i32;
    // Basis layout: constant, sin(1), cos(1), sin(2), cos(2), …
    // so the basis function at index `idx >= 1` has frequency (idx + 1) / 2.
    for idx in 1..k {
        let freq = (idx + 1) / 2;
        let omega = 2.0 * PI * freq as f64 / period;
        penalty[idx + idx * k] = omega.powi(exponent);
    }

    penalty
}
159
160// ─── Smoothing Functions ────────────────────────────────────────────────────
161
162/// Perform basis-penalized smoothing.
163///
164/// Solves `(Φ'Φ + λR)c = Φ'y` per curve via Cholesky decomposition.
165/// This implements `smooth.basis` from R's fda package.
166///
167/// # Arguments
168/// * `data` — Functional data matrix (n × m)
169/// * `argvals` — Evaluation points (length m)
170/// * `fdpar` — Functional parameter object specifying basis and penalty
171///
172/// # Returns
173/// [`SmoothBasisResult`] with coefficients, fitted values, and diagnostics.
174pub fn smooth_basis(
175    data: &FdMatrix,
176    argvals: &[f64],
177    fdpar: &FdPar,
178) -> Result<SmoothBasisResult, crate::FdarError> {
179    let (n, m) = data.shape();
180    if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
181        return Err(crate::FdarError::InvalidDimension {
182            parameter: "data/argvals/fdpar",
183            expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
184            actual: format!(
185                "n={}, m={}, argvals.len()={}, nbasis={}",
186                n,
187                m,
188                argvals.len(),
189                fdpar.nbasis
190            ),
191        });
192    }
193
194    // Evaluate basis on argvals
195    let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
196    let k = actual_nbasis;
197
198    let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
199    let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
200
201    // (Φ'Φ + λR + εI) — small ridge ensures positive definiteness
202    let btb = b_mat.transpose() * &b_mat;
203    let ridge_eps = 1e-10;
204    let system: DMatrix<f64> =
205        &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
206
207    // Invert the penalized system
208    let system_inv =
209        invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
210            operation: "matrix inversion",
211            detail: "failed to invert penalized system (Φ'Φ + λR); try increasing lambda or reducing the number of basis functions".to_string(),
212        })?;
213
214    // Hat matrix: H = Φ (Φ'Φ + λR)^{-1} Φ'  →  EDF = tr(H)
215    let h_mat = &b_mat * &system_inv * b_mat.transpose();
216    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
217
218    // Project all curves
219    let proj = &system_inv * b_mat.transpose();
220    let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
221
222    let total_points = (n * m) as f64;
223    let gcv = compute_gcv(total_rss, total_points, edf, m);
224    let mse = total_rss / total_points;
225    // Total effective degrees of freedom = n curves * per-curve edf
226    let total_edf = n as f64 * edf;
227    let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
228    let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
229
230    Ok(SmoothBasisResult {
231        coefficients: all_coefs,
232        fitted: all_fitted,
233        edf,
234        gcv,
235        aic,
236        bic,
237        penalty_matrix: fdpar.penalty_matrix.clone(),
238        nbasis: k,
239    })
240}
241
242/// Perform basis-penalized smoothing with GCV-optimal lambda.
243///
244/// Searches over a log-lambda grid and selects the lambda minimizing GCV.
245///
246/// # Arguments
247/// * `data` — Functional data matrix (n × m)
248/// * `argvals` — Evaluation points (length m)
249/// * `basis_type` — Type of basis system
250/// * `nbasis` — Number of basis functions
251/// * `lfd_order` — Derivative order for penalty
252/// * `log_lambda_range` — Range of log10(lambda) to search, e.g. (-8.0, 4.0)
253/// * `n_grid` — Number of grid points for the search
254pub fn smooth_basis_gcv(
255    data: &FdMatrix,
256    argvals: &[f64],
257    basis_type: &BasisType,
258    nbasis: usize,
259    lfd_order: usize,
260    log_lambda_range: (f64, f64),
261    n_grid: usize,
262) -> Option<SmoothBasisResult> {
263    let m = argvals.len();
264    if m == 0 || nbasis < 2 || n_grid < 2 {
265        return None;
266    }
267
268    // Compute penalty matrix once
269    let penalty = match basis_type {
270        BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
271        BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
272    };
273
274    let (lo, hi) = log_lambda_range;
275    let mut best_gcv = f64::INFINITY;
276    let mut best_result: Option<SmoothBasisResult> = None;
277
278    for i in 0..n_grid {
279        let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
280        let lam = 10.0_f64.powf(log_lam);
281
282        let fdpar = FdPar {
283            basis_type: basis_type.clone(),
284            nbasis,
285            lambda: lam,
286            lfd_order,
287            penalty_matrix: penalty.clone(),
288        };
289
290        if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
291            if result.gcv < best_gcv {
292                best_gcv = result.gcv;
293                best_result = Some(result);
294            }
295        }
296    }
297
298    best_result
299}
300
301// ─── Config Structs ─────────────────────────────────────────────────────────
302
/// Configuration for GCV-based smoothing parameter selection.
///
/// Collects all tuning parameters for [`smooth_basis_gcv_with_config`], with
/// sensible defaults obtained via [`SmoothBasisGcvConfig::default()`].
///
/// # Example
/// ```no_run
/// use fdars_core::smooth_basis::{SmoothBasisGcvConfig, BasisType};
///
/// let config = SmoothBasisGcvConfig {
///     nbasis: 20,
///     n_grid: 100,
///     ..SmoothBasisGcvConfig::default()
/// };
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct SmoothBasisGcvConfig {
    /// Basis type (`Bspline` or `Fourier`; default: `Bspline` of order 4).
    pub basis_type: BasisType,
    /// Number of basis functions (default: 15).
    pub nbasis: usize,
    /// Order of the roughness penalty differential operator (default: 2).
    pub lfd_order: usize,
    /// Range of log10(lambda) values to search (default: (-10.0, 2.0)).
    pub log_lambda_range: (f64, f64),
    /// Number of grid points in the lambda search (default: 50).
    /// The search returns no result when this is below 2.
    pub n_grid: usize,
}

impl Default for SmoothBasisGcvConfig {
    /// Cubic B-splines, 15 basis functions, second-derivative penalty and a
    /// 50-point search over log10(lambda) in [-10, 2].
    fn default() -> Self {
        Self {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis: 15,
            lfd_order: 2,
            log_lambda_range: (-10.0, 2.0),
            n_grid: 50,
        }
    }
}
343
344/// Perform basis-penalized smoothing with GCV-optimal lambda using a config struct.
345///
346/// This is the config-based alternative to [`smooth_basis_gcv`]. It takes data
347/// parameters directly and reads all tuning parameters from the config.
348///
349/// # Arguments
350/// * `data` — Functional data matrix (n × m)
351/// * `argvals` — Evaluation points (length m)
352/// * `config` — Tuning parameters
353///
354/// # Errors
355///
356/// Returns [`crate::FdarError::ComputationFailed`] if no valid smoothing result
357/// is found for any lambda in the search grid.
358#[must_use = "expensive computation whose result should not be discarded"]
359pub fn smooth_basis_gcv_with_config(
360    data: &FdMatrix,
361    argvals: &[f64],
362    config: &SmoothBasisGcvConfig,
363) -> Result<SmoothBasisResult, crate::FdarError> {
364    smooth_basis_gcv(
365        data,
366        argvals,
367        &config.basis_type,
368        config.nbasis,
369        config.lfd_order,
370        config.log_lambda_range,
371        config.n_grid,
372    )
373    .ok_or_else(|| crate::FdarError::ComputationFailed {
374        operation: "smooth_basis_gcv_with_config",
375        detail: "no valid smoothing result found in GCV lambda search".to_string(),
376    })
377}
378
/// Configuration for cross-validation-based basis selection.
///
/// Collects all tuning parameters for [`basis_nbasis_cv_with_config`], with
/// sensible defaults obtained via [`BasisNbasisCvConfig::default()`].
///
/// # Example
/// ```no_run
/// use fdars_core::smooth_basis::{BasisNbasisCvConfig, BasisType, BasisCriterion};
///
/// let config = BasisNbasisCvConfig {
///     nbasis_range: (5, 25),
///     criterion: BasisCriterion::Aic,
///     ..BasisNbasisCvConfig::default()
/// };
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct BasisNbasisCvConfig {
    /// Basis type (default: `Bspline` with order 4).
    pub basis_type: BasisType,
    /// Range of nbasis values to try, inclusive (default: (5, 30)).
    pub nbasis_range: (usize, usize),
    /// Roughness penalty lambda (default: 1e-4).
    pub lambda: f64,
    /// Penalty order (default: 2).
    ///
    /// NOTE(review): [`basis_nbasis_cv_with_config`] does not forward this
    /// field; the selection routines always build a second-order penalty.
    pub lfd_order: usize,
    /// Number of CV folds (default: 5). Only used when `criterion` is `Cv`.
    pub n_folds: usize,
    /// Selection criterion (default: `Gcv`).
    pub criterion: BasisCriterion,
}

impl Default for BasisNbasisCvConfig {
    /// Cubic B-splines, nbasis from 5 to 30, lambda = 1e-4, second-order
    /// penalty, 5 folds, GCV criterion.
    fn default() -> Self {
        Self {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis_range: (5, 30),
            lambda: 1e-4,
            lfd_order: 2,
            n_folds: 5,
            criterion: BasisCriterion::Gcv,
        }
    }
}
422
423/// Select the optimal number of basis functions using a config struct.
424///
425/// This is the config-based alternative to [`basis_nbasis_cv`]. It takes data
426/// parameters directly and reads all tuning parameters from the config.
427///
428/// The `nbasis_range` tuple `(lo, hi)` is expanded to `lo..=hi` to form the
429/// candidate set.
430///
431/// # Arguments
432/// * `data` — Functional data matrix (n × m)
433/// * `argvals` — Evaluation points (length m)
434/// * `config` — Tuning parameters
435///
436/// # Errors
437///
438/// Returns [`crate::FdarError::ComputationFailed`] if no valid result is found
439/// for any nbasis in the search range.
440#[must_use = "expensive computation whose result should not be discarded"]
441pub fn basis_nbasis_cv_with_config(
442    data: &FdMatrix,
443    argvals: &[f64],
444    config: &BasisNbasisCvConfig,
445) -> Result<BasisNbasisCvResult, crate::FdarError> {
446    let nbasis_range: Vec<usize> = (config.nbasis_range.0..=config.nbasis_range.1).collect();
447    basis_nbasis_cv(
448        data,
449        argvals,
450        &nbasis_range,
451        &config.basis_type,
452        config.criterion,
453        config.n_folds,
454        config.lambda,
455    )
456    .ok_or_else(|| crate::FdarError::ComputationFailed {
457        operation: "basis_nbasis_cv_with_config",
458        detail: "no valid result found in nbasis CV search".to_string(),
459    })
460}
461
462// ─── Internal Helpers ───────────────────────────────────────────────────────
463
464/// Differentiate column-major basis matrix `lfd_order` times using gradient_uniform.
465fn differentiate_basis_columns(
466    basis: &[f64],
467    n_quad: usize,
468    nbasis: usize,
469    h: f64,
470    lfd_order: usize,
471) -> Vec<f64> {
472    let mut deriv = basis.to_vec();
473    for _ in 0..lfd_order {
474        let mut new_deriv = vec![0.0; n_quad * nbasis];
475        for j in 0..nbasis {
476            let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
477            let grad = crate::helpers::gradient_uniform(&col, h);
478            for i in 0..n_quad {
479                new_deriv[i + j * n_quad] = grad[i];
480            }
481        }
482        deriv = new_deriv;
483    }
484    deriv
485}
486
/// Weighted Gram matrix of the differentiated basis:
/// R[j,k] = Σ_i w_i · D^m B_j(t_i) · D^m B_k(t_i) ≈ ∫ D^m B_j · D^m B_k dt.
///
/// `deriv_basis` is column-major with `n_quad` rows and `k` columns, and
/// `weights` holds one quadrature weight per row. Only the upper triangle is
/// computed; each value is mirrored into the lower triangle. The result is
/// `k × k` in column-major layout.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let mut gram = vec![0.0; k * k];
    for a in 0..k {
        for b in a..k {
            let integral = (0..n_quad).fold(0.0, |acc, i| {
                acc + deriv_basis[i + a * n_quad] * deriv_basis[i + b * n_quad] * weights[i]
            });
            gram[a + b * k] = integral;
            gram[b + a * k] = integral;
        }
    }
    gram
}
507
508/// Evaluate basis functions on argvals, returning (flat column-major, actual_nbasis).
509fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
510    let m = argvals.len();
511    match basis_type {
512        BasisType::Bspline { order } => {
513            let nknots = nbasis.saturating_sub(*order).max(2);
514            let basis = bspline_basis(argvals, nknots, *order);
515            let actual = basis.len() / m;
516            (basis, actual)
517        }
518        BasisType::Fourier { period } => {
519            let basis = fourier_basis_with_period(argvals, nbasis, *period);
520            (basis, nbasis)
521        }
522    }
523}
524
525/// Invert the penalized system matrix via Cholesky or SVD pseudoinverse.
526fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
527    if let Some(chol) = system.clone().cholesky() {
528        return Some(chol.inverse());
529    }
530    // SVD fallback
531    let svd = nalgebra::SVD::new(system.clone(), true, true);
532    let u = svd.u.as_ref()?;
533    let v_t = svd.v_t.as_ref()?;
534    let max_sv: f64 = svd.singular_values.iter().copied().fold(0.0_f64, f64::max);
535    let eps = 1e-10 * max_sv;
536    let mut inv = DMatrix::<f64>::zeros(k, k);
537    for ii in 0..k {
538        for jj in 0..k {
539            let mut sum = 0.0;
540            for s in 0..k.min(svd.singular_values.len()) {
541                if svd.singular_values[s] > eps {
542                    sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
543                }
544            }
545            inv[(ii, jj)] = sum;
546        }
547    }
548    Some(inv)
549}
550
551/// Project all curves onto basis, returning (coefficients, fitted, total_rss).
552fn project_all_curves(
553    data: &FdMatrix,
554    b_mat: &DMatrix<f64>,
555    proj: &DMatrix<f64>,
556    n: usize,
557    m: usize,
558    k: usize,
559) -> (FdMatrix, FdMatrix, f64) {
560    let mut all_coefs = FdMatrix::zeros(n, k);
561    let mut all_fitted = FdMatrix::zeros(n, m);
562    let mut total_rss = 0.0;
563
564    for i in 0..n {
565        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
566        let y_vec = nalgebra::DVector::from_vec(curve.clone());
567        let coefs = proj * &y_vec;
568
569        for j in 0..k {
570            all_coefs[(i, j)] = coefs[j];
571        }
572        let fitted = b_mat * &coefs;
573        for j in 0..m {
574            all_fitted[(i, j)] = fitted[j];
575            let resid = curve[j] - fitted[j];
576            total_rss += resid * resid;
577        }
578    }
579
580    (all_coefs, all_fitted, total_rss)
581}
582
/// Generalized cross-validation score: `(rss / n_points) / (1 - edf / m)²`.
///
/// Returns `f64::INFINITY` when `1 - edf / m` is not larger than `1e-10` in
/// magnitude (including NaN), i.e. when the fit uses essentially all `m`
/// degrees of freedom available per curve.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let shrink = 1.0 - edf / m as f64;
    let usable = shrink.abs() > 1e-10;
    if !usable {
        return f64::INFINITY;
    }
    (rss / n_points) / (shrink * shrink)
}
592
593// ─── Nbasis Selection via CV ────────────────────────────────────────────────
594
/// Criterion for nbasis selection.
///
/// `Gcv`, `Aic` and `Bic` are read from a single fit on all curves. `Cv`
/// splits the curves into folds instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BasisCriterion {
    /// Generalized cross-validation.
    Gcv,
    /// K-fold cross-validation over curves (the fold count is set by the caller).
    Cv,
    /// Akaike Information Criterion.
    Aic,
    /// Bayesian Information Criterion.
    Bic,
}
607
/// Result of nbasis selection.
///
/// `scores` and `nbasis_range` are parallel: `scores[i]` is the criterion
/// value obtained with `nbasis_range[i]` basis functions (lower is better).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct BasisNbasisCvResult {
    /// Optimal number of basis functions (the candidate with the lowest score).
    pub optimal_nbasis: usize,
    /// Score for each nbasis tested; `f64::INFINITY` marks a candidate that
    /// could not be fitted.
    pub scores: Vec<f64>,
    /// Range of nbasis values tested.
    pub nbasis_range: Vec<usize>,
    /// Criterion used.
    pub criterion: BasisCriterion,
}
621
622/// Evaluate information criterion (GCV/AIC/BIC) for a range of nbasis values.
623fn evaluate_nbasis_info_criterion(
624    data: &FdMatrix,
625    argvals: &[f64],
626    nbasis_range: &[usize],
627    basis_type: &BasisType,
628    criterion: BasisCriterion,
629    lambda: f64,
630) -> Vec<f64> {
631    let mut scores = Vec::with_capacity(nbasis_range.len());
632    for &nb in nbasis_range {
633        if nb < 2 {
634            scores.push(f64::INFINITY);
635            continue;
636        }
637        let penalty = match basis_type {
638            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
639            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
640        };
641        let fdpar = FdPar {
642            basis_type: basis_type.clone(),
643            nbasis: nb,
644            lambda,
645            lfd_order: 2,
646            penalty_matrix: penalty,
647        };
648        match smooth_basis(data, argvals, &fdpar) {
649            Ok(result) => {
650                let score = match criterion {
651                    BasisCriterion::Gcv => result.gcv,
652                    BasisCriterion::Aic => result.aic,
653                    BasisCriterion::Bic => result.bic,
654                    BasisCriterion::Cv => unreachable!(),
655                };
656                scores.push(score);
657            }
658            Err(_) => scores.push(f64::INFINITY),
659        }
660    }
661    scores
662}
663
664/// Evaluate nbasis via k-fold cross-validation of reconstruction error.
665fn evaluate_nbasis_cv(
666    data: &FdMatrix,
667    argvals: &[f64],
668    nbasis_range: &[usize],
669    basis_type: &BasisType,
670    lambda: f64,
671    n_folds: usize,
672) -> Vec<f64> {
673    let (n, m) = data.shape();
674    let n_folds = n_folds.max(2);
675    let folds = crate::cv::create_folds(n, n_folds, 42);
676    let mut scores = Vec::with_capacity(nbasis_range.len());
677
678    for &nb in nbasis_range {
679        if nb < 2 {
680            scores.push(f64::INFINITY);
681            continue;
682        }
683        let penalty = match basis_type {
684            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
685            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
686        };
687
688        let mut total_mse = 0.0;
689        let mut total_count = 0;
690
691        for fold in 0..n_folds {
692            let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
693            if train_idx.is_empty() || test_idx.is_empty() {
694                continue;
695            }
696            let train_data = crate::cv::subset_rows(data, &train_idx);
697            let fdpar = FdPar {
698                basis_type: basis_type.clone(),
699                nbasis: nb,
700                lambda,
701                lfd_order: 2,
702                penalty_matrix: penalty.clone(),
703            };
704
705            if let Ok(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
706                let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
707                let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
708                let r_mat =
709                    DMatrix::from_column_slice(actual_k, actual_k, &train_result.penalty_matrix);
710                let btb = b_mat.transpose() * &b_mat;
711                let ridge_eps = 1e-10;
712                let system: DMatrix<f64> = &btb
713                    + lambda * &r_mat
714                    + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
715
716                if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
717                    let proj = &system_inv * b_mat.transpose();
718                    for &ti in &test_idx {
719                        let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
720                        let y_vec = nalgebra::DVector::from_vec(curve.clone());
721                        let coefs = &proj * &y_vec;
722                        let fitted = &b_mat * &coefs;
723                        let mse: f64 =
724                            (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
725                        total_mse += mse;
726                        total_count += 1;
727                    }
728                }
729            }
730        }
731
732        if total_count > 0 {
733            scores.push(total_mse / f64::from(total_count));
734        } else {
735            scores.push(f64::INFINITY);
736        }
737    }
738    scores
739}
740
741/// Select the optimal number of basis functions using multiple criteria
742/// (R's `fdata2basis_cv`).
743pub fn basis_nbasis_cv(
744    data: &FdMatrix,
745    argvals: &[f64],
746    nbasis_range: &[usize],
747    basis_type: &BasisType,
748    criterion: BasisCriterion,
749    n_folds: usize,
750    lambda: f64,
751) -> Option<BasisNbasisCvResult> {
752    let (n, m) = data.shape();
753    if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
754        return None;
755    }
756
757    let scores = match criterion {
758        BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
759            evaluate_nbasis_info_criterion(
760                data,
761                argvals,
762                nbasis_range,
763                basis_type,
764                criterion,
765                lambda,
766            )
767        }
768        BasisCriterion::Cv => {
769            evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
770        }
771    };
772
773    let (best_idx, _) = scores
774        .iter()
775        .enumerate()
776        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
777
778    Some(BasisNbasisCvResult {
779        optimal_nbasis: nbasis_range[best_idx],
780        scores,
781        nbasis_range: nbasis_range.to_vec(),
782        criterion,
783    })
784}
785
786#[cfg(test)]
787mod tests {
788    use super::*;
789    use crate::test_helpers::uniform_grid;
790    use std::f64::consts::PI;
791
792    #[test]
793    fn test_bspline_penalty_matrix_symmetric() {
794        let t = uniform_grid(101);
795        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
796        let _k = 15; // may differ from actual due to knot construction
797        let actual_k = (penalty.len() as f64).sqrt() as usize;
798        for i in 0..actual_k {
799            for j in 0..actual_k {
800                assert!(
801                    (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
802                    "Penalty matrix not symmetric at ({}, {})",
803                    i,
804                    j
805                );
806            }
807        }
808    }
809
810    #[test]
811    fn test_bspline_penalty_matrix_positive_semidefinite() {
812        let t = uniform_grid(101);
813        let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
814        let k = (penalty.len() as f64).sqrt() as usize;
815        // Diagonal elements should be non-negative
816        for i in 0..k {
817            assert!(
818                penalty[i + i * k] >= -1e-10,
819                "Diagonal element {} is negative: {}",
820                i,
821                penalty[i + i * k]
822            );
823        }
824    }
825
826    #[test]
827    fn test_fourier_penalty_diagonal() {
828        let penalty = fourier_penalty_matrix(7, 1.0, 2);
829        // Should be diagonal
830        for i in 0..7 {
831            for j in 0..7 {
832                if i != j {
833                    assert!(
834                        penalty[i + j * 7].abs() < 1e-10,
835                        "Off-diagonal ({},{}) = {}",
836                        i,
837                        j,
838                        penalty[i + j * 7]
839                    );
840                }
841            }
842        }
843        // Constant term should have zero penalty
844        assert!(penalty[0].abs() < 1e-10);
845        // Higher frequency terms should have larger penalties
846        assert!(penalty[1 + 7] > 0.0);
847        assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
848    }
849
850    #[test]
851    fn test_smooth_basis_bspline() {
852        let m = 101;
853        let n = 5;
854        let t = uniform_grid(m);
855
856        // Generate noisy sine curves
857        let mut data = FdMatrix::zeros(n, m);
858        for i in 0..n {
859            for j in 0..m {
860                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
861            }
862        }
863
864        let nbasis = 15;
865        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
866        let _actual_k = (penalty.len() as f64).sqrt() as usize;
867
868        let fdpar = FdPar {
869            basis_type: BasisType::Bspline { order: 4 },
870            nbasis,
871            lambda: 1e-4,
872            lfd_order: 2,
873            penalty_matrix: penalty,
874        };
875
876        let result = smooth_basis(&data, &t, &fdpar);
877        assert!(result.is_ok(), "smooth_basis should succeed");
878
879        let res = result.unwrap();
880        assert_eq!(res.fitted.shape(), (n, m));
881        assert_eq!(res.coefficients.nrows(), n);
882        assert!(res.edf > 0.0, "EDF should be positive");
883        assert!(res.gcv > 0.0, "GCV should be positive");
884    }
885
886    #[test]
887    fn test_smooth_basis_fourier() {
888        let m = 101;
889        let n = 3;
890        let t = uniform_grid(m);
891
892        let mut data = FdMatrix::zeros(n, m);
893        for i in 0..n {
894            for j in 0..m {
895                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
896            }
897        }
898
899        let nbasis = 7;
900        let period = 1.0;
901        let penalty = fourier_penalty_matrix(nbasis, period, 2);
902
903        let fdpar = FdPar {
904            basis_type: BasisType::Fourier { period },
905            nbasis,
906            lambda: 1e-6,
907            lfd_order: 2,
908            penalty_matrix: penalty,
909        };
910
911        let result = smooth_basis(&data, &t, &fdpar);
912        assert!(result.is_ok());
913
914        let res = result.unwrap();
915        // Fourier basis should fit periodic data well
916        for j in 0..m {
917            let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
918            assert!(
919                (res.fitted[(0, j)] - expected).abs() < 0.1,
920                "Fourier fit poor at j={}: got {}, expected {}",
921                j,
922                res.fitted[(0, j)],
923                expected
924            );
925        }
926    }
927
928    #[test]
929    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
930        let m = 101;
931        let n = 5;
932        let t = uniform_grid(m);
933
934        let mut data = FdMatrix::zeros(n, m);
935        for i in 0..n {
936            for j in 0..m {
937                data[(i, j)] =
938                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
939            }
940        }
941
942        let basis_type = BasisType::Bspline { order: 4 };
943        let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
944        assert!(result.is_some(), "GCV search should succeed");
945    }
946
947    #[test]
948    fn test_smooth_basis_large_lambda_reduces_edf() {
949        let m = 101;
950        let n = 3;
951        let t = uniform_grid(m);
952
953        let mut data = FdMatrix::zeros(n, m);
954        for i in 0..n {
955            for j in 0..m {
956                data[(i, j)] = (2.0 * PI * t[j]).sin();
957            }
958        }
959
960        let nbasis = 15;
961        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
962        let _actual_k = (penalty.len() as f64).sqrt() as usize;
963
964        let fdpar_small = FdPar {
965            basis_type: BasisType::Bspline { order: 4 },
966            nbasis,
967            lambda: 1e-8,
968            lfd_order: 2,
969            penalty_matrix: penalty.clone(),
970        };
971        let fdpar_large = FdPar {
972            basis_type: BasisType::Bspline { order: 4 },
973            nbasis,
974            lambda: 1e2,
975            lfd_order: 2,
976            penalty_matrix: penalty,
977        };
978
979        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
980        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
981
982        assert!(
983            res_large.edf < res_small.edf,
984            "Larger lambda should reduce EDF: {} vs {}",
985            res_large.edf,
986            res_small.edf
987        );
988    }
989
990    // ============== basis_nbasis_cv tests ==============
991
992    #[test]
993    fn test_basis_nbasis_cv_gcv() {
994        let m = 101;
995        let n = 5;
996        let t = uniform_grid(m);
997        let mut data = FdMatrix::zeros(n, m);
998        for i in 0..n {
999            for j in 0..m {
1000                data[(i, j)] =
1001                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1002            }
1003        }
1004
1005        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
1006        let result = basis_nbasis_cv(
1007            &data,
1008            &t,
1009            &nbasis_range,
1010            &BasisType::Bspline { order: 4 },
1011            BasisCriterion::Gcv,
1012            5,
1013            1e-4,
1014        );
1015        assert!(result.is_some());
1016        let res = result.unwrap();
1017        assert!(nbasis_range.contains(&res.optimal_nbasis));
1018        assert_eq!(res.scores.len(), nbasis_range.len());
1019        assert_eq!(res.criterion, BasisCriterion::Gcv);
1020    }
1021
1022    #[test]
1023    fn test_basis_nbasis_cv_aic_bic() {
1024        let m = 51;
1025        let n = 5;
1026        let t = uniform_grid(m);
1027        let mut data = FdMatrix::zeros(n, m);
1028        for i in 0..n {
1029            for j in 0..m {
1030                data[(i, j)] = (2.0 * PI * t[j]).sin();
1031            }
1032        }
1033
1034        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
1035        let aic_result = basis_nbasis_cv(
1036            &data,
1037            &t,
1038            &nbasis_range,
1039            &BasisType::Bspline { order: 4 },
1040            BasisCriterion::Aic,
1041            5,
1042            0.0,
1043        );
1044        let bic_result = basis_nbasis_cv(
1045            &data,
1046            &t,
1047            &nbasis_range,
1048            &BasisType::Bspline { order: 4 },
1049            BasisCriterion::Bic,
1050            5,
1051            0.0,
1052        );
1053        assert!(aic_result.is_some());
1054        assert!(bic_result.is_some());
1055    }
1056
1057    #[test]
1058    fn test_basis_nbasis_cv_kfold() {
1059        let m = 51;
1060        let n = 10;
1061        let t = uniform_grid(m);
1062        let mut data = FdMatrix::zeros(n, m);
1063        for i in 0..n {
1064            for j in 0..m {
1065                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
1066            }
1067        }
1068
1069        let nbasis_range: Vec<usize> = vec![5, 7, 9];
1070        let result = basis_nbasis_cv(
1071            &data,
1072            &t,
1073            &nbasis_range,
1074            &BasisType::Bspline { order: 4 },
1075            BasisCriterion::Cv,
1076            5,
1077            1e-4,
1078        );
1079        assert!(result.is_some());
1080        let res = result.unwrap();
1081        assert!(nbasis_range.contains(&res.optimal_nbasis));
1082        assert_eq!(res.criterion, BasisCriterion::Cv);
1083    }
1084
1085    // ============== Comprehensive additional tests ==============
1086
1087    // Helper: generate standard test data (sine + high-freq component)
1088    fn make_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
1089        let t: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1090        let mut data = FdMatrix::zeros(n, m);
1091        for i in 0..n {
1092            for j in 0..m {
1093                data[(i, j)] = (2.0 * PI * t[j]).sin()
1094                    + 0.1 * (10.0 * t[j]).sin()
1095                    + 0.05 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1096            }
1097        }
1098        (data, t)
1099    }
1100
1101    // Helper: create an FdPar for B-spline smoothing
1102    fn make_bspline_fdpar(argvals: &[f64], nbasis: usize, lambda: f64) -> FdPar {
1103        let penalty = bspline_penalty_matrix(argvals, nbasis, 4, 2);
1104        FdPar {
1105            basis_type: BasisType::Bspline { order: 4 },
1106            nbasis,
1107            lambda,
1108            lfd_order: 2,
1109            penalty_matrix: penalty,
1110        }
1111    }
1112
1113    // Helper: create an FdPar for Fourier smoothing
1114    fn make_fourier_fdpar(nbasis: usize, period: f64, lambda: f64) -> FdPar {
1115        let penalty = fourier_penalty_matrix(nbasis, period, 2);
1116        FdPar {
1117            basis_type: BasisType::Fourier { period },
1118            nbasis,
1119            lambda,
1120            lfd_order: 2,
1121            penalty_matrix: penalty,
1122        }
1123    }
1124
1125    // ─── BasisType enum tests ───────────────────────────────────────────────
1126
1127    #[test]
1128    fn test_basis_type_bspline_variant() {
1129        let bt = BasisType::Bspline { order: 4 };
1130        assert_eq!(bt, BasisType::Bspline { order: 4 });
1131        // Different orders are not equal
1132        assert_ne!(bt, BasisType::Bspline { order: 3 });
1133    }
1134
1135    #[test]
1136    fn test_basis_type_fourier_variant() {
1137        let bt = BasisType::Fourier { period: 1.0 };
1138        assert_eq!(bt, BasisType::Fourier { period: 1.0 });
1139        assert_ne!(bt, BasisType::Fourier { period: 2.0 });
1140    }
1141
1142    #[test]
1143    fn test_basis_type_cross_variant_inequality() {
1144        let bspline = BasisType::Bspline { order: 4 };
1145        let fourier = BasisType::Fourier { period: 1.0 };
1146        assert_ne!(bspline, fourier);
1147    }
1148
1149    #[test]
1150    fn test_basis_type_clone_and_debug() {
1151        let bt = BasisType::Bspline { order: 4 };
1152        let cloned = bt.clone();
1153        assert_eq!(bt, cloned);
1154        let debug_str = format!("{:?}", bt);
1155        assert!(debug_str.contains("Bspline"));
1156        assert!(debug_str.contains("4"));
1157    }
1158
1159    // ─── FdPar struct tests ─────────────────────────────────────────────────
1160
1161    #[test]
1162    fn test_fdpar_construction_and_fields() {
1163        let penalty = vec![1.0, 0.0, 0.0, 1.0];
1164        let fdpar = FdPar {
1165            basis_type: BasisType::Bspline { order: 4 },
1166            nbasis: 2,
1167            lambda: 0.01,
1168            lfd_order: 2,
1169            penalty_matrix: penalty.clone(),
1170        };
1171        assert_eq!(fdpar.nbasis, 2);
1172        assert!((fdpar.lambda - 0.01).abs() < 1e-15);
1173        assert_eq!(fdpar.lfd_order, 2);
1174        assert_eq!(fdpar.penalty_matrix.len(), 4);
1175    }
1176
1177    #[test]
1178    fn test_fdpar_clone_and_debug() {
1179        let t = uniform_grid(50);
1180        let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1181        let cloned = fdpar.clone();
1182        assert_eq!(fdpar, cloned);
1183        let debug_str = format!("{:?}", fdpar);
1184        assert!(debug_str.contains("FdPar"));
1185    }
1186
1187    // ─── BasisCriterion enum tests ──────────────────────────────────────────
1188
1189    #[test]
1190    fn test_basis_criterion_variants() {
1191        assert_eq!(BasisCriterion::Gcv, BasisCriterion::Gcv);
1192        assert_eq!(BasisCriterion::Cv, BasisCriterion::Cv);
1193        assert_eq!(BasisCriterion::Aic, BasisCriterion::Aic);
1194        assert_eq!(BasisCriterion::Bic, BasisCriterion::Bic);
1195        assert_ne!(BasisCriterion::Gcv, BasisCriterion::Aic);
1196        assert_ne!(BasisCriterion::Cv, BasisCriterion::Bic);
1197    }
1198
1199    #[test]
1200    fn test_basis_criterion_copy() {
1201        let c = BasisCriterion::Gcv;
1202        let copied = c; // Copy
1203        assert_eq!(c, copied);
1204    }
1205
1206    #[test]
1207    fn test_basis_criterion_debug() {
1208        let debug_str = format!("{:?}", BasisCriterion::Bic);
1209        assert!(debug_str.contains("Bic"));
1210    }
1211
1212    // ─── SmoothBasisResult tests ────────────────────────────────────────────
1213
1214    #[test]
1215    fn test_smooth_basis_result_all_fields() {
1216        let (data, t) = make_test_data(3, 50);
1217        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1218        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1219
1220        // coefficients: n curves x k basis functions
1221        assert_eq!(res.coefficients.nrows(), 3);
1222        assert!(res.coefficients.ncols() > 0);
1223        assert_eq!(res.nbasis, res.coefficients.ncols());
1224        // fitted: n x m
1225        assert_eq!(res.fitted.shape(), (3, 50));
1226        // edf should be between 1 and nbasis
1227        assert!(res.edf > 0.0 && res.edf <= res.nbasis as f64);
1228        // gcv, aic, bic should be finite
1229        assert!(res.gcv.is_finite());
1230        assert!(res.aic.is_finite());
1231        assert!(res.bic.is_finite());
1232        // penalty_matrix should be k x k
1233        let k = res.nbasis;
1234        assert_eq!(res.penalty_matrix.len(), k * k);
1235    }
1236
1237    #[test]
1238    fn test_smooth_basis_result_clone() {
1239        let (data, t) = make_test_data(2, 50);
1240        let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1241        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1242        let cloned = res.clone();
1243        assert_eq!(res, cloned);
1244    }
1245
1246    // ─── smooth_basis: B-spline detailed tests ──────────────────────────────
1247
1248    #[test]
1249    fn test_smooth_basis_bspline_coefficient_shape() {
1250        let (data, t) = make_test_data(4, 50);
1251        let nbasis = 12;
1252        let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
1253        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1254        assert_eq!(res.coefficients.nrows(), 4);
1255        // actual nbasis may differ from requested due to knot construction
1256        assert!(res.coefficients.ncols() >= 2);
1257        assert_eq!(res.nbasis, res.coefficients.ncols());
1258    }
1259
1260    #[test]
1261    fn test_smooth_basis_bspline_fitted_values_shape() {
1262        let m = 80;
1263        let n = 6;
1264        let (data, t) = make_test_data(n, m);
1265        let fdpar = make_bspline_fdpar(&t, 15, 1e-4);
1266        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1267        assert_eq!(res.fitted.shape(), (n, m));
1268    }
1269
1270    #[test]
1271    fn test_smooth_basis_bspline_zero_lambda_interpolates() {
1272        // With lambda=0, the smoother should nearly interpolate the data
1273        let m = 30;
1274        let n = 2;
1275        let (data, t) = make_test_data(n, m);
1276        let fdpar = make_bspline_fdpar(&t, 15, 0.0);
1277        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1278
1279        // Residuals should be very small (near interpolation)
1280        let mut max_resid = 0.0_f64;
1281        for i in 0..n {
1282            for j in 0..m {
1283                let resid = (data[(i, j)] - res.fitted[(i, j)]).abs();
1284                max_resid = max_resid.max(resid);
1285            }
1286        }
1287        assert!(
1288            max_resid < 0.5,
1289            "Zero-lambda B-spline should closely interpolate; max_resid = {}",
1290            max_resid
1291        );
1292    }
1293
1294    #[test]
1295    fn test_smooth_basis_bspline_large_lambda_oversmooths() {
1296        // With very large lambda, the fit should be much smoother (lower variance)
1297        // than with small lambda
1298        let m = 50;
1299        let n = 1;
1300        let (data, t) = make_test_data(n, m);
1301
1302        let fdpar_small = make_bspline_fdpar(&t, 15, 1e-6);
1303        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1304
1305        let fdpar_large = make_bspline_fdpar(&t, 15, 1e6);
1306        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1307
1308        let compute_variance = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1309            let vals: Vec<f64> = (0..ncols).map(|j| fitted[(row, j)]).collect();
1310            let mean = vals.iter().sum::<f64>() / ncols as f64;
1311            vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / ncols as f64
1312        };
1313
1314        let var_small = compute_variance(&res_small.fitted, 0, m);
1315        let var_large = compute_variance(&res_large.fitted, 0, m);
1316        assert!(
1317            var_large < var_small,
1318            "Large lambda should yield lower variance fit: var_large={}, var_small={}",
1319            var_large,
1320            var_small
1321        );
1322    }
1323
1324    #[test]
1325    fn test_smooth_basis_bspline_penalty_effect_on_smoothness() {
1326        // Compare roughness of fits with small vs large lambda
1327        let m = 50;
1328        let n = 1;
1329        let (data, t) = make_test_data(n, m);
1330
1331        let fdpar_small = make_bspline_fdpar(&t, 15, 1e-8);
1332        let fdpar_large = make_bspline_fdpar(&t, 15, 1.0);
1333
1334        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1335        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1336
1337        // Measure roughness as sum of squared second differences
1338        let roughness = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1339            (1..ncols - 1)
1340                .map(|j| {
1341                    let d2 = fitted[(row, j + 1)] - 2.0 * fitted[(row, j)] + fitted[(row, j - 1)];
1342                    d2 * d2
1343                })
1344                .sum::<f64>()
1345        };
1346
1347        let r_small = roughness(&res_small.fitted, 0, m);
1348        let r_large = roughness(&res_large.fitted, 0, m);
1349        assert!(
1350            r_large < r_small,
1351            "Larger lambda should produce smoother fit: roughness_large={}, roughness_small={}",
1352            r_large,
1353            r_small
1354        );
1355    }
1356
1357    #[test]
1358    fn test_smooth_basis_bspline_single_curve() {
1359        let m = 50;
1360        let (data, t) = make_test_data(1, m);
1361        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1362        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1363        assert_eq!(res.fitted.nrows(), 1);
1364        assert_eq!(res.fitted.ncols(), m);
1365        assert!(res.gcv.is_finite());
1366    }
1367
1368    #[test]
1369    fn test_smooth_basis_bspline_many_curves() {
1370        let m = 50;
1371        let n = 20;
1372        let (data, t) = make_test_data(n, m);
1373        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1374        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1375        assert_eq!(res.fitted.nrows(), n);
1376        assert_eq!(res.coefficients.nrows(), n);
1377    }
1378
1379    #[test]
1380    fn test_smooth_basis_bspline_minimal_nbasis() {
1381        // nbasis = 2 is the minimum allowed
1382        let m = 50;
1383        let (data, t) = make_test_data(1, m);
1384        let fdpar = make_bspline_fdpar(&t, 2, 1e-4);
1385        let res = smooth_basis(&data, &t, &fdpar);
1386        // Should succeed (or at least not panic); the fit may be poor
1387        assert!(res.is_ok());
1388    }
1389
1390    #[test]
1391    fn test_smooth_basis_bspline_different_orders() {
1392        let m = 50;
1393        let (data, t) = make_test_data(2, m);
1394        // Order 3 (quadratic B-splines)
1395        let penalty3 = bspline_penalty_matrix(&t, 10, 3, 2);
1396        let fdpar3 = FdPar {
1397            basis_type: BasisType::Bspline { order: 3 },
1398            nbasis: 10,
1399            lambda: 1e-4,
1400            lfd_order: 2,
1401            penalty_matrix: penalty3,
1402        };
1403        let res3 = smooth_basis(&data, &t, &fdpar3);
1404        assert!(res3.is_ok());
1405
1406        // Order 5 (quartic B-splines)
1407        let penalty5 = bspline_penalty_matrix(&t, 10, 5, 2);
1408        let fdpar5 = FdPar {
1409            basis_type: BasisType::Bspline { order: 5 },
1410            nbasis: 10,
1411            lambda: 1e-4,
1412            lfd_order: 2,
1413            penalty_matrix: penalty5,
1414        };
1415        let res5 = smooth_basis(&data, &t, &fdpar5);
1416        assert!(res5.is_ok());
1417    }
1418
1419    // ─── smooth_basis: Fourier detailed tests ───────────────────────────────
1420
1421    #[test]
1422    fn test_smooth_basis_fourier_coefficient_shape() {
1423        let m = 50;
1424        let n = 3;
1425        let t = uniform_grid(m);
1426        let mut data = FdMatrix::zeros(n, m);
1427        for i in 0..n {
1428            for j in 0..m {
1429                data[(i, j)] = (2.0 * PI * t[j]).sin();
1430            }
1431        }
1432        let nbasis = 7;
1433        let fdpar = make_fourier_fdpar(nbasis, 1.0, 1e-6);
1434        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1435        assert_eq!(res.coefficients.nrows(), n);
1436        assert_eq!(res.coefficients.ncols(), nbasis);
1437        assert_eq!(res.nbasis, nbasis);
1438    }
1439
1440    #[test]
1441    fn test_smooth_basis_fourier_fits_pure_sine() {
1442        // Fourier basis should perfectly fit a pure sine with enough basis fns
1443        let m = 100;
1444        let t = uniform_grid(m);
1445        let mut data = FdMatrix::zeros(1, m);
1446        for j in 0..m {
1447            data[(0, j)] = (2.0 * PI * t[j]).sin();
1448        }
1449        let fdpar = make_fourier_fdpar(5, 1.0, 1e-8);
1450        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1451
1452        for j in 0..m {
1453            let expected = (2.0 * PI * t[j]).sin();
1454            assert!(
1455                (res.fitted[(0, j)] - expected).abs() < 0.05,
1456                "Fourier should fit pure sine; j={}, got={}, expected={}",
1457                j,
1458                res.fitted[(0, j)],
1459                expected
1460            );
1461        }
1462    }
1463
1464    #[test]
1465    fn test_smooth_basis_fourier_different_periods() {
1466        let m = 50;
1467        let t = uniform_grid(m);
1468        let mut data = FdMatrix::zeros(1, m);
1469        for j in 0..m {
1470            data[(0, j)] = (2.0 * PI * t[j]).sin();
1471        }
1472
1473        // Period = 1.0 (matches the data)
1474        let fdpar1 = make_fourier_fdpar(7, 1.0, 1e-6);
1475        let res1 = smooth_basis(&data, &t, &fdpar1).unwrap();
1476
1477        // Period = 2.0 (mismatch, but should still produce a result)
1478        let fdpar2 = make_fourier_fdpar(7, 2.0, 1e-6);
1479        let res2 = smooth_basis(&data, &t, &fdpar2).unwrap();
1480
1481        // Both should succeed and have valid shapes
1482        assert_eq!(res1.fitted.shape(), (1, m));
1483        assert_eq!(res2.fitted.shape(), (1, m));
1484    }
1485
1486    #[test]
1487    fn test_smooth_basis_fourier_zero_lambda() {
1488        let m = 50;
1489        let t = uniform_grid(m);
1490        let mut data = FdMatrix::zeros(1, m);
1491        for j in 0..m {
1492            data[(0, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
1493        }
1494        let fdpar = make_fourier_fdpar(9, 1.0, 0.0);
1495        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1496        assert_eq!(res.fitted.shape(), (1, m));
1497        // EDF should be close to nbasis with zero penalty
1498        assert!(res.edf > 1.0);
1499    }
1500
1501    #[test]
1502    fn test_smooth_basis_fourier_large_lambda() {
1503        let m = 50;
1504        let t = uniform_grid(m);
1505        let mut data = FdMatrix::zeros(1, m);
1506        for j in 0..m {
1507            data[(0, j)] = (2.0 * PI * t[j]).sin();
1508        }
1509        let fdpar = make_fourier_fdpar(9, 1.0, 1e6);
1510        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1511        // EDF should be very small with huge penalty
1512        assert!(
1513            res.edf < 5.0,
1514            "Large lambda should reduce EDF; edf={}",
1515            res.edf
1516        );
1517    }
1518
1519    // ─── smooth_basis: Lambda comparison tests ──────────────────────────────
1520
1521    #[test]
1522    fn test_smooth_basis_lambda_gradient_edf() {
1523        // EDF should monotonically decrease with increasing lambda
1524        let m = 50;
1525        let (data, t) = make_test_data(3, m);
1526        let lambdas = [1e-8, 1e-4, 1e-2, 1.0, 1e2];
1527        let mut prev_edf = f64::INFINITY;
1528        for &lam in &lambdas {
1529            let fdpar = make_bspline_fdpar(&t, 12, lam);
1530            let res = smooth_basis(&data, &t, &fdpar).unwrap();
1531            assert!(
1532                res.edf <= prev_edf + 0.01,
1533                "EDF should decrease: lambda={}, edf={}, prev_edf={}",
1534                lam,
1535                res.edf,
1536                prev_edf
1537            );
1538            prev_edf = res.edf;
1539        }
1540    }
1541
1542    #[test]
1543    fn test_smooth_basis_lambda_gradient_rss() {
1544        // RSS should monotonically increase with increasing lambda
1545        let m = 50;
1546        let n = 2;
1547        let (data, t) = make_test_data(n, m);
1548        let lambdas = [0.0, 1e-6, 1e-2, 1.0, 1e4];
1549        let mut prev_rss = -1.0;
1550        for &lam in &lambdas {
1551            let fdpar = make_bspline_fdpar(&t, 12, lam);
1552            let res = smooth_basis(&data, &t, &fdpar).unwrap();
1553            let mut rss = 0.0;
1554            for i in 0..n {
1555                for j in 0..m {
1556                    rss += (data[(i, j)] - res.fitted[(i, j)]).powi(2);
1557                }
1558            }
1559            assert!(
1560                rss >= prev_rss - 1e-8,
1561                "RSS should increase: lambda={}, rss={}, prev_rss={}",
1562                lam,
1563                rss,
1564                prev_rss
1565            );
1566            prev_rss = rss;
1567        }
1568    }
1569
1570    // ─── smooth_basis: Error cases ──────────────────────────────────────────
1571
1572    #[test]
1573    fn test_smooth_basis_empty_data_rows() {
1574        let t = uniform_grid(50);
1575        let data = FdMatrix::zeros(0, 50);
1576        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1577        let res = smooth_basis(&data, &t, &fdpar);
1578        assert!(res.is_err());
1579    }
1580
1581    #[test]
1582    fn test_smooth_basis_empty_data_cols() {
1583        let data = FdMatrix::zeros(5, 0);
1584        let fdpar = FdPar {
1585            basis_type: BasisType::Bspline { order: 4 },
1586            nbasis: 10,
1587            lambda: 1e-4,
1588            lfd_order: 2,
1589            penalty_matrix: vec![0.0; 100],
1590        };
1591        let res = smooth_basis(&data, &[], &fdpar);
1592        assert!(res.is_err());
1593    }
1594
1595    #[test]
1596    fn test_smooth_basis_mismatched_argvals() {
1597        let t = uniform_grid(50);
1598        let data = FdMatrix::zeros(3, 40); // m=40 but argvals has 50
1599        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1600        let res = smooth_basis(&data, &t, &fdpar);
1601        assert!(res.is_err());
1602    }
1603
1604    #[test]
1605    fn test_smooth_basis_nbasis_too_small() {
1606        let t = uniform_grid(50);
1607        let data = FdMatrix::zeros(3, 50);
1608        // nbasis = 1, which is below minimum of 2
1609        let fdpar = FdPar {
1610            basis_type: BasisType::Bspline { order: 4 },
1611            nbasis: 1,
1612            lambda: 1e-4,
1613            lfd_order: 2,
1614            penalty_matrix: vec![0.0; 1],
1615        };
1616        let res = smooth_basis(&data, &t, &fdpar);
1617        assert!(res.is_err());
1618    }
1619
1620    #[test]
1621    fn test_smooth_basis_error_is_invalid_dimension() {
1622        let t = uniform_grid(50);
1623        let data = FdMatrix::zeros(0, 50);
1624        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1625        let err = smooth_basis(&data, &t, &fdpar).unwrap_err();
1626        match err {
1627            crate::FdarError::InvalidDimension { .. } => {} // expected
1628            other => panic!("Expected InvalidDimension, got {:?}", other),
1629        }
1630    }
1631
1632    // ─── Penalty matrix detailed tests ──────────────────────────────────────
1633
1634    #[test]
1635    fn test_bspline_penalty_matrix_different_orders() {
1636        let t = uniform_grid(101);
1637        // Order 1 penalty (penalize derivatives)
1638        let p1 = bspline_penalty_matrix(&t, 10, 4, 1);
1639        // Order 2 penalty (penalize curvature)
1640        let p2 = bspline_penalty_matrix(&t, 10, 4, 2);
1641        // Both should be square and same size
1642        assert_eq!(p1.len(), p2.len());
1643        // But they should differ
1644        let diff: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).abs()).sum();
1645        assert!(
1646            diff > 1e-10,
1647            "Different lfd_orders should produce different penalties"
1648        );
1649    }
1650
1651    #[test]
1652    fn test_bspline_penalty_matrix_edge_cases() {
1653        // Too few argvals
1654        let t = vec![0.0];
1655        let p = bspline_penalty_matrix(&t, 10, 4, 2);
1656        // Should return zero matrix
1657        assert!(p.iter().all(|&v| v == 0.0));
1658
1659        // nbasis < 2
1660        let t2 = uniform_grid(50);
1661        let p2 = bspline_penalty_matrix(&t2, 1, 4, 2);
1662        assert!(p2.iter().all(|&v| v == 0.0));
1663
1664        // lfd_order >= order
1665        let p3 = bspline_penalty_matrix(&t2, 10, 4, 4);
1666        assert!(p3.iter().all(|&v| v == 0.0));
1667    }
1668
1669    #[test]
1670    fn test_bspline_penalty_nonnegative_diagonal() {
1671        let t = uniform_grid(101);
1672        for nbasis in [5, 10, 20] {
1673            let p = bspline_penalty_matrix(&t, nbasis, 4, 2);
1674            let k = (p.len() as f64).sqrt() as usize;
1675            for i in 0..k {
1676                assert!(
1677                    p[i + i * k] >= -1e-10,
1678                    "Diagonal ({},{}) negative for nbasis={}: {}",
1679                    i,
1680                    i,
1681                    nbasis,
1682                    p[i + i * k]
1683                );
1684            }
1685        }
1686    }
1687
1688    #[test]
1689    fn test_fourier_penalty_increasing_with_frequency() {
1690        let penalty = fourier_penalty_matrix(11, 1.0, 2);
1691        let k = 11;
1692        // Constant term is zero
1693        assert!(penalty[0].abs() < 1e-15);
1694        // Pairs: (1,2) -> freq 1, (3,4) -> freq 2, etc.
1695        let mut prev_eigenval = 0.0;
1696        for freq in 1..=5 {
1697            let idx_sin = 2 * freq - 1;
1698            let eigenval = penalty[idx_sin + idx_sin * k];
1699            assert!(
1700                eigenval > prev_eigenval,
1701                "Higher frequency should have larger penalty: freq={}, eigenval={}, prev={}",
1702                freq,
1703                eigenval,
1704                prev_eigenval
1705            );
1706            prev_eigenval = eigenval;
1707            // cos and sin of same frequency should have same penalty
1708            let idx_cos = 2 * freq;
1709            if idx_cos < k {
1710                assert!(
1711                    (penalty[idx_cos + idx_cos * k] - eigenval).abs() < 1e-10,
1712                    "Sin and cos penalty should match at freq {}",
1713                    freq
1714                );
1715            }
1716        }
1717    }
1718
1719    #[test]
1720    fn test_fourier_penalty_different_periods() {
1721        let p1 = fourier_penalty_matrix(7, 1.0, 2);
1722        let p2 = fourier_penalty_matrix(7, 2.0, 2);
1723        // Longer period -> smaller omega -> smaller penalty eigenvalues
1724        for i in 1..7 {
1725            assert!(
1726                p2[i + i * 7] < p1[i + i * 7] || (p1[i + i * 7] == 0.0 && p2[i + i * 7] == 0.0),
1727                "Longer period should have smaller penalties at i={}",
1728                i
1729            );
1730        }
1731    }
1732
1733    #[test]
1734    fn test_fourier_penalty_first_order() {
1735        // lfd_order = 1: penalize first derivative
1736        let p = fourier_penalty_matrix(5, 1.0, 1);
1737        // Eigenvalues: (2*pi*freq)^2 for lfd_order=1
1738        let omega1 = 2.0 * PI;
1739        let expected1 = omega1.powi(2);
1740        assert!(
1741            (p[1 + 5] - expected1).abs() < 1e-6,
1742            "First-order penalty eigenval: got {}, expected {}",
1743            p[1 + 5],
1744            expected1
1745        );
1746    }
1747
1748    #[test]
1749    fn test_fourier_penalty_zero_nbasis() {
1750        let p = fourier_penalty_matrix(0, 1.0, 2);
1751        assert!(p.is_empty());
1752    }
1753
1754    #[test]
1755    fn test_fourier_penalty_nbasis_one() {
1756        let p = fourier_penalty_matrix(1, 1.0, 2);
1757        assert_eq!(p.len(), 1);
1758        assert!(p[0].abs() < 1e-15); // constant term has zero penalty
1759    }
1760
1761    // ─── smooth_basis_gcv detailed tests ────────────────────────────────────
1762
1763    #[test]
1764    fn test_smooth_basis_gcv_returns_valid_result() {
1765        let (data, t) = make_test_data(5, 50);
1766        let bt = BasisType::Bspline { order: 4 };
1767        let result = smooth_basis_gcv(&data, &t, &bt, 12, 2, (-6.0, 2.0), 20);
1768        assert!(result.is_some());
1769        let res = result.unwrap();
1770        assert_eq!(res.fitted.shape(), (5, 50));
1771        assert!(res.gcv.is_finite());
1772        assert!(res.edf > 0.0);
1773    }
1774
1775    #[test]
1776    fn test_smooth_basis_gcv_fourier() {
1777        let m = 80;
1778        let t = uniform_grid(m);
1779        let mut data = FdMatrix::zeros(3, m);
1780        for i in 0..3 {
1781            for j in 0..m {
1782                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.5 * (4.0 * PI * t[j]).cos();
1783            }
1784        }
1785        let bt = BasisType::Fourier { period: 1.0 };
1786        let result = smooth_basis_gcv(&data, &t, &bt, 9, 2, (-8.0, 4.0), 25);
1787        assert!(result.is_some());
1788        let res = result.unwrap();
1789        assert_eq!(res.fitted.nrows(), 3);
1790        assert_eq!(res.nbasis, 9);
1791    }
1792
1793    #[test]
1794    fn test_smooth_basis_gcv_selects_finite_gcv() {
1795        let (data, t) = make_test_data(5, 60);
1796        let bt = BasisType::Bspline { order: 4 };
1797        let res = smooth_basis_gcv(&data, &t, &bt, 12, 2, (-6.0, 2.0), 15).unwrap();
1798        assert!(res.gcv.is_finite());
1799        assert!(res.gcv > 0.0);
1800    }
1801
1802    #[test]
1803    fn test_smooth_basis_gcv_empty_data() {
1804        let data = FdMatrix::zeros(0, 50);
1805        let t = uniform_grid(50);
1806        let bt = BasisType::Bspline { order: 4 };
1807        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-6.0, 2.0), 10);
1808        // Should return None since smooth_basis will error for empty data
1809        assert!(result.is_none());
1810    }
1811
1812    #[test]
1813    fn test_smooth_basis_gcv_empty_argvals() {
1814        let data = FdMatrix::zeros(5, 0);
1815        let bt = BasisType::Bspline { order: 4 };
1816        let result = smooth_basis_gcv(&data, &[], &bt, 10, 2, (-6.0, 2.0), 10);
1817        assert!(result.is_none());
1818    }
1819
1820    #[test]
1821    fn test_smooth_basis_gcv_nbasis_too_small() {
1822        let (data, t) = make_test_data(5, 50);
1823        let bt = BasisType::Bspline { order: 4 };
1824        let result = smooth_basis_gcv(&data, &t, &bt, 1, 2, (-6.0, 2.0), 10);
1825        assert!(result.is_none());
1826    }
1827
1828    #[test]
1829    fn test_smooth_basis_gcv_ngrid_too_small() {
1830        let (data, t) = make_test_data(5, 50);
1831        let bt = BasisType::Bspline { order: 4 };
1832        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-6.0, 2.0), 1);
1833        assert!(result.is_none());
1834    }
1835
1836    #[test]
1837    fn test_smooth_basis_gcv_narrow_range() {
1838        let (data, t) = make_test_data(3, 50);
1839        let bt = BasisType::Bspline { order: 4 };
1840        // Very narrow search range
1841        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-3.0, -2.0), 5);
1842        assert!(result.is_some());
1843    }
1844
1845    #[test]
1846    fn test_smooth_basis_gcv_wide_range() {
1847        let (data, t) = make_test_data(3, 50);
1848        let bt = BasisType::Bspline { order: 4 };
1849        // Very wide search range
1850        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-12.0, 8.0), 30);
1851        assert!(result.is_some());
1852    }
1853
1854    // ─── basis_nbasis_cv detailed tests ─────────────────────────────────────
1855
1856    #[test]
1857    fn test_basis_nbasis_cv_scores_length() {
1858        let (data, t) = make_test_data(5, 50);
1859        let nbasis_range: Vec<usize> = vec![4, 6, 8, 10, 12];
1860        let res = basis_nbasis_cv(
1861            &data,
1862            &t,
1863            &nbasis_range,
1864            &BasisType::Bspline { order: 4 },
1865            BasisCriterion::Gcv,
1866            5,
1867            1e-4,
1868        )
1869        .unwrap();
1870        assert_eq!(res.scores.len(), 5);
1871        assert_eq!(res.nbasis_range.len(), 5);
1872        assert_eq!(res.nbasis_range, nbasis_range);
1873    }
1874
1875    #[test]
1876    fn test_basis_nbasis_cv_optimal_within_range() {
1877        let (data, t) = make_test_data(8, 50);
1878        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11, 13, 15];
1879        for criterion in [
1880            BasisCriterion::Gcv,
1881            BasisCriterion::Aic,
1882            BasisCriterion::Bic,
1883        ] {
1884            let res = basis_nbasis_cv(
1885                &data,
1886                &t,
1887                &nbasis_range,
1888                &BasisType::Bspline { order: 4 },
1889                criterion,
1890                5,
1891                1e-4,
1892            )
1893            .unwrap();
1894            assert!(
1895                nbasis_range.contains(&res.optimal_nbasis),
1896                "optimal_nbasis {} not in range for {:?}",
1897                res.optimal_nbasis,
1898                criterion
1899            );
1900        }
1901    }
1902
1903    #[test]
1904    fn test_basis_nbasis_cv_fourier_gcv() {
1905        let m = 80;
1906        let t = uniform_grid(m);
1907        let mut data = FdMatrix::zeros(5, m);
1908        for i in 0..5 {
1909            for j in 0..m {
1910                data[(i, j)] = (2.0 * PI * t[j]).sin()
1911                    + 0.3 * (4.0 * PI * t[j]).cos()
1912                    + 0.02 * ((i * 7 + j * 3) % 10) as f64;
1913            }
1914        }
1915        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
1916        let res = basis_nbasis_cv(
1917            &data,
1918            &t,
1919            &nbasis_range,
1920            &BasisType::Fourier { period: 1.0 },
1921            BasisCriterion::Gcv,
1922            5,
1923            1e-4,
1924        )
1925        .unwrap();
1926        assert!(nbasis_range.contains(&res.optimal_nbasis));
1927    }
1928
1929    #[test]
1930    fn test_basis_nbasis_cv_fourier_cv() {
1931        let m = 60;
1932        let t = uniform_grid(m);
1933        let n = 10;
1934        let mut data = FdMatrix::zeros(n, m);
1935        for i in 0..n {
1936            for j in 0..m {
1937                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.02 * ((i * 11 + j) % 15) as f64;
1938            }
1939        }
1940        let nbasis_range: Vec<usize> = vec![5, 7, 9];
1941        let res = basis_nbasis_cv(
1942            &data,
1943            &t,
1944            &nbasis_range,
1945            &BasisType::Fourier { period: 1.0 },
1946            BasisCriterion::Cv,
1947            5,
1948            1e-4,
1949        )
1950        .unwrap();
1951        assert!(nbasis_range.contains(&res.optimal_nbasis));
1952        assert_eq!(res.criterion, BasisCriterion::Cv);
1953    }
1954
1955    #[test]
1956    fn test_basis_nbasis_cv_with_nbasis_below_minimum() {
1957        // Range includes nbasis = 1 which is invalid
1958        let (data, t) = make_test_data(5, 50);
1959        let nbasis_range: Vec<usize> = vec![1, 5, 10];
1960        let res = basis_nbasis_cv(
1961            &data,
1962            &t,
1963            &nbasis_range,
1964            &BasisType::Bspline { order: 4 },
1965            BasisCriterion::Gcv,
1966            5,
1967            1e-4,
1968        )
1969        .unwrap();
1970        // Score for nbasis=1 should be infinity, so optimal should be 5 or 10
1971        assert!(
1972            res.optimal_nbasis >= 5,
1973            "Should skip invalid nbasis=1, got optimal={}",
1974            res.optimal_nbasis
1975        );
1976        assert!(res.scores[0].is_infinite());
1977    }
1978
1979    #[test]
1980    fn test_basis_nbasis_cv_empty_range() {
1981        let (data, t) = make_test_data(5, 50);
1982        let nbasis_range: Vec<usize> = vec![];
1983        let result = basis_nbasis_cv(
1984            &data,
1985            &t,
1986            &nbasis_range,
1987            &BasisType::Bspline { order: 4 },
1988            BasisCriterion::Gcv,
1989            5,
1990            1e-4,
1991        );
1992        assert!(result.is_none());
1993    }
1994
1995    #[test]
1996    fn test_basis_nbasis_cv_empty_data() {
1997        let data = FdMatrix::zeros(0, 50);
1998        let t = uniform_grid(50);
1999        let nbasis_range: Vec<usize> = vec![5, 10];
2000        let result = basis_nbasis_cv(
2001            &data,
2002            &t,
2003            &nbasis_range,
2004            &BasisType::Bspline { order: 4 },
2005            BasisCriterion::Gcv,
2006            5,
2007            1e-4,
2008        );
2009        assert!(result.is_none());
2010    }
2011
2012    #[test]
2013    fn test_basis_nbasis_cv_mismatched_argvals() {
2014        let data = FdMatrix::zeros(5, 50);
2015        let t = uniform_grid(40); // mismatch
2016        let nbasis_range: Vec<usize> = vec![5, 10];
2017        let result = basis_nbasis_cv(
2018            &data,
2019            &t,
2020            &nbasis_range,
2021            &BasisType::Bspline { order: 4 },
2022            BasisCriterion::Gcv,
2023            5,
2024            1e-4,
2025        );
2026        assert!(result.is_none());
2027    }
2028
2029    #[test]
2030    fn test_basis_nbasis_cv_single_nbasis() {
2031        let (data, t) = make_test_data(5, 50);
2032        let nbasis_range: Vec<usize> = vec![10];
2033        let res = basis_nbasis_cv(
2034            &data,
2035            &t,
2036            &nbasis_range,
2037            &BasisType::Bspline { order: 4 },
2038            BasisCriterion::Gcv,
2039            5,
2040            1e-4,
2041        )
2042        .unwrap();
2043        assert_eq!(res.optimal_nbasis, 10);
2044        assert_eq!(res.scores.len(), 1);
2045    }
2046
2047    #[test]
2048    fn test_basis_nbasis_cv_bic_penalizes_more_than_aic() {
2049        // BIC penalizes complexity more heavily than AIC, so it should generally
2050        // select the same or fewer basis functions
2051        let (data, t) = make_test_data(5, 80);
2052        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
2053
2054        let aic_res = basis_nbasis_cv(
2055            &data,
2056            &t,
2057            &nbasis_range,
2058            &BasisType::Bspline { order: 4 },
2059            BasisCriterion::Aic,
2060            5,
2061            1e-4,
2062        )
2063        .unwrap();
2064        let bic_res = basis_nbasis_cv(
2065            &data,
2066            &t,
2067            &nbasis_range,
2068            &BasisType::Bspline { order: 4 },
2069            BasisCriterion::Bic,
2070            5,
2071            1e-4,
2072        )
2073        .unwrap();
2074        // BIC should select at most as many basis functions as AIC
2075        // (not guaranteed in all cases, but typical behavior)
2076        assert!(
2077            bic_res.optimal_nbasis <= aic_res.optimal_nbasis + 4,
2078            "BIC selected {} vs AIC selected {} -- BIC should not select much more than AIC",
2079            bic_res.optimal_nbasis,
2080            aic_res.optimal_nbasis
2081        );
2082    }
2083
2084    // ─── Fitted values quality tests ────────────────────────────────────────
2085
2086    #[test]
2087    fn test_smooth_basis_fitted_close_to_data() {
2088        // With moderate penalty and enough basis functions, fitted should be close to data
2089        let m = 50;
2090        let n = 3;
2091        let t = uniform_grid(m);
2092        let mut data = FdMatrix::zeros(n, m);
2093        for i in 0..n {
2094            for j in 0..m {
2095                data[(i, j)] = (2.0 * PI * t[j]).sin();
2096            }
2097        }
2098        let fdpar = make_bspline_fdpar(&t, 15, 1e-6);
2099        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2100
2101        let mut max_err = 0.0_f64;
2102        for i in 0..n {
2103            for j in 0..m {
2104                let err = (data[(i, j)] - res.fitted[(i, j)]).abs();
2105                max_err = max_err.max(err);
2106            }
2107        }
2108        assert!(
2109            max_err < 0.1,
2110            "Fitted should be close to smooth data; max_err={}",
2111            max_err
2112        );
2113    }
2114
2115    #[test]
2116    fn test_smooth_basis_constant_data() {
2117        // Constant data should be fit exactly
2118        let m = 50;
2119        let n = 2;
2120        let t = uniform_grid(m);
2121        let mut data = FdMatrix::zeros(n, m);
2122        for i in 0..n {
2123            for j in 0..m {
2124                data[(i, j)] = 3.15;
2125            }
2126        }
2127        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2128        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2129        for i in 0..n {
2130            for j in 0..m {
2131                assert!(
2132                    (res.fitted[(i, j)] - 3.15).abs() < 0.01,
2133                    "Constant data should be fit well at ({},{}): got {}",
2134                    i,
2135                    j,
2136                    res.fitted[(i, j)]
2137                );
2138            }
2139        }
2140    }
2141
2142    #[test]
2143    fn test_smooth_basis_linear_data() {
2144        // Linear data should be fit well with cubic B-splines
2145        let m = 50;
2146        let t = uniform_grid(m);
2147        let mut data = FdMatrix::zeros(1, m);
2148        for j in 0..m {
2149            data[(0, j)] = 2.0 * t[j] + 1.0;
2150        }
2151        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2152        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2153        for j in 0..m {
2154            let expected = 2.0 * t[j] + 1.0;
2155            assert!(
2156                (res.fitted[(0, j)] - expected).abs() < 0.05,
2157                "Linear data should be fit well at j={}: got {}, expected {}",
2158                j,
2159                res.fitted[(0, j)],
2160                expected
2161            );
2162        }
2163    }
2164
2165    // ─── EDF and diagnostic tests ───────────────────────────────────────────
2166
2167    #[test]
2168    fn test_smooth_basis_edf_bounded() {
2169        let m = 50;
2170        let (data, t) = make_test_data(3, m);
2171        let fdpar = make_bspline_fdpar(&t, 12, 1e-4);
2172        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2173        // EDF should be between 1 and m (evaluation points)
2174        assert!(
2175            res.edf > 0.0 && res.edf <= m as f64,
2176            "EDF should be in (0, {}]; got {}",
2177            m,
2178            res.edf
2179        );
2180    }
2181
2182    #[test]
2183    fn test_smooth_basis_gcv_aic_bic_all_finite() {
2184        let (data, t) = make_test_data(4, 60);
2185        let fdpar = make_bspline_fdpar(&t, 12, 1e-3);
2186        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2187        assert!(res.gcv.is_finite(), "GCV should be finite: {}", res.gcv);
2188        assert!(res.aic.is_finite(), "AIC should be finite: {}", res.aic);
2189        assert!(res.bic.is_finite(), "BIC should be finite: {}", res.bic);
2190    }
2191
2192    // ─── Penalty matrix size consistency tests ──────────────────────────────
2193
2194    #[test]
2195    fn test_smooth_basis_penalty_matrix_in_result() {
2196        let (data, t) = make_test_data(3, 50);
2197        let nbasis = 10;
2198        let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
2199        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2200        let k = res.nbasis;
2201        assert_eq!(
2202            res.penalty_matrix.len(),
2203            k * k,
2204            "Penalty matrix should be k*k = {}*{} = {}; got {}",
2205            k,
2206            k,
2207            k * k,
2208            res.penalty_matrix.len()
2209        );
2210    }
2211
2212    // ─── Regression: multiple identical curves ──────────────────────────────
2213
2214    #[test]
2215    fn test_smooth_basis_identical_curves_same_coefficients() {
2216        let m = 50;
2217        let t = uniform_grid(m);
2218        let curve: Vec<f64> = (0..m).map(|j| (2.0 * PI * t[j]).sin()).collect();
2219        let n = 4;
2220        let mut data = FdMatrix::zeros(n, m);
2221        for i in 0..n {
2222            for j in 0..m {
2223                data[(i, j)] = curve[j];
2224            }
2225        }
2226        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2227        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2228
2229        // All curves should have the same coefficients
2230        let k = res.coefficients.ncols();
2231        for i in 1..n {
2232            for j in 0..k {
2233                assert!(
2234                    (res.coefficients[(i, j)] - res.coefficients[(0, j)]).abs() < 1e-10,
2235                    "Identical curves should have identical coefficients: curve {} col {} differs",
2236                    i,
2237                    j
2238                );
2239            }
2240        }
2241    }
2242
2243    // ─── Cross-validation: different numbers of folds ───────────────────────
2244
2245    #[test]
2246    fn test_basis_nbasis_cv_different_nfolds() {
2247        let (data, t) = make_test_data(12, 50);
2248        let nbasis_range: Vec<usize> = vec![5, 8, 11];
2249        for nfolds in [2, 3, 5, 10] {
2250            let res = basis_nbasis_cv(
2251                &data,
2252                &t,
2253                &nbasis_range,
2254                &BasisType::Bspline { order: 4 },
2255                BasisCriterion::Cv,
2256                nfolds,
2257                1e-4,
2258            );
2259            assert!(res.is_some(), "CV should succeed with nfolds={}", nfolds);
2260            let r = res.unwrap();
2261            assert!(nbasis_range.contains(&r.optimal_nbasis));
2262        }
2263    }
2264
2265    // ─── Large nbasis / more basis than reasonable ──────────────────────────
2266
2267    #[test]
2268    fn test_smooth_basis_many_basis_functions() {
2269        let m = 100;
2270        let (data, t) = make_test_data(2, m);
2271        // Many basis functions relative to data points
2272        let fdpar = make_bspline_fdpar(&t, 40, 1e-2);
2273        let res = smooth_basis(&data, &t, &fdpar);
2274        assert!(
2275            res.is_ok(),
2276            "Should handle many basis functions with penalty"
2277        );
2278    }
2279
2280    // ─── evaluate_basis internal function (indirectly tested) ───────────────
2281
2282    #[test]
2283    fn test_smooth_basis_bspline_vs_fourier_different_results() {
2284        let m = 50;
2285        let (data, t) = make_test_data(2, m);
2286        let fdpar_bs = make_bspline_fdpar(&t, 9, 1e-4);
2287        let fdpar_f = make_fourier_fdpar(9, 1.0, 1e-4);
2288        let res_bs = smooth_basis(&data, &t, &fdpar_bs).unwrap();
2289        let res_f = smooth_basis(&data, &t, &fdpar_f).unwrap();
2290        // Results should differ between the two basis types
2291        let diff: f64 = (0..m)
2292            .map(|j| (res_bs.fitted[(0, j)] - res_f.fitted[(0, j)]).abs())
2293            .sum();
2294        // They fit the same data, so some difference is expected but not huge
2295        assert!(
2296            diff > 1e-10,
2297            "B-spline and Fourier fits should differ for the same data"
2298        );
2299    }
2300
2301    // ─── compute_gcv edge cases (indirectly tested) ─────────────────────────
2302
2303    #[test]
2304    fn test_smooth_basis_gcv_positive_for_noisy_data() {
2305        let m = 50;
2306        let t = uniform_grid(m);
2307        let mut data = FdMatrix::zeros(1, m);
2308        for j in 0..m {
2309            // Noisy data
2310            data[(0, j)] = (2.0 * PI * t[j]).sin() + 0.5 * ((j * 37) % 20) as f64 / 20.0 - 0.25;
2311        }
2312        let fdpar = make_bspline_fdpar(&t, 10, 1e-3);
2313        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2314        assert!(res.gcv > 0.0, "GCV should be positive for noisy data");
2315    }
2316
2317    // ─── Penalty order (lfd_order) tests ────────────────────────────────────
2318
2319    #[test]
2320    fn test_smooth_basis_different_lfd_orders() {
2321        let m = 50;
2322        let (data, t) = make_test_data(2, m);
2323
2324        // lfd_order = 1 (penalize first derivative)
2325        let penalty1 = bspline_penalty_matrix(&t, 10, 4, 1);
2326        let fdpar1 = FdPar {
2327            basis_type: BasisType::Bspline { order: 4 },
2328            nbasis: 10,
2329            lambda: 1e-2,
2330            lfd_order: 1,
2331            penalty_matrix: penalty1,
2332        };
2333        let res1 = smooth_basis(&data, &t, &fdpar1);
2334        assert!(res1.is_ok());
2335
2336        // lfd_order = 2 (penalize second derivative)
2337        let penalty2 = bspline_penalty_matrix(&t, 10, 4, 2);
2338        let fdpar2 = FdPar {
2339            basis_type: BasisType::Bspline { order: 4 },
2340            nbasis: 10,
2341            lambda: 1e-2,
2342            lfd_order: 2,
2343            penalty_matrix: penalty2,
2344        };
2345        let res2 = smooth_basis(&data, &t, &fdpar2);
2346        assert!(res2.is_ok());
2347
2348        // Different penalty orders should produce different fitted values
2349        let r1 = res1.unwrap();
2350        let r2 = res2.unwrap();
2351        let diff: f64 = (0..m)
2352            .map(|j| (r1.fitted[(0, j)] - r2.fitted[(0, j)]).abs())
2353            .sum();
2354        assert!(
2355            diff > 1e-10,
2356            "Different lfd_orders should produce different fits"
2357        );
2358    }
2359
2360    // ─── BasisNbasisCvResult field tests ────────────────────────────────────
2361
2362    #[test]
2363    fn test_basis_nbasis_cv_result_fields() {
2364        let (data, t) = make_test_data(6, 50);
2365        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11, 13];
2366        let res = basis_nbasis_cv(
2367            &data,
2368            &t,
2369            &nbasis_range,
2370            &BasisType::Bspline { order: 4 },
2371            BasisCriterion::Aic,
2372            5,
2373            1e-4,
2374        )
2375        .unwrap();
2376
2377        assert!(nbasis_range.contains(&res.optimal_nbasis));
2378        assert_eq!(res.scores.len(), nbasis_range.len());
2379        assert_eq!(res.nbasis_range, nbasis_range);
2380        assert_eq!(res.criterion, BasisCriterion::Aic);
2381        // optimal_nbasis should correspond to minimum score
2382        let min_score = res.scores.iter().copied().fold(f64::INFINITY, f64::min);
2383        let best_idx = res
2384            .scores
2385            .iter()
2386            .position(|&s| (s - min_score).abs() < 1e-15)
2387            .unwrap();
2388        assert_eq!(res.optimal_nbasis, nbasis_range[best_idx]);
2389    }
2390
2391    #[test]
2392    fn test_basis_nbasis_cv_result_clone() {
2393        let (data, t) = make_test_data(5, 50);
2394        let nbasis_range: Vec<usize> = vec![5, 10];
2395        let res = basis_nbasis_cv(
2396            &data,
2397            &t,
2398            &nbasis_range,
2399            &BasisType::Bspline { order: 4 },
2400            BasisCriterion::Gcv,
2401            5,
2402            1e-4,
2403        )
2404        .unwrap();
2405        let cloned = res.clone();
2406        assert_eq!(res, cloned);
2407    }
2408
2409    // ─── Non-uniform argvals ────────────────────────────────────────────────
2410
2411    #[test]
2412    fn test_smooth_basis_nonuniform_argvals() {
2413        let m = 50;
2414        // Non-uniform grid: denser at the ends
2415        let t: Vec<f64> = (0..m)
2416            .map(|i| {
2417                let x = i as f64 / (m - 1) as f64;
2418                0.5 * (1.0 - (PI * x).cos())
2419            })
2420            .collect();
2421        let mut data = FdMatrix::zeros(2, m);
2422        for i in 0..2 {
2423            for j in 0..m {
2424                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * i as f64;
2425            }
2426        }
2427        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2428        let res = smooth_basis(&data, &t, &fdpar);
2429        assert!(res.is_ok(), "Should handle non-uniform argvals");
2430        let r = res.unwrap();
2431        assert_eq!(r.fitted.shape(), (2, m));
2432    }
2433
2434    // ─── Numerical stability with extreme lambda ────────────────────────────
2435
2436    #[test]
2437    fn test_smooth_basis_very_small_lambda() {
2438        let m = 50;
2439        let (data, t) = make_test_data(2, m);
2440        let fdpar = make_bspline_fdpar(&t, 10, 1e-15);
2441        let res = smooth_basis(&data, &t, &fdpar);
2442        assert!(res.is_ok(), "Should handle very small lambda");
2443    }
2444
2445    #[test]
2446    fn test_smooth_basis_very_large_lambda() {
2447        let m = 50;
2448        let (data, t) = make_test_data(2, m);
2449        let fdpar = make_bspline_fdpar(&t, 10, 1e10);
2450        let res = smooth_basis(&data, &t, &fdpar);
2451        assert!(res.is_ok(), "Should handle very large lambda");
2452    }
2453
2454    // ─── Multiple curves consistency ────────────────────────────────────────
2455
2456    #[test]
2457    fn test_smooth_basis_multi_curve_vs_single_curve() {
2458        // Smoothing multiple curves at once should give the same result as smoothing each individually
2459        let m = 50;
2460        let n = 3;
2461        let (data, t) = make_test_data(n, m);
2462        let fdpar = make_bspline_fdpar(&t, 10, 1e-3);
2463
2464        // All at once
2465        let res_all = smooth_basis(&data, &t, &fdpar).unwrap();
2466
2467        // One at a time
2468        for i in 0..n {
2469            let mut single = FdMatrix::zeros(1, m);
2470            for j in 0..m {
2471                single[(0, j)] = data[(i, j)];
2472            }
2473            let res_single = smooth_basis(&single, &t, &fdpar).unwrap();
2474            for j in 0..m {
2475                assert!(
2476                    (res_all.fitted[(i, j)] - res_single.fitted[(0, j)]).abs() < 1e-10,
2477                    "Multi-curve fit should match single-curve fit: curve {} point {}",
2478                    i,
2479                    j
2480                );
2481            }
2482        }
2483    }
2484
2485    // ─── BasisCriterion comparison: all criteria produce finite scores ──────
2486
2487    #[test]
2488    fn test_basis_nbasis_cv_all_criteria_finite_scores() {
2489        let (data, t) = make_test_data(10, 60);
2490        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
2491
2492        for criterion in [
2493            BasisCriterion::Gcv,
2494            BasisCriterion::Aic,
2495            BasisCriterion::Bic,
2496            BasisCriterion::Cv,
2497        ] {
2498            let res = basis_nbasis_cv(
2499                &data,
2500                &t,
2501                &nbasis_range,
2502                &BasisType::Bspline { order: 4 },
2503                criterion,
2504                5,
2505                1e-4,
2506            )
2507            .unwrap();
2508            // At least some scores should be finite (valid nbasis values)
2509            let finite_count = res.scores.iter().filter(|s| s.is_finite()).count();
2510            assert!(
2511                finite_count > 0,
2512                "At least one score should be finite for {:?}",
2513                criterion
2514            );
2515        }
2516    }
2517
2518    // ─── SmoothBasisGcvConfig tests ────────────────────────────────────────
2519
2520    #[test]
2521    fn test_smooth_basis_gcv_config_default() {
2522        let config = SmoothBasisGcvConfig::default();
2523        assert_eq!(config.basis_type, BasisType::Bspline { order: 4 });
2524        assert_eq!(config.nbasis, 15);
2525        assert_eq!(config.lfd_order, 2);
2526        assert_eq!(config.log_lambda_range, (-10.0, 2.0));
2527        assert_eq!(config.n_grid, 50);
2528    }
2529
2530    #[test]
2531    fn test_smooth_basis_gcv_config_clone_eq() {
2532        let config = SmoothBasisGcvConfig {
2533            nbasis: 20,
2534            ..SmoothBasisGcvConfig::default()
2535        };
2536        let cloned = config.clone();
2537        assert_eq!(config, cloned);
2538    }
2539
2540    #[test]
2541    fn test_smooth_basis_gcv_config_debug() {
2542        let config = SmoothBasisGcvConfig::default();
2543        let debug_str = format!("{:?}", config);
2544        assert!(debug_str.contains("SmoothBasisGcvConfig"));
2545        assert!(debug_str.contains("nbasis"));
2546    }
2547
2548    #[test]
2549    fn test_smooth_basis_gcv_config_partial_override() {
2550        let config = SmoothBasisGcvConfig {
2551            basis_type: BasisType::Fourier { period: 2.0 },
2552            n_grid: 100,
2553            ..SmoothBasisGcvConfig::default()
2554        };
2555        assert_eq!(config.basis_type, BasisType::Fourier { period: 2.0 });
2556        assert_eq!(config.n_grid, 100);
2557        // defaults preserved
2558        assert_eq!(config.nbasis, 15);
2559        assert_eq!(config.lfd_order, 2);
2560    }
2561
2562    #[test]
2563    fn test_smooth_basis_gcv_with_config_default() {
2564        let (data, t) = make_test_data(5, 101);
2565        let config = SmoothBasisGcvConfig::default();
2566        let result = smooth_basis_gcv_with_config(&data, &t, &config);
2567        assert!(result.is_ok(), "GCV with default config should succeed");
2568        let res = result.unwrap();
2569        assert_eq!(res.fitted.shape(), (5, 101));
2570        assert!(res.edf > 0.0);
2571        assert!(res.gcv.is_finite());
2572    }
2573
2574    #[test]
2575    fn test_smooth_basis_gcv_with_config_custom() {
2576        let (data, t) = make_test_data(3, 50);
2577        let config = SmoothBasisGcvConfig {
2578            nbasis: 10,
2579            log_lambda_range: (-6.0, 0.0),
2580            n_grid: 15,
2581            ..SmoothBasisGcvConfig::default()
2582        };
2583        let result = smooth_basis_gcv_with_config(&data, &t, &config);
2584        assert!(result.is_ok());
2585    }
2586
2587    #[test]
2588    fn test_smooth_basis_gcv_with_config_matches_direct() {
2589        let (data, t) = make_test_data(3, 50);
2590        let config = SmoothBasisGcvConfig {
2591            nbasis: 10,
2592            log_lambda_range: (-6.0, 0.0),
2593            n_grid: 20,
2594            ..SmoothBasisGcvConfig::default()
2595        };
2596        let with_config = smooth_basis_gcv_with_config(&data, &t, &config).unwrap();
2597        let direct = smooth_basis_gcv(
2598            &data,
2599            &t,
2600            &config.basis_type,
2601            config.nbasis,
2602            config.lfd_order,
2603            config.log_lambda_range,
2604            config.n_grid,
2605        )
2606        .unwrap();
2607        assert_eq!(with_config.gcv, direct.gcv);
2608        assert_eq!(with_config.edf, direct.edf);
2609        assert_eq!(with_config.nbasis, direct.nbasis);
2610    }
2611
2612    #[test]
2613    fn test_smooth_basis_gcv_with_config_fourier() {
2614        let m = 100;
2615        let t = uniform_grid(m);
2616        let mut data = FdMatrix::zeros(2, m);
2617        for i in 0..2 {
2618            for j in 0..m {
2619                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
2620            }
2621        }
2622        let config = SmoothBasisGcvConfig {
2623            basis_type: BasisType::Fourier { period: 1.0 },
2624            nbasis: 7,
2625            n_grid: 20,
2626            ..SmoothBasisGcvConfig::default()
2627        };
2628        let result = smooth_basis_gcv_with_config(&data, &t, &config);
2629        assert!(result.is_ok());
2630    }
2631
2632    // ─── BasisNbasisCvConfig tests ─────────────────────────────────────────
2633
2634    #[test]
2635    fn test_basis_nbasis_cv_config_default() {
2636        let config = BasisNbasisCvConfig::default();
2637        assert_eq!(config.basis_type, BasisType::Bspline { order: 4 });
2638        assert_eq!(config.nbasis_range, (5, 30));
2639        assert!((config.lambda - 1e-4).abs() < 1e-15);
2640        assert_eq!(config.lfd_order, 2);
2641        assert_eq!(config.n_folds, 5);
2642        assert_eq!(config.criterion, BasisCriterion::Gcv);
2643    }
2644
2645    #[test]
2646    fn test_basis_nbasis_cv_config_clone_eq() {
2647        let config = BasisNbasisCvConfig {
2648            nbasis_range: (4, 15),
2649            ..BasisNbasisCvConfig::default()
2650        };
2651        let cloned = config.clone();
2652        assert_eq!(config, cloned);
2653    }
2654
2655    #[test]
2656    fn test_basis_nbasis_cv_config_debug() {
2657        let config = BasisNbasisCvConfig::default();
2658        let debug_str = format!("{:?}", config);
2659        assert!(debug_str.contains("BasisNbasisCvConfig"));
2660        assert!(debug_str.contains("nbasis_range"));
2661    }
2662
2663    #[test]
2664    fn test_basis_nbasis_cv_config_partial_override() {
2665        let config = BasisNbasisCvConfig {
2666            criterion: BasisCriterion::Aic,
2667            lambda: 1e-2,
2668            ..BasisNbasisCvConfig::default()
2669        };
2670        assert_eq!(config.criterion, BasisCriterion::Aic);
2671        assert!((config.lambda - 1e-2).abs() < 1e-15);
2672        // defaults preserved
2673        assert_eq!(config.nbasis_range, (5, 30));
2674        assert_eq!(config.n_folds, 5);
2675    }
2676
2677    #[test]
2678    fn test_basis_nbasis_cv_with_config_default() {
2679        let (data, t) = make_test_data(5, 51);
2680        let config = BasisNbasisCvConfig {
2681            nbasis_range: (5, 12),
2682            ..BasisNbasisCvConfig::default()
2683        };
2684        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2685        assert!(
2686            result.is_ok(),
2687            "nbasis CV with default config should succeed"
2688        );
2689        let res = result.unwrap();
2690        assert!(res.optimal_nbasis >= 5 && res.optimal_nbasis <= 12);
2691        assert_eq!(res.scores.len(), 8); // 5..=12 = 8 values
2692        assert_eq!(res.criterion, BasisCriterion::Gcv);
2693    }
2694
2695    #[test]
2696    fn test_basis_nbasis_cv_with_config_aic() {
2697        let (data, t) = make_test_data(5, 51);
2698        let config = BasisNbasisCvConfig {
2699            nbasis_range: (5, 10),
2700            criterion: BasisCriterion::Aic,
2701            ..BasisNbasisCvConfig::default()
2702        };
2703        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2704        assert!(result.is_ok());
2705        assert_eq!(result.unwrap().criterion, BasisCriterion::Aic);
2706    }
2707
2708    #[test]
2709    fn test_basis_nbasis_cv_with_config_cv_folds() {
2710        let (data, t) = make_test_data(10, 51);
2711        let config = BasisNbasisCvConfig {
2712            nbasis_range: (5, 9),
2713            criterion: BasisCriterion::Cv,
2714            n_folds: 3,
2715            ..BasisNbasisCvConfig::default()
2716        };
2717        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2718        assert!(result.is_ok());
2719        assert_eq!(result.unwrap().criterion, BasisCriterion::Cv);
2720    }
2721
2722    #[test]
2723    fn test_basis_nbasis_cv_with_config_matches_direct() {
2724        let (data, t) = make_test_data(5, 51);
2725        let config = BasisNbasisCvConfig {
2726            nbasis_range: (5, 10),
2727            criterion: BasisCriterion::Bic,
2728            lambda: 1e-3,
2729            ..BasisNbasisCvConfig::default()
2730        };
2731        let with_config = basis_nbasis_cv_with_config(&data, &t, &config).unwrap();
2732        let nbasis_range: Vec<usize> = (5..=10).collect();
2733        let direct = basis_nbasis_cv(
2734            &data,
2735            &t,
2736            &nbasis_range,
2737            &config.basis_type,
2738            config.criterion,
2739            config.n_folds,
2740            config.lambda,
2741        )
2742        .unwrap();
2743        assert_eq!(with_config.optimal_nbasis, direct.optimal_nbasis);
2744        assert_eq!(with_config.scores, direct.scores);
2745        assert_eq!(with_config.nbasis_range, direct.nbasis_range);
2746    }
2747
2748    #[test]
2749    fn test_basis_nbasis_cv_with_config_nbasis_range_expansion() {
2750        let (data, t) = make_test_data(5, 51);
2751        let config = BasisNbasisCvConfig {
2752            nbasis_range: (7, 7), // single value
2753            ..BasisNbasisCvConfig::default()
2754        };
2755        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2756        assert!(result.is_ok());
2757        let res = result.unwrap();
2758        assert_eq!(res.optimal_nbasis, 7);
2759        assert_eq!(res.scores.len(), 1);
2760    }
2761}