use crate::basis::{bspline_basis, fourier_basis_with_period};
use crate::helpers::simpsons_weights;
use crate::matrix::FdMatrix;
use nalgebra::DMatrix;
use std::f64::consts::PI;
/// Basis family used when evaluating the design matrix for smoothing.
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
/// B-spline basis; `order` is passed through to `bspline_basis`.
Bspline { order: usize },
/// Fourier basis; `period` is passed through to `fourier_basis_with_period`.
Fourier { period: f64 },
}
/// Smoothing specification consumed by `smooth_basis`.
#[derive(Debug, Clone, PartialEq)]
pub struct FdPar {
/// Basis family to evaluate at the observation points.
pub basis_type: BasisType,
/// Requested number of basis functions (`smooth_basis` rejects values below 2).
/// For B-splines the number actually used is derived from the evaluated basis
/// and may differ from this request.
pub nbasis: usize,
/// Weight applied to `penalty_matrix` when forming the penalized system.
pub lambda: f64,
/// Derivative order associated with the penalty. `smooth_basis` in this file
/// does not read this field; it uses `penalty_matrix` exactly as supplied.
pub lfd_order: usize,
/// Flattened square penalty matrix, read with `DMatrix::from_column_slice(k, k, ..)`
/// where `k` is the number of basis functions actually evaluated.
pub penalty_matrix: Vec<f64>,
}
/// Output of `smooth_basis`: per-curve coefficients, fitted values and fit diagnostics.
#[derive(Debug, Clone, PartialEq)]
pub struct SmoothBasisResult {
/// Basis coefficients, one row per input curve and one column per basis function.
pub coefficients: FdMatrix,
/// Fitted values, same shape as the input data (curves x evaluation points).
pub fitted: FdMatrix,
/// Effective degrees of freedom: the sum of the diagonal of the hat matrix.
pub edf: f64,
/// Generalized cross-validation score as returned by `compute_gcv`.
pub gcv: f64,
/// `N * ln(mse) + 2 * n * edf`, with `N` the total number of observed points
/// and `n` the number of curves.
pub aic: f64,
/// `N * ln(mse) + ln(N) * n * edf`, with `N` and `n` as for `aic`.
pub bic: f64,
/// Copy of the penalty matrix that was supplied through `FdPar`.
pub penalty_matrix: Vec<f64>,
/// Number of basis functions actually used (columns of `coefficients`).
pub nbasis: usize,
}
/// Numerically builds the roughness penalty matrix for a B-spline basis.
///
/// The basis is evaluated on a uniform quadrature grid spanning
/// `argvals[0]..=argvals[last]` (ten sub-intervals per original interval),
/// differentiated `lfd_order` times by finite differences, and the pairwise
/// products of the differentiated basis functions are integrated with the
/// weights returned by `simpsons_weights`.
///
/// # Arguments
/// * `argvals` - evaluation points; only the first and last entries are used
///   to define the integration domain, so they are assumed to bound it.
/// * `nbasis` - requested number of basis functions.
/// * `order` - B-spline order forwarded to `bspline_basis`.
/// * `lfd_order` - derivative order to penalize; must be smaller than `order`.
///
/// # Returns
/// A flattened square matrix. For degenerate input (`nbasis < 2`, `order < 1`,
/// `lfd_order >= order`, fewer than two `argvals`, or a zero-width / non-finite
/// domain) an all-zero `nbasis * nbasis` vector is returned. Otherwise the side
/// length equals the number of basis functions reported by `bspline_basis`,
/// which may differ from `nbasis`.
pub fn bspline_penalty_matrix(
    argvals: &[f64],
    nbasis: usize,
    order: usize,
    lfd_order: usize,
) -> Vec<f64> {
    if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
        return vec![0.0; nbasis * nbasis];
    }
    let t_min = argvals[0];
    let t_max = argvals[argvals.len() - 1];
    // A zero-width or non-finite domain would make the grid step 0 or NaN and
    // poison every finite difference below; treat it like the other
    // degenerate inputs instead of returning NaN/inf entries.
    let span = t_max - t_min;
    if !span.is_finite() || span == 0.0 {
        return vec![0.0; nbasis * nbasis];
    }

    // Same knot-count rule as `evaluate_basis`, so both agree on the basis size.
    let nknots = nbasis.saturating_sub(order).max(2);
    let n_sub = 10;
    let n_quad = (argvals.len() - 1) * n_sub + 1;
    let quad_t: Vec<f64> = (0..n_quad)
        .map(|i| t_min + span * i as f64 / (n_quad - 1) as f64)
        .collect();

    let basis_fine = bspline_basis(&quad_t, nknots, order);
    let actual_nbasis = basis_fine.len() / n_quad;
    let h = span / (n_quad - 1) as f64;
    let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
    let weights = simpsons_weights(&quad_t);
    integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
}
/// Builds the diagonal penalty matrix for a Fourier basis.
///
/// The result is a flattened `nbasis x nbasis` matrix that is zero everywhere
/// except on the diagonal. Entry 0 (presumably the constant basis function)
/// stays zero; entries `2f - 1` and `2f` both receive
/// `(2 * PI * f / period)^(2 * lfd_order)` for frequency `f = 1, 2, ...`.
///
/// # Arguments
/// * `nbasis` - number of basis functions (side length of the matrix).
/// * `period` - period of the Fourier basis.
/// * `lfd_order` - derivative order being penalized.
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let mut penalty = vec![0.0; nbasis * nbasis];
    for diag in 1..nbasis {
        // Diagonal positions come in pairs sharing one frequency:
        // (1, 2) -> 1, (3, 4) -> 2, ...
        let freq = (diag + 1) / 2;
        let omega = 2.0 * PI * freq as f64 / period;
        penalty[diag * nbasis + diag] = omega.powi(2 * lfd_order as i32);
    }
    penalty
}
/// Smooths every curve (row) of `data` by penalized least squares in the basis
/// described by `fdpar`.
///
/// With `B` the basis evaluated at `argvals`, `R` the supplied penalty matrix
/// and a small ridge term for numerical stability, the coefficients of each
/// curve `y` are `(B'B + lambda * R + 1e-10 * I)^-1 B' y`.
///
/// # Arguments
/// * `data` - curves in rows, evaluation points in columns.
/// * `argvals` - evaluation points; its length must equal the column count.
/// * `fdpar` - basis, smoothing weight and penalty matrix. The penalty must be
///   a flattened `k x k` matrix where `k` is the number of basis functions
///   actually evaluated (see [`SmoothBasisResult::nbasis`]).
///
/// # Errors
/// * `FdarError::InvalidDimension` when `data` is empty, `argvals` does not
///   match the number of columns, `fdpar.nbasis < 2`, or the penalty matrix
///   does not have `k * k` elements.
/// * `FdarError::ComputationFailed` when the evaluated basis does not have the
///   expected size or the penalized system cannot be inverted.
pub fn smooth_basis(
    data: &FdMatrix,
    argvals: &[f64],
    fdpar: &FdPar,
) -> Result<SmoothBasisResult, crate::FdarError> {
    let (n, m) = data.shape();
    if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
        return Err(crate::FdarError::InvalidDimension {
            parameter: "data/argvals/fdpar",
            expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
            actual: format!(
                "n={}, m={}, argvals.len()={}, nbasis={}",
                n,
                m,
                argvals.len(),
                fdpar.nbasis
            ),
        });
    }

    let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
    let k = actual_nbasis;

    // `DMatrix::from_column_slice` panics on a length mismatch, so validate the
    // buffers first and report the problem as an error instead.
    if k == 0 || basis_flat.len() != m * k {
        return Err(crate::FdarError::ComputationFailed {
            operation: "basis evaluation",
            detail: format!(
                "evaluated basis has {} values, expected {} ({} points x {} basis functions)",
                basis_flat.len(),
                m * k,
                m,
                k
            ),
        });
    }
    if fdpar.penalty_matrix.len() != k * k {
        return Err(crate::FdarError::InvalidDimension {
            parameter: "fdpar.penalty_matrix",
            expected: format!("{} elements ({} x {})", k * k, k, k),
            actual: format!("{} elements", fdpar.penalty_matrix.len()),
        });
    }

    let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
    let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
    let btb = b_mat.transpose() * &b_mat;
    // Tiny ridge keeps the system invertible when lambda is 0 or R is singular.
    let ridge_eps = 1e-10;
    let system: DMatrix<f64> =
        &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
    let system_inv =
        invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
            operation: "matrix inversion",
            detail: "failed to invert penalized system (Φ'Φ + λR); try increasing lambda or reducing the number of basis functions".to_string(),
        })?;

    // proj maps an observed curve to its coefficients (k x m).
    let proj = &system_inv * b_mat.transpose();
    // edf = trace(B * proj). Only the diagonal of the hat matrix is needed, so
    // it is accumulated directly instead of forming the full m x m product.
    let edf: f64 = (0..m)
        .map(|i| (0..k).map(|j| b_mat[(i, j)] * proj[(j, i)]).sum::<f64>())
        .sum();

    let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);

    let total_points = (n * m) as f64;
    let gcv = compute_gcv(total_rss, total_points, edf, m);
    let mse = total_rss / total_points;
    let total_edf = n as f64 * edf;
    // Clamp the MSE away from zero so a perfect fit does not produce ln(0).
    let log_mse = mse.max(1e-300).ln();
    let aic = total_points * log_mse + 2.0 * total_edf;
    let bic = total_points * log_mse + total_points.ln() * total_edf;

    Ok(SmoothBasisResult {
        coefficients: all_coefs,
        fitted: all_fitted,
        edf,
        gcv,
        aic,
        bic,
        penalty_matrix: fdpar.penalty_matrix.clone(),
        nbasis: k,
    })
}
/// Picks the smoothing weight by a grid search over GCV and returns the best fit.
///
/// `n_grid` values of `log10(lambda)` are spaced evenly over
/// `log_lambda_range` (both ends included); each is passed to `smooth_basis`
/// and the result with the strictly smallest GCV score is kept.
///
/// # Arguments
/// * `data`, `argvals` - forwarded to `smooth_basis`.
/// * `basis_type`, `nbasis` - basis specification; also used to build the penalty.
/// * `lfd_order` - derivative order used when building the penalty matrix.
/// * `log_lambda_range` - `(lo, hi)` bounds for `log10(lambda)`.
/// * `n_grid` - number of grid points; must be at least 2.
///
/// # Returns
/// `None` when `argvals` is empty, `nbasis < 2`, `n_grid < 2`, or no grid point
/// produced a fit with a GCV score below `f64::INFINITY`.
pub fn smooth_basis_gcv(
    data: &FdMatrix,
    argvals: &[f64],
    basis_type: &BasisType,
    nbasis: usize,
    lfd_order: usize,
    log_lambda_range: (f64, f64),
    n_grid: usize,
) -> Option<SmoothBasisResult> {
    let m = argvals.len();
    if m == 0 || nbasis < 2 || n_grid < 2 {
        return None;
    }

    let penalty = match basis_type {
        BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
        BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
    };

    // Only lambda changes between grid points, so the specification (and its
    // penalty matrix) is built once and updated in place rather than re-cloned
    // on every iteration.
    let mut fdpar = FdPar {
        basis_type: basis_type.clone(),
        nbasis,
        lambda: 0.0,
        lfd_order,
        penalty_matrix: penalty,
    };

    let (lo, hi) = log_lambda_range;
    let mut best_gcv = f64::INFINITY;
    let mut best_result: Option<SmoothBasisResult> = None;
    for i in 0..n_grid {
        let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
        fdpar.lambda = 10.0_f64.powf(log_lam);
        // Grid points whose fit fails are skipped rather than aborting the search.
        if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
            if result.gcv < best_gcv {
                best_gcv = result.gcv;
                best_result = Some(result);
            }
        }
    }
    best_result
}
/// Parameter bundle for `smooth_basis_gcv_with_config`; each field is forwarded
/// to the matching argument of `smooth_basis_gcv`.
#[derive(Debug, Clone, PartialEq)]
pub struct SmoothBasisGcvConfig {
/// Basis family to fit.
pub basis_type: BasisType,
/// Requested number of basis functions.
pub nbasis: usize,
/// Derivative order used when building the penalty matrix.
pub lfd_order: usize,
/// `(lo, hi)` bounds of the search grid, expressed as base-10 exponents of lambda.
pub log_lambda_range: (f64, f64),
/// Number of grid points; `smooth_basis_gcv` returns `None` for values below 2.
pub n_grid: usize,
}
/// Defaults: order-4 B-splines, 15 basis functions, second-derivative penalty,
/// 50 grid points for `log10(lambda)` between -10 and 2.
impl Default for SmoothBasisGcvConfig {
fn default() -> Self {
Self {
basis_type: BasisType::Bspline { order: 4 },
nbasis: 15,
lfd_order: 2,
log_lambda_range: (-10.0, 2.0),
n_grid: 50,
}
}
}
#[must_use = "expensive computation whose result should not be discarded"]
pub fn smooth_basis_gcv_with_config(
data: &FdMatrix,
argvals: &[f64],
config: &SmoothBasisGcvConfig,
) -> Result<SmoothBasisResult, crate::FdarError> {
smooth_basis_gcv(
data,
argvals,
&config.basis_type,
config.nbasis,
config.lfd_order,
config.log_lambda_range,
config.n_grid,
)
.ok_or_else(|| crate::FdarError::ComputationFailed {
operation: "smooth_basis_gcv_with_config",
detail: "no valid smoothing result found in GCV lambda search".to_string(),
})
}
/// Parameter bundle for `basis_nbasis_cv_with_config`.
#[derive(Debug, Clone, PartialEq)]
pub struct BasisNbasisCvConfig {
/// Basis family to fit.
pub basis_type: BasisType,
/// Inclusive `(min, max)` range of candidate basis sizes.
pub nbasis_range: (usize, usize),
/// Fixed smoothing weight used for every candidate.
pub lambda: f64,
/// NOTE(review): `basis_nbasis_cv_with_config` does not forward this field;
/// the candidate evaluators in this file build their penalties with a
/// hard-coded derivative order of 2.
pub lfd_order: usize,
/// Number of folds; only relevant for `BasisCriterion::Cv`, where it is
/// raised to at least 2.
pub n_folds: usize,
/// Score used to rank the candidates.
pub criterion: BasisCriterion,
}
/// Defaults: order-4 B-splines, candidates 5 through 30, `lambda = 1e-4`,
/// second-derivative penalty, 5 folds, GCV criterion.
impl Default for BasisNbasisCvConfig {
fn default() -> Self {
Self {
basis_type: BasisType::Bspline { order: 4 },
nbasis_range: (5, 30),
lambda: 1e-4,
lfd_order: 2,
n_folds: 5,
criterion: BasisCriterion::Gcv,
}
}
}
/// Config-struct front end for `basis_nbasis_cv`.
///
/// The inclusive `config.nbasis_range` is expanded into an explicit candidate
/// list before delegating.
///
/// NOTE(review): `config.lfd_order` is not passed on; `basis_nbasis_cv` has no
/// such parameter and its evaluators use a derivative order of 2.
///
/// # Errors
/// Returns `FdarError::ComputationFailed` when `basis_nbasis_cv` returns `None`
/// (for example when the range is empty because its lower bound exceeds its
/// upper bound).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn basis_nbasis_cv_with_config(
    data: &FdMatrix,
    argvals: &[f64],
    config: &BasisNbasisCvConfig,
) -> Result<BasisNbasisCvResult, crate::FdarError> {
    let (first, last) = config.nbasis_range;
    let candidates: Vec<usize> = (first..=last).collect();

    let selection = basis_nbasis_cv(
        data,
        argvals,
        &candidates,
        &config.basis_type,
        config.criterion,
        config.n_folds,
        config.lambda,
    );
    match selection {
        Some(found) => Ok(found),
        None => Err(crate::FdarError::ComputationFailed {
            operation: "basis_nbasis_cv_with_config",
            detail: "no valid result found in nbasis CV search".to_string(),
        }),
    }
}
/// Differentiates each column of a flattened, column-major basis matrix
/// `lfd_order` times.
///
/// Every pass hands one column (`n_quad` consecutive values) to
/// `crate::helpers::gradient_uniform` with grid spacing `h` and stores the
/// first `n_quad` values of its output in the same position. With
/// `lfd_order == 0` the input is returned as a copy.
fn differentiate_basis_columns(
    basis: &[f64],
    n_quad: usize,
    nbasis: usize,
    h: f64,
    lfd_order: usize,
) -> Vec<f64> {
    let mut current = basis.to_vec();
    for _pass in 0..lfd_order {
        let mut next = vec![0.0; n_quad * nbasis];
        for column in 0..nbasis {
            let start = column * n_quad;
            let end = start + n_quad;
            let values: Vec<f64> = current[start..end].to_vec();
            let slope = crate::helpers::gradient_uniform(&values, h);
            for (offset, slot) in next[start..end].iter_mut().enumerate() {
                *slot = slope[offset];
            }
        }
        current = next;
    }
    current
}
/// Computes the symmetric matrix of weighted inner products between the
/// columns of a flattened, column-major matrix.
///
/// Entry `(j, l)` is `sum_i deriv_basis[i, j] * deriv_basis[i, l] * weights[i]`
/// over the `n_quad` rows. Only the upper triangle is computed; each value is
/// mirrored into the lower triangle. The result is a flattened `k x k` matrix.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let mut gram = vec![0.0; k * k];
    for row in 0..k {
        let lhs = &deriv_basis[row * n_quad..(row + 1) * n_quad];
        let quad_w = &weights[..n_quad];
        for col in row..k {
            let rhs = &deriv_basis[col * n_quad..(col + 1) * n_quad];
            let integral = lhs
                .iter()
                .zip(rhs)
                .zip(quad_w)
                .fold(0.0_f64, |acc, ((a, b), w)| acc + a * b * w);
            gram[row + col * k] = integral;
            gram[col + row * k] = integral;
        }
    }
    gram
}
/// Evaluates the requested basis at `argvals`.
///
/// Returns the flattened basis values together with the number of basis
/// functions they represent. For B-splines that count is derived from the
/// length of the evaluated buffer (so it can differ from `nbasis`); for
/// Fourier bases it is `nbasis` itself.
fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
    match *basis_type {
        BasisType::Fourier { period } => {
            let values = fourier_basis_with_period(argvals, nbasis, period);
            (values, nbasis)
        }
        BasisType::Bspline { order } => {
            // At least two knots, however small the request is.
            let nknots = nbasis.saturating_sub(order).max(2);
            let values = bspline_basis(argvals, nknots, order);
            let columns = values.len() / argvals.len();
            (values, columns)
        }
    }
}
/// Inverts the `k x k` penalized normal-equations matrix.
///
/// Cholesky factorization is tried first. If the matrix is not positive
/// definite, the function falls back to an SVD-based pseudo-inverse that
/// discards singular values at or below `1e-10` times the largest one.
///
/// Returns `None` when the matrix contains NaN or infinite entries (for
/// example because of a non-finite smoothing weight) or when the SVD did not
/// produce its `U` / `V^T` factors.
fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
    // Neither factorization can produce a meaningful inverse from non-finite
    // input, so bail out before handing such a matrix to the iterative SVD.
    if system.iter().any(|value| !value.is_finite()) {
        return None;
    }

    if let Some(chol) = system.clone().cholesky() {
        return Some(chol.inverse());
    }

    let svd = nalgebra::SVD::new(system.clone(), true, true);
    let u = svd.u.as_ref()?;
    let v_t = svd.v_t.as_ref()?;
    let sigma = &svd.singular_values;
    let max_sv: f64 = sigma.iter().copied().fold(0.0_f64, f64::max);
    let eps = 1e-10 * max_sv;

    // Pseudo-inverse V * diag(1 / sigma) * U^T, restricted to the retained
    // singular values.
    let mut inv = DMatrix::<f64>::zeros(k, k);
    for s in 0..k.min(sigma.len()) {
        let sv = sigma[s];
        if sv > eps {
            for ii in 0..k {
                let scaled = v_t[(s, ii)] / sv;
                for jj in 0..k {
                    inv[(ii, jj)] += scaled * u[(jj, s)];
                }
            }
        }
    }
    Some(inv)
}
/// Applies the precomputed projection to every curve.
///
/// For each of the `n` rows of `data`, the coefficients are `proj * y` and the
/// fitted values are `b_mat * coefficients`. Returns the `n x k` coefficient
/// matrix, the `n x m` fitted matrix and the residual sum of squares
/// accumulated over all curves and points.
fn project_all_curves(
    data: &FdMatrix,
    b_mat: &DMatrix<f64>,
    proj: &DMatrix<f64>,
    n: usize,
    m: usize,
    k: usize,
) -> (FdMatrix, FdMatrix, f64) {
    let mut coef_out = FdMatrix::zeros(n, k);
    let mut fit_out = FdMatrix::zeros(n, m);
    let mut rss = 0.0;

    for row in 0..n {
        let observed = nalgebra::DVector::from_vec((0..m).map(|col| data[(row, col)]).collect());
        let beta = proj * &observed;
        let smoothed = b_mat * &beta;

        for col in 0..k {
            coef_out[(row, col)] = beta[col];
        }
        for col in 0..m {
            fit_out[(row, col)] = smoothed[col];
            let resid = observed[col] - smoothed[col];
            rss += resid * resid;
        }
    }

    (coef_out, fit_out, rss)
}
/// Generalized cross-validation score: `(rss / n_points) / (1 - edf / m)^2`.
///
/// Returns `f64::INFINITY` when `1 - edf / m` is within `1e-10` of zero (or is
/// NaN), i.e. when the fit uses essentially all `m` degrees of freedom.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let shrink = 1.0 - edf / m as f64;
    let usable = shrink.abs() > 1e-10;
    if !usable {
        return f64::INFINITY;
    }
    (rss / n_points) / (shrink * shrink)
}
/// Score used by `basis_nbasis_cv` to rank candidate basis sizes (lower is better).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BasisCriterion {
/// Use the `gcv` field of the `smooth_basis` result.
Gcv,
/// Use the k-fold score computed by `evaluate_nbasis_cv`.
Cv,
/// Use the `aic` field of the `smooth_basis` result.
Aic,
/// Use the `bic` field of the `smooth_basis` result.
Bic,
}
/// Outcome of `basis_nbasis_cv`.
#[derive(Debug, Clone, PartialEq)]
pub struct BasisNbasisCvResult {
/// Candidate from `nbasis_range` with the lowest score.
pub optimal_nbasis: usize,
/// One score per candidate, in the same order as `nbasis_range`;
/// `f64::INFINITY` marks candidates below 2 or whose fit failed.
pub scores: Vec<f64>,
/// The candidate basis sizes that were evaluated.
pub nbasis_range: Vec<usize>,
/// Criterion the scores were computed with.
pub criterion: BasisCriterion,
}
/// Scores every candidate basis size with GCV, AIC or BIC.
///
/// Each candidate is fitted with `smooth_basis` using a second-derivative
/// penalty and the fixed `lambda`. Candidates below 2 and failed fits score
/// `f64::INFINITY`. The returned vector is parallel to `nbasis_range`.
///
/// # Panics
/// Panics if called with `BasisCriterion::Cv`; that criterion is handled by
/// `evaluate_nbasis_cv`.
fn evaluate_nbasis_info_criterion(
    data: &FdMatrix,
    argvals: &[f64],
    nbasis_range: &[usize],
    basis_type: &BasisType,
    criterion: BasisCriterion,
    lambda: f64,
) -> Vec<f64> {
    nbasis_range
        .iter()
        .map(|&candidate| {
            if candidate < 2 {
                return f64::INFINITY;
            }
            let penalty_matrix = match basis_type {
                BasisType::Bspline { order } => {
                    bspline_penalty_matrix(argvals, candidate, *order, 2)
                }
                BasisType::Fourier { period } => fourier_penalty_matrix(candidate, *period, 2),
            };
            let spec = FdPar {
                basis_type: basis_type.clone(),
                nbasis: candidate,
                lambda,
                lfd_order: 2,
                penalty_matrix,
            };
            match smooth_basis(data, argvals, &spec) {
                Err(_) => f64::INFINITY,
                Ok(fit) => match criterion {
                    BasisCriterion::Gcv => fit.gcv,
                    BasisCriterion::Aic => fit.aic,
                    BasisCriterion::Bic => fit.bic,
                    BasisCriterion::Cv => unreachable!(),
                },
            }
        })
        .collect()
}
/// Scores every candidate basis size by the mean squared smoothing error of
/// the curves that land in a test fold.
///
/// Folds come from `crate::cv::create_folds` with a fixed seed of 42 and at
/// least two folds. For each candidate the penalized projection
/// `(B'B + lambda * R + 1e-10 * I)^-1 B'` is built once; every test curve is
/// then smoothed with it and its mean squared residual is recorded. Smoothing
/// acts on each curve independently through this fixed operator, so the
/// training rows of a fold do not influence the held-out error and no separate
/// training fit is performed.
///
/// Folds with an empty training or test set are skipped. Candidates below 2,
/// candidates whose basis/penalty sizes are inconsistent, candidates whose
/// system cannot be inverted, and candidates with no scored curve receive
/// `f64::INFINITY`. The returned vector is parallel to `nbasis_range`.
fn evaluate_nbasis_cv(
    data: &FdMatrix,
    argvals: &[f64],
    nbasis_range: &[usize],
    basis_type: &BasisType,
    lambda: f64,
    n_folds: usize,
) -> Vec<f64> {
    let (n, m) = data.shape();
    let n_folds = n_folds.max(2);
    let folds = crate::cv::create_folds(n, n_folds, 42);
    let mut scores = Vec::with_capacity(nbasis_range.len());

    for &nb in nbasis_range {
        if nb < 2 {
            scores.push(f64::INFINITY);
            continue;
        }
        let penalty = match basis_type {
            BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
            BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
        };

        // The projection depends only on the basis, the penalty and lambda, so
        // it is computed once per candidate instead of once per fold.
        let operator = cv_projection_operator(argvals, basis_type, nb, &penalty, lambda, m);

        let mut total_mse = 0.0;
        let mut total_count = 0_usize;
        if let Some((b_mat, proj)) = &operator {
            for fold in 0..n_folds {
                let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
                if train_idx.is_empty() || test_idx.is_empty() {
                    continue;
                }
                for &ti in &test_idx {
                    let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
                    let y_vec = nalgebra::DVector::from_vec(curve.clone());
                    let coefs = proj * &y_vec;
                    let fitted = b_mat * &coefs;
                    let mse: f64 =
                        (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
                    total_mse += mse;
                    total_count += 1;
                }
            }
        }

        if total_count > 0 {
            scores.push(total_mse / total_count as f64);
        } else {
            scores.push(f64::INFINITY);
        }
    }
    scores
}

/// Builds the basis matrix and the penalized projection for one candidate
/// basis size, or `None` when the buffer sizes are inconsistent or the system
/// cannot be inverted.
fn cv_projection_operator(
    argvals: &[f64],
    basis_type: &BasisType,
    nbasis: usize,
    penalty: &[f64],
    lambda: f64,
    m: usize,
) -> Option<(DMatrix<f64>, DMatrix<f64>)> {
    let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nbasis);
    // `DMatrix::from_column_slice` panics on a length mismatch; treat such a
    // candidate as unusable instead.
    if actual_k == 0 || basis_flat.len() != m * actual_k || penalty.len() != actual_k * actual_k {
        return None;
    }
    let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
    let r_mat = DMatrix::from_column_slice(actual_k, actual_k, penalty);
    let btb = b_mat.transpose() * &b_mat;
    let ridge_eps = 1e-10;
    let system: DMatrix<f64> =
        &btb + lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
    let system_inv = invert_penalized_system(&system, actual_k)?;
    let proj = &system_inv * b_mat.transpose();
    Some((b_mat, proj))
}
/// Selects the number of basis functions that minimizes the chosen criterion.
///
/// # Arguments
/// * `data` - curves in rows, evaluation points in columns.
/// * `argvals` - evaluation points; its length must equal the column count.
/// * `nbasis_range` - candidate basis sizes to evaluate.
/// * `basis_type` - basis family to fit.
/// * `criterion` - GCV/AIC/BIC (computed from a full fit) or k-fold CV.
/// * `n_folds` - number of folds; only used with `BasisCriterion::Cv`.
/// * `lambda` - fixed smoothing weight used for every candidate.
///
/// # Returns
/// `None` when `data` is empty, `argvals` does not match the column count,
/// `nbasis_range` is empty, or every score is NaN. Otherwise the candidate
/// with the smallest non-NaN score (the first one on ties) together with all
/// scores.
pub fn basis_nbasis_cv(
    data: &FdMatrix,
    argvals: &[f64],
    nbasis_range: &[usize],
    basis_type: &BasisType,
    criterion: BasisCriterion,
    n_folds: usize,
    lambda: f64,
) -> Option<BasisNbasisCvResult> {
    let (n, m) = data.shape();
    if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
        return None;
    }

    let scores = match criterion {
        BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
            evaluate_nbasis_info_criterion(
                data,
                argvals,
                nbasis_range,
                basis_type,
                criterion,
                lambda,
            )
        }
        BasisCriterion::Cv => {
            evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
        }
    };

    // NaN compares as "equal" to everything under the partial_cmp fallback, so
    // a NaN score could otherwise be reported as the optimum. Exclude NaN from
    // the search; the remaining values are totally ordered.
    let best_idx = scores
        .iter()
        .enumerate()
        .filter(|(_, score)| !score.is_nan())
        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(idx, _)| idx)?;

    Some(BasisNbasisCvResult {
        optimal_nbasis: nbasis_range[best_idx],
        scores,
        nbasis_range: nbasis_range.to_vec(),
        criterion,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::uniform_grid;
use std::f64::consts::PI;
/// The B-spline penalty matrix must equal its own transpose.
#[test]
fn test_bspline_penalty_matrix_symmetric() {
    let t = uniform_grid(101);
    let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
    // The side length is recovered from the buffer because the function may
    // return a different number of basis functions than requested.
    let actual_k = (penalty.len() as f64).sqrt() as usize;
    for i in 0..actual_k {
        for j in 0..actual_k {
            assert!(
                (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
                "Penalty matrix not symmetric at ({}, {})",
                i,
                j
            );
        }
    }
}
#[test]
fn test_bspline_penalty_matrix_positive_semidefinite() {
let t = uniform_grid(101);
let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
let k = (penalty.len() as f64).sqrt() as usize;
for i in 0..k {
assert!(
penalty[i + i * k] >= -1e-10,
"Diagonal element {} is negative: {}",
i,
penalty[i + i * k]
);
}
}
#[test]
fn test_fourier_penalty_diagonal() {
let penalty = fourier_penalty_matrix(7, 1.0, 2);
for i in 0..7 {
for j in 0..7 {
if i != j {
assert!(
penalty[i + j * 7].abs() < 1e-10,
"Off-diagonal ({},{}) = {}",
i,
j,
penalty[i + j * 7]
);
}
}
}
assert!(penalty[0].abs() < 1e-10);
assert!(penalty[1 + 7] > 0.0);
assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
}
/// Smoke test: B-spline smoothing of noisy sine curves succeeds and reports
/// sensible shapes and positive diagnostics.
#[test]
fn test_smooth_basis_bspline() {
    let m = 101;
    let n = 5;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for i in 0..n {
        for j in 0..m {
            data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
        }
    }
    let nbasis = 15;
    let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
    let fdpar = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis,
        lambda: 1e-4,
        lfd_order: 2,
        penalty_matrix: penalty,
    };
    let result = smooth_basis(&data, &t, &fdpar);
    assert!(result.is_ok(), "smooth_basis should succeed");
    let res = result.unwrap();
    assert_eq!(res.fitted.shape(), (n, m));
    assert_eq!(res.coefficients.nrows(), n);
    assert!(res.edf > 0.0, "EDF should be positive");
    assert!(res.gcv > 0.0, "GCV should be positive");
}
#[test]
fn test_smooth_basis_fourier() {
let m = 101;
let n = 3;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
}
}
let nbasis = 7;
let period = 1.0;
let penalty = fourier_penalty_matrix(nbasis, period, 2);
let fdpar = FdPar {
basis_type: BasisType::Fourier { period },
nbasis,
lambda: 1e-6,
lfd_order: 2,
penalty_matrix: penalty,
};
let result = smooth_basis(&data, &t, &fdpar);
assert!(result.is_ok());
let res = result.unwrap();
for j in 0..m {
let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
assert!(
(res.fitted[(0, j)] - expected).abs() < 0.1,
"Fourier fit poor at j={}: got {}, expected {}",
j,
res.fitted[(0, j)],
expected
);
}
}
#[test]
fn test_smooth_basis_gcv_selects_reasonable_lambda() {
let m = 101;
let n = 5;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] =
(2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
}
}
let basis_type = BasisType::Bspline { order: 4 };
let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
assert!(result.is_some(), "GCV search should succeed");
}
/// A heavier penalty must lower the effective degrees of freedom.
#[test]
fn test_smooth_basis_large_lambda_reduces_edf() {
    let m = 101;
    let n = 3;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for i in 0..n {
        for j in 0..m {
            data[(i, j)] = (2.0 * PI * t[j]).sin();
        }
    }
    let nbasis = 15;
    let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
    let fdpar_small = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis,
        lambda: 1e-8,
        lfd_order: 2,
        penalty_matrix: penalty.clone(),
    };
    let fdpar_large = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis,
        lambda: 1e2,
        lfd_order: 2,
        penalty_matrix: penalty,
    };
    let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
    let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
    assert!(
        res_large.edf < res_small.edf,
        "Larger lambda should reduce EDF: {} vs {}",
        res_large.edf,
        res_small.edf
    );
}
#[test]
fn test_basis_nbasis_cv_gcv() {
let m = 101;
let n = 5;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] =
(2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
}
}
let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
let result = basis_nbasis_cv(
&data,
&t,
&nbasis_range,
&BasisType::Bspline { order: 4 },
BasisCriterion::Gcv,
5,
1e-4,
);
assert!(result.is_some());
let res = result.unwrap();
assert!(nbasis_range.contains(&res.optimal_nbasis));
assert_eq!(res.scores.len(), nbasis_range.len());
assert_eq!(res.criterion, BasisCriterion::Gcv);
}
#[test]
fn test_basis_nbasis_cv_aic_bic() {
let m = 51;
let n = 5;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] = (2.0 * PI * t[j]).sin();
}
}
let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
let aic_result = basis_nbasis_cv(
&data,
&t,
&nbasis_range,
&BasisType::Bspline { order: 4 },
BasisCriterion::Aic,
5,
0.0,
);
let bic_result = basis_nbasis_cv(
&data,
&t,
&nbasis_range,
&BasisType::Bspline { order: 4 },
BasisCriterion::Bic,
5,
0.0,
);
assert!(aic_result.is_some());
assert!(bic_result.is_some());
}
#[test]
fn test_basis_nbasis_cv_kfold() {
let m = 51;
let n = 10;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
}
}
let nbasis_range: Vec<usize> = vec![5, 7, 9];
let result = basis_nbasis_cv(
&data,
&t,
&nbasis_range,
&BasisType::Bspline { order: 4 },
BasisCriterion::Cv,
5,
1e-4,
);
assert!(result.is_some());
let res = result.unwrap();
assert!(nbasis_range.contains(&res.optimal_nbasis));
assert_eq!(res.criterion, BasisCriterion::Cv);
}
fn make_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
let t: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] = (2.0 * PI * t[j]).sin()
+ 0.1 * (10.0 * t[j]).sin()
+ 0.05 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
}
}
(data, t)
}
fn make_bspline_fdpar(argvals: &[f64], nbasis: usize, lambda: f64) -> FdPar {
let penalty = bspline_penalty_matrix(argvals, nbasis, 4, 2);
FdPar {
basis_type: BasisType::Bspline { order: 4 },
nbasis,
lambda,
lfd_order: 2,
penalty_matrix: penalty,
}
}
fn make_fourier_fdpar(nbasis: usize, period: f64, lambda: f64) -> FdPar {
let penalty = fourier_penalty_matrix(nbasis, period, 2);
FdPar {
basis_type: BasisType::Fourier { period },
nbasis,
lambda,
lfd_order: 2,
penalty_matrix: penalty,
}
}
#[test]
fn test_basis_type_bspline_variant() {
let bt = BasisType::Bspline { order: 4 };
assert_eq!(bt, BasisType::Bspline { order: 4 });
assert_ne!(bt, BasisType::Bspline { order: 3 });
}
#[test]
fn test_basis_type_fourier_variant() {
let bt = BasisType::Fourier { period: 1.0 };
assert_eq!(bt, BasisType::Fourier { period: 1.0 });
assert_ne!(bt, BasisType::Fourier { period: 2.0 });
}
#[test]
fn test_basis_type_cross_variant_inequality() {
let bspline = BasisType::Bspline { order: 4 };
let fourier = BasisType::Fourier { period: 1.0 };
assert_ne!(bspline, fourier);
}
#[test]
fn test_basis_type_clone_and_debug() {
let bt = BasisType::Bspline { order: 4 };
let cloned = bt.clone();
assert_eq!(bt, cloned);
let debug_str = format!("{:?}", bt);
assert!(debug_str.contains("Bspline"));
assert!(debug_str.contains("4"));
}
#[test]
fn test_fdpar_construction_and_fields() {
let penalty = vec![1.0, 0.0, 0.0, 1.0];
let fdpar = FdPar {
basis_type: BasisType::Bspline { order: 4 },
nbasis: 2,
lambda: 0.01,
lfd_order: 2,
penalty_matrix: penalty.clone(),
};
assert_eq!(fdpar.nbasis, 2);
assert!((fdpar.lambda - 0.01).abs() < 1e-15);
assert_eq!(fdpar.lfd_order, 2);
assert_eq!(fdpar.penalty_matrix.len(), 4);
}
#[test]
fn test_fdpar_clone_and_debug() {
let t = uniform_grid(50);
let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
let cloned = fdpar.clone();
assert_eq!(fdpar, cloned);
let debug_str = format!("{:?}", fdpar);
assert!(debug_str.contains("FdPar"));
}
#[test]
fn test_basis_criterion_variants() {
assert_eq!(BasisCriterion::Gcv, BasisCriterion::Gcv);
assert_eq!(BasisCriterion::Cv, BasisCriterion::Cv);
assert_eq!(BasisCriterion::Aic, BasisCriterion::Aic);
assert_eq!(BasisCriterion::Bic, BasisCriterion::Bic);
assert_ne!(BasisCriterion::Gcv, BasisCriterion::Aic);
assert_ne!(BasisCriterion::Cv, BasisCriterion::Bic);
}
#[test]
fn test_basis_criterion_copy() {
let c = BasisCriterion::Gcv;
let copied = c; assert_eq!(c, copied);
}
#[test]
fn test_basis_criterion_debug() {
let debug_str = format!("{:?}", BasisCriterion::Bic);
assert!(debug_str.contains("Bic"));
}
#[test]
fn test_smooth_basis_result_all_fields() {
let (data, t) = make_test_data(3, 50);
let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.coefficients.nrows(), 3);
assert!(res.coefficients.ncols() > 0);
assert_eq!(res.nbasis, res.coefficients.ncols());
assert_eq!(res.fitted.shape(), (3, 50));
assert!(res.edf > 0.0 && res.edf <= res.nbasis as f64);
assert!(res.gcv.is_finite());
assert!(res.aic.is_finite());
assert!(res.bic.is_finite());
let k = res.nbasis;
assert_eq!(res.penalty_matrix.len(), k * k);
}
#[test]
fn test_smooth_basis_result_clone() {
let (data, t) = make_test_data(2, 50);
let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
let cloned = res.clone();
assert_eq!(res, cloned);
}
#[test]
fn test_smooth_basis_bspline_coefficient_shape() {
let (data, t) = make_test_data(4, 50);
let nbasis = 12;
let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.coefficients.nrows(), 4);
assert!(res.coefficients.ncols() >= 2);
assert_eq!(res.nbasis, res.coefficients.ncols());
}
#[test]
fn test_smooth_basis_bspline_fitted_values_shape() {
let m = 80;
let n = 6;
let (data, t) = make_test_data(n, m);
let fdpar = make_bspline_fdpar(&t, 15, 1e-4);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.fitted.shape(), (n, m));
}
#[test]
fn test_smooth_basis_bspline_zero_lambda_interpolates() {
let m = 30;
let n = 2;
let (data, t) = make_test_data(n, m);
let fdpar = make_bspline_fdpar(&t, 15, 0.0);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
let mut max_resid = 0.0_f64;
for i in 0..n {
for j in 0..m {
let resid = (data[(i, j)] - res.fitted[(i, j)]).abs();
max_resid = max_resid.max(resid);
}
}
assert!(
max_resid < 0.5,
"Zero-lambda B-spline should closely interpolate; max_resid = {}",
max_resid
);
}
#[test]
fn test_smooth_basis_bspline_large_lambda_oversmooths() {
let m = 50;
let n = 1;
let (data, t) = make_test_data(n, m);
let fdpar_small = make_bspline_fdpar(&t, 15, 1e-6);
let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
let fdpar_large = make_bspline_fdpar(&t, 15, 1e6);
let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
let compute_variance = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
let vals: Vec<f64> = (0..ncols).map(|j| fitted[(row, j)]).collect();
let mean = vals.iter().sum::<f64>() / ncols as f64;
vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / ncols as f64
};
let var_small = compute_variance(&res_small.fitted, 0, m);
let var_large = compute_variance(&res_large.fitted, 0, m);
assert!(
var_large < var_small,
"Large lambda should yield lower variance fit: var_large={}, var_small={}",
var_large,
var_small
);
}
#[test]
fn test_smooth_basis_bspline_penalty_effect_on_smoothness() {
let m = 50;
let n = 1;
let (data, t) = make_test_data(n, m);
let fdpar_small = make_bspline_fdpar(&t, 15, 1e-8);
let fdpar_large = make_bspline_fdpar(&t, 15, 1.0);
let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
let roughness = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
(1..ncols - 1)
.map(|j| {
let d2 = fitted[(row, j + 1)] - 2.0 * fitted[(row, j)] + fitted[(row, j - 1)];
d2 * d2
})
.sum::<f64>()
};
let r_small = roughness(&res_small.fitted, 0, m);
let r_large = roughness(&res_large.fitted, 0, m);
assert!(
r_large < r_small,
"Larger lambda should produce smoother fit: roughness_large={}, roughness_small={}",
r_large,
r_small
);
}
#[test]
fn test_smooth_basis_bspline_single_curve() {
let m = 50;
let (data, t) = make_test_data(1, m);
let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.fitted.nrows(), 1);
assert_eq!(res.fitted.ncols(), m);
assert!(res.gcv.is_finite());
}
#[test]
fn test_smooth_basis_bspline_many_curves() {
let m = 50;
let n = 20;
let (data, t) = make_test_data(n, m);
let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.fitted.nrows(), n);
assert_eq!(res.coefficients.nrows(), n);
}
#[test]
fn test_smooth_basis_bspline_minimal_nbasis() {
let m = 50;
let (data, t) = make_test_data(1, m);
let fdpar = make_bspline_fdpar(&t, 2, 1e-4);
let res = smooth_basis(&data, &t, &fdpar);
assert!(res.is_ok());
}
#[test]
fn test_smooth_basis_bspline_different_orders() {
let m = 50;
let (data, t) = make_test_data(2, m);
let penalty3 = bspline_penalty_matrix(&t, 10, 3, 2);
let fdpar3 = FdPar {
basis_type: BasisType::Bspline { order: 3 },
nbasis: 10,
lambda: 1e-4,
lfd_order: 2,
penalty_matrix: penalty3,
};
let res3 = smooth_basis(&data, &t, &fdpar3);
assert!(res3.is_ok());
let penalty5 = bspline_penalty_matrix(&t, 10, 5, 2);
let fdpar5 = FdPar {
basis_type: BasisType::Bspline { order: 5 },
nbasis: 10,
lambda: 1e-4,
lfd_order: 2,
penalty_matrix: penalty5,
};
let res5 = smooth_basis(&data, &t, &fdpar5);
assert!(res5.is_ok());
}
#[test]
fn test_smooth_basis_fourier_coefficient_shape() {
let m = 50;
let n = 3;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(n, m);
for i in 0..n {
for j in 0..m {
data[(i, j)] = (2.0 * PI * t[j]).sin();
}
}
let nbasis = 7;
let fdpar = make_fourier_fdpar(nbasis, 1.0, 1e-6);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.coefficients.nrows(), n);
assert_eq!(res.coefficients.ncols(), nbasis);
assert_eq!(res.nbasis, nbasis);
}
#[test]
fn test_smooth_basis_fourier_fits_pure_sine() {
let m = 100;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(1, m);
for j in 0..m {
data[(0, j)] = (2.0 * PI * t[j]).sin();
}
let fdpar = make_fourier_fdpar(5, 1.0, 1e-8);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
for j in 0..m {
let expected = (2.0 * PI * t[j]).sin();
assert!(
(res.fitted[(0, j)] - expected).abs() < 0.05,
"Fourier should fit pure sine; j={}, got={}, expected={}",
j,
res.fitted[(0, j)],
expected
);
}
}
#[test]
fn test_smooth_basis_fourier_different_periods() {
let m = 50;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(1, m);
for j in 0..m {
data[(0, j)] = (2.0 * PI * t[j]).sin();
}
let fdpar1 = make_fourier_fdpar(7, 1.0, 1e-6);
let res1 = smooth_basis(&data, &t, &fdpar1).unwrap();
let fdpar2 = make_fourier_fdpar(7, 2.0, 1e-6);
let res2 = smooth_basis(&data, &t, &fdpar2).unwrap();
assert_eq!(res1.fitted.shape(), (1, m));
assert_eq!(res2.fitted.shape(), (1, m));
}
#[test]
fn test_smooth_basis_fourier_zero_lambda() {
let m = 50;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(1, m);
for j in 0..m {
data[(0, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
}
let fdpar = make_fourier_fdpar(9, 1.0, 0.0);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert_eq!(res.fitted.shape(), (1, m));
assert!(res.edf > 1.0);
}
#[test]
fn test_smooth_basis_fourier_large_lambda() {
let m = 50;
let t = uniform_grid(m);
let mut data = FdMatrix::zeros(1, m);
for j in 0..m {
data[(0, j)] = (2.0 * PI * t[j]).sin();
}
let fdpar = make_fourier_fdpar(9, 1.0, 1e6);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert!(
res.edf < 5.0,
"Large lambda should reduce EDF; edf={}",
res.edf
);
}
#[test]
fn test_smooth_basis_lambda_gradient_edf() {
let m = 50;
let (data, t) = make_test_data(3, m);
let lambdas = [1e-8, 1e-4, 1e-2, 1.0, 1e2];
let mut prev_edf = f64::INFINITY;
for &lam in &lambdas {
let fdpar = make_bspline_fdpar(&t, 12, lam);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
assert!(
res.edf <= prev_edf + 0.01,
"EDF should decrease: lambda={}, edf={}, prev_edf={}",
lam,
res.edf,
prev_edf
);
prev_edf = res.edf;
}
}
#[test]
fn test_smooth_basis_lambda_gradient_rss() {
let m = 50;
let n = 2;
let (data, t) = make_test_data(n, m);
let lambdas = [0.0, 1e-6, 1e-2, 1.0, 1e4];
let mut prev_rss = -1.0;
for &lam in &lambdas {
let fdpar = make_bspline_fdpar(&t, 12, lam);
let res = smooth_basis(&data, &t, &fdpar).unwrap();
let mut rss = 0.0;
for i in 0..n {
for j in 0..m {
rss += (data[(i, j)] - res.fitted[(i, j)]).powi(2);
}
}
assert!(
rss >= prev_rss - 1e-8,
"RSS should increase: lambda={}, rss={}, prev_rss={}",
lam,
rss,
prev_rss
);
prev_rss = rss;
}
}
#[test]
fn test_smooth_basis_empty_data_rows() {
let t = uniform_grid(50);
let data = FdMatrix::zeros(0, 50);
let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
let res = smooth_basis(&data, &t, &fdpar);
assert!(res.is_err());
}
#[test]
fn test_smooth_basis_empty_data_cols() {
let data = FdMatrix::zeros(5, 0);
let fdpar = FdPar {
basis_type: BasisType::Bspline { order: 4 },
nbasis: 10,
lambda: 1e-4,
lfd_order: 2,
penalty_matrix: vec![0.0; 100],
};
let res = smooth_basis(&data, &[], &fdpar);
assert!(res.is_err());
}
#[test]
fn test_smooth_basis_mismatched_argvals() {
let t = uniform_grid(50);
let data = FdMatrix::zeros(3, 40); let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
let res = smooth_basis(&data, &t, &fdpar);
assert!(res.is_err());
}
#[test]
fn test_smooth_basis_nbasis_too_small() {
let t = uniform_grid(50);
let data = FdMatrix::zeros(3, 50);
let fdpar = FdPar {
basis_type: BasisType::Bspline { order: 4 },
nbasis: 1,
lambda: 1e-4,
lfd_order: 2,
penalty_matrix: vec![0.0; 1],
};
let res = smooth_basis(&data, &t, &fdpar);
assert!(res.is_err());
}
#[test]
fn test_smooth_basis_error_is_invalid_dimension() {
let t = uniform_grid(50);
let data = FdMatrix::zeros(0, 50);
let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
let err = smooth_basis(&data, &t, &fdpar).unwrap_err();
match err {
crate::FdarError::InvalidDimension { .. } => {} other => panic!("Expected InvalidDimension, got {:?}", other),
}
}
/// Penalty matrices built with derivative orders 1 and 2 must have the same
/// length but differ element-wise.
#[test]
fn test_bspline_penalty_matrix_different_orders() {
    let grid = uniform_grid(101);
    let first_order = bspline_penalty_matrix(&grid, 10, 4, 1);
    let second_order = bspline_penalty_matrix(&grid, 10, 4, 2);
    assert_eq!(first_order.len(), second_order.len());
    // Sum of absolute element-wise differences.
    let diff: f64 = first_order
        .iter()
        .zip(&second_order)
        .map(|(lhs, rhs)| (lhs - rhs).abs())
        .sum();
    assert!(
        diff > 1e-10,
        "Different lfd_orders should produce different penalties"
    );
}
/// Degenerate inputs must yield an all-zero penalty matrix: a single argval,
/// `nbasis == 1`, and `lfd_order == order`.
#[test]
fn test_bspline_penalty_matrix_edge_cases() {
    // Only one evaluation point.
    let single_point = vec![0.0];
    let from_single_point = bspline_penalty_matrix(&single_point, 10, 4, 2);
    assert!(from_single_point.iter().all(|&v| v == 0.0));

    let grid = uniform_grid(50);

    // Only one basis function.
    let from_one_basis = bspline_penalty_matrix(&grid, 1, 4, 2);
    assert!(from_one_basis.iter().all(|&v| v == 0.0));

    // Derivative order equal to the spline order.
    let from_high_deriv = bspline_penalty_matrix(&grid, 10, 4, 4);
    assert!(from_high_deriv.iter().all(|&v| v == 0.0));
}
/// Every diagonal entry of the second-order B-spline penalty must be
/// non-negative (up to -1e-10) for several basis sizes.
#[test]
fn test_bspline_penalty_nonnegative_diagonal() {
    let grid = uniform_grid(101);
    for nbasis in [5, 10, 20] {
        let penalty = bspline_penalty_matrix(&grid, nbasis, 4, 2);
        // Recover the side length of the flattened square matrix.
        let side = (penalty.len() as f64).sqrt() as usize;
        for d in 0..side {
            let entry = penalty[d + d * side];
            assert!(
                entry >= -1e-10,
                "Diagonal ({},{}) negative for nbasis={}: {}",
                d,
                d,
                nbasis,
                entry
            );
        }
    }
}
/// For an 11-function Fourier penalty (unit period, second order): the first
/// entry is zero, the diagonal grows strictly with frequency, and the two
/// diagonal entries belonging to each frequency agree.
#[test]
fn test_fourier_penalty_increasing_with_frequency() {
    let k = 11;
    let penalty = fourier_penalty_matrix(k, 1.0, 2);
    let diag = |idx: usize| penalty[idx + idx * k];
    assert!(penalty[0].abs() < 1e-15);
    let mut prev_eigenval = 0.0;
    for freq in 1..=5 {
        // Odd index first, even index second for each frequency.
        let (idx_sin, idx_cos) = (2 * freq - 1, 2 * freq);
        let eigenval = diag(idx_sin);
        assert!(
            eigenval > prev_eigenval,
            "Higher frequency should have larger penalty: freq={}, eigenval={}, prev={}",
            freq,
            eigenval,
            prev_eigenval
        );
        prev_eigenval = eigenval;
        if idx_cos < k {
            assert!(
                (diag(idx_cos) - eigenval).abs() < 1e-10,
                "Sin and cos penalty should match at freq {}",
                freq
            );
        }
    }
}
/// Doubling the period must shrink every non-constant diagonal penalty entry
/// (or leave both entries at exactly zero).
#[test]
fn test_fourier_penalty_different_periods() {
    let unit_period = fourier_penalty_matrix(7, 1.0, 2);
    let double_period = fourier_penalty_matrix(7, 2.0, 2);
    for i in 1..7 {
        let short = unit_period[i + i * 7];
        let long = double_period[i + i * 7];
        assert!(
            long < short || (short == 0.0 && long == 0.0),
            "Longer period should have smaller penalties at i={}",
            i
        );
    }
}
/// With `lfd_order = 1` and unit period, the diagonal entry at index (1, 1)
/// of a 5-function penalty must equal (2*pi)^2.
#[test]
fn test_fourier_penalty_first_order() {
    let penalty = fourier_penalty_matrix(5, 1.0, 1);
    let expected1 = (2.0 * PI).powi(2);
    // Flattened position of entry (1, 1) in a 5 x 5 matrix.
    let got = penalty[1 + 5];
    assert!(
        (got - expected1).abs() < 1e-6,
        "First-order penalty eigenval: got {}, expected {}",
        got,
        expected1
    );
}
/// Zero basis functions must give an empty penalty vector.
#[test]
fn test_fourier_penalty_zero_nbasis() {
    assert!(fourier_penalty_matrix(0, 1.0, 2).is_empty());
}
/// One basis function must give a single penalty entry, and that entry is zero.
#[test]
fn test_fourier_penalty_nbasis_one() {
    let penalty = fourier_penalty_matrix(1, 1.0, 2);
    assert_eq!(penalty.len(), 1);
    assert!(penalty[0].abs() < 1e-15);
}
/// GCV-based smoothing on B-spline test data must return a result whose
/// fitted matrix has the data's shape, with finite GCV and positive EDF.
#[test]
fn test_smooth_basis_gcv_returns_valid_result() {
    let (data, t) = make_test_data(5, 50);
    let basis = BasisType::Bspline { order: 4 };
    let selected = smooth_basis_gcv(&data, &t, &basis, 12, 2, (-6.0, 2.0), 20);
    assert!(selected.is_some());
    let best = selected.unwrap();
    assert_eq!(best.fitted.shape(), (5, 50));
    assert!(best.gcv.is_finite());
    assert!(best.edf > 0.0);
}
/// GCV-based smoothing with a Fourier basis must succeed on three identical
/// periodic curves and report the requested row count and `nbasis`.
#[test]
fn test_smooth_basis_gcv_fourier() {
    let m = 80;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(3, m);
    for j in 0..m {
        // All three rows carry the same two-harmonic signal.
        let y = (2.0 * PI * t[j]).sin() + 0.5 * (4.0 * PI * t[j]).cos();
        for i in 0..3 {
            data[(i, j)] = y;
        }
    }
    let basis = BasisType::Fourier { period: 1.0 };
    let selected = smooth_basis_gcv(&data, &t, &basis, 9, 2, (-8.0, 4.0), 25);
    assert!(selected.is_some());
    let best = selected.unwrap();
    assert_eq!(best.fitted.nrows(), 3);
    assert_eq!(best.nbasis, 9);
}
/// The GCV score of the selected fit must be finite and strictly positive.
#[test]
fn test_smooth_basis_gcv_selects_finite_gcv() {
    let (data, t) = make_test_data(5, 60);
    let basis = BasisType::Bspline { order: 4 };
    let best = smooth_basis_gcv(&data, &t, &basis, 12, 2, (-6.0, 2.0), 15).unwrap();
    assert!(best.gcv.is_finite());
    assert!(best.gcv > 0.0);
}
/// A zero-row data matrix must make `smooth_basis_gcv` return `None`.
#[test]
fn test_smooth_basis_gcv_empty_data() {
    let no_rows = FdMatrix::zeros(0, 50);
    let grid = uniform_grid(50);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&no_rows, &grid, &basis, 10, 2, (-6.0, 2.0), 10).is_none());
}
/// Empty `argvals` with a zero-column matrix must make `smooth_basis_gcv`
/// return `None`.
#[test]
fn test_smooth_basis_gcv_empty_argvals() {
    let no_cols = FdMatrix::zeros(5, 0);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&no_cols, &[], &basis, 10, 2, (-6.0, 2.0), 10).is_none());
}
/// `nbasis = 1` must make `smooth_basis_gcv` return `None`.
#[test]
fn test_smooth_basis_gcv_nbasis_too_small() {
    let (data, t) = make_test_data(5, 50);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&data, &t, &basis, 1, 2, (-6.0, 2.0), 10).is_none());
}
/// A lambda grid of size 1 must make `smooth_basis_gcv` return `None`.
#[test]
fn test_smooth_basis_gcv_ngrid_too_small() {
    let (data, t) = make_test_data(5, 50);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&data, &t, &basis, 10, 2, (-6.0, 2.0), 1).is_none());
}
/// A narrow log-lambda range of (-3, -2) with 5 grid points must still
/// yield a result.
#[test]
fn test_smooth_basis_gcv_narrow_range() {
    let (data, t) = make_test_data(3, 50);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&data, &t, &basis, 10, 2, (-3.0, -2.0), 5).is_some());
}
/// A wide log-lambda range of (-12, 8) with 30 grid points must still
/// yield a result.
#[test]
fn test_smooth_basis_gcv_wide_range() {
    let (data, t) = make_test_data(3, 50);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&data, &t, &basis, 10, 2, (-12.0, 8.0), 30).is_some());
}
/// The result must carry one score per candidate and echo the candidate list.
#[test]
fn test_basis_nbasis_cv_scores_length() {
    let (data, t) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![4, 6, 8, 10, 12];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4).unwrap();
    assert_eq!(outcome.scores.len(), 5);
    assert_eq!(outcome.nbasis_range.len(), 5);
    assert_eq!(outcome.nbasis_range, candidates);
}
/// For the GCV, AIC and BIC criteria the selected `nbasis` must be one of
/// the candidates.
#[test]
fn test_basis_nbasis_cv_optimal_within_range() {
    let (data, t) = make_test_data(8, 50);
    let candidates: Vec<usize> = vec![5, 7, 9, 11, 13, 15];
    let basis = BasisType::Bspline { order: 4 };
    let criteria = [
        BasisCriterion::Gcv,
        BasisCriterion::Aic,
        BasisCriterion::Bic,
    ];
    for criterion in criteria {
        let outcome = basis_nbasis_cv(&data, &t, &candidates, &basis, criterion, 5, 1e-4).unwrap();
        assert!(
            candidates.contains(&outcome.optimal_nbasis),
            "optimal_nbasis {} not in range for {:?}",
            outcome.optimal_nbasis,
            criterion
        );
    }
}
/// Fourier-basis selection under GCV must pick one of the candidates on
/// periodic data with a small deterministic perturbation.
#[test]
fn test_basis_nbasis_cv_fourier_gcv() {
    let m = 80;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(5, m);
    for i in 0..5 {
        for j in 0..m {
            let signal = (2.0 * PI * t[j]).sin() + 0.3 * (4.0 * PI * t[j]).cos();
            // Deterministic pseudo-noise that depends on both indices.
            let jitter = 0.02 * ((i * 7 + j * 3) % 10) as f64;
            data[(i, j)] = signal + jitter;
        }
    }
    let candidates: Vec<usize> = vec![5, 7, 9, 11];
    let basis = BasisType::Fourier { period: 1.0 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4).unwrap();
    assert!(candidates.contains(&outcome.optimal_nbasis));
}
/// Fourier-basis selection under the `Cv` criterion must pick one of the
/// candidates and report `Cv` as the criterion used.
#[test]
fn test_basis_nbasis_cv_fourier_cv() {
    let (n, m) = (10, 60);
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for i in 0..n {
        for j in 0..m {
            // Deterministic pseudo-noise that depends on both indices.
            let jitter = 0.02 * ((i * 11 + j) % 15) as f64;
            data[(i, j)] = (2.0 * PI * t[j]).sin() + jitter;
        }
    }
    let candidates: Vec<usize> = vec![5, 7, 9];
    let basis = BasisType::Fourier { period: 1.0 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Cv, 5, 1e-4).unwrap();
    assert!(candidates.contains(&outcome.optimal_nbasis));
    assert_eq!(outcome.criterion, BasisCriterion::Cv);
}
/// A candidate of `nbasis = 1` must get an infinite score and must not be
/// selected.
#[test]
fn test_basis_nbasis_cv_with_nbasis_below_minimum() {
    let (data, t) = make_test_data(5, 50);
    // First candidate is deliberately too small.
    let candidates: Vec<usize> = vec![1, 5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4).unwrap();
    assert!(
        outcome.optimal_nbasis >= 5,
        "Should skip invalid nbasis=1, got optimal={}",
        outcome.optimal_nbasis
    );
    assert!(outcome.scores[0].is_infinite());
}
/// An empty candidate list must make `basis_nbasis_cv` return `None`.
#[test]
fn test_basis_nbasis_cv_empty_range() {
    let (data, t) = make_test_data(5, 50);
    let candidates: Vec<usize> = Vec::new();
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4);
    assert!(outcome.is_none());
}
/// A zero-row data matrix must make `basis_nbasis_cv` return `None`.
#[test]
fn test_basis_nbasis_cv_empty_data() {
    let no_rows = FdMatrix::zeros(0, 50);
    let grid = uniform_grid(50);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &no_rows,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    );
    assert!(outcome.is_none());
}
/// 40 argvals against a 50-column data matrix must make `basis_nbasis_cv`
/// return `None`.
#[test]
fn test_basis_nbasis_cv_mismatched_argvals() {
    let zeros = FdMatrix::zeros(5, 50);
    // Grid length deliberately differs from the column count.
    let short_grid = uniform_grid(40);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &zeros,
        &short_grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    );
    assert!(outcome.is_none());
}
/// With a single candidate, that candidate must be selected and exactly one
/// score reported.
#[test]
fn test_basis_nbasis_cv_single_nbasis() {
    let (data, t) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4).unwrap();
    assert_eq!(outcome.optimal_nbasis, 10);
    assert_eq!(outcome.scores.len(), 1);
}
/// The `nbasis` chosen by BIC must not exceed the one chosen by AIC by more
/// than 4 on the same data and candidates.
#[test]
fn test_basis_nbasis_cv_bic_penalizes_more_than_aic() {
    let (data, t) = make_test_data(5, 80);
    let candidates: Vec<usize> = (4..=20).step_by(2).collect();
    let basis = BasisType::Bspline { order: 4 };
    let aic_res =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Aic, 5, 1e-4).unwrap();
    let bic_res =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Bic, 5, 1e-4).unwrap();
    assert!(
        bic_res.optimal_nbasis <= aic_res.optimal_nbasis + 4,
        "BIC selected {} vs AIC selected {} -- BIC should not select much more than AIC",
        bic_res.optimal_nbasis,
        aic_res.optimal_nbasis
    );
}
/// With a tiny penalty and 15 B-splines, the fit to a noise-free sine must
/// stay within 0.1 of the data at every sample.
#[test]
fn test_smooth_basis_fitted_close_to_data() {
    let (n, m) = (3, 50);
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for j in 0..m {
        // Every row holds the same sine curve.
        let y = (2.0 * PI * t[j]).sin();
        for i in 0..n {
            data[(i, j)] = y;
        }
    }
    let params = make_bspline_fdpar(&t, 15, 1e-6);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    // Largest absolute pointwise deviation across all curves.
    let mut max_err = 0.0_f64;
    for i in 0..n {
        for j in 0..m {
            max_err = max_err.max((data[(i, j)] - fit.fitted[(i, j)]).abs());
        }
    }
    assert!(
        max_err < 0.1,
        "Fitted should be close to smooth data; max_err={}",
        max_err
    );
}
/// Constant data (3.15 everywhere) must be reproduced to within 0.01 at
/// every sample.
#[test]
fn test_smooth_basis_constant_data() {
    let (n, m) = (2, 50);
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for row in 0..n {
        for col in 0..m {
            data[(row, col)] = 3.15;
        }
    }
    let params = make_bspline_fdpar(&t, 10, 1e-4);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    for row in 0..n {
        for col in 0..m {
            let fitted = fit.fitted[(row, col)];
            assert!(
                (fitted - 3.15).abs() < 0.01,
                "Constant data should be fit well at ({},{}): got {}",
                row,
                col,
                fitted
            );
        }
    }
}
/// A straight line `2t + 1` must be reproduced to within 0.05 at every sample.
#[test]
fn test_smooth_basis_linear_data() {
    let m = 50;
    let t = uniform_grid(m);
    let line = |x: f64| 2.0 * x + 1.0;
    let mut data = FdMatrix::zeros(1, m);
    for j in 0..m {
        data[(0, j)] = line(t[j]);
    }
    let params = make_bspline_fdpar(&t, 10, 1e-4);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    for j in 0..m {
        let expected = line(t[j]);
        let fitted = fit.fitted[(0, j)];
        assert!(
            (fitted - expected).abs() < 0.05,
            "Linear data should be fit well at j={}: got {}, expected {}",
            j,
            fitted,
            expected
        );
    }
}
/// The effective degrees of freedom must lie in `(0, m]`, where `m` is the
/// number of sample points.
#[test]
fn test_smooth_basis_edf_bounded() {
    let m = 50;
    let (data, t) = make_test_data(3, m);
    let params = make_bspline_fdpar(&t, 12, 1e-4);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    let upper = m as f64;
    assert!(
        fit.edf > 0.0 && fit.edf <= upper,
        "EDF should be in (0, {}]; got {}",
        m,
        fit.edf
    );
}
/// The GCV, AIC and BIC scores of an ordinary fit must all be finite.
#[test]
fn test_smooth_basis_gcv_aic_bic_all_finite() {
    let (data, t) = make_test_data(4, 60);
    let params = make_bspline_fdpar(&t, 12, 1e-3);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    assert!(fit.gcv.is_finite(), "GCV should be finite: {}", fit.gcv);
    assert!(fit.aic.is_finite(), "AIC should be finite: {}", fit.aic);
    assert!(fit.bic.is_finite(), "BIC should be finite: {}", fit.bic);
}
/// The penalty matrix stored in the result must have `nbasis * nbasis`
/// entries, using the `nbasis` reported by the result.
#[test]
fn test_smooth_basis_penalty_matrix_in_result() {
    let (data, t) = make_test_data(3, 50);
    let nbasis = 10;
    let params = make_bspline_fdpar(&t, nbasis, 1e-4);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    let k = fit.nbasis;
    let stored_len = fit.penalty_matrix.len();
    assert_eq!(
        stored_len,
        k * k,
        "Penalty matrix should be k*k = {}*{} = {}; got {}",
        k,
        k,
        k * k,
        stored_len
    );
}
/// Four copies of the same curve must get coefficient rows that agree with
/// the first row to within 1e-10.
#[test]
fn test_smooth_basis_identical_curves_same_coefficients() {
    let (n, m) = (4, 50);
    let t = uniform_grid(m);
    let curve: Vec<f64> = (0..m).map(|j| (2.0 * PI * t[j]).sin()).collect();
    let mut data = FdMatrix::zeros(n, m);
    for row in 0..n {
        for col in 0..m {
            data[(row, col)] = curve[col];
        }
    }
    let params = make_bspline_fdpar(&t, 10, 1e-4);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    let n_coef = fit.coefficients.ncols();
    // Compare every later row against row 0.
    for row in 1..n {
        for col in 0..n_coef {
            let gap = (fit.coefficients[(row, col)] - fit.coefficients[(0, col)]).abs();
            assert!(
                gap < 1e-10,
                "Identical curves should have identical coefficients: curve {} col {} differs",
                row,
                col
            );
        }
    }
}
/// Selection under the `Cv` criterion must succeed and pick a candidate for
/// fold counts 2, 3, 5 and 10.
#[test]
fn test_basis_nbasis_cv_different_nfolds() {
    let (data, t) = make_test_data(12, 50);
    let candidates: Vec<usize> = vec![5, 8, 11];
    let basis = BasisType::Bspline { order: 4 };
    for nfolds in [2, 3, 5, 10] {
        let outcome = basis_nbasis_cv(
            &data,
            &t,
            &candidates,
            &basis,
            BasisCriterion::Cv,
            nfolds,
            1e-4,
        );
        assert!(outcome.is_some(), "CV should succeed with nfolds={}", nfolds);
        let picked = outcome.unwrap();
        assert!(candidates.contains(&picked.optimal_nbasis));
    }
}
/// Smoothing with 40 basis functions on 100 points and a 1e-2 penalty must
/// succeed.
#[test]
fn test_smooth_basis_many_basis_functions() {
    let m = 100;
    let (data, t) = make_test_data(2, m);
    let params = make_bspline_fdpar(&t, 40, 1e-2);
    let outcome = smooth_basis(&data, &t, &params);
    assert!(
        outcome.is_ok(),
        "Should handle many basis functions with penalty"
    );
}
/// B-spline and Fourier fits of the same data (9 functions each) must differ
/// on the first curve.
#[test]
fn test_smooth_basis_bspline_vs_fourier_different_results() {
    let m = 50;
    let (data, t) = make_test_data(2, m);
    let bspline_params = make_bspline_fdpar(&t, 9, 1e-4);
    let fourier_params = make_fourier_fdpar(9, 1.0, 1e-4);
    let bspline_fit = smooth_basis(&data, &t, &bspline_params).unwrap();
    let fourier_fit = smooth_basis(&data, &t, &fourier_params).unwrap();
    // Accumulated absolute gap between the two fits of curve 0.
    let diff: f64 = (0..m)
        .map(|col| (bspline_fit.fitted[(0, col)] - fourier_fit.fitted[(0, col)]).abs())
        .sum();
    assert!(
        diff > 1e-10,
        "B-spline and Fourier fits should differ for the same data"
    );
}
/// A sine with deterministic pseudo-noise must give a strictly positive
/// GCV score.
#[test]
fn test_smooth_basis_gcv_positive_for_noisy_data() {
    let m = 50;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(1, m);
    for j in 0..m {
        // Pseudo-noise in [0, 0.5), shifted by -0.25 below.
        let noise = 0.5 * ((j * 37) % 20) as f64 / 20.0;
        data[(0, j)] = (2.0 * PI * t[j]).sin() + noise - 0.25;
    }
    let params = make_bspline_fdpar(&t, 10, 1e-3);
    let fit = smooth_basis(&data, &t, &params).unwrap();
    assert!(fit.gcv > 0.0, "GCV should be positive for noisy data");
}
/// Fits with first- and second-derivative penalties must both succeed and
/// produce different fitted values on the first curve.
#[test]
fn test_smooth_basis_different_lfd_orders() {
    let m = 50;
    let (data, t) = make_test_data(2, m);
    // Builds the matching penalty and fits with the given derivative order.
    let fit_with_order = |lfd_order: usize| {
        let fdpar = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis: 10,
            lambda: 1e-2,
            lfd_order,
            penalty_matrix: bspline_penalty_matrix(&t, 10, 4, lfd_order),
        };
        smooth_basis(&data, &t, &fdpar)
    };
    let res1 = fit_with_order(1);
    assert!(res1.is_ok());
    let res2 = fit_with_order(2);
    assert!(res2.is_ok());
    let (fit1, fit2) = (res1.unwrap(), res2.unwrap());
    let diff: f64 = (0..m)
        .map(|col| (fit1.fitted[(0, col)] - fit2.fitted[(0, col)]).abs())
        .sum();
    assert!(
        diff > 1e-10,
        "Different lfd_orders should produce different fits"
    );
}
/// Checks the result fields for an AIC run: candidate echo, score count,
/// criterion, and that `optimal_nbasis` is the candidate at the first
/// minimum score.
#[test]
fn test_basis_nbasis_cv_result_fields() {
    let (data, t) = make_test_data(6, 50);
    let candidates: Vec<usize> = vec![5, 7, 9, 11, 13];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Aic, 5, 1e-4).unwrap();
    assert!(candidates.contains(&outcome.optimal_nbasis));
    assert_eq!(outcome.scores.len(), candidates.len());
    assert_eq!(outcome.nbasis_range, candidates);
    assert_eq!(outcome.criterion, BasisCriterion::Aic);
    let min_score = outcome
        .scores
        .iter()
        .fold(f64::INFINITY, |acc, &score| acc.min(score));
    // First position whose score matches the minimum.
    let best_idx = outcome
        .scores
        .iter()
        .position(|&score| (score - min_score).abs() < 1e-15)
        .unwrap();
    assert_eq!(outcome.optimal_nbasis, candidates[best_idx]);
}
/// A cloned selection result must compare equal to the original.
#[test]
fn test_basis_nbasis_cv_result_clone() {
    let (data, t) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4).unwrap();
    let duplicate = outcome.clone();
    assert_eq!(outcome, duplicate);
}
/// Smoothing must succeed on a cosine-spaced (non-uniform) grid and return a
/// fitted matrix with the data's shape.
#[test]
fn test_smooth_basis_nonuniform_argvals() {
    let m = 50;
    // Cosine spacing: points cluster towards both ends of [0, 1].
    let t: Vec<f64> = (0..m)
        .map(|idx| {
            let frac = idx as f64 / (m - 1) as f64;
            0.5 * (1.0 - (PI * frac).cos())
        })
        .collect();
    let mut data = FdMatrix::zeros(2, m);
    for row in 0..2 {
        for col in 0..m {
            // Second curve is the first shifted up by 0.1.
            data[(row, col)] = (2.0 * PI * t[col]).sin() + 0.1 * row as f64;
        }
    }
    let params = make_bspline_fdpar(&t, 10, 1e-4);
    let outcome = smooth_basis(&data, &t, &params);
    assert!(outcome.is_ok(), "Should handle non-uniform argvals");
    let fit = outcome.unwrap();
    assert_eq!(fit.fitted.shape(), (2, m));
}
/// Smoothing with `lambda = 1e-15` must succeed.
#[test]
fn test_smooth_basis_very_small_lambda() {
    let m = 50;
    let (data, t) = make_test_data(2, m);
    let params = make_bspline_fdpar(&t, 10, 1e-15);
    let outcome = smooth_basis(&data, &t, &params);
    assert!(outcome.is_ok(), "Should handle very small lambda");
}
/// Smoothing with `lambda = 1e10` must succeed.
#[test]
fn test_smooth_basis_very_large_lambda() {
    let m = 50;
    let (data, t) = make_test_data(2, m);
    let params = make_bspline_fdpar(&t, 10, 1e10);
    let outcome = smooth_basis(&data, &t, &params);
    assert!(outcome.is_ok(), "Should handle very large lambda");
}
/// Fitting all curves together must give the same fitted values (within
/// 1e-10) as fitting each curve alone.
#[test]
fn test_smooth_basis_multi_curve_vs_single_curve() {
    let (n, m) = (3, 50);
    let (data, t) = make_test_data(n, m);
    let params = make_bspline_fdpar(&t, 10, 1e-3);
    let joint_fit = smooth_basis(&data, &t, &params).unwrap();
    for curve in 0..n {
        // Copy this curve into a one-row matrix and fit it on its own.
        let mut lone = FdMatrix::zeros(1, m);
        for col in 0..m {
            lone[(0, col)] = data[(curve, col)];
        }
        let lone_fit = smooth_basis(&lone, &t, &params).unwrap();
        for col in 0..m {
            let gap = (joint_fit.fitted[(curve, col)] - lone_fit.fitted[(0, col)]).abs();
            assert!(
                gap < 1e-10,
                "Multi-curve fit should match single-curve fit: curve {} point {}",
                curve,
                col
            );
        }
    }
}
/// Every criterion (GCV, AIC, BIC, CV) must produce at least one finite score.
#[test]
fn test_basis_nbasis_cv_all_criteria_finite_scores() {
    let (data, t) = make_test_data(10, 60);
    let candidates: Vec<usize> = vec![5, 7, 9, 11];
    let basis = BasisType::Bspline { order: 4 };
    let criteria = [
        BasisCriterion::Gcv,
        BasisCriterion::Aic,
        BasisCriterion::Bic,
        BasisCriterion::Cv,
    ];
    for criterion in criteria {
        let outcome = basis_nbasis_cv(&data, &t, &candidates, &basis, criterion, 5, 1e-4).unwrap();
        let has_finite = outcome.scores.iter().any(|score| score.is_finite());
        assert!(
            has_finite,
            "At least one score should be finite for {:?}",
            criterion
        );
    }
}
/// Pins the default field values of `SmoothBasisGcvConfig`.
#[test]
fn test_smooth_basis_gcv_config_default() {
    let defaults = SmoothBasisGcvConfig::default();
    assert_eq!(defaults.basis_type, BasisType::Bspline { order: 4 });
    assert_eq!(defaults.nbasis, 15);
    assert_eq!(defaults.lfd_order, 2);
    assert_eq!(defaults.log_lambda_range, (-10.0, 2.0));
    assert_eq!(defaults.n_grid, 50);
}
/// A cloned `SmoothBasisGcvConfig` must compare equal to the original.
#[test]
fn test_smooth_basis_gcv_config_clone_eq() {
    let original = SmoothBasisGcvConfig {
        nbasis: 20,
        ..SmoothBasisGcvConfig::default()
    };
    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
/// The `Debug` output of `SmoothBasisGcvConfig` must mention the type name
/// and the `nbasis` field.
#[test]
fn test_smooth_basis_gcv_config_debug() {
    let rendered = format!("{:?}", SmoothBasisGcvConfig::default());
    assert!(rendered.contains("SmoothBasisGcvConfig"));
    assert!(rendered.contains("nbasis"));
}
/// Overriding two fields via struct-update syntax must leave the remaining
/// fields at their defaults.
#[test]
fn test_smooth_basis_gcv_config_partial_override() {
    let custom = SmoothBasisGcvConfig {
        basis_type: BasisType::Fourier { period: 2.0 },
        n_grid: 100,
        ..SmoothBasisGcvConfig::default()
    };
    // Overridden fields.
    assert_eq!(custom.basis_type, BasisType::Fourier { period: 2.0 });
    assert_eq!(custom.n_grid, 100);
    // Fields inherited from the default.
    assert_eq!(custom.nbasis, 15);
    assert_eq!(custom.lfd_order, 2);
}
/// The config-based GCV entry point must succeed with default settings and
/// return a correctly shaped fit with positive EDF and finite GCV.
#[test]
fn test_smooth_basis_gcv_with_config_default() {
    let (data, t) = make_test_data(5, 101);
    let defaults = SmoothBasisGcvConfig::default();
    let outcome = smooth_basis_gcv_with_config(&data, &t, &defaults);
    assert!(outcome.is_ok(), "GCV with default config should succeed");
    let fit = outcome.unwrap();
    assert_eq!(fit.fitted.shape(), (5, 101));
    assert!(fit.edf > 0.0);
    assert!(fit.gcv.is_finite());
}
/// The config-based GCV entry point must succeed with a customised basis
/// size, lambda range and grid size.
#[test]
fn test_smooth_basis_gcv_with_config_custom() {
    let (data, t) = make_test_data(3, 50);
    let custom = SmoothBasisGcvConfig {
        nbasis: 10,
        log_lambda_range: (-6.0, 0.0),
        n_grid: 15,
        ..SmoothBasisGcvConfig::default()
    };
    assert!(smooth_basis_gcv_with_config(&data, &t, &custom).is_ok());
}
/// The config-based GCV entry point must return exactly the same GCV, EDF
/// and `nbasis` as the positional-argument function given the same settings.
#[test]
fn test_smooth_basis_gcv_with_config_matches_direct() {
    let (data, t) = make_test_data(3, 50);
    let cfg = SmoothBasisGcvConfig {
        nbasis: 10,
        log_lambda_range: (-6.0, 0.0),
        n_grid: 20,
        ..SmoothBasisGcvConfig::default()
    };
    let via_config = smooth_basis_gcv_with_config(&data, &t, &cfg).unwrap();
    // Same settings, unpacked into positional arguments.
    let via_args = smooth_basis_gcv(
        &data,
        &t,
        &cfg.basis_type,
        cfg.nbasis,
        cfg.lfd_order,
        cfg.log_lambda_range,
        cfg.n_grid,
    )
    .unwrap();
    assert_eq!(via_config.gcv, via_args.gcv);
    assert_eq!(via_config.edf, via_args.edf);
    assert_eq!(via_config.nbasis, via_args.nbasis);
}
/// The config-based GCV entry point must succeed with a Fourier basis on
/// two identical periodic curves.
#[test]
fn test_smooth_basis_gcv_with_config_fourier() {
    let m = 100;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(2, m);
    for j in 0..m {
        // Both rows carry the same two-harmonic signal.
        let y = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
        for i in 0..2 {
            data[(i, j)] = y;
        }
    }
    let cfg = SmoothBasisGcvConfig {
        basis_type: BasisType::Fourier { period: 1.0 },
        nbasis: 7,
        n_grid: 20,
        ..SmoothBasisGcvConfig::default()
    };
    assert!(smooth_basis_gcv_with_config(&data, &t, &cfg).is_ok());
}
/// Pins the default field values of `BasisNbasisCvConfig`.
#[test]
fn test_basis_nbasis_cv_config_default() {
    let defaults = BasisNbasisCvConfig::default();
    assert_eq!(defaults.basis_type, BasisType::Bspline { order: 4 });
    assert_eq!(defaults.nbasis_range, (5, 30));
    // Lambda is compared with a tolerance rather than exactly.
    assert!((defaults.lambda - 1e-4).abs() < 1e-15);
    assert_eq!(defaults.lfd_order, 2);
    assert_eq!(defaults.n_folds, 5);
    assert_eq!(defaults.criterion, BasisCriterion::Gcv);
}
/// A cloned `BasisNbasisCvConfig` must compare equal to the original.
#[test]
fn test_basis_nbasis_cv_config_clone_eq() {
    let original = BasisNbasisCvConfig {
        nbasis_range: (4, 15),
        ..BasisNbasisCvConfig::default()
    };
    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
/// The `Debug` output of `BasisNbasisCvConfig` must mention the type name
/// and the `nbasis_range` field.
#[test]
fn test_basis_nbasis_cv_config_debug() {
    let rendered = format!("{:?}", BasisNbasisCvConfig::default());
    assert!(rendered.contains("BasisNbasisCvConfig"));
    assert!(rendered.contains("nbasis_range"));
}
/// Overriding two fields via struct-update syntax must leave the remaining
/// fields at their defaults.
#[test]
fn test_basis_nbasis_cv_config_partial_override() {
    let custom = BasisNbasisCvConfig {
        criterion: BasisCriterion::Aic,
        lambda: 1e-2,
        ..BasisNbasisCvConfig::default()
    };
    // Overridden fields.
    assert_eq!(custom.criterion, BasisCriterion::Aic);
    assert!((custom.lambda - 1e-2).abs() < 1e-15);
    // Fields inherited from the default.
    assert_eq!(custom.nbasis_range, (5, 30));
    assert_eq!(custom.n_folds, 5);
}
/// With `nbasis_range = (5, 12)` and otherwise default settings, the
/// config-based entry point must succeed, pick a value in 5..=12, report 8
/// scores, and use the GCV criterion.
#[test]
fn test_basis_nbasis_cv_with_config_default() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 12),
        ..BasisNbasisCvConfig::default()
    };
    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(
        outcome.is_ok(),
        "nbasis CV with default config should succeed"
    );
    let picked = outcome.unwrap();
    assert!((5..=12).contains(&picked.optimal_nbasis));
    // One score for each of 5, 6, ..., 12.
    assert_eq!(picked.scores.len(), 8);
    assert_eq!(picked.criterion, BasisCriterion::Gcv);
}
/// The config-based entry point must succeed under AIC and report AIC as
/// the criterion used.
#[test]
fn test_basis_nbasis_cv_with_config_aic() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 10),
        criterion: BasisCriterion::Aic,
        ..BasisNbasisCvConfig::default()
    };
    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().criterion, BasisCriterion::Aic);
}
/// The config-based entry point must succeed under the `Cv` criterion with
/// 3 folds and report `Cv` as the criterion used.
#[test]
fn test_basis_nbasis_cv_with_config_cv_folds() {
    let (data, t) = make_test_data(10, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 9),
        criterion: BasisCriterion::Cv,
        n_folds: 3,
        ..BasisNbasisCvConfig::default()
    };
    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().criterion, BasisCriterion::Cv);
}
/// The config-based entry point must return the same optimum, scores and
/// candidate list as the positional-argument function called with the
/// expanded range 5..=10.
#[test]
fn test_basis_nbasis_cv_with_config_matches_direct() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 10),
        criterion: BasisCriterion::Bic,
        lambda: 1e-3,
        ..BasisNbasisCvConfig::default()
    };
    let via_config = basis_nbasis_cv_with_config(&data, &t, &cfg).unwrap();
    // Explicit candidate list equivalent to the config's (5, 10) bounds.
    let candidates: Vec<usize> = (5..=10).collect();
    let via_args = basis_nbasis_cv(
        &data,
        &t,
        &candidates,
        &cfg.basis_type,
        cfg.criterion,
        cfg.n_folds,
        cfg.lambda,
    )
    .unwrap();
    assert_eq!(via_config.optimal_nbasis, via_args.optimal_nbasis);
    assert_eq!(via_config.scores, via_args.scores);
    assert_eq!(via_config.nbasis_range, via_args.nbasis_range);
}
/// A degenerate range `(7, 7)` must be evaluated as the single candidate 7.
#[test]
fn test_basis_nbasis_cv_with_config_nbasis_range_expansion() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        // Lower and upper bound coincide.
        nbasis_range: (7, 7),
        ..BasisNbasisCvConfig::default()
    };
    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());
    let picked = outcome.unwrap();
    assert_eq!(picked.optimal_nbasis, 7);
    assert_eq!(picked.scores.len(), 1);
}
}