Skip to main content

fdars_core/
smooth_basis.rs

1//! Basis-penalized smoothing with continuous derivative penalties.
2//!
3//! This module implements `smooth.basis` from R's fda package. Unlike the
4//! discrete difference penalty used in P-splines (`basis.rs`), this uses
5//! continuous derivative penalties: `min ||y - Φc||² + λ·∫(Lf)² dt`.
6//!
7//! Key capabilities:
8//! - [`smooth_basis`] — Penalized least squares with continuous roughness penalty
9//! - [`smooth_basis_gcv`] — GCV-optimal smoothing parameter selection
10//! - [`bspline_penalty_matrix`] / [`fourier_penalty_matrix`] — Roughness penalty matrices
11
12use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
18// ─── Types ──────────────────────────────────────────────────────────────────
19
/// Family of basis functions used to represent a smoothed curve.
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
    /// Piecewise-polynomial B-spline basis; `order` is the polynomial order
    /// (degree + 1), so 4 gives cubic splines.
    Bspline { order: usize },
    /// Periodic Fourier basis (a constant followed by sine/cosine pairs) that
    /// repeats every `period`.
    Fourier { period: f64 },
}
28
29/// Functional data parameter object (basis + penalty specification).
30#[derive(Debug, Clone, PartialEq)]
31pub struct FdPar {
32    /// Type of basis system.
33    pub basis_type: BasisType,
34    /// Number of basis functions.
35    pub nbasis: usize,
36    /// Smoothing parameter.
37    pub lambda: f64,
38    /// Derivative order for the penalty (default: 2).
39    pub lfd_order: usize,
40    /// Precomputed K×K penalty matrix (column-major).
41    pub penalty_matrix: Vec<f64>,
42}
43
44/// Result of basis-penalized smoothing.
45#[derive(Debug, Clone, PartialEq)]
46pub struct SmoothBasisResult {
47    /// Basis coefficients (n × K).
48    pub coefficients: FdMatrix,
49    /// Fitted values (n × m).
50    pub fitted: FdMatrix,
51    /// Effective degrees of freedom.
52    pub edf: f64,
53    /// Generalized cross-validation score.
54    pub gcv: f64,
55    /// AIC.
56    pub aic: f64,
57    /// BIC.
58    pub bic: f64,
59    /// Roughness penalty matrix (K × K, column-major).
60    pub penalty_matrix: Vec<f64>,
61    /// Number of basis functions used.
62    pub nbasis: usize,
63}
64
65// ─── Penalty Matrices ───────────────────────────────────────────────────────
66
67/// Compute the roughness penalty matrix for B-splines via numerical quadrature.
68///
69/// R\[j,k\] = ∫ D^m B_j(t) · D^m B_k(t) dt
70///
71/// Uses Simpson's rule on a fine sub-grid for each knot interval.
72///
73/// # Arguments
74/// * `argvals` — Evaluation points (length m)
75/// * `nbasis` — Number of basis functions
76/// * `order` — B-spline order (typically 4 for cubic)
77/// * `lfd_order` — Derivative order for penalty (typically 2)
78///
79/// # Returns
80/// K × K penalty matrix in column-major layout (K = nbasis)
81pub fn bspline_penalty_matrix(
82    argvals: &[f64],
83    nbasis: usize,
84    order: usize,
85    lfd_order: usize,
86) -> Vec<f64> {
87    if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
88        return vec![0.0; nbasis * nbasis];
89    }
90
91    let nknots = nbasis.saturating_sub(order).max(2);
92
93    // Create a fine quadrature grid (10 sub-points per original interval)
94    let n_sub = 10;
95    let t_min = argvals[0];
96    let t_max = argvals[argvals.len() - 1];
97    let n_quad = (argvals.len() - 1) * n_sub + 1;
98    let quad_t: Vec<f64> = (0..n_quad)
99        .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100        .collect();
101
102    // Evaluate B-spline basis on fine grid
103    let basis_fine = bspline_basis(&quad_t, nknots, order);
104    let actual_nbasis = basis_fine.len() / n_quad;
105
106    // Compute derivatives of B-spline basis numerically
107    let h = (t_max - t_min) / (n_quad - 1) as f64;
108    let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110    // Integration weights on fine grid
111    let weights = simpsons_weights(&quad_t);
112
113    // Compute penalty matrix: R[j,k] = ∫ D^m B_j · D^m B_k dt
114    integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Compute the roughness penalty matrix for a Fourier basis.
///
/// The basis is a constant followed by sine/cosine pairs of increasing frequency.
/// The penalty is diagonal with eigenvalues `(2πf/T)^(2m)`, where `f` is the frequency
/// of the basis function and `m = lfd_order`. This matches R's fda package convention
/// (sqrt(2)-normalized basis, `pendiagfn`). The constant term has frequency 0, so its
/// entry is 0 for `m ≥ 1` and 1 for `m = 0` (plain ridge penalty).
///
/// # Arguments
/// * `nbasis` — Number of basis functions
/// * `period` — Period of the Fourier basis
/// * `lfd_order` — Derivative order for penalty
///
/// # Returns
/// K × K penalty matrix in column-major layout (K = `nbasis`)
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let k = nbasis;
    let mut penalty = vec![0.0; k * k];

    // Constant basis function: its derivatives of order >= 1 vanish, but with
    // lfd_order == 0 the penalty is ∫φ₀² = 1 for the normalized basis (R's
    // `pendiagfn` uses the same value). Handled separately so `period == 0`
    // cannot turn this entry into 0/0 = NaN.
    if k > 0 && lfd_order == 0 {
        penalty[0] = 1.0;
    }

    // Remaining functions come in sin/cos pairs sharing frequency f = (idx + 1) / 2.
    let power = 2 * lfd_order as i32;
    for idx in 1..k {
        let freq = (idx + 1) / 2;
        let omega = 2.0 * PI * freq as f64 / period;
        penalty[idx + idx * k] = omega.powi(power);
    }

    penalty
}
158
159// ─── Smoothing Functions ────────────────────────────────────────────────────
160
161/// Perform basis-penalized smoothing.
162///
163/// Solves `(Φ'Φ + λR)c = Φ'y` per curve via Cholesky decomposition.
164/// This implements `smooth.basis` from R's fda package.
165///
166/// # Arguments
167/// * `data` — Functional data matrix (n × m)
168/// * `argvals` — Evaluation points (length m)
169/// * `fdpar` — Functional parameter object specifying basis and penalty
170///
171/// # Returns
172/// [`SmoothBasisResult`] with coefficients, fitted values, and diagnostics.
173pub fn smooth_basis(
174    data: &FdMatrix,
175    argvals: &[f64],
176    fdpar: &FdPar,
177) -> Result<SmoothBasisResult, crate::FdarError> {
178    let (n, m) = data.shape();
179    if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
180        return Err(crate::FdarError::InvalidDimension {
181            parameter: "data/argvals/fdpar",
182            expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
183            actual: format!(
184                "n={}, m={}, argvals.len()={}, nbasis={}",
185                n,
186                m,
187                argvals.len(),
188                fdpar.nbasis
189            ),
190        });
191    }
192
193    // Evaluate basis on argvals
194    let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
195    let k = actual_nbasis;
196
197    let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
198    let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
199
200    // (Φ'Φ + λR + εI) — small ridge ensures positive definiteness
201    let btb = b_mat.transpose() * &b_mat;
202    let ridge_eps = 1e-10;
203    let system: DMatrix<f64> =
204        &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
205
206    // Invert the penalized system
207    let system_inv =
208        invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
209            operation: "matrix inversion",
210            detail: "failed to invert penalized system (Φ'Φ + λR); try increasing lambda or reducing the number of basis functions".to_string(),
211        })?;
212
213    // Hat matrix: H = Φ (Φ'Φ + λR)^{-1} Φ'  →  EDF = tr(H)
214    let h_mat = &b_mat * &system_inv * b_mat.transpose();
215    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
216
217    // Project all curves
218    let proj = &system_inv * b_mat.transpose();
219    let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
220
221    let total_points = (n * m) as f64;
222    let gcv = compute_gcv(total_rss, total_points, edf, m);
223    let mse = total_rss / total_points;
224    // Total effective degrees of freedom = n curves * per-curve edf
225    let total_edf = n as f64 * edf;
226    let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
227    let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
228
229    Ok(SmoothBasisResult {
230        coefficients: all_coefs,
231        fitted: all_fitted,
232        edf,
233        gcv,
234        aic,
235        bic,
236        penalty_matrix: fdpar.penalty_matrix.clone(),
237        nbasis: k,
238    })
239}
240
241/// Perform basis-penalized smoothing with GCV-optimal lambda.
242///
243/// Searches over a log-lambda grid and selects the lambda minimizing GCV.
244///
245/// # Arguments
246/// * `data` — Functional data matrix (n × m)
247/// * `argvals` — Evaluation points (length m)
248/// * `basis_type` — Type of basis system
249/// * `nbasis` — Number of basis functions
250/// * `lfd_order` — Derivative order for penalty
251/// * `log_lambda_range` — Range of log10(lambda) to search, e.g. (-8.0, 4.0)
252/// * `n_grid` — Number of grid points for the search
253pub fn smooth_basis_gcv(
254    data: &FdMatrix,
255    argvals: &[f64],
256    basis_type: &BasisType,
257    nbasis: usize,
258    lfd_order: usize,
259    log_lambda_range: (f64, f64),
260    n_grid: usize,
261) -> Option<SmoothBasisResult> {
262    let m = argvals.len();
263    if m == 0 || nbasis < 2 || n_grid < 2 {
264        return None;
265    }
266
267    // Compute penalty matrix once
268    let penalty = match basis_type {
269        BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
270        BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
271    };
272
273    let (lo, hi) = log_lambda_range;
274    let mut best_gcv = f64::INFINITY;
275    let mut best_result: Option<SmoothBasisResult> = None;
276
277    for i in 0..n_grid {
278        let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
279        let lam = 10.0_f64.powf(log_lam);
280
281        let fdpar = FdPar {
282            basis_type: basis_type.clone(),
283            nbasis,
284            lambda: lam,
285            lfd_order,
286            penalty_matrix: penalty.clone(),
287        };
288
289        if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
290            if result.gcv < best_gcv {
291                best_gcv = result.gcv;
292                best_result = Some(result);
293            }
294        }
295    }
296
297    best_result
298}
299
300// ─── Config Structs ─────────────────────────────────────────────────────────
301
302/// Configuration for GCV-based smoothing parameter selection.
303///
304/// Collects all tuning parameters for [`smooth_basis_gcv_with_config`], with
305/// sensible defaults obtained via [`SmoothBasisGcvConfig::default()`].
306///
307/// # Example
308/// ```no_run
309/// use fdars_core::smooth_basis::{SmoothBasisGcvConfig, BasisType};
310///
311/// let config = SmoothBasisGcvConfig {
312///     nbasis: 20,
313///     n_grid: 100,
314///     ..SmoothBasisGcvConfig::default()
315/// };
316/// ```
317#[derive(Debug, Clone, PartialEq)]
318pub struct SmoothBasisGcvConfig {
319    /// Basis type (BSpline or Fourier).
320    pub basis_type: BasisType,
321    /// Number of basis functions (default: 15).
322    pub nbasis: usize,
323    /// Order of the roughness penalty differential operator (default: 2).
324    pub lfd_order: usize,
325    /// Range of log10(lambda) values to search (default: (-10.0, 2.0)).
326    pub log_lambda_range: (f64, f64),
327    /// Number of grid points in the lambda search (default: 50).
328    pub n_grid: usize,
329}
330
331impl Default for SmoothBasisGcvConfig {
332    fn default() -> Self {
333        Self {
334            basis_type: BasisType::Bspline { order: 4 },
335            nbasis: 15,
336            lfd_order: 2,
337            log_lambda_range: (-10.0, 2.0),
338            n_grid: 50,
339        }
340    }
341}
342
343/// Perform basis-penalized smoothing with GCV-optimal lambda using a config struct.
344///
345/// This is the config-based alternative to [`smooth_basis_gcv`]. It takes data
346/// parameters directly and reads all tuning parameters from the config.
347///
348/// # Arguments
349/// * `data` — Functional data matrix (n × m)
350/// * `argvals` — Evaluation points (length m)
351/// * `config` — Tuning parameters
352///
353/// # Errors
354///
355/// Returns [`crate::FdarError::ComputationFailed`] if no valid smoothing result
356/// is found for any lambda in the search grid.
357#[must_use = "expensive computation whose result should not be discarded"]
358pub fn smooth_basis_gcv_with_config(
359    data: &FdMatrix,
360    argvals: &[f64],
361    config: &SmoothBasisGcvConfig,
362) -> Result<SmoothBasisResult, crate::FdarError> {
363    smooth_basis_gcv(
364        data,
365        argvals,
366        &config.basis_type,
367        config.nbasis,
368        config.lfd_order,
369        config.log_lambda_range,
370        config.n_grid,
371    )
372    .ok_or_else(|| crate::FdarError::ComputationFailed {
373        operation: "smooth_basis_gcv_with_config",
374        detail: "no valid smoothing result found in GCV lambda search".to_string(),
375    })
376}
377
378/// Configuration for cross-validation-based basis selection.
379///
380/// Collects all tuning parameters for [`basis_nbasis_cv_with_config`], with
381/// sensible defaults obtained via [`BasisNbasisCvConfig::default()`].
382///
383/// # Example
384/// ```no_run
385/// use fdars_core::smooth_basis::{BasisNbasisCvConfig, BasisType, BasisCriterion};
386///
387/// let config = BasisNbasisCvConfig {
388///     nbasis_range: (5, 25),
389///     criterion: BasisCriterion::Aic,
390///     ..BasisNbasisCvConfig::default()
391/// };
392/// ```
393#[derive(Debug, Clone, PartialEq)]
394pub struct BasisNbasisCvConfig {
395    /// Basis type (default: BSpline with order 4).
396    pub basis_type: BasisType,
397    /// Range of nbasis values to try, inclusive (default: (5, 30)).
398    pub nbasis_range: (usize, usize),
399    /// Roughness penalty lambda (default: 1e-4).
400    pub lambda: f64,
401    /// Penalty order (default: 2).
402    pub lfd_order: usize,
403    /// Number of CV folds (default: 5). Only used when `criterion` is `Cv`.
404    pub n_folds: usize,
405    /// Selection criterion (default: `Gcv`).
406    pub criterion: BasisCriterion,
407}
408
409impl Default for BasisNbasisCvConfig {
410    fn default() -> Self {
411        Self {
412            basis_type: BasisType::Bspline { order: 4 },
413            nbasis_range: (5, 30),
414            lambda: 1e-4,
415            lfd_order: 2,
416            n_folds: 5,
417            criterion: BasisCriterion::Gcv,
418        }
419    }
420}
421
422/// Select the optimal number of basis functions using a config struct.
423///
424/// This is the config-based alternative to [`basis_nbasis_cv`]. It takes data
425/// parameters directly and reads all tuning parameters from the config.
426///
427/// The `nbasis_range` tuple `(lo, hi)` is expanded to `lo..=hi` to form the
428/// candidate set.
429///
430/// # Arguments
431/// * `data` — Functional data matrix (n × m)
432/// * `argvals` — Evaluation points (length m)
433/// * `config` — Tuning parameters
434///
435/// # Errors
436///
437/// Returns [`crate::FdarError::ComputationFailed`] if no valid result is found
438/// for any nbasis in the search range.
439#[must_use = "expensive computation whose result should not be discarded"]
440pub fn basis_nbasis_cv_with_config(
441    data: &FdMatrix,
442    argvals: &[f64],
443    config: &BasisNbasisCvConfig,
444) -> Result<BasisNbasisCvResult, crate::FdarError> {
445    let nbasis_range: Vec<usize> = (config.nbasis_range.0..=config.nbasis_range.1).collect();
446    basis_nbasis_cv(
447        data,
448        argvals,
449        &nbasis_range,
450        &config.basis_type,
451        config.criterion,
452        config.n_folds,
453        config.lambda,
454    )
455    .ok_or_else(|| crate::FdarError::ComputationFailed {
456        operation: "basis_nbasis_cv_with_config",
457        detail: "no valid result found in nbasis CV search".to_string(),
458    })
459}
460
461// ─── Internal Helpers ───────────────────────────────────────────────────────
462
463/// Differentiate column-major basis matrix `lfd_order` times using gradient_uniform.
464fn differentiate_basis_columns(
465    basis: &[f64],
466    n_quad: usize,
467    nbasis: usize,
468    h: f64,
469    lfd_order: usize,
470) -> Vec<f64> {
471    let mut deriv = basis.to_vec();
472    for _ in 0..lfd_order {
473        let mut new_deriv = vec![0.0; n_quad * nbasis];
474        for j in 0..nbasis {
475            let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
476            let grad = crate::helpers::gradient_uniform(&col, h);
477            for i in 0..n_quad {
478                new_deriv[i + j * n_quad] = grad[i];
479            }
480        }
481        deriv = new_deriv;
482    }
483    deriv
484}
485
486/// Integrate symmetric penalty: R[j,k] = ∫ D^m B_j · D^m B_k dt.
487fn integrate_symmetric_penalty(
488    deriv_basis: &[f64],
489    weights: &[f64],
490    k: usize,
491    n_quad: usize,
492) -> Vec<f64> {
493    let mut penalty = vec![0.0; k * k];
494    for j in 0..k {
495        for l in j..k {
496            let mut val = 0.0;
497            for i in 0..n_quad {
498                val += deriv_basis[i + j * n_quad] * deriv_basis[i + l * n_quad] * weights[i];
499            }
500            penalty[j + l * k] = val;
501            penalty[l + j * k] = val;
502        }
503    }
504    penalty
505}
506
507/// Evaluate basis functions on argvals, returning (flat column-major, actual_nbasis).
508fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
509    let m = argvals.len();
510    match basis_type {
511        BasisType::Bspline { order } => {
512            let nknots = nbasis.saturating_sub(*order).max(2);
513            let basis = bspline_basis(argvals, nknots, *order);
514            let actual = basis.len() / m;
515            (basis, actual)
516        }
517        BasisType::Fourier { period } => {
518            let basis = fourier_basis_with_period(argvals, nbasis, *period);
519            (basis, nbasis)
520        }
521    }
522}
523
524/// Invert the penalized system matrix via Cholesky or SVD pseudoinverse.
525fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
526    if let Some(chol) = system.clone().cholesky() {
527        return Some(chol.inverse());
528    }
529    // SVD fallback
530    let svd = nalgebra::SVD::new(system.clone(), true, true);
531    let u = svd.u.as_ref()?;
532    let v_t = svd.v_t.as_ref()?;
533    let max_sv: f64 = svd.singular_values.iter().copied().fold(0.0_f64, f64::max);
534    let eps = 1e-10 * max_sv;
535    let mut inv = DMatrix::<f64>::zeros(k, k);
536    for ii in 0..k {
537        for jj in 0..k {
538            let mut sum = 0.0;
539            for s in 0..k.min(svd.singular_values.len()) {
540                if svd.singular_values[s] > eps {
541                    sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
542                }
543            }
544            inv[(ii, jj)] = sum;
545        }
546    }
547    Some(inv)
548}
549
550/// Project all curves onto basis, returning (coefficients, fitted, total_rss).
551fn project_all_curves(
552    data: &FdMatrix,
553    b_mat: &DMatrix<f64>,
554    proj: &DMatrix<f64>,
555    n: usize,
556    m: usize,
557    k: usize,
558) -> (FdMatrix, FdMatrix, f64) {
559    let mut all_coefs = FdMatrix::zeros(n, k);
560    let mut all_fitted = FdMatrix::zeros(n, m);
561    let mut total_rss = 0.0;
562
563    for i in 0..n {
564        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
565        let y_vec = nalgebra::DVector::from_vec(curve.clone());
566        let coefs = proj * &y_vec;
567
568        for j in 0..k {
569            all_coefs[(i, j)] = coefs[j];
570        }
571        let fitted = b_mat * &coefs;
572        for j in 0..m {
573            all_fitted[(i, j)] = fitted[j];
574            let resid = curve[j] - fitted[j];
575            total_rss += resid * resid;
576        }
577    }
578
579    (all_coefs, all_fitted, total_rss)
580}
581
582/// Compute GCV score.
583fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
584    let gcv_denom = 1.0 - edf / m as f64;
585    if gcv_denom.abs() > 1e-10 {
586        (rss / n_points) / (gcv_denom * gcv_denom)
587    } else {
588        f64::INFINITY
589    }
590}
591
592// ─── Nbasis Selection via CV ────────────────────────────────────────────────
593
594/// Criterion for nbasis selection.
595#[derive(Debug, Clone, Copy, PartialEq)]
596pub enum BasisCriterion {
597    /// Generalized cross-validation.
598    Gcv,
599    /// Leave-one-out cross-validation (k-fold).
600    Cv,
601    /// Akaike Information Criterion.
602    Aic,
603    /// Bayesian Information Criterion.
604    Bic,
605}
606
607/// Result of nbasis selection.
608#[derive(Debug, Clone, PartialEq)]
609pub struct BasisNbasisCvResult {
610    /// Optimal number of basis functions.
611    pub optimal_nbasis: usize,
612    /// Score for each nbasis tested.
613    pub scores: Vec<f64>,
614    /// Range of nbasis values tested.
615    pub nbasis_range: Vec<usize>,
616    /// Criterion used.
617    pub criterion: BasisCriterion,
618}
619
620/// Evaluate information criterion (GCV/AIC/BIC) for a range of nbasis values.
621fn evaluate_nbasis_info_criterion(
622    data: &FdMatrix,
623    argvals: &[f64],
624    nbasis_range: &[usize],
625    basis_type: &BasisType,
626    criterion: BasisCriterion,
627    lambda: f64,
628) -> Vec<f64> {
629    let mut scores = Vec::with_capacity(nbasis_range.len());
630    for &nb in nbasis_range {
631        if nb < 2 {
632            scores.push(f64::INFINITY);
633            continue;
634        }
635        let penalty = match basis_type {
636            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
637            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
638        };
639        let fdpar = FdPar {
640            basis_type: basis_type.clone(),
641            nbasis: nb,
642            lambda,
643            lfd_order: 2,
644            penalty_matrix: penalty,
645        };
646        match smooth_basis(data, argvals, &fdpar) {
647            Ok(result) => {
648                let score = match criterion {
649                    BasisCriterion::Gcv => result.gcv,
650                    BasisCriterion::Aic => result.aic,
651                    BasisCriterion::Bic => result.bic,
652                    BasisCriterion::Cv => unreachable!(),
653                };
654                scores.push(score);
655            }
656            Err(_) => scores.push(f64::INFINITY),
657        }
658    }
659    scores
660}
661
662/// Evaluate nbasis via k-fold cross-validation of reconstruction error.
663fn evaluate_nbasis_cv(
664    data: &FdMatrix,
665    argvals: &[f64],
666    nbasis_range: &[usize],
667    basis_type: &BasisType,
668    lambda: f64,
669    n_folds: usize,
670) -> Vec<f64> {
671    let (n, m) = data.shape();
672    let n_folds = n_folds.max(2);
673    let folds = crate::cv::create_folds(n, n_folds, 42);
674    let mut scores = Vec::with_capacity(nbasis_range.len());
675
676    for &nb in nbasis_range {
677        if nb < 2 {
678            scores.push(f64::INFINITY);
679            continue;
680        }
681        let penalty = match basis_type {
682            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
683            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
684        };
685
686        let mut total_mse = 0.0;
687        let mut total_count = 0;
688
689        for fold in 0..n_folds {
690            let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
691            if train_idx.is_empty() || test_idx.is_empty() {
692                continue;
693            }
694            let train_data = crate::cv::subset_rows(data, &train_idx);
695            let fdpar = FdPar {
696                basis_type: basis_type.clone(),
697                nbasis: nb,
698                lambda,
699                lfd_order: 2,
700                penalty_matrix: penalty.clone(),
701            };
702
703            if let Ok(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
704                let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
705                let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
706                let r_mat =
707                    DMatrix::from_column_slice(actual_k, actual_k, &train_result.penalty_matrix);
708                let btb = b_mat.transpose() * &b_mat;
709                let ridge_eps = 1e-10;
710                let system: DMatrix<f64> = &btb
711                    + lambda * &r_mat
712                    + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
713
714                if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
715                    let proj = &system_inv * b_mat.transpose();
716                    for &ti in &test_idx {
717                        let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
718                        let y_vec = nalgebra::DVector::from_vec(curve.clone());
719                        let coefs = &proj * &y_vec;
720                        let fitted = &b_mat * &coefs;
721                        let mse: f64 =
722                            (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
723                        total_mse += mse;
724                        total_count += 1;
725                    }
726                }
727            }
728        }
729
730        if total_count > 0 {
731            scores.push(total_mse / f64::from(total_count));
732        } else {
733            scores.push(f64::INFINITY);
734        }
735    }
736    scores
737}
738
739/// Select the optimal number of basis functions using multiple criteria
740/// (R's `fdata2basis_cv`).
741pub fn basis_nbasis_cv(
742    data: &FdMatrix,
743    argvals: &[f64],
744    nbasis_range: &[usize],
745    basis_type: &BasisType,
746    criterion: BasisCriterion,
747    n_folds: usize,
748    lambda: f64,
749) -> Option<BasisNbasisCvResult> {
750    let (n, m) = data.shape();
751    if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
752        return None;
753    }
754
755    let scores = match criterion {
756        BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
757            evaluate_nbasis_info_criterion(
758                data,
759                argvals,
760                nbasis_range,
761                basis_type,
762                criterion,
763                lambda,
764            )
765        }
766        BasisCriterion::Cv => {
767            evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
768        }
769    };
770
771    let (best_idx, _) = scores
772        .iter()
773        .enumerate()
774        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
775
776    Some(BasisNbasisCvResult {
777        optimal_nbasis: nbasis_range[best_idx],
778        scores,
779        nbasis_range: nbasis_range.to_vec(),
780        criterion,
781    })
782}
783
784#[cfg(test)]
785mod tests {
786    use super::*;
787    use crate::test_helpers::uniform_grid;
788    use std::f64::consts::PI;
789
790    #[test]
791    fn test_bspline_penalty_matrix_symmetric() {
792        let t = uniform_grid(101);
793        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
794        let _k = 15; // may differ from actual due to knot construction
795        let actual_k = (penalty.len() as f64).sqrt() as usize;
796        for i in 0..actual_k {
797            for j in 0..actual_k {
798                assert!(
799                    (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
800                    "Penalty matrix not symmetric at ({}, {})",
801                    i,
802                    j
803                );
804            }
805        }
806    }
807
808    #[test]
809    fn test_bspline_penalty_matrix_positive_semidefinite() {
810        let t = uniform_grid(101);
811        let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
812        let k = (penalty.len() as f64).sqrt() as usize;
813        // Diagonal elements should be non-negative
814        for i in 0..k {
815            assert!(
816                penalty[i + i * k] >= -1e-10,
817                "Diagonal element {} is negative: {}",
818                i,
819                penalty[i + i * k]
820            );
821        }
822    }
823
824    #[test]
825    fn test_fourier_penalty_diagonal() {
826        let penalty = fourier_penalty_matrix(7, 1.0, 2);
827        // Should be diagonal
828        for i in 0..7 {
829            for j in 0..7 {
830                if i != j {
831                    assert!(
832                        penalty[i + j * 7].abs() < 1e-10,
833                        "Off-diagonal ({},{}) = {}",
834                        i,
835                        j,
836                        penalty[i + j * 7]
837                    );
838                }
839            }
840        }
841        // Constant term should have zero penalty
842        assert!(penalty[0].abs() < 1e-10);
843        // Higher frequency terms should have larger penalties
844        assert!(penalty[1 + 7] > 0.0);
845        assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
846    }
847
848    #[test]
849    fn test_smooth_basis_bspline() {
850        let m = 101;
851        let n = 5;
852        let t = uniform_grid(m);
853
854        // Generate noisy sine curves
855        let mut data = FdMatrix::zeros(n, m);
856        for i in 0..n {
857            for j in 0..m {
858                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
859            }
860        }
861
862        let nbasis = 15;
863        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
864        let _actual_k = (penalty.len() as f64).sqrt() as usize;
865
866        let fdpar = FdPar {
867            basis_type: BasisType::Bspline { order: 4 },
868            nbasis,
869            lambda: 1e-4,
870            lfd_order: 2,
871            penalty_matrix: penalty,
872        };
873
874        let result = smooth_basis(&data, &t, &fdpar);
875        assert!(result.is_ok(), "smooth_basis should succeed");
876
877        let res = result.unwrap();
878        assert_eq!(res.fitted.shape(), (n, m));
879        assert_eq!(res.coefficients.nrows(), n);
880        assert!(res.edf > 0.0, "EDF should be positive");
881        assert!(res.gcv > 0.0, "GCV should be positive");
882    }
883
884    #[test]
885    fn test_smooth_basis_fourier() {
886        let m = 101;
887        let n = 3;
888        let t = uniform_grid(m);
889
890        let mut data = FdMatrix::zeros(n, m);
891        for i in 0..n {
892            for j in 0..m {
893                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
894            }
895        }
896
897        let nbasis = 7;
898        let period = 1.0;
899        let penalty = fourier_penalty_matrix(nbasis, period, 2);
900
901        let fdpar = FdPar {
902            basis_type: BasisType::Fourier { period },
903            nbasis,
904            lambda: 1e-6,
905            lfd_order: 2,
906            penalty_matrix: penalty,
907        };
908
909        let result = smooth_basis(&data, &t, &fdpar);
910        assert!(result.is_ok());
911
912        let res = result.unwrap();
913        // Fourier basis should fit periodic data well
914        for j in 0..m {
915            let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
916            assert!(
917                (res.fitted[(0, j)] - expected).abs() < 0.1,
918                "Fourier fit poor at j={}: got {}, expected {}",
919                j,
920                res.fitted[(0, j)],
921                expected
922            );
923        }
924    }
925
926    #[test]
927    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
928        let m = 101;
929        let n = 5;
930        let t = uniform_grid(m);
931
932        let mut data = FdMatrix::zeros(n, m);
933        for i in 0..n {
934            for j in 0..m {
935                data[(i, j)] =
936                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
937            }
938        }
939
940        let basis_type = BasisType::Bspline { order: 4 };
941        let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
942        assert!(result.is_some(), "GCV search should succeed");
943    }
944
945    #[test]
946    fn test_smooth_basis_large_lambda_reduces_edf() {
947        let m = 101;
948        let n = 3;
949        let t = uniform_grid(m);
950
951        let mut data = FdMatrix::zeros(n, m);
952        for i in 0..n {
953            for j in 0..m {
954                data[(i, j)] = (2.0 * PI * t[j]).sin();
955            }
956        }
957
958        let nbasis = 15;
959        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
960        let _actual_k = (penalty.len() as f64).sqrt() as usize;
961
962        let fdpar_small = FdPar {
963            basis_type: BasisType::Bspline { order: 4 },
964            nbasis,
965            lambda: 1e-8,
966            lfd_order: 2,
967            penalty_matrix: penalty.clone(),
968        };
969        let fdpar_large = FdPar {
970            basis_type: BasisType::Bspline { order: 4 },
971            nbasis,
972            lambda: 1e2,
973            lfd_order: 2,
974            penalty_matrix: penalty,
975        };
976
977        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
978        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
979
980        assert!(
981            res_large.edf < res_small.edf,
982            "Larger lambda should reduce EDF: {} vs {}",
983            res_large.edf,
984            res_small.edf
985        );
986    }
987
988    // ============== basis_nbasis_cv tests ==============
989
990    #[test]
991    fn test_basis_nbasis_cv_gcv() {
992        let m = 101;
993        let n = 5;
994        let t = uniform_grid(m);
995        let mut data = FdMatrix::zeros(n, m);
996        for i in 0..n {
997            for j in 0..m {
998                data[(i, j)] =
999                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1000            }
1001        }
1002
1003        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
1004        let result = basis_nbasis_cv(
1005            &data,
1006            &t,
1007            &nbasis_range,
1008            &BasisType::Bspline { order: 4 },
1009            BasisCriterion::Gcv,
1010            5,
1011            1e-4,
1012        );
1013        assert!(result.is_some());
1014        let res = result.unwrap();
1015        assert!(nbasis_range.contains(&res.optimal_nbasis));
1016        assert_eq!(res.scores.len(), nbasis_range.len());
1017        assert_eq!(res.criterion, BasisCriterion::Gcv);
1018    }
1019
1020    #[test]
1021    fn test_basis_nbasis_cv_aic_bic() {
1022        let m = 51;
1023        let n = 5;
1024        let t = uniform_grid(m);
1025        let mut data = FdMatrix::zeros(n, m);
1026        for i in 0..n {
1027            for j in 0..m {
1028                data[(i, j)] = (2.0 * PI * t[j]).sin();
1029            }
1030        }
1031
1032        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
1033        let aic_result = basis_nbasis_cv(
1034            &data,
1035            &t,
1036            &nbasis_range,
1037            &BasisType::Bspline { order: 4 },
1038            BasisCriterion::Aic,
1039            5,
1040            0.0,
1041        );
1042        let bic_result = basis_nbasis_cv(
1043            &data,
1044            &t,
1045            &nbasis_range,
1046            &BasisType::Bspline { order: 4 },
1047            BasisCriterion::Bic,
1048            5,
1049            0.0,
1050        );
1051        assert!(aic_result.is_some());
1052        assert!(bic_result.is_some());
1053    }
1054
1055    #[test]
1056    fn test_basis_nbasis_cv_kfold() {
1057        let m = 51;
1058        let n = 10;
1059        let t = uniform_grid(m);
1060        let mut data = FdMatrix::zeros(n, m);
1061        for i in 0..n {
1062            for j in 0..m {
1063                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
1064            }
1065        }
1066
1067        let nbasis_range: Vec<usize> = vec![5, 7, 9];
1068        let result = basis_nbasis_cv(
1069            &data,
1070            &t,
1071            &nbasis_range,
1072            &BasisType::Bspline { order: 4 },
1073            BasisCriterion::Cv,
1074            5,
1075            1e-4,
1076        );
1077        assert!(result.is_some());
1078        let res = result.unwrap();
1079        assert!(nbasis_range.contains(&res.optimal_nbasis));
1080        assert_eq!(res.criterion, BasisCriterion::Cv);
1081    }
1082
1083    // ============== Comprehensive additional tests ==============
1084
1085    // Helper: generate standard test data (sine + high-freq component)
1086    fn make_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
1087        let t: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1088        let mut data = FdMatrix::zeros(n, m);
1089        for i in 0..n {
1090            for j in 0..m {
1091                data[(i, j)] = (2.0 * PI * t[j]).sin()
1092                    + 0.1 * (10.0 * t[j]).sin()
1093                    + 0.05 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1094            }
1095        }
1096        (data, t)
1097    }
1098
1099    // Helper: create an FdPar for B-spline smoothing
1100    fn make_bspline_fdpar(argvals: &[f64], nbasis: usize, lambda: f64) -> FdPar {
1101        let penalty = bspline_penalty_matrix(argvals, nbasis, 4, 2);
1102        FdPar {
1103            basis_type: BasisType::Bspline { order: 4 },
1104            nbasis,
1105            lambda,
1106            lfd_order: 2,
1107            penalty_matrix: penalty,
1108        }
1109    }
1110
1111    // Helper: create an FdPar for Fourier smoothing
1112    fn make_fourier_fdpar(nbasis: usize, period: f64, lambda: f64) -> FdPar {
1113        let penalty = fourier_penalty_matrix(nbasis, period, 2);
1114        FdPar {
1115            basis_type: BasisType::Fourier { period },
1116            nbasis,
1117            lambda,
1118            lfd_order: 2,
1119            penalty_matrix: penalty,
1120        }
1121    }
1122
1123    // ─── BasisType enum tests ───────────────────────────────────────────────
1124
1125    #[test]
1126    fn test_basis_type_bspline_variant() {
1127        let bt = BasisType::Bspline { order: 4 };
1128        assert_eq!(bt, BasisType::Bspline { order: 4 });
1129        // Different orders are not equal
1130        assert_ne!(bt, BasisType::Bspline { order: 3 });
1131    }
1132
1133    #[test]
1134    fn test_basis_type_fourier_variant() {
1135        let bt = BasisType::Fourier { period: 1.0 };
1136        assert_eq!(bt, BasisType::Fourier { period: 1.0 });
1137        assert_ne!(bt, BasisType::Fourier { period: 2.0 });
1138    }
1139
1140    #[test]
1141    fn test_basis_type_cross_variant_inequality() {
1142        let bspline = BasisType::Bspline { order: 4 };
1143        let fourier = BasisType::Fourier { period: 1.0 };
1144        assert_ne!(bspline, fourier);
1145    }
1146
1147    #[test]
1148    fn test_basis_type_clone_and_debug() {
1149        let bt = BasisType::Bspline { order: 4 };
1150        let cloned = bt.clone();
1151        assert_eq!(bt, cloned);
1152        let debug_str = format!("{:?}", bt);
1153        assert!(debug_str.contains("Bspline"));
1154        assert!(debug_str.contains("4"));
1155    }
1156
1157    // ─── FdPar struct tests ─────────────────────────────────────────────────
1158
1159    #[test]
1160    fn test_fdpar_construction_and_fields() {
1161        let penalty = vec![1.0, 0.0, 0.0, 1.0];
1162        let fdpar = FdPar {
1163            basis_type: BasisType::Bspline { order: 4 },
1164            nbasis: 2,
1165            lambda: 0.01,
1166            lfd_order: 2,
1167            penalty_matrix: penalty.clone(),
1168        };
1169        assert_eq!(fdpar.nbasis, 2);
1170        assert!((fdpar.lambda - 0.01).abs() < 1e-15);
1171        assert_eq!(fdpar.lfd_order, 2);
1172        assert_eq!(fdpar.penalty_matrix.len(), 4);
1173    }
1174
1175    #[test]
1176    fn test_fdpar_clone_and_debug() {
1177        let t = uniform_grid(50);
1178        let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1179        let cloned = fdpar.clone();
1180        assert_eq!(fdpar, cloned);
1181        let debug_str = format!("{:?}", fdpar);
1182        assert!(debug_str.contains("FdPar"));
1183    }
1184
1185    // ─── BasisCriterion enum tests ──────────────────────────────────────────
1186
1187    #[test]
1188    fn test_basis_criterion_variants() {
1189        assert_eq!(BasisCriterion::Gcv, BasisCriterion::Gcv);
1190        assert_eq!(BasisCriterion::Cv, BasisCriterion::Cv);
1191        assert_eq!(BasisCriterion::Aic, BasisCriterion::Aic);
1192        assert_eq!(BasisCriterion::Bic, BasisCriterion::Bic);
1193        assert_ne!(BasisCriterion::Gcv, BasisCriterion::Aic);
1194        assert_ne!(BasisCriterion::Cv, BasisCriterion::Bic);
1195    }
1196
1197    #[test]
1198    fn test_basis_criterion_copy() {
1199        let c = BasisCriterion::Gcv;
1200        let copied = c; // Copy
1201        assert_eq!(c, copied);
1202    }
1203
1204    #[test]
1205    fn test_basis_criterion_debug() {
1206        let debug_str = format!("{:?}", BasisCriterion::Bic);
1207        assert!(debug_str.contains("Bic"));
1208    }
1209
1210    // ─── SmoothBasisResult tests ────────────────────────────────────────────
1211
1212    #[test]
1213    fn test_smooth_basis_result_all_fields() {
1214        let (data, t) = make_test_data(3, 50);
1215        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1216        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1217
1218        // coefficients: n curves x k basis functions
1219        assert_eq!(res.coefficients.nrows(), 3);
1220        assert!(res.coefficients.ncols() > 0);
1221        assert_eq!(res.nbasis, res.coefficients.ncols());
1222        // fitted: n x m
1223        assert_eq!(res.fitted.shape(), (3, 50));
1224        // edf should be between 1 and nbasis
1225        assert!(res.edf > 0.0 && res.edf <= res.nbasis as f64);
1226        // gcv, aic, bic should be finite
1227        assert!(res.gcv.is_finite());
1228        assert!(res.aic.is_finite());
1229        assert!(res.bic.is_finite());
1230        // penalty_matrix should be k x k
1231        let k = res.nbasis;
1232        assert_eq!(res.penalty_matrix.len(), k * k);
1233    }
1234
1235    #[test]
1236    fn test_smooth_basis_result_clone() {
1237        let (data, t) = make_test_data(2, 50);
1238        let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1239        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1240        let cloned = res.clone();
1241        assert_eq!(res, cloned);
1242    }
1243
1244    // ─── smooth_basis: B-spline detailed tests ──────────────────────────────
1245
1246    #[test]
1247    fn test_smooth_basis_bspline_coefficient_shape() {
1248        let (data, t) = make_test_data(4, 50);
1249        let nbasis = 12;
1250        let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
1251        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1252        assert_eq!(res.coefficients.nrows(), 4);
1253        // actual nbasis may differ from requested due to knot construction
1254        assert!(res.coefficients.ncols() >= 2);
1255        assert_eq!(res.nbasis, res.coefficients.ncols());
1256    }
1257
1258    #[test]
1259    fn test_smooth_basis_bspline_fitted_values_shape() {
1260        let m = 80;
1261        let n = 6;
1262        let (data, t) = make_test_data(n, m);
1263        let fdpar = make_bspline_fdpar(&t, 15, 1e-4);
1264        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1265        assert_eq!(res.fitted.shape(), (n, m));
1266    }
1267
1268    #[test]
1269    fn test_smooth_basis_bspline_zero_lambda_interpolates() {
1270        // With lambda=0, the smoother should nearly interpolate the data
1271        let m = 30;
1272        let n = 2;
1273        let (data, t) = make_test_data(n, m);
1274        let fdpar = make_bspline_fdpar(&t, 15, 0.0);
1275        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1276
1277        // Residuals should be very small (near interpolation)
1278        let mut max_resid = 0.0_f64;
1279        for i in 0..n {
1280            for j in 0..m {
1281                let resid = (data[(i, j)] - res.fitted[(i, j)]).abs();
1282                max_resid = max_resid.max(resid);
1283            }
1284        }
1285        assert!(
1286            max_resid < 0.5,
1287            "Zero-lambda B-spline should closely interpolate; max_resid = {}",
1288            max_resid
1289        );
1290    }
1291
1292    #[test]
1293    fn test_smooth_basis_bspline_large_lambda_oversmooths() {
1294        // With very large lambda, the fit should be much smoother (lower variance)
1295        // than with small lambda
1296        let m = 50;
1297        let n = 1;
1298        let (data, t) = make_test_data(n, m);
1299
1300        let fdpar_small = make_bspline_fdpar(&t, 15, 1e-6);
1301        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1302
1303        let fdpar_large = make_bspline_fdpar(&t, 15, 1e6);
1304        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1305
1306        let compute_variance = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1307            let vals: Vec<f64> = (0..ncols).map(|j| fitted[(row, j)]).collect();
1308            let mean = vals.iter().sum::<f64>() / ncols as f64;
1309            vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / ncols as f64
1310        };
1311
1312        let var_small = compute_variance(&res_small.fitted, 0, m);
1313        let var_large = compute_variance(&res_large.fitted, 0, m);
1314        assert!(
1315            var_large < var_small,
1316            "Large lambda should yield lower variance fit: var_large={}, var_small={}",
1317            var_large,
1318            var_small
1319        );
1320    }
1321
1322    #[test]
1323    fn test_smooth_basis_bspline_penalty_effect_on_smoothness() {
1324        // Compare roughness of fits with small vs large lambda
1325        let m = 50;
1326        let n = 1;
1327        let (data, t) = make_test_data(n, m);
1328
1329        let fdpar_small = make_bspline_fdpar(&t, 15, 1e-8);
1330        let fdpar_large = make_bspline_fdpar(&t, 15, 1.0);
1331
1332        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1333        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1334
1335        // Measure roughness as sum of squared second differences
1336        let roughness = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1337            (1..ncols - 1)
1338                .map(|j| {
1339                    let d2 = fitted[(row, j + 1)] - 2.0 * fitted[(row, j)] + fitted[(row, j - 1)];
1340                    d2 * d2
1341                })
1342                .sum::<f64>()
1343        };
1344
1345        let r_small = roughness(&res_small.fitted, 0, m);
1346        let r_large = roughness(&res_large.fitted, 0, m);
1347        assert!(
1348            r_large < r_small,
1349            "Larger lambda should produce smoother fit: roughness_large={}, roughness_small={}",
1350            r_large,
1351            r_small
1352        );
1353    }
1354
1355    #[test]
1356    fn test_smooth_basis_bspline_single_curve() {
1357        let m = 50;
1358        let (data, t) = make_test_data(1, m);
1359        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1360        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1361        assert_eq!(res.fitted.nrows(), 1);
1362        assert_eq!(res.fitted.ncols(), m);
1363        assert!(res.gcv.is_finite());
1364    }
1365
1366    #[test]
1367    fn test_smooth_basis_bspline_many_curves() {
1368        let m = 50;
1369        let n = 20;
1370        let (data, t) = make_test_data(n, m);
1371        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1372        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1373        assert_eq!(res.fitted.nrows(), n);
1374        assert_eq!(res.coefficients.nrows(), n);
1375    }
1376
1377    #[test]
1378    fn test_smooth_basis_bspline_minimal_nbasis() {
1379        // nbasis = 2 is the minimum allowed
1380        let m = 50;
1381        let (data, t) = make_test_data(1, m);
1382        let fdpar = make_bspline_fdpar(&t, 2, 1e-4);
1383        let res = smooth_basis(&data, &t, &fdpar);
1384        // Should succeed (or at least not panic); the fit may be poor
1385        assert!(res.is_ok());
1386    }
1387
1388    #[test]
1389    fn test_smooth_basis_bspline_different_orders() {
1390        let m = 50;
1391        let (data, t) = make_test_data(2, m);
1392        // Order 3 (quadratic B-splines)
1393        let penalty3 = bspline_penalty_matrix(&t, 10, 3, 2);
1394        let fdpar3 = FdPar {
1395            basis_type: BasisType::Bspline { order: 3 },
1396            nbasis: 10,
1397            lambda: 1e-4,
1398            lfd_order: 2,
1399            penalty_matrix: penalty3,
1400        };
1401        let res3 = smooth_basis(&data, &t, &fdpar3);
1402        assert!(res3.is_ok());
1403
1404        // Order 5 (quartic B-splines)
1405        let penalty5 = bspline_penalty_matrix(&t, 10, 5, 2);
1406        let fdpar5 = FdPar {
1407            basis_type: BasisType::Bspline { order: 5 },
1408            nbasis: 10,
1409            lambda: 1e-4,
1410            lfd_order: 2,
1411            penalty_matrix: penalty5,
1412        };
1413        let res5 = smooth_basis(&data, &t, &fdpar5);
1414        assert!(res5.is_ok());
1415    }
1416
1417    // ─── smooth_basis: Fourier detailed tests ───────────────────────────────
1418
1419    #[test]
1420    fn test_smooth_basis_fourier_coefficient_shape() {
1421        let m = 50;
1422        let n = 3;
1423        let t = uniform_grid(m);
1424        let mut data = FdMatrix::zeros(n, m);
1425        for i in 0..n {
1426            for j in 0..m {
1427                data[(i, j)] = (2.0 * PI * t[j]).sin();
1428            }
1429        }
1430        let nbasis = 7;
1431        let fdpar = make_fourier_fdpar(nbasis, 1.0, 1e-6);
1432        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1433        assert_eq!(res.coefficients.nrows(), n);
1434        assert_eq!(res.coefficients.ncols(), nbasis);
1435        assert_eq!(res.nbasis, nbasis);
1436    }
1437
1438    #[test]
1439    fn test_smooth_basis_fourier_fits_pure_sine() {
1440        // Fourier basis should perfectly fit a pure sine with enough basis fns
1441        let m = 100;
1442        let t = uniform_grid(m);
1443        let mut data = FdMatrix::zeros(1, m);
1444        for j in 0..m {
1445            data[(0, j)] = (2.0 * PI * t[j]).sin();
1446        }
1447        let fdpar = make_fourier_fdpar(5, 1.0, 1e-8);
1448        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1449
1450        for j in 0..m {
1451            let expected = (2.0 * PI * t[j]).sin();
1452            assert!(
1453                (res.fitted[(0, j)] - expected).abs() < 0.05,
1454                "Fourier should fit pure sine; j={}, got={}, expected={}",
1455                j,
1456                res.fitted[(0, j)],
1457                expected
1458            );
1459        }
1460    }
1461
1462    #[test]
1463    fn test_smooth_basis_fourier_different_periods() {
1464        let m = 50;
1465        let t = uniform_grid(m);
1466        let mut data = FdMatrix::zeros(1, m);
1467        for j in 0..m {
1468            data[(0, j)] = (2.0 * PI * t[j]).sin();
1469        }
1470
1471        // Period = 1.0 (matches the data)
1472        let fdpar1 = make_fourier_fdpar(7, 1.0, 1e-6);
1473        let res1 = smooth_basis(&data, &t, &fdpar1).unwrap();
1474
1475        // Period = 2.0 (mismatch, but should still produce a result)
1476        let fdpar2 = make_fourier_fdpar(7, 2.0, 1e-6);
1477        let res2 = smooth_basis(&data, &t, &fdpar2).unwrap();
1478
1479        // Both should succeed and have valid shapes
1480        assert_eq!(res1.fitted.shape(), (1, m));
1481        assert_eq!(res2.fitted.shape(), (1, m));
1482    }
1483
1484    #[test]
1485    fn test_smooth_basis_fourier_zero_lambda() {
1486        let m = 50;
1487        let t = uniform_grid(m);
1488        let mut data = FdMatrix::zeros(1, m);
1489        for j in 0..m {
1490            data[(0, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
1491        }
1492        let fdpar = make_fourier_fdpar(9, 1.0, 0.0);
1493        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1494        assert_eq!(res.fitted.shape(), (1, m));
1495        // EDF should be close to nbasis with zero penalty
1496        assert!(res.edf > 1.0);
1497    }
1498
1499    #[test]
1500    fn test_smooth_basis_fourier_large_lambda() {
1501        let m = 50;
1502        let t = uniform_grid(m);
1503        let mut data = FdMatrix::zeros(1, m);
1504        for j in 0..m {
1505            data[(0, j)] = (2.0 * PI * t[j]).sin();
1506        }
1507        let fdpar = make_fourier_fdpar(9, 1.0, 1e6);
1508        let res = smooth_basis(&data, &t, &fdpar).unwrap();
1509        // EDF should be very small with huge penalty
1510        assert!(
1511            res.edf < 5.0,
1512            "Large lambda should reduce EDF; edf={}",
1513            res.edf
1514        );
1515    }
1516
1517    // ─── smooth_basis: Lambda comparison tests ──────────────────────────────
1518
1519    #[test]
1520    fn test_smooth_basis_lambda_gradient_edf() {
1521        // EDF should monotonically decrease with increasing lambda
1522        let m = 50;
1523        let (data, t) = make_test_data(3, m);
1524        let lambdas = [1e-8, 1e-4, 1e-2, 1.0, 1e2];
1525        let mut prev_edf = f64::INFINITY;
1526        for &lam in &lambdas {
1527            let fdpar = make_bspline_fdpar(&t, 12, lam);
1528            let res = smooth_basis(&data, &t, &fdpar).unwrap();
1529            assert!(
1530                res.edf <= prev_edf + 0.01,
1531                "EDF should decrease: lambda={}, edf={}, prev_edf={}",
1532                lam,
1533                res.edf,
1534                prev_edf
1535            );
1536            prev_edf = res.edf;
1537        }
1538    }
1539
1540    #[test]
1541    fn test_smooth_basis_lambda_gradient_rss() {
1542        // RSS should monotonically increase with increasing lambda
1543        let m = 50;
1544        let n = 2;
1545        let (data, t) = make_test_data(n, m);
1546        let lambdas = [0.0, 1e-6, 1e-2, 1.0, 1e4];
1547        let mut prev_rss = -1.0;
1548        for &lam in &lambdas {
1549            let fdpar = make_bspline_fdpar(&t, 12, lam);
1550            let res = smooth_basis(&data, &t, &fdpar).unwrap();
1551            let mut rss = 0.0;
1552            for i in 0..n {
1553                for j in 0..m {
1554                    rss += (data[(i, j)] - res.fitted[(i, j)]).powi(2);
1555                }
1556            }
1557            assert!(
1558                rss >= prev_rss - 1e-8,
1559                "RSS should increase: lambda={}, rss={}, prev_rss={}",
1560                lam,
1561                rss,
1562                prev_rss
1563            );
1564            prev_rss = rss;
1565        }
1566    }
1567
1568    // ─── smooth_basis: Error cases ──────────────────────────────────────────
1569
1570    #[test]
1571    fn test_smooth_basis_empty_data_rows() {
1572        let t = uniform_grid(50);
1573        let data = FdMatrix::zeros(0, 50);
1574        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1575        let res = smooth_basis(&data, &t, &fdpar);
1576        assert!(res.is_err());
1577    }
1578
1579    #[test]
1580    fn test_smooth_basis_empty_data_cols() {
1581        let data = FdMatrix::zeros(5, 0);
1582        let fdpar = FdPar {
1583            basis_type: BasisType::Bspline { order: 4 },
1584            nbasis: 10,
1585            lambda: 1e-4,
1586            lfd_order: 2,
1587            penalty_matrix: vec![0.0; 100],
1588        };
1589        let res = smooth_basis(&data, &[], &fdpar);
1590        assert!(res.is_err());
1591    }
1592
1593    #[test]
1594    fn test_smooth_basis_mismatched_argvals() {
1595        let t = uniform_grid(50);
1596        let data = FdMatrix::zeros(3, 40); // m=40 but argvals has 50
1597        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1598        let res = smooth_basis(&data, &t, &fdpar);
1599        assert!(res.is_err());
1600    }
1601
1602    #[test]
1603    fn test_smooth_basis_nbasis_too_small() {
1604        let t = uniform_grid(50);
1605        let data = FdMatrix::zeros(3, 50);
1606        // nbasis = 1, which is below minimum of 2
1607        let fdpar = FdPar {
1608            basis_type: BasisType::Bspline { order: 4 },
1609            nbasis: 1,
1610            lambda: 1e-4,
1611            lfd_order: 2,
1612            penalty_matrix: vec![0.0; 1],
1613        };
1614        let res = smooth_basis(&data, &t, &fdpar);
1615        assert!(res.is_err());
1616    }
1617
1618    #[test]
1619    fn test_smooth_basis_error_is_invalid_dimension() {
1620        let t = uniform_grid(50);
1621        let data = FdMatrix::zeros(0, 50);
1622        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1623        let err = smooth_basis(&data, &t, &fdpar).unwrap_err();
1624        match err {
1625            crate::FdarError::InvalidDimension { .. } => {} // expected
1626            other => panic!("Expected InvalidDimension, got {:?}", other),
1627        }
1628    }
1629
1630    // ─── Penalty matrix detailed tests ──────────────────────────────────────
1631
1632    #[test]
1633    fn test_bspline_penalty_matrix_different_orders() {
1634        let t = uniform_grid(101);
1635        // Order 1 penalty (penalize derivatives)
1636        let p1 = bspline_penalty_matrix(&t, 10, 4, 1);
1637        // Order 2 penalty (penalize curvature)
1638        let p2 = bspline_penalty_matrix(&t, 10, 4, 2);
1639        // Both should be square and same size
1640        assert_eq!(p1.len(), p2.len());
1641        // But they should differ
1642        let diff: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).abs()).sum();
1643        assert!(
1644            diff > 1e-10,
1645            "Different lfd_orders should produce different penalties"
1646        );
1647    }
1648
1649    #[test]
1650    fn test_bspline_penalty_matrix_edge_cases() {
1651        // Too few argvals
1652        let t = vec![0.0];
1653        let p = bspline_penalty_matrix(&t, 10, 4, 2);
1654        // Should return zero matrix
1655        assert!(p.iter().all(|&v| v == 0.0));
1656
1657        // nbasis < 2
1658        let t2 = uniform_grid(50);
1659        let p2 = bspline_penalty_matrix(&t2, 1, 4, 2);
1660        assert!(p2.iter().all(|&v| v == 0.0));
1661
1662        // lfd_order >= order
1663        let p3 = bspline_penalty_matrix(&t2, 10, 4, 4);
1664        assert!(p3.iter().all(|&v| v == 0.0));
1665    }
1666
1667    #[test]
1668    fn test_bspline_penalty_nonnegative_diagonal() {
1669        let t = uniform_grid(101);
1670        for nbasis in [5, 10, 20] {
1671            let p = bspline_penalty_matrix(&t, nbasis, 4, 2);
1672            let k = (p.len() as f64).sqrt() as usize;
1673            for i in 0..k {
1674                assert!(
1675                    p[i + i * k] >= -1e-10,
1676                    "Diagonal ({},{}) negative for nbasis={}: {}",
1677                    i,
1678                    i,
1679                    nbasis,
1680                    p[i + i * k]
1681                );
1682            }
1683        }
1684    }
1685
1686    #[test]
1687    fn test_fourier_penalty_increasing_with_frequency() {
1688        let penalty = fourier_penalty_matrix(11, 1.0, 2);
1689        let k = 11;
1690        // Constant term is zero
1691        assert!(penalty[0].abs() < 1e-15);
1692        // Pairs: (1,2) -> freq 1, (3,4) -> freq 2, etc.
1693        let mut prev_eigenval = 0.0;
1694        for freq in 1..=5 {
1695            let idx_sin = 2 * freq - 1;
1696            let eigenval = penalty[idx_sin + idx_sin * k];
1697            assert!(
1698                eigenval > prev_eigenval,
1699                "Higher frequency should have larger penalty: freq={}, eigenval={}, prev={}",
1700                freq,
1701                eigenval,
1702                prev_eigenval
1703            );
1704            prev_eigenval = eigenval;
1705            // cos and sin of same frequency should have same penalty
1706            let idx_cos = 2 * freq;
1707            if idx_cos < k {
1708                assert!(
1709                    (penalty[idx_cos + idx_cos * k] - eigenval).abs() < 1e-10,
1710                    "Sin and cos penalty should match at freq {}",
1711                    freq
1712                );
1713            }
1714        }
1715    }
1716
1717    #[test]
1718    fn test_fourier_penalty_different_periods() {
1719        let p1 = fourier_penalty_matrix(7, 1.0, 2);
1720        let p2 = fourier_penalty_matrix(7, 2.0, 2);
1721        // Longer period -> smaller omega -> smaller penalty eigenvalues
1722        for i in 1..7 {
1723            assert!(
1724                p2[i + i * 7] < p1[i + i * 7] || (p1[i + i * 7] == 0.0 && p2[i + i * 7] == 0.0),
1725                "Longer period should have smaller penalties at i={}",
1726                i
1727            );
1728        }
1729    }
1730
1731    #[test]
1732    fn test_fourier_penalty_first_order() {
1733        // lfd_order = 1: penalize first derivative
1734        let p = fourier_penalty_matrix(5, 1.0, 1);
1735        // Eigenvalues: (2*pi*freq)^2 for lfd_order=1
1736        let omega1 = 2.0 * PI;
1737        let expected1 = omega1.powi(2);
1738        assert!(
1739            (p[1 + 5] - expected1).abs() < 1e-6,
1740            "First-order penalty eigenval: got {}, expected {}",
1741            p[1 + 5],
1742            expected1
1743        );
1744    }
1745
1746    #[test]
1747    fn test_fourier_penalty_zero_nbasis() {
1748        let p = fourier_penalty_matrix(0, 1.0, 2);
1749        assert!(p.is_empty());
1750    }
1751
1752    #[test]
1753    fn test_fourier_penalty_nbasis_one() {
1754        let p = fourier_penalty_matrix(1, 1.0, 2);
1755        assert_eq!(p.len(), 1);
1756        assert!(p[0].abs() < 1e-15); // constant term has zero penalty
1757    }
1758
1759    // ─── smooth_basis_gcv detailed tests ────────────────────────────────────
1760
1761    #[test]
1762    fn test_smooth_basis_gcv_returns_valid_result() {
1763        let (data, t) = make_test_data(5, 50);
1764        let bt = BasisType::Bspline { order: 4 };
1765        let result = smooth_basis_gcv(&data, &t, &bt, 12, 2, (-6.0, 2.0), 20);
1766        assert!(result.is_some());
1767        let res = result.unwrap();
1768        assert_eq!(res.fitted.shape(), (5, 50));
1769        assert!(res.gcv.is_finite());
1770        assert!(res.edf > 0.0);
1771    }
1772
1773    #[test]
1774    fn test_smooth_basis_gcv_fourier() {
1775        let m = 80;
1776        let t = uniform_grid(m);
1777        let mut data = FdMatrix::zeros(3, m);
1778        for i in 0..3 {
1779            for j in 0..m {
1780                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.5 * (4.0 * PI * t[j]).cos();
1781            }
1782        }
1783        let bt = BasisType::Fourier { period: 1.0 };
1784        let result = smooth_basis_gcv(&data, &t, &bt, 9, 2, (-8.0, 4.0), 25);
1785        assert!(result.is_some());
1786        let res = result.unwrap();
1787        assert_eq!(res.fitted.nrows(), 3);
1788        assert_eq!(res.nbasis, 9);
1789    }
1790
1791    #[test]
1792    fn test_smooth_basis_gcv_selects_finite_gcv() {
1793        let (data, t) = make_test_data(5, 60);
1794        let bt = BasisType::Bspline { order: 4 };
1795        let res = smooth_basis_gcv(&data, &t, &bt, 12, 2, (-6.0, 2.0), 15).unwrap();
1796        assert!(res.gcv.is_finite());
1797        assert!(res.gcv > 0.0);
1798    }
1799
1800    #[test]
1801    fn test_smooth_basis_gcv_empty_data() {
1802        let data = FdMatrix::zeros(0, 50);
1803        let t = uniform_grid(50);
1804        let bt = BasisType::Bspline { order: 4 };
1805        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-6.0, 2.0), 10);
1806        // Should return None since smooth_basis will error for empty data
1807        assert!(result.is_none());
1808    }
1809
1810    #[test]
1811    fn test_smooth_basis_gcv_empty_argvals() {
1812        let data = FdMatrix::zeros(5, 0);
1813        let bt = BasisType::Bspline { order: 4 };
1814        let result = smooth_basis_gcv(&data, &[], &bt, 10, 2, (-6.0, 2.0), 10);
1815        assert!(result.is_none());
1816    }
1817
1818    #[test]
1819    fn test_smooth_basis_gcv_nbasis_too_small() {
1820        let (data, t) = make_test_data(5, 50);
1821        let bt = BasisType::Bspline { order: 4 };
1822        let result = smooth_basis_gcv(&data, &t, &bt, 1, 2, (-6.0, 2.0), 10);
1823        assert!(result.is_none());
1824    }
1825
1826    #[test]
1827    fn test_smooth_basis_gcv_ngrid_too_small() {
1828        let (data, t) = make_test_data(5, 50);
1829        let bt = BasisType::Bspline { order: 4 };
1830        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-6.0, 2.0), 1);
1831        assert!(result.is_none());
1832    }
1833
1834    #[test]
1835    fn test_smooth_basis_gcv_narrow_range() {
1836        let (data, t) = make_test_data(3, 50);
1837        let bt = BasisType::Bspline { order: 4 };
1838        // Very narrow search range
1839        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-3.0, -2.0), 5);
1840        assert!(result.is_some());
1841    }
1842
1843    #[test]
1844    fn test_smooth_basis_gcv_wide_range() {
1845        let (data, t) = make_test_data(3, 50);
1846        let bt = BasisType::Bspline { order: 4 };
1847        // Very wide search range
1848        let result = smooth_basis_gcv(&data, &t, &bt, 10, 2, (-12.0, 8.0), 30);
1849        assert!(result.is_some());
1850    }
1851
1852    // ─── basis_nbasis_cv detailed tests ─────────────────────────────────────
1853
1854    #[test]
1855    fn test_basis_nbasis_cv_scores_length() {
1856        let (data, t) = make_test_data(5, 50);
1857        let nbasis_range: Vec<usize> = vec![4, 6, 8, 10, 12];
1858        let res = basis_nbasis_cv(
1859            &data,
1860            &t,
1861            &nbasis_range,
1862            &BasisType::Bspline { order: 4 },
1863            BasisCriterion::Gcv,
1864            5,
1865            1e-4,
1866        )
1867        .unwrap();
1868        assert_eq!(res.scores.len(), 5);
1869        assert_eq!(res.nbasis_range.len(), 5);
1870        assert_eq!(res.nbasis_range, nbasis_range);
1871    }
1872
1873    #[test]
1874    fn test_basis_nbasis_cv_optimal_within_range() {
1875        let (data, t) = make_test_data(8, 50);
1876        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11, 13, 15];
1877        for criterion in [
1878            BasisCriterion::Gcv,
1879            BasisCriterion::Aic,
1880            BasisCriterion::Bic,
1881        ] {
1882            let res = basis_nbasis_cv(
1883                &data,
1884                &t,
1885                &nbasis_range,
1886                &BasisType::Bspline { order: 4 },
1887                criterion,
1888                5,
1889                1e-4,
1890            )
1891            .unwrap();
1892            assert!(
1893                nbasis_range.contains(&res.optimal_nbasis),
1894                "optimal_nbasis {} not in range for {:?}",
1895                res.optimal_nbasis,
1896                criterion
1897            );
1898        }
1899    }
1900
1901    #[test]
1902    fn test_basis_nbasis_cv_fourier_gcv() {
1903        let m = 80;
1904        let t = uniform_grid(m);
1905        let mut data = FdMatrix::zeros(5, m);
1906        for i in 0..5 {
1907            for j in 0..m {
1908                data[(i, j)] = (2.0 * PI * t[j]).sin()
1909                    + 0.3 * (4.0 * PI * t[j]).cos()
1910                    + 0.02 * ((i * 7 + j * 3) % 10) as f64;
1911            }
1912        }
1913        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
1914        let res = basis_nbasis_cv(
1915            &data,
1916            &t,
1917            &nbasis_range,
1918            &BasisType::Fourier { period: 1.0 },
1919            BasisCriterion::Gcv,
1920            5,
1921            1e-4,
1922        )
1923        .unwrap();
1924        assert!(nbasis_range.contains(&res.optimal_nbasis));
1925    }
1926
1927    #[test]
1928    fn test_basis_nbasis_cv_fourier_cv() {
1929        let m = 60;
1930        let t = uniform_grid(m);
1931        let n = 10;
1932        let mut data = FdMatrix::zeros(n, m);
1933        for i in 0..n {
1934            for j in 0..m {
1935                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.02 * ((i * 11 + j) % 15) as f64;
1936            }
1937        }
1938        let nbasis_range: Vec<usize> = vec![5, 7, 9];
1939        let res = basis_nbasis_cv(
1940            &data,
1941            &t,
1942            &nbasis_range,
1943            &BasisType::Fourier { period: 1.0 },
1944            BasisCriterion::Cv,
1945            5,
1946            1e-4,
1947        )
1948        .unwrap();
1949        assert!(nbasis_range.contains(&res.optimal_nbasis));
1950        assert_eq!(res.criterion, BasisCriterion::Cv);
1951    }
1952
1953    #[test]
1954    fn test_basis_nbasis_cv_with_nbasis_below_minimum() {
1955        // Range includes nbasis = 1 which is invalid
1956        let (data, t) = make_test_data(5, 50);
1957        let nbasis_range: Vec<usize> = vec![1, 5, 10];
1958        let res = basis_nbasis_cv(
1959            &data,
1960            &t,
1961            &nbasis_range,
1962            &BasisType::Bspline { order: 4 },
1963            BasisCriterion::Gcv,
1964            5,
1965            1e-4,
1966        )
1967        .unwrap();
1968        // Score for nbasis=1 should be infinity, so optimal should be 5 or 10
1969        assert!(
1970            res.optimal_nbasis >= 5,
1971            "Should skip invalid nbasis=1, got optimal={}",
1972            res.optimal_nbasis
1973        );
1974        assert!(res.scores[0].is_infinite());
1975    }
1976
1977    #[test]
1978    fn test_basis_nbasis_cv_empty_range() {
1979        let (data, t) = make_test_data(5, 50);
1980        let nbasis_range: Vec<usize> = vec![];
1981        let result = basis_nbasis_cv(
1982            &data,
1983            &t,
1984            &nbasis_range,
1985            &BasisType::Bspline { order: 4 },
1986            BasisCriterion::Gcv,
1987            5,
1988            1e-4,
1989        );
1990        assert!(result.is_none());
1991    }
1992
1993    #[test]
1994    fn test_basis_nbasis_cv_empty_data() {
1995        let data = FdMatrix::zeros(0, 50);
1996        let t = uniform_grid(50);
1997        let nbasis_range: Vec<usize> = vec![5, 10];
1998        let result = basis_nbasis_cv(
1999            &data,
2000            &t,
2001            &nbasis_range,
2002            &BasisType::Bspline { order: 4 },
2003            BasisCriterion::Gcv,
2004            5,
2005            1e-4,
2006        );
2007        assert!(result.is_none());
2008    }
2009
2010    #[test]
2011    fn test_basis_nbasis_cv_mismatched_argvals() {
2012        let data = FdMatrix::zeros(5, 50);
2013        let t = uniform_grid(40); // mismatch
2014        let nbasis_range: Vec<usize> = vec![5, 10];
2015        let result = basis_nbasis_cv(
2016            &data,
2017            &t,
2018            &nbasis_range,
2019            &BasisType::Bspline { order: 4 },
2020            BasisCriterion::Gcv,
2021            5,
2022            1e-4,
2023        );
2024        assert!(result.is_none());
2025    }
2026
2027    #[test]
2028    fn test_basis_nbasis_cv_single_nbasis() {
2029        let (data, t) = make_test_data(5, 50);
2030        let nbasis_range: Vec<usize> = vec![10];
2031        let res = basis_nbasis_cv(
2032            &data,
2033            &t,
2034            &nbasis_range,
2035            &BasisType::Bspline { order: 4 },
2036            BasisCriterion::Gcv,
2037            5,
2038            1e-4,
2039        )
2040        .unwrap();
2041        assert_eq!(res.optimal_nbasis, 10);
2042        assert_eq!(res.scores.len(), 1);
2043    }
2044
2045    #[test]
2046    fn test_basis_nbasis_cv_bic_penalizes_more_than_aic() {
2047        // BIC penalizes complexity more heavily than AIC, so it should generally
2048        // select the same or fewer basis functions
2049        let (data, t) = make_test_data(5, 80);
2050        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
2051
2052        let aic_res = basis_nbasis_cv(
2053            &data,
2054            &t,
2055            &nbasis_range,
2056            &BasisType::Bspline { order: 4 },
2057            BasisCriterion::Aic,
2058            5,
2059            1e-4,
2060        )
2061        .unwrap();
2062        let bic_res = basis_nbasis_cv(
2063            &data,
2064            &t,
2065            &nbasis_range,
2066            &BasisType::Bspline { order: 4 },
2067            BasisCriterion::Bic,
2068            5,
2069            1e-4,
2070        )
2071        .unwrap();
2072        // BIC should select at most as many basis functions as AIC
2073        // (not guaranteed in all cases, but typical behavior)
2074        assert!(
2075            bic_res.optimal_nbasis <= aic_res.optimal_nbasis + 4,
2076            "BIC selected {} vs AIC selected {} -- BIC should not select much more than AIC",
2077            bic_res.optimal_nbasis,
2078            aic_res.optimal_nbasis
2079        );
2080    }
2081
2082    // ─── Fitted values quality tests ────────────────────────────────────────
2083
2084    #[test]
2085    fn test_smooth_basis_fitted_close_to_data() {
2086        // With moderate penalty and enough basis functions, fitted should be close to data
2087        let m = 50;
2088        let n = 3;
2089        let t = uniform_grid(m);
2090        let mut data = FdMatrix::zeros(n, m);
2091        for i in 0..n {
2092            for j in 0..m {
2093                data[(i, j)] = (2.0 * PI * t[j]).sin();
2094            }
2095        }
2096        let fdpar = make_bspline_fdpar(&t, 15, 1e-6);
2097        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2098
2099        let mut max_err = 0.0_f64;
2100        for i in 0..n {
2101            for j in 0..m {
2102                let err = (data[(i, j)] - res.fitted[(i, j)]).abs();
2103                max_err = max_err.max(err);
2104            }
2105        }
2106        assert!(
2107            max_err < 0.1,
2108            "Fitted should be close to smooth data; max_err={}",
2109            max_err
2110        );
2111    }
2112
2113    #[test]
2114    fn test_smooth_basis_constant_data() {
2115        // Constant data should be fit exactly
2116        let m = 50;
2117        let n = 2;
2118        let t = uniform_grid(m);
2119        let mut data = FdMatrix::zeros(n, m);
2120        for i in 0..n {
2121            for j in 0..m {
2122                data[(i, j)] = 3.15;
2123            }
2124        }
2125        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2126        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2127        for i in 0..n {
2128            for j in 0..m {
2129                assert!(
2130                    (res.fitted[(i, j)] - 3.15).abs() < 0.01,
2131                    "Constant data should be fit well at ({},{}): got {}",
2132                    i,
2133                    j,
2134                    res.fitted[(i, j)]
2135                );
2136            }
2137        }
2138    }
2139
2140    #[test]
2141    fn test_smooth_basis_linear_data() {
2142        // Linear data should be fit well with cubic B-splines
2143        let m = 50;
2144        let t = uniform_grid(m);
2145        let mut data = FdMatrix::zeros(1, m);
2146        for j in 0..m {
2147            data[(0, j)] = 2.0 * t[j] + 1.0;
2148        }
2149        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2150        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2151        for j in 0..m {
2152            let expected = 2.0 * t[j] + 1.0;
2153            assert!(
2154                (res.fitted[(0, j)] - expected).abs() < 0.05,
2155                "Linear data should be fit well at j={}: got {}, expected {}",
2156                j,
2157                res.fitted[(0, j)],
2158                expected
2159            );
2160        }
2161    }
2162
2163    // ─── EDF and diagnostic tests ───────────────────────────────────────────
2164
2165    #[test]
2166    fn test_smooth_basis_edf_bounded() {
2167        let m = 50;
2168        let (data, t) = make_test_data(3, m);
2169        let fdpar = make_bspline_fdpar(&t, 12, 1e-4);
2170        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2171        // EDF should be between 1 and m (evaluation points)
2172        assert!(
2173            res.edf > 0.0 && res.edf <= m as f64,
2174            "EDF should be in (0, {}]; got {}",
2175            m,
2176            res.edf
2177        );
2178    }
2179
2180    #[test]
2181    fn test_smooth_basis_gcv_aic_bic_all_finite() {
2182        let (data, t) = make_test_data(4, 60);
2183        let fdpar = make_bspline_fdpar(&t, 12, 1e-3);
2184        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2185        assert!(res.gcv.is_finite(), "GCV should be finite: {}", res.gcv);
2186        assert!(res.aic.is_finite(), "AIC should be finite: {}", res.aic);
2187        assert!(res.bic.is_finite(), "BIC should be finite: {}", res.bic);
2188    }
2189
2190    // ─── Penalty matrix size consistency tests ──────────────────────────────
2191
2192    #[test]
2193    fn test_smooth_basis_penalty_matrix_in_result() {
2194        let (data, t) = make_test_data(3, 50);
2195        let nbasis = 10;
2196        let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
2197        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2198        let k = res.nbasis;
2199        assert_eq!(
2200            res.penalty_matrix.len(),
2201            k * k,
2202            "Penalty matrix should be k*k = {}*{} = {}; got {}",
2203            k,
2204            k,
2205            k * k,
2206            res.penalty_matrix.len()
2207        );
2208    }
2209
2210    // ─── Regression: multiple identical curves ──────────────────────────────
2211
2212    #[test]
2213    fn test_smooth_basis_identical_curves_same_coefficients() {
2214        let m = 50;
2215        let t = uniform_grid(m);
2216        let curve: Vec<f64> = (0..m).map(|j| (2.0 * PI * t[j]).sin()).collect();
2217        let n = 4;
2218        let mut data = FdMatrix::zeros(n, m);
2219        for i in 0..n {
2220            for j in 0..m {
2221                data[(i, j)] = curve[j];
2222            }
2223        }
2224        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2225        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2226
2227        // All curves should have the same coefficients
2228        let k = res.coefficients.ncols();
2229        for i in 1..n {
2230            for j in 0..k {
2231                assert!(
2232                    (res.coefficients[(i, j)] - res.coefficients[(0, j)]).abs() < 1e-10,
2233                    "Identical curves should have identical coefficients: curve {} col {} differs",
2234                    i,
2235                    j
2236                );
2237            }
2238        }
2239    }
2240
2241    // ─── Cross-validation: different numbers of folds ───────────────────────
2242
2243    #[test]
2244    fn test_basis_nbasis_cv_different_nfolds() {
2245        let (data, t) = make_test_data(12, 50);
2246        let nbasis_range: Vec<usize> = vec![5, 8, 11];
2247        for nfolds in [2, 3, 5, 10] {
2248            let res = basis_nbasis_cv(
2249                &data,
2250                &t,
2251                &nbasis_range,
2252                &BasisType::Bspline { order: 4 },
2253                BasisCriterion::Cv,
2254                nfolds,
2255                1e-4,
2256            );
2257            assert!(res.is_some(), "CV should succeed with nfolds={}", nfolds);
2258            let r = res.unwrap();
2259            assert!(nbasis_range.contains(&r.optimal_nbasis));
2260        }
2261    }
2262
2263    // ─── Large nbasis / more basis than reasonable ──────────────────────────
2264
2265    #[test]
2266    fn test_smooth_basis_many_basis_functions() {
2267        let m = 100;
2268        let (data, t) = make_test_data(2, m);
2269        // Many basis functions relative to data points
2270        let fdpar = make_bspline_fdpar(&t, 40, 1e-2);
2271        let res = smooth_basis(&data, &t, &fdpar);
2272        assert!(
2273            res.is_ok(),
2274            "Should handle many basis functions with penalty"
2275        );
2276    }
2277
2278    // ─── evaluate_basis internal function (indirectly tested) ───────────────
2279
2280    #[test]
2281    fn test_smooth_basis_bspline_vs_fourier_different_results() {
2282        let m = 50;
2283        let (data, t) = make_test_data(2, m);
2284        let fdpar_bs = make_bspline_fdpar(&t, 9, 1e-4);
2285        let fdpar_f = make_fourier_fdpar(9, 1.0, 1e-4);
2286        let res_bs = smooth_basis(&data, &t, &fdpar_bs).unwrap();
2287        let res_f = smooth_basis(&data, &t, &fdpar_f).unwrap();
2288        // Results should differ between the two basis types
2289        let diff: f64 = (0..m)
2290            .map(|j| (res_bs.fitted[(0, j)] - res_f.fitted[(0, j)]).abs())
2291            .sum();
2292        // They fit the same data, so some difference is expected but not huge
2293        assert!(
2294            diff > 1e-10,
2295            "B-spline and Fourier fits should differ for the same data"
2296        );
2297    }
2298
2299    // ─── compute_gcv edge cases (indirectly tested) ─────────────────────────
2300
2301    #[test]
2302    fn test_smooth_basis_gcv_positive_for_noisy_data() {
2303        let m = 50;
2304        let t = uniform_grid(m);
2305        let mut data = FdMatrix::zeros(1, m);
2306        for j in 0..m {
2307            // Noisy data
2308            data[(0, j)] = (2.0 * PI * t[j]).sin() + 0.5 * ((j * 37) % 20) as f64 / 20.0 - 0.25;
2309        }
2310        let fdpar = make_bspline_fdpar(&t, 10, 1e-3);
2311        let res = smooth_basis(&data, &t, &fdpar).unwrap();
2312        assert!(res.gcv > 0.0, "GCV should be positive for noisy data");
2313    }
2314
2315    // ─── Penalty order (lfd_order) tests ────────────────────────────────────
2316
2317    #[test]
2318    fn test_smooth_basis_different_lfd_orders() {
2319        let m = 50;
2320        let (data, t) = make_test_data(2, m);
2321
2322        // lfd_order = 1 (penalize first derivative)
2323        let penalty1 = bspline_penalty_matrix(&t, 10, 4, 1);
2324        let fdpar1 = FdPar {
2325            basis_type: BasisType::Bspline { order: 4 },
2326            nbasis: 10,
2327            lambda: 1e-2,
2328            lfd_order: 1,
2329            penalty_matrix: penalty1,
2330        };
2331        let res1 = smooth_basis(&data, &t, &fdpar1);
2332        assert!(res1.is_ok());
2333
2334        // lfd_order = 2 (penalize second derivative)
2335        let penalty2 = bspline_penalty_matrix(&t, 10, 4, 2);
2336        let fdpar2 = FdPar {
2337            basis_type: BasisType::Bspline { order: 4 },
2338            nbasis: 10,
2339            lambda: 1e-2,
2340            lfd_order: 2,
2341            penalty_matrix: penalty2,
2342        };
2343        let res2 = smooth_basis(&data, &t, &fdpar2);
2344        assert!(res2.is_ok());
2345
2346        // Different penalty orders should produce different fitted values
2347        let r1 = res1.unwrap();
2348        let r2 = res2.unwrap();
2349        let diff: f64 = (0..m)
2350            .map(|j| (r1.fitted[(0, j)] - r2.fitted[(0, j)]).abs())
2351            .sum();
2352        assert!(
2353            diff > 1e-10,
2354            "Different lfd_orders should produce different fits"
2355        );
2356    }
2357
2358    // ─── BasisNbasisCvResult field tests ────────────────────────────────────
2359
2360    #[test]
2361    fn test_basis_nbasis_cv_result_fields() {
2362        let (data, t) = make_test_data(6, 50);
2363        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11, 13];
2364        let res = basis_nbasis_cv(
2365            &data,
2366            &t,
2367            &nbasis_range,
2368            &BasisType::Bspline { order: 4 },
2369            BasisCriterion::Aic,
2370            5,
2371            1e-4,
2372        )
2373        .unwrap();
2374
2375        assert!(nbasis_range.contains(&res.optimal_nbasis));
2376        assert_eq!(res.scores.len(), nbasis_range.len());
2377        assert_eq!(res.nbasis_range, nbasis_range);
2378        assert_eq!(res.criterion, BasisCriterion::Aic);
2379        // optimal_nbasis should correspond to minimum score
2380        let min_score = res.scores.iter().copied().fold(f64::INFINITY, f64::min);
2381        let best_idx = res
2382            .scores
2383            .iter()
2384            .position(|&s| (s - min_score).abs() < 1e-15)
2385            .unwrap();
2386        assert_eq!(res.optimal_nbasis, nbasis_range[best_idx]);
2387    }
2388
2389    #[test]
2390    fn test_basis_nbasis_cv_result_clone() {
2391        let (data, t) = make_test_data(5, 50);
2392        let nbasis_range: Vec<usize> = vec![5, 10];
2393        let res = basis_nbasis_cv(
2394            &data,
2395            &t,
2396            &nbasis_range,
2397            &BasisType::Bspline { order: 4 },
2398            BasisCriterion::Gcv,
2399            5,
2400            1e-4,
2401        )
2402        .unwrap();
2403        let cloned = res.clone();
2404        assert_eq!(res, cloned);
2405    }
2406
2407    // ─── Non-uniform argvals ────────────────────────────────────────────────
2408
2409    #[test]
2410    fn test_smooth_basis_nonuniform_argvals() {
2411        let m = 50;
2412        // Non-uniform grid: denser at the ends
2413        let t: Vec<f64> = (0..m)
2414            .map(|i| {
2415                let x = i as f64 / (m - 1) as f64;
2416                0.5 * (1.0 - (PI * x).cos())
2417            })
2418            .collect();
2419        let mut data = FdMatrix::zeros(2, m);
2420        for i in 0..2 {
2421            for j in 0..m {
2422                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * i as f64;
2423            }
2424        }
2425        let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
2426        let res = smooth_basis(&data, &t, &fdpar);
2427        assert!(res.is_ok(), "Should handle non-uniform argvals");
2428        let r = res.unwrap();
2429        assert_eq!(r.fitted.shape(), (2, m));
2430    }
2431
2432    // ─── Numerical stability with extreme lambda ────────────────────────────
2433
2434    #[test]
2435    fn test_smooth_basis_very_small_lambda() {
2436        let m = 50;
2437        let (data, t) = make_test_data(2, m);
2438        let fdpar = make_bspline_fdpar(&t, 10, 1e-15);
2439        let res = smooth_basis(&data, &t, &fdpar);
2440        assert!(res.is_ok(), "Should handle very small lambda");
2441    }
2442
2443    #[test]
2444    fn test_smooth_basis_very_large_lambda() {
2445        let m = 50;
2446        let (data, t) = make_test_data(2, m);
2447        let fdpar = make_bspline_fdpar(&t, 10, 1e10);
2448        let res = smooth_basis(&data, &t, &fdpar);
2449        assert!(res.is_ok(), "Should handle very large lambda");
2450    }
2451
2452    // ─── Multiple curves consistency ────────────────────────────────────────
2453
2454    #[test]
2455    fn test_smooth_basis_multi_curve_vs_single_curve() {
2456        // Smoothing multiple curves at once should give the same result as smoothing each individually
2457        let m = 50;
2458        let n = 3;
2459        let (data, t) = make_test_data(n, m);
2460        let fdpar = make_bspline_fdpar(&t, 10, 1e-3);
2461
2462        // All at once
2463        let res_all = smooth_basis(&data, &t, &fdpar).unwrap();
2464
2465        // One at a time
2466        for i in 0..n {
2467            let mut single = FdMatrix::zeros(1, m);
2468            for j in 0..m {
2469                single[(0, j)] = data[(i, j)];
2470            }
2471            let res_single = smooth_basis(&single, &t, &fdpar).unwrap();
2472            for j in 0..m {
2473                assert!(
2474                    (res_all.fitted[(i, j)] - res_single.fitted[(0, j)]).abs() < 1e-10,
2475                    "Multi-curve fit should match single-curve fit: curve {} point {}",
2476                    i,
2477                    j
2478                );
2479            }
2480        }
2481    }
2482
2483    // ─── BasisCriterion comparison: all criteria produce finite scores ──────
2484
2485    #[test]
2486    fn test_basis_nbasis_cv_all_criteria_finite_scores() {
2487        let (data, t) = make_test_data(10, 60);
2488        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
2489
2490        for criterion in [
2491            BasisCriterion::Gcv,
2492            BasisCriterion::Aic,
2493            BasisCriterion::Bic,
2494            BasisCriterion::Cv,
2495        ] {
2496            let res = basis_nbasis_cv(
2497                &data,
2498                &t,
2499                &nbasis_range,
2500                &BasisType::Bspline { order: 4 },
2501                criterion,
2502                5,
2503                1e-4,
2504            )
2505            .unwrap();
2506            // At least some scores should be finite (valid nbasis values)
2507            let finite_count = res.scores.iter().filter(|s| s.is_finite()).count();
2508            assert!(
2509                finite_count > 0,
2510                "At least one score should be finite for {:?}",
2511                criterion
2512            );
2513        }
2514    }
2515
2516    // ─── SmoothBasisGcvConfig tests ────────────────────────────────────────
2517
2518    #[test]
2519    fn test_smooth_basis_gcv_config_default() {
2520        let config = SmoothBasisGcvConfig::default();
2521        assert_eq!(config.basis_type, BasisType::Bspline { order: 4 });
2522        assert_eq!(config.nbasis, 15);
2523        assert_eq!(config.lfd_order, 2);
2524        assert_eq!(config.log_lambda_range, (-10.0, 2.0));
2525        assert_eq!(config.n_grid, 50);
2526    }
2527
2528    #[test]
2529    fn test_smooth_basis_gcv_config_clone_eq() {
2530        let config = SmoothBasisGcvConfig {
2531            nbasis: 20,
2532            ..SmoothBasisGcvConfig::default()
2533        };
2534        let cloned = config.clone();
2535        assert_eq!(config, cloned);
2536    }
2537
2538    #[test]
2539    fn test_smooth_basis_gcv_config_debug() {
2540        let config = SmoothBasisGcvConfig::default();
2541        let debug_str = format!("{:?}", config);
2542        assert!(debug_str.contains("SmoothBasisGcvConfig"));
2543        assert!(debug_str.contains("nbasis"));
2544    }
2545
2546    #[test]
2547    fn test_smooth_basis_gcv_config_partial_override() {
2548        let config = SmoothBasisGcvConfig {
2549            basis_type: BasisType::Fourier { period: 2.0 },
2550            n_grid: 100,
2551            ..SmoothBasisGcvConfig::default()
2552        };
2553        assert_eq!(config.basis_type, BasisType::Fourier { period: 2.0 });
2554        assert_eq!(config.n_grid, 100);
2555        // defaults preserved
2556        assert_eq!(config.nbasis, 15);
2557        assert_eq!(config.lfd_order, 2);
2558    }
2559
2560    #[test]
2561    fn test_smooth_basis_gcv_with_config_default() {
2562        let (data, t) = make_test_data(5, 101);
2563        let config = SmoothBasisGcvConfig::default();
2564        let result = smooth_basis_gcv_with_config(&data, &t, &config);
2565        assert!(result.is_ok(), "GCV with default config should succeed");
2566        let res = result.unwrap();
2567        assert_eq!(res.fitted.shape(), (5, 101));
2568        assert!(res.edf > 0.0);
2569        assert!(res.gcv.is_finite());
2570    }
2571
2572    #[test]
2573    fn test_smooth_basis_gcv_with_config_custom() {
2574        let (data, t) = make_test_data(3, 50);
2575        let config = SmoothBasisGcvConfig {
2576            nbasis: 10,
2577            log_lambda_range: (-6.0, 0.0),
2578            n_grid: 15,
2579            ..SmoothBasisGcvConfig::default()
2580        };
2581        let result = smooth_basis_gcv_with_config(&data, &t, &config);
2582        assert!(result.is_ok());
2583    }
2584
2585    #[test]
2586    fn test_smooth_basis_gcv_with_config_matches_direct() {
2587        let (data, t) = make_test_data(3, 50);
2588        let config = SmoothBasisGcvConfig {
2589            nbasis: 10,
2590            log_lambda_range: (-6.0, 0.0),
2591            n_grid: 20,
2592            ..SmoothBasisGcvConfig::default()
2593        };
2594        let with_config = smooth_basis_gcv_with_config(&data, &t, &config).unwrap();
2595        let direct = smooth_basis_gcv(
2596            &data,
2597            &t,
2598            &config.basis_type,
2599            config.nbasis,
2600            config.lfd_order,
2601            config.log_lambda_range,
2602            config.n_grid,
2603        )
2604        .unwrap();
2605        assert_eq!(with_config.gcv, direct.gcv);
2606        assert_eq!(with_config.edf, direct.edf);
2607        assert_eq!(with_config.nbasis, direct.nbasis);
2608    }
2609
2610    #[test]
2611    fn test_smooth_basis_gcv_with_config_fourier() {
2612        let m = 100;
2613        let t = uniform_grid(m);
2614        let mut data = FdMatrix::zeros(2, m);
2615        for i in 0..2 {
2616            for j in 0..m {
2617                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
2618            }
2619        }
2620        let config = SmoothBasisGcvConfig {
2621            basis_type: BasisType::Fourier { period: 1.0 },
2622            nbasis: 7,
2623            n_grid: 20,
2624            ..SmoothBasisGcvConfig::default()
2625        };
2626        let result = smooth_basis_gcv_with_config(&data, &t, &config);
2627        assert!(result.is_ok());
2628    }
2629
2630    // ─── BasisNbasisCvConfig tests ─────────────────────────────────────────
2631
2632    #[test]
2633    fn test_basis_nbasis_cv_config_default() {
2634        let config = BasisNbasisCvConfig::default();
2635        assert_eq!(config.basis_type, BasisType::Bspline { order: 4 });
2636        assert_eq!(config.nbasis_range, (5, 30));
2637        assert!((config.lambda - 1e-4).abs() < 1e-15);
2638        assert_eq!(config.lfd_order, 2);
2639        assert_eq!(config.n_folds, 5);
2640        assert_eq!(config.criterion, BasisCriterion::Gcv);
2641    }
2642
2643    #[test]
2644    fn test_basis_nbasis_cv_config_clone_eq() {
2645        let config = BasisNbasisCvConfig {
2646            nbasis_range: (4, 15),
2647            ..BasisNbasisCvConfig::default()
2648        };
2649        let cloned = config.clone();
2650        assert_eq!(config, cloned);
2651    }
2652
2653    #[test]
2654    fn test_basis_nbasis_cv_config_debug() {
2655        let config = BasisNbasisCvConfig::default();
2656        let debug_str = format!("{:?}", config);
2657        assert!(debug_str.contains("BasisNbasisCvConfig"));
2658        assert!(debug_str.contains("nbasis_range"));
2659    }
2660
2661    #[test]
2662    fn test_basis_nbasis_cv_config_partial_override() {
2663        let config = BasisNbasisCvConfig {
2664            criterion: BasisCriterion::Aic,
2665            lambda: 1e-2,
2666            ..BasisNbasisCvConfig::default()
2667        };
2668        assert_eq!(config.criterion, BasisCriterion::Aic);
2669        assert!((config.lambda - 1e-2).abs() < 1e-15);
2670        // defaults preserved
2671        assert_eq!(config.nbasis_range, (5, 30));
2672        assert_eq!(config.n_folds, 5);
2673    }
2674
2675    #[test]
2676    fn test_basis_nbasis_cv_with_config_default() {
2677        let (data, t) = make_test_data(5, 51);
2678        let config = BasisNbasisCvConfig {
2679            nbasis_range: (5, 12),
2680            ..BasisNbasisCvConfig::default()
2681        };
2682        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2683        assert!(
2684            result.is_ok(),
2685            "nbasis CV with default config should succeed"
2686        );
2687        let res = result.unwrap();
2688        assert!(res.optimal_nbasis >= 5 && res.optimal_nbasis <= 12);
2689        assert_eq!(res.scores.len(), 8); // 5..=12 = 8 values
2690        assert_eq!(res.criterion, BasisCriterion::Gcv);
2691    }
2692
2693    #[test]
2694    fn test_basis_nbasis_cv_with_config_aic() {
2695        let (data, t) = make_test_data(5, 51);
2696        let config = BasisNbasisCvConfig {
2697            nbasis_range: (5, 10),
2698            criterion: BasisCriterion::Aic,
2699            ..BasisNbasisCvConfig::default()
2700        };
2701        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2702        assert!(result.is_ok());
2703        assert_eq!(result.unwrap().criterion, BasisCriterion::Aic);
2704    }
2705
2706    #[test]
2707    fn test_basis_nbasis_cv_with_config_cv_folds() {
2708        let (data, t) = make_test_data(10, 51);
2709        let config = BasisNbasisCvConfig {
2710            nbasis_range: (5, 9),
2711            criterion: BasisCriterion::Cv,
2712            n_folds: 3,
2713            ..BasisNbasisCvConfig::default()
2714        };
2715        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2716        assert!(result.is_ok());
2717        assert_eq!(result.unwrap().criterion, BasisCriterion::Cv);
2718    }
2719
2720    #[test]
2721    fn test_basis_nbasis_cv_with_config_matches_direct() {
2722        let (data, t) = make_test_data(5, 51);
2723        let config = BasisNbasisCvConfig {
2724            nbasis_range: (5, 10),
2725            criterion: BasisCriterion::Bic,
2726            lambda: 1e-3,
2727            ..BasisNbasisCvConfig::default()
2728        };
2729        let with_config = basis_nbasis_cv_with_config(&data, &t, &config).unwrap();
2730        let nbasis_range: Vec<usize> = (5..=10).collect();
2731        let direct = basis_nbasis_cv(
2732            &data,
2733            &t,
2734            &nbasis_range,
2735            &config.basis_type,
2736            config.criterion,
2737            config.n_folds,
2738            config.lambda,
2739        )
2740        .unwrap();
2741        assert_eq!(with_config.optimal_nbasis, direct.optimal_nbasis);
2742        assert_eq!(with_config.scores, direct.scores);
2743        assert_eq!(with_config.nbasis_range, direct.nbasis_range);
2744    }
2745
2746    #[test]
2747    fn test_basis_nbasis_cv_with_config_nbasis_range_expansion() {
2748        let (data, t) = make_test_data(5, 51);
2749        let config = BasisNbasisCvConfig {
2750            nbasis_range: (7, 7), // single value
2751            ..BasisNbasisCvConfig::default()
2752        };
2753        let result = basis_nbasis_cv_with_config(&data, &t, &config);
2754        assert!(result.is_ok());
2755        let res = result.unwrap();
2756        assert_eq!(res.optimal_nbasis, 7);
2757        assert_eq!(res.scores.len(), 1);
2758    }
2759}