1use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
/// Basis family used to expand the curves.
///
/// The variant payload is forwarded unchanged to the matching basis
/// evaluator (`bspline_basis` or `fourier_basis_with_period`).
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
    /// B-spline basis; `order` is the spline order handed to `bspline_basis`.
    Bspline { order: usize },
    /// Fourier basis; `period` is handed to `fourier_basis_with_period` and
    /// sets the fundamental frequency used by `fourier_penalty_matrix`.
    Fourier { period: f64 },
}
28
/// Smoothing specification consumed by `smooth_basis`.
#[derive(Debug, Clone, PartialEq)]
pub struct FdPar {
    /// Basis family and its family-specific parameter.
    pub basis_type: BasisType,
    /// Requested number of basis functions; `smooth_basis` rejects values below 2.
    pub nbasis: usize,
    /// Smoothing parameter that multiplies the penalty matrix in the normal equations.
    pub lambda: f64,
    /// Derivative order the penalty refers to. `smooth_basis` itself does not
    /// read this field; the penalty is taken from `penalty_matrix` as given.
    pub lfd_order: usize,
    /// Flattened penalty matrix, read by `smooth_basis` as a column-major
    /// `k x k` matrix, where `k` is the number of basis functions that the
    /// basis evaluation actually returns.
    pub penalty_matrix: Vec<f64>,
}
43
/// Output of `smooth_basis`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct SmoothBasisResult {
    /// Basis coefficients, one row per curve and one column per basis function.
    pub coefficients: FdMatrix,
    /// Smoothed curves evaluated at the input `argvals`, same shape as the data.
    pub fitted: FdMatrix,
    /// Effective degrees of freedom of the smoother for a single curve
    /// (trace of the hat matrix).
    pub edf: f64,
    /// Generalised cross-validation score: mean squared residual divided by
    /// `(1 - edf / m)^2`; infinite when that denominator is numerically zero.
    pub gcv: f64,
    /// `N * ln(mse) + 2 * n * edf`, with `N = n * m` observations in total.
    pub aic: f64,
    /// `N * ln(mse) + ln(N) * n * edf`, with `N = n * m` observations in total.
    pub bic: f64,
    /// Copy of the penalty matrix that was supplied through `FdPar`.
    pub penalty_matrix: Vec<f64>,
    /// Number of basis functions actually used (column count of `coefficients`).
    pub nbasis: usize,
}
65
66pub fn bspline_penalty_matrix(
83 argvals: &[f64],
84 nbasis: usize,
85 order: usize,
86 lfd_order: usize,
87) -> Vec<f64> {
88 if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
89 return vec![0.0; nbasis * nbasis];
90 }
91
92 let nknots = nbasis.saturating_sub(order).max(2);
93
94 let n_sub = 10;
96 let t_min = argvals[0];
97 let t_max = argvals[argvals.len() - 1];
98 let n_quad = (argvals.len() - 1) * n_sub + 1;
99 let quad_t: Vec<f64> = (0..n_quad)
100 .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
101 .collect();
102
103 let basis_fine = bspline_basis(&quad_t, nknots, order);
105 let actual_nbasis = basis_fine.len() / n_quad;
106
107 let h = (t_max - t_min) / (n_quad - 1) as f64;
109 let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
110
111 let weights = simpsons_weights(&quad_t);
113
114 integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
116}
117
/// Builds the diagonal roughness penalty of a Fourier basis.
///
/// Index 0 (the constant term) is left unpenalised. Indices `1, 2` share the
/// first harmonic, `3, 4` the second, and so on; each diagonal entry is
/// `(2 * pi * harmonic / period)^(2 * lfd_order)`.
///
/// # Arguments
/// * `nbasis` - number of basis functions (matrix dimension).
/// * `period` - period of the basis.
/// * `lfd_order` - order of the derivative being penalised.
///
/// # Returns
/// A flattened `nbasis x nbasis` matrix with zeros off the diagonal.
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let exponent = 2 * lfd_order as i32;
    let mut penalty = vec![0.0; nbasis * nbasis];

    for idx in 1..nbasis {
        // Two consecutive basis functions share one harmonic: 1,2 -> 1; 3,4 -> 2; ...
        let harmonic = (idx + 1) / 2;
        let omega = 2.0 * PI * harmonic as f64 / period;
        // `idx * (nbasis + 1)` is the flat position of diagonal entry (idx, idx).
        penalty[idx * (nbasis + 1)] = omega.powi(exponent);
    }

    penalty
}
159
160pub fn smooth_basis(
175 data: &FdMatrix,
176 argvals: &[f64],
177 fdpar: &FdPar,
178) -> Result<SmoothBasisResult, crate::FdarError> {
179 let (n, m) = data.shape();
180 if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
181 return Err(crate::FdarError::InvalidDimension {
182 parameter: "data/argvals/fdpar",
183 expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
184 actual: format!(
185 "n={}, m={}, argvals.len()={}, nbasis={}",
186 n,
187 m,
188 argvals.len(),
189 fdpar.nbasis
190 ),
191 });
192 }
193
194 let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
196 let k = actual_nbasis;
197
198 let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
199 let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
200
201 let btb = b_mat.transpose() * &b_mat;
203 let ridge_eps = 1e-10;
204 let system: DMatrix<f64> =
205 &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
206
207 let system_inv =
209 invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
210 operation: "matrix inversion",
211 detail: "failed to invert penalized system (Φ'Φ + λR); try increasing lambda or reducing the number of basis functions".to_string(),
212 })?;
213
214 let h_mat = &b_mat * &system_inv * b_mat.transpose();
216 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
217
218 let proj = &system_inv * b_mat.transpose();
220 let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
221
222 let total_points = (n * m) as f64;
223 let gcv = compute_gcv(total_rss, total_points, edf, m);
224 let mse = total_rss / total_points;
225 let total_edf = n as f64 * edf;
227 let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
228 let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
229
230 Ok(SmoothBasisResult {
231 coefficients: all_coefs,
232 fitted: all_fitted,
233 edf,
234 gcv,
235 aic,
236 bic,
237 penalty_matrix: fdpar.penalty_matrix.clone(),
238 nbasis: k,
239 })
240}
241
242pub fn smooth_basis_gcv(
255 data: &FdMatrix,
256 argvals: &[f64],
257 basis_type: &BasisType,
258 nbasis: usize,
259 lfd_order: usize,
260 log_lambda_range: (f64, f64),
261 n_grid: usize,
262) -> Option<SmoothBasisResult> {
263 let m = argvals.len();
264 if m == 0 || nbasis < 2 || n_grid < 2 {
265 return None;
266 }
267
268 let penalty = match basis_type {
270 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
271 BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
272 };
273
274 let (lo, hi) = log_lambda_range;
275 let mut best_gcv = f64::INFINITY;
276 let mut best_result: Option<SmoothBasisResult> = None;
277
278 for i in 0..n_grid {
279 let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
280 let lam = 10.0_f64.powf(log_lam);
281
282 let fdpar = FdPar {
283 basis_type: basis_type.clone(),
284 nbasis,
285 lambda: lam,
286 lfd_order,
287 penalty_matrix: penalty.clone(),
288 };
289
290 if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
291 if result.gcv < best_gcv {
292 best_gcv = result.gcv;
293 best_result = Some(result);
294 }
295 }
296 }
297
298 best_result
299}
300
301#[derive(Debug, Clone, PartialEq)]
319pub struct SmoothBasisGcvConfig {
320 pub basis_type: BasisType,
322 pub nbasis: usize,
324 pub lfd_order: usize,
326 pub log_lambda_range: (f64, f64),
328 pub n_grid: usize,
330}
331
332impl Default for SmoothBasisGcvConfig {
333 fn default() -> Self {
334 Self {
335 basis_type: BasisType::Bspline { order: 4 },
336 nbasis: 15,
337 lfd_order: 2,
338 log_lambda_range: (-10.0, 2.0),
339 n_grid: 50,
340 }
341 }
342}
343
344#[must_use = "expensive computation whose result should not be discarded"]
359pub fn smooth_basis_gcv_with_config(
360 data: &FdMatrix,
361 argvals: &[f64],
362 config: &SmoothBasisGcvConfig,
363) -> Result<SmoothBasisResult, crate::FdarError> {
364 smooth_basis_gcv(
365 data,
366 argvals,
367 &config.basis_type,
368 config.nbasis,
369 config.lfd_order,
370 config.log_lambda_range,
371 config.n_grid,
372 )
373 .ok_or_else(|| crate::FdarError::ComputationFailed {
374 operation: "smooth_basis_gcv_with_config",
375 detail: "no valid smoothing result found in GCV lambda search".to_string(),
376 })
377}
378
379#[derive(Debug, Clone, PartialEq)]
395pub struct BasisNbasisCvConfig {
396 pub basis_type: BasisType,
398 pub nbasis_range: (usize, usize),
400 pub lambda: f64,
402 pub lfd_order: usize,
404 pub n_folds: usize,
406 pub criterion: BasisCriterion,
408}
409
410impl Default for BasisNbasisCvConfig {
411 fn default() -> Self {
412 Self {
413 basis_type: BasisType::Bspline { order: 4 },
414 nbasis_range: (5, 30),
415 lambda: 1e-4,
416 lfd_order: 2,
417 n_folds: 5,
418 criterion: BasisCriterion::Gcv,
419 }
420 }
421}
422
423#[must_use = "expensive computation whose result should not be discarded"]
441pub fn basis_nbasis_cv_with_config(
442 data: &FdMatrix,
443 argvals: &[f64],
444 config: &BasisNbasisCvConfig,
445) -> Result<BasisNbasisCvResult, crate::FdarError> {
446 let nbasis_range: Vec<usize> = (config.nbasis_range.0..=config.nbasis_range.1).collect();
447 basis_nbasis_cv(
448 data,
449 argvals,
450 &nbasis_range,
451 &config.basis_type,
452 config.criterion,
453 config.n_folds,
454 config.lambda,
455 )
456 .ok_or_else(|| crate::FdarError::ComputationFailed {
457 operation: "basis_nbasis_cv_with_config",
458 detail: "no valid result found in nbasis CV search".to_string(),
459 })
460}
461
462fn differentiate_basis_columns(
466 basis: &[f64],
467 n_quad: usize,
468 nbasis: usize,
469 h: f64,
470 lfd_order: usize,
471) -> Vec<f64> {
472 let mut deriv = basis.to_vec();
473 for _ in 0..lfd_order {
474 let mut new_deriv = vec![0.0; n_quad * nbasis];
475 for j in 0..nbasis {
476 let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
477 let grad = crate::helpers::gradient_uniform(&col, h);
478 for i in 0..n_quad {
479 new_deriv[i + j * n_quad] = grad[i];
480 }
481 }
482 deriv = new_deriv;
483 }
484 deriv
485}
486
/// Computes the weighted Gram matrix of the columns of `deriv_basis`.
///
/// `deriv_basis` is column-major with `k` columns of `n_quad` samples.
/// Entry `(j, l)` of the result is `sum_i d[i, j] * d[i, l] * weights[i]`.
/// Only the upper triangle is computed and then mirrored, so the result is
/// exactly symmetric. The result is returned as a flattened `k x k` matrix.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let mut gram = vec![0.0; k * k];
    for row in 0..k {
        let left = &deriv_basis[row * n_quad..(row + 1) * n_quad];
        for col in row..k {
            let right = &deriv_basis[col * n_quad..(col + 1) * n_quad];
            let integral = left
                .iter()
                .zip(right)
                .zip(&weights[..n_quad])
                .fold(0.0, |acc, ((&a, &b), &w)| acc + a * b * w);
            gram[row + col * k] = integral;
            gram[col + row * k] = integral;
        }
    }
    gram
}
507
508fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
510 let m = argvals.len();
511 match basis_type {
512 BasisType::Bspline { order } => {
513 let nknots = nbasis.saturating_sub(*order).max(2);
514 let basis = bspline_basis(argvals, nknots, *order);
515 let actual = basis.len() / m;
516 (basis, actual)
517 }
518 BasisType::Fourier { period } => {
519 let basis = fourier_basis_with_period(argvals, nbasis, *period);
520 (basis, nbasis)
521 }
522 }
523}
524
525fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
527 if let Some(chol) = system.clone().cholesky() {
528 return Some(chol.inverse());
529 }
530 let svd = nalgebra::SVD::new(system.clone(), true, true);
532 let u = svd.u.as_ref()?;
533 let v_t = svd.v_t.as_ref()?;
534 let max_sv: f64 = svd.singular_values.iter().copied().fold(0.0_f64, f64::max);
535 let eps = 1e-10 * max_sv;
536 let mut inv = DMatrix::<f64>::zeros(k, k);
537 for ii in 0..k {
538 for jj in 0..k {
539 let mut sum = 0.0;
540 for s in 0..k.min(svd.singular_values.len()) {
541 if svd.singular_values[s] > eps {
542 sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
543 }
544 }
545 inv[(ii, jj)] = sum;
546 }
547 }
548 Some(inv)
549}
550
551fn project_all_curves(
553 data: &FdMatrix,
554 b_mat: &DMatrix<f64>,
555 proj: &DMatrix<f64>,
556 n: usize,
557 m: usize,
558 k: usize,
559) -> (FdMatrix, FdMatrix, f64) {
560 let mut all_coefs = FdMatrix::zeros(n, k);
561 let mut all_fitted = FdMatrix::zeros(n, m);
562 let mut total_rss = 0.0;
563
564 for i in 0..n {
565 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
566 let y_vec = nalgebra::DVector::from_vec(curve.clone());
567 let coefs = proj * &y_vec;
568
569 for j in 0..k {
570 all_coefs[(i, j)] = coefs[j];
571 }
572 let fitted = b_mat * &coefs;
573 for j in 0..m {
574 all_fitted[(i, j)] = fitted[j];
575 let resid = curve[j] - fitted[j];
576 total_rss += resid * resid;
577 }
578 }
579
580 (all_coefs, all_fitted, total_rss)
581}
582
/// Generalised cross-validation score
/// `(rss / n_points) / (1 - edf / m)^2`.
///
/// Returns `f64::INFINITY` when `|1 - edf / m|` is not greater than `1e-10`
/// (this also covers a NaN denominator).
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let shrink = 1.0 - edf / m as f64;
    let well_conditioned = shrink.abs() > 1e-10;
    if !well_conditioned {
        return f64::INFINITY;
    }
    let mean_sq_resid = rss / n_points;
    mean_sq_resid / (shrink * shrink)
}
592
/// Score used to rank candidate basis sizes in `basis_nbasis_cv`.
///
/// The enum has no payload, so it also implements `Eq` and `Hash` and can be
/// used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BasisCriterion {
    /// Generalised cross-validation score of the full-data fit.
    Gcv,
    /// K-fold cross-validation over curves (mean squared error per held-out curve).
    Cv,
    /// Akaike information criterion of the full-data fit.
    Aic,
    /// Bayesian information criterion of the full-data fit.
    Bic,
}
607
/// Outcome of the basis-size search performed by `basis_nbasis_cv`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct BasisNbasisCvResult {
    /// Entry of `nbasis_range` with the smallest score.
    pub optimal_nbasis: usize,
    /// One score per candidate, in the order of `nbasis_range`;
    /// `f64::INFINITY` marks candidates that could not be scored.
    pub scores: Vec<f64>,
    /// The candidate basis sizes that were evaluated.
    pub nbasis_range: Vec<usize>,
    /// Criterion the scores refer to.
    pub criterion: BasisCriterion,
}
621
622fn evaluate_nbasis_info_criterion(
624 data: &FdMatrix,
625 argvals: &[f64],
626 nbasis_range: &[usize],
627 basis_type: &BasisType,
628 criterion: BasisCriterion,
629 lambda: f64,
630) -> Vec<f64> {
631 let mut scores = Vec::with_capacity(nbasis_range.len());
632 for &nb in nbasis_range {
633 if nb < 2 {
634 scores.push(f64::INFINITY);
635 continue;
636 }
637 let penalty = match basis_type {
638 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
639 BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
640 };
641 let fdpar = FdPar {
642 basis_type: basis_type.clone(),
643 nbasis: nb,
644 lambda,
645 lfd_order: 2,
646 penalty_matrix: penalty,
647 };
648 match smooth_basis(data, argvals, &fdpar) {
649 Ok(result) => {
650 let score = match criterion {
651 BasisCriterion::Gcv => result.gcv,
652 BasisCriterion::Aic => result.aic,
653 BasisCriterion::Bic => result.bic,
654 BasisCriterion::Cv => unreachable!(),
655 };
656 scores.push(score);
657 }
658 Err(_) => scores.push(f64::INFINITY),
659 }
660 }
661 scores
662}
663
664fn evaluate_nbasis_cv(
666 data: &FdMatrix,
667 argvals: &[f64],
668 nbasis_range: &[usize],
669 basis_type: &BasisType,
670 lambda: f64,
671 n_folds: usize,
672) -> Vec<f64> {
673 let (n, m) = data.shape();
674 let n_folds = n_folds.max(2);
675 let folds = crate::cv::create_folds(n, n_folds, 42);
676 let mut scores = Vec::with_capacity(nbasis_range.len());
677
678 for &nb in nbasis_range {
679 if nb < 2 {
680 scores.push(f64::INFINITY);
681 continue;
682 }
683 let penalty = match basis_type {
684 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
685 BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
686 };
687
688 let mut total_mse = 0.0;
689 let mut total_count = 0;
690
691 for fold in 0..n_folds {
692 let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
693 if train_idx.is_empty() || test_idx.is_empty() {
694 continue;
695 }
696 let train_data = crate::cv::subset_rows(data, &train_idx);
697 let fdpar = FdPar {
698 basis_type: basis_type.clone(),
699 nbasis: nb,
700 lambda,
701 lfd_order: 2,
702 penalty_matrix: penalty.clone(),
703 };
704
705 if let Ok(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
706 let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
707 let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
708 let r_mat =
709 DMatrix::from_column_slice(actual_k, actual_k, &train_result.penalty_matrix);
710 let btb = b_mat.transpose() * &b_mat;
711 let ridge_eps = 1e-10;
712 let system: DMatrix<f64> = &btb
713 + lambda * &r_mat
714 + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
715
716 if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
717 let proj = &system_inv * b_mat.transpose();
718 for &ti in &test_idx {
719 let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
720 let y_vec = nalgebra::DVector::from_vec(curve.clone());
721 let coefs = &proj * &y_vec;
722 let fitted = &b_mat * &coefs;
723 let mse: f64 =
724 (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
725 total_mse += mse;
726 total_count += 1;
727 }
728 }
729 }
730 }
731
732 if total_count > 0 {
733 scores.push(total_mse / f64::from(total_count));
734 } else {
735 scores.push(f64::INFINITY);
736 }
737 }
738 scores
739}
740
741pub fn basis_nbasis_cv(
744 data: &FdMatrix,
745 argvals: &[f64],
746 nbasis_range: &[usize],
747 basis_type: &BasisType,
748 criterion: BasisCriterion,
749 n_folds: usize,
750 lambda: f64,
751) -> Option<BasisNbasisCvResult> {
752 let (n, m) = data.shape();
753 if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
754 return None;
755 }
756
757 let scores = match criterion {
758 BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
759 evaluate_nbasis_info_criterion(
760 data,
761 argvals,
762 nbasis_range,
763 basis_type,
764 criterion,
765 lambda,
766 )
767 }
768 BasisCriterion::Cv => {
769 evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
770 }
771 };
772
773 let (best_idx, _) = scores
774 .iter()
775 .enumerate()
776 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
777
778 Some(BasisNbasisCvResult {
779 optimal_nbasis: nbasis_range[best_idx],
780 scores,
781 nbasis_range: nbasis_range.to_vec(),
782 criterion,
783 })
784}
785
786#[cfg(test)]
787mod tests {
788 use super::*;
789 use crate::test_helpers::uniform_grid;
790 use std::f64::consts::PI;
791
792 #[test]
793 fn test_bspline_penalty_matrix_symmetric() {
794 let t = uniform_grid(101);
795 let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
796 let _k = 15; let actual_k = (penalty.len() as f64).sqrt() as usize;
798 for i in 0..actual_k {
799 for j in 0..actual_k {
800 assert!(
801 (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
802 "Penalty matrix not symmetric at ({}, {})",
803 i,
804 j
805 );
806 }
807 }
808 }
809
810 #[test]
811 fn test_bspline_penalty_matrix_positive_semidefinite() {
812 let t = uniform_grid(101);
813 let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
814 let k = (penalty.len() as f64).sqrt() as usize;
815 for i in 0..k {
817 assert!(
818 penalty[i + i * k] >= -1e-10,
819 "Diagonal element {} is negative: {}",
820 i,
821 penalty[i + i * k]
822 );
823 }
824 }
825
826 #[test]
827 fn test_fourier_penalty_diagonal() {
828 let penalty = fourier_penalty_matrix(7, 1.0, 2);
829 for i in 0..7 {
831 for j in 0..7 {
832 if i != j {
833 assert!(
834 penalty[i + j * 7].abs() < 1e-10,
835 "Off-diagonal ({},{}) = {}",
836 i,
837 j,
838 penalty[i + j * 7]
839 );
840 }
841 }
842 }
843 assert!(penalty[0].abs() < 1e-10);
845 assert!(penalty[1 + 7] > 0.0);
847 assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
848 }
849
850 #[test]
851 fn test_smooth_basis_bspline() {
852 let m = 101;
853 let n = 5;
854 let t = uniform_grid(m);
855
856 let mut data = FdMatrix::zeros(n, m);
858 for i in 0..n {
859 for j in 0..m {
860 data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
861 }
862 }
863
864 let nbasis = 15;
865 let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
866 let _actual_k = (penalty.len() as f64).sqrt() as usize;
867
868 let fdpar = FdPar {
869 basis_type: BasisType::Bspline { order: 4 },
870 nbasis,
871 lambda: 1e-4,
872 lfd_order: 2,
873 penalty_matrix: penalty,
874 };
875
876 let result = smooth_basis(&data, &t, &fdpar);
877 assert!(result.is_ok(), "smooth_basis should succeed");
878
879 let res = result.unwrap();
880 assert_eq!(res.fitted.shape(), (n, m));
881 assert_eq!(res.coefficients.nrows(), n);
882 assert!(res.edf > 0.0, "EDF should be positive");
883 assert!(res.gcv > 0.0, "GCV should be positive");
884 }
885
886 #[test]
887 fn test_smooth_basis_fourier() {
888 let m = 101;
889 let n = 3;
890 let t = uniform_grid(m);
891
892 let mut data = FdMatrix::zeros(n, m);
893 for i in 0..n {
894 for j in 0..m {
895 data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
896 }
897 }
898
899 let nbasis = 7;
900 let period = 1.0;
901 let penalty = fourier_penalty_matrix(nbasis, period, 2);
902
903 let fdpar = FdPar {
904 basis_type: BasisType::Fourier { period },
905 nbasis,
906 lambda: 1e-6,
907 lfd_order: 2,
908 penalty_matrix: penalty,
909 };
910
911 let result = smooth_basis(&data, &t, &fdpar);
912 assert!(result.is_ok());
913
914 let res = result.unwrap();
915 for j in 0..m {
917 let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
918 assert!(
919 (res.fitted[(0, j)] - expected).abs() < 0.1,
920 "Fourier fit poor at j={}: got {}, expected {}",
921 j,
922 res.fitted[(0, j)],
923 expected
924 );
925 }
926 }
927
928 #[test]
929 fn test_smooth_basis_gcv_selects_reasonable_lambda() {
930 let m = 101;
931 let n = 5;
932 let t = uniform_grid(m);
933
934 let mut data = FdMatrix::zeros(n, m);
935 for i in 0..n {
936 for j in 0..m {
937 data[(i, j)] =
938 (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
939 }
940 }
941
942 let basis_type = BasisType::Bspline { order: 4 };
943 let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
944 assert!(result.is_some(), "GCV search should succeed");
945 }
946
947 #[test]
948 fn test_smooth_basis_large_lambda_reduces_edf() {
949 let m = 101;
950 let n = 3;
951 let t = uniform_grid(m);
952
953 let mut data = FdMatrix::zeros(n, m);
954 for i in 0..n {
955 for j in 0..m {
956 data[(i, j)] = (2.0 * PI * t[j]).sin();
957 }
958 }
959
960 let nbasis = 15;
961 let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
962 let _actual_k = (penalty.len() as f64).sqrt() as usize;
963
964 let fdpar_small = FdPar {
965 basis_type: BasisType::Bspline { order: 4 },
966 nbasis,
967 lambda: 1e-8,
968 lfd_order: 2,
969 penalty_matrix: penalty.clone(),
970 };
971 let fdpar_large = FdPar {
972 basis_type: BasisType::Bspline { order: 4 },
973 nbasis,
974 lambda: 1e2,
975 lfd_order: 2,
976 penalty_matrix: penalty,
977 };
978
979 let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
980 let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
981
982 assert!(
983 res_large.edf < res_small.edf,
984 "Larger lambda should reduce EDF: {} vs {}",
985 res_large.edf,
986 res_small.edf
987 );
988 }
989
990 #[test]
993 fn test_basis_nbasis_cv_gcv() {
994 let m = 101;
995 let n = 5;
996 let t = uniform_grid(m);
997 let mut data = FdMatrix::zeros(n, m);
998 for i in 0..n {
999 for j in 0..m {
1000 data[(i, j)] =
1001 (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1002 }
1003 }
1004
1005 let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
1006 let result = basis_nbasis_cv(
1007 &data,
1008 &t,
1009 &nbasis_range,
1010 &BasisType::Bspline { order: 4 },
1011 BasisCriterion::Gcv,
1012 5,
1013 1e-4,
1014 );
1015 assert!(result.is_some());
1016 let res = result.unwrap();
1017 assert!(nbasis_range.contains(&res.optimal_nbasis));
1018 assert_eq!(res.scores.len(), nbasis_range.len());
1019 assert_eq!(res.criterion, BasisCriterion::Gcv);
1020 }
1021
1022 #[test]
1023 fn test_basis_nbasis_cv_aic_bic() {
1024 let m = 51;
1025 let n = 5;
1026 let t = uniform_grid(m);
1027 let mut data = FdMatrix::zeros(n, m);
1028 for i in 0..n {
1029 for j in 0..m {
1030 data[(i, j)] = (2.0 * PI * t[j]).sin();
1031 }
1032 }
1033
1034 let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
1035 let aic_result = basis_nbasis_cv(
1036 &data,
1037 &t,
1038 &nbasis_range,
1039 &BasisType::Bspline { order: 4 },
1040 BasisCriterion::Aic,
1041 5,
1042 0.0,
1043 );
1044 let bic_result = basis_nbasis_cv(
1045 &data,
1046 &t,
1047 &nbasis_range,
1048 &BasisType::Bspline { order: 4 },
1049 BasisCriterion::Bic,
1050 5,
1051 0.0,
1052 );
1053 assert!(aic_result.is_some());
1054 assert!(bic_result.is_some());
1055 }
1056
1057 #[test]
1058 fn test_basis_nbasis_cv_kfold() {
1059 let m = 51;
1060 let n = 10;
1061 let t = uniform_grid(m);
1062 let mut data = FdMatrix::zeros(n, m);
1063 for i in 0..n {
1064 for j in 0..m {
1065 data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
1066 }
1067 }
1068
1069 let nbasis_range: Vec<usize> = vec![5, 7, 9];
1070 let result = basis_nbasis_cv(
1071 &data,
1072 &t,
1073 &nbasis_range,
1074 &BasisType::Bspline { order: 4 },
1075 BasisCriterion::Cv,
1076 5,
1077 1e-4,
1078 );
1079 assert!(result.is_some());
1080 let res = result.unwrap();
1081 assert!(nbasis_range.contains(&res.optimal_nbasis));
1082 assert_eq!(res.criterion, BasisCriterion::Cv);
1083 }
1084
1085 fn make_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
1089 let t: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1090 let mut data = FdMatrix::zeros(n, m);
1091 for i in 0..n {
1092 for j in 0..m {
1093 data[(i, j)] = (2.0 * PI * t[j]).sin()
1094 + 0.1 * (10.0 * t[j]).sin()
1095 + 0.05 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1096 }
1097 }
1098 (data, t)
1099 }
1100
1101 fn make_bspline_fdpar(argvals: &[f64], nbasis: usize, lambda: f64) -> FdPar {
1103 let penalty = bspline_penalty_matrix(argvals, nbasis, 4, 2);
1104 FdPar {
1105 basis_type: BasisType::Bspline { order: 4 },
1106 nbasis,
1107 lambda,
1108 lfd_order: 2,
1109 penalty_matrix: penalty,
1110 }
1111 }
1112
1113 fn make_fourier_fdpar(nbasis: usize, period: f64, lambda: f64) -> FdPar {
1115 let penalty = fourier_penalty_matrix(nbasis, period, 2);
1116 FdPar {
1117 basis_type: BasisType::Fourier { period },
1118 nbasis,
1119 lambda,
1120 lfd_order: 2,
1121 penalty_matrix: penalty,
1122 }
1123 }
1124
1125 #[test]
1128 fn test_basis_type_bspline_variant() {
1129 let bt = BasisType::Bspline { order: 4 };
1130 assert_eq!(bt, BasisType::Bspline { order: 4 });
1131 assert_ne!(bt, BasisType::Bspline { order: 3 });
1133 }
1134
1135 #[test]
1136 fn test_basis_type_fourier_variant() {
1137 let bt = BasisType::Fourier { period: 1.0 };
1138 assert_eq!(bt, BasisType::Fourier { period: 1.0 });
1139 assert_ne!(bt, BasisType::Fourier { period: 2.0 });
1140 }
1141
1142 #[test]
1143 fn test_basis_type_cross_variant_inequality() {
1144 let bspline = BasisType::Bspline { order: 4 };
1145 let fourier = BasisType::Fourier { period: 1.0 };
1146 assert_ne!(bspline, fourier);
1147 }
1148
1149 #[test]
1150 fn test_basis_type_clone_and_debug() {
1151 let bt = BasisType::Bspline { order: 4 };
1152 let cloned = bt.clone();
1153 assert_eq!(bt, cloned);
1154 let debug_str = format!("{:?}", bt);
1155 assert!(debug_str.contains("Bspline"));
1156 assert!(debug_str.contains("4"));
1157 }
1158
1159 #[test]
1162 fn test_fdpar_construction_and_fields() {
1163 let penalty = vec![1.0, 0.0, 0.0, 1.0];
1164 let fdpar = FdPar {
1165 basis_type: BasisType::Bspline { order: 4 },
1166 nbasis: 2,
1167 lambda: 0.01,
1168 lfd_order: 2,
1169 penalty_matrix: penalty.clone(),
1170 };
1171 assert_eq!(fdpar.nbasis, 2);
1172 assert!((fdpar.lambda - 0.01).abs() < 1e-15);
1173 assert_eq!(fdpar.lfd_order, 2);
1174 assert_eq!(fdpar.penalty_matrix.len(), 4);
1175 }
1176
1177 #[test]
1178 fn test_fdpar_clone_and_debug() {
1179 let t = uniform_grid(50);
1180 let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1181 let cloned = fdpar.clone();
1182 assert_eq!(fdpar, cloned);
1183 let debug_str = format!("{:?}", fdpar);
1184 assert!(debug_str.contains("FdPar"));
1185 }
1186
1187 #[test]
1190 fn test_basis_criterion_variants() {
1191 assert_eq!(BasisCriterion::Gcv, BasisCriterion::Gcv);
1192 assert_eq!(BasisCriterion::Cv, BasisCriterion::Cv);
1193 assert_eq!(BasisCriterion::Aic, BasisCriterion::Aic);
1194 assert_eq!(BasisCriterion::Bic, BasisCriterion::Bic);
1195 assert_ne!(BasisCriterion::Gcv, BasisCriterion::Aic);
1196 assert_ne!(BasisCriterion::Cv, BasisCriterion::Bic);
1197 }
1198
1199 #[test]
1200 fn test_basis_criterion_copy() {
1201 let c = BasisCriterion::Gcv;
1202 let copied = c; assert_eq!(c, copied);
1204 }
1205
1206 #[test]
1207 fn test_basis_criterion_debug() {
1208 let debug_str = format!("{:?}", BasisCriterion::Bic);
1209 assert!(debug_str.contains("Bic"));
1210 }
1211
1212 #[test]
1215 fn test_smooth_basis_result_all_fields() {
1216 let (data, t) = make_test_data(3, 50);
1217 let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1218 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1219
1220 assert_eq!(res.coefficients.nrows(), 3);
1222 assert!(res.coefficients.ncols() > 0);
1223 assert_eq!(res.nbasis, res.coefficients.ncols());
1224 assert_eq!(res.fitted.shape(), (3, 50));
1226 assert!(res.edf > 0.0 && res.edf <= res.nbasis as f64);
1228 assert!(res.gcv.is_finite());
1230 assert!(res.aic.is_finite());
1231 assert!(res.bic.is_finite());
1232 let k = res.nbasis;
1234 assert_eq!(res.penalty_matrix.len(), k * k);
1235 }
1236
1237 #[test]
1238 fn test_smooth_basis_result_clone() {
1239 let (data, t) = make_test_data(2, 50);
1240 let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1241 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1242 let cloned = res.clone();
1243 assert_eq!(res, cloned);
1244 }
1245
1246 #[test]
1249 fn test_smooth_basis_bspline_coefficient_shape() {
1250 let (data, t) = make_test_data(4, 50);
1251 let nbasis = 12;
1252 let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
1253 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1254 assert_eq!(res.coefficients.nrows(), 4);
1255 assert!(res.coefficients.ncols() >= 2);
1257 assert_eq!(res.nbasis, res.coefficients.ncols());
1258 }
1259
1260 #[test]
1261 fn test_smooth_basis_bspline_fitted_values_shape() {
1262 let m = 80;
1263 let n = 6;
1264 let (data, t) = make_test_data(n, m);
1265 let fdpar = make_bspline_fdpar(&t, 15, 1e-4);
1266 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1267 assert_eq!(res.fitted.shape(), (n, m));
1268 }
1269
1270 #[test]
1271 fn test_smooth_basis_bspline_zero_lambda_interpolates() {
1272 let m = 30;
1274 let n = 2;
1275 let (data, t) = make_test_data(n, m);
1276 let fdpar = make_bspline_fdpar(&t, 15, 0.0);
1277 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1278
1279 let mut max_resid = 0.0_f64;
1281 for i in 0..n {
1282 for j in 0..m {
1283 let resid = (data[(i, j)] - res.fitted[(i, j)]).abs();
1284 max_resid = max_resid.max(resid);
1285 }
1286 }
1287 assert!(
1288 max_resid < 0.5,
1289 "Zero-lambda B-spline should closely interpolate; max_resid = {}",
1290 max_resid
1291 );
1292 }
1293
1294 #[test]
1295 fn test_smooth_basis_bspline_large_lambda_oversmooths() {
1296 let m = 50;
1299 let n = 1;
1300 let (data, t) = make_test_data(n, m);
1301
1302 let fdpar_small = make_bspline_fdpar(&t, 15, 1e-6);
1303 let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1304
1305 let fdpar_large = make_bspline_fdpar(&t, 15, 1e6);
1306 let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1307
1308 let compute_variance = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1309 let vals: Vec<f64> = (0..ncols).map(|j| fitted[(row, j)]).collect();
1310 let mean = vals.iter().sum::<f64>() / ncols as f64;
1311 vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / ncols as f64
1312 };
1313
1314 let var_small = compute_variance(&res_small.fitted, 0, m);
1315 let var_large = compute_variance(&res_large.fitted, 0, m);
1316 assert!(
1317 var_large < var_small,
1318 "Large lambda should yield lower variance fit: var_large={}, var_small={}",
1319 var_large,
1320 var_small
1321 );
1322 }
1323
1324 #[test]
1325 fn test_smooth_basis_bspline_penalty_effect_on_smoothness() {
1326 let m = 50;
1328 let n = 1;
1329 let (data, t) = make_test_data(n, m);
1330
1331 let fdpar_small = make_bspline_fdpar(&t, 15, 1e-8);
1332 let fdpar_large = make_bspline_fdpar(&t, 15, 1.0);
1333
1334 let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1335 let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1336
1337 let roughness = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1339 (1..ncols - 1)
1340 .map(|j| {
1341 let d2 = fitted[(row, j + 1)] - 2.0 * fitted[(row, j)] + fitted[(row, j - 1)];
1342 d2 * d2
1343 })
1344 .sum::<f64>()
1345 };
1346
1347 let r_small = roughness(&res_small.fitted, 0, m);
1348 let r_large = roughness(&res_large.fitted, 0, m);
1349 assert!(
1350 r_large < r_small,
1351 "Larger lambda should produce smoother fit: roughness_large={}, roughness_small={}",
1352 r_large,
1353 r_small
1354 );
1355 }
1356
1357 #[test]
1358 fn test_smooth_basis_bspline_single_curve() {
1359 let m = 50;
1360 let (data, t) = make_test_data(1, m);
1361 let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1362 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1363 assert_eq!(res.fitted.nrows(), 1);
1364 assert_eq!(res.fitted.ncols(), m);
1365 assert!(res.gcv.is_finite());
1366 }
1367
1368 #[test]
1369 fn test_smooth_basis_bspline_many_curves() {
1370 let m = 50;
1371 let n = 20;
1372 let (data, t) = make_test_data(n, m);
1373 let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1374 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1375 assert_eq!(res.fitted.nrows(), n);
1376 assert_eq!(res.coefficients.nrows(), n);
1377 }
1378
1379 #[test]
1380 fn test_smooth_basis_bspline_minimal_nbasis() {
1381 let m = 50;
1383 let (data, t) = make_test_data(1, m);
1384 let fdpar = make_bspline_fdpar(&t, 2, 1e-4);
1385 let res = smooth_basis(&data, &t, &fdpar);
1386 assert!(res.is_ok());
1388 }
1389
/// Smoothing succeeds for B-spline orders 3 and 5 when the penalty matrix
/// is built for the matching order.
#[test]
fn test_smooth_basis_bspline_different_orders() {
    let n_points = 50;
    let (curves, grid) = make_test_data(2, n_points);

    // Same setup for each order: build the order-specific penalty, then fit.
    for order in [3, 5] {
        let params = FdPar {
            basis_type: BasisType::Bspline { order },
            nbasis: 10,
            lambda: 1e-4,
            lfd_order: 2,
            penalty_matrix: bspline_penalty_matrix(&grid, 10, order, 2),
        };
        let outcome = smooth_basis(&curves, &grid, &params);
        assert!(outcome.is_ok());
    }
}
1418
/// With a Fourier basis the coefficient matrix is `n x nbasis` and the
/// result reports the requested `nbasis`.
#[test]
fn test_smooth_basis_fourier_coefficient_shape() {
    let (n_curves, n_points) = (3, 50);
    let grid = uniform_grid(n_points);

    // Every curve is the same one-period sine wave.
    let wave: Vec<f64> = (0..n_points)
        .map(|j| (2.0 * PI * grid[j]).sin())
        .collect();
    let mut curves = FdMatrix::zeros(n_curves, n_points);
    for row in 0..n_curves {
        for (col, &value) in wave.iter().enumerate() {
            curves[(row, col)] = value;
        }
    }

    let nbasis = 7;
    let params = make_fourier_fdpar(nbasis, 1.0, 1e-6);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();
    assert_eq!(fit.coefficients.nrows(), n_curves);
    assert_eq!(fit.coefficients.ncols(), nbasis);
    assert_eq!(fit.nbasis, nbasis);
}
1439
/// A lightly penalised Fourier fit reproduces a pure sine to within 0.05
/// at every grid point.
#[test]
fn test_smooth_basis_fourier_fits_pure_sine() {
    let n_points = 100;
    let grid = uniform_grid(n_points);
    let sine_at = |j: usize| (2.0 * PI * grid[j]).sin();

    let mut curves = FdMatrix::zeros(1, n_points);
    for j in 0..n_points {
        curves[(0, j)] = sine_at(j);
    }

    let params = make_fourier_fdpar(5, 1.0, 1e-8);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();

    for j in 0..n_points {
        let expected = sine_at(j);
        let got = fit.fitted[(0, j)];
        assert!(
            (got - expected).abs() < 0.05,
            "Fourier should fit pure sine; j={}, got={}, expected={}",
            j,
            got,
            expected
        );
    }
}
1463
/// Fourier smoothing with period 1.0 and period 2.0 both return a `1 x m`
/// fitted matrix.
#[test]
fn test_smooth_basis_fourier_different_periods() {
    let n_points = 50;
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(1, n_points);
    for (j, &tj) in grid.iter().enumerate().take(n_points) {
        curves[(0, j)] = (2.0 * PI * tj).sin();
    }

    let unit_period = make_fourier_fdpar(7, 1.0, 1e-6);
    let fit_unit = smooth_basis(&curves, &grid, &unit_period).unwrap();

    let double_period = make_fourier_fdpar(7, 2.0, 1e-6);
    let fit_double = smooth_basis(&curves, &grid, &double_period).unwrap();

    let expected_shape = (1, n_points);
    assert_eq!(fit_unit.fitted.shape(), expected_shape);
    assert_eq!(fit_double.fitted.shape(), expected_shape);
}
1485
/// With `lambda = 0` a Fourier fit still succeeds, keeps the `1 x m`
/// shape, and has more than one effective degree of freedom.
#[test]
fn test_smooth_basis_fourier_zero_lambda() {
    let n_points = 50;
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(1, n_points);
    for j in 0..n_points {
        let fundamental = (2.0 * PI * grid[j]).sin();
        let harmonic = (4.0 * PI * grid[j]).cos();
        curves[(0, j)] = fundamental + harmonic;
    }

    let unpenalised = make_fourier_fdpar(9, 1.0, 0.0);
    let fit = smooth_basis(&curves, &grid, &unpenalised).unwrap();
    assert_eq!(fit.fitted.shape(), (1, n_points));
    assert!(fit.edf > 1.0);
}
1500
/// A very large smoothing parameter pushes the effective degrees of
/// freedom of a 9-function Fourier fit below 5.
#[test]
fn test_smooth_basis_fourier_large_lambda() {
    let n_points = 50;
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(1, n_points);
    for j in 0..n_points {
        curves[(0, j)] = (2.0 * PI * grid[j]).sin();
    }

    let heavy_penalty = make_fourier_fdpar(9, 1.0, 1e6);
    let fit = smooth_basis(&curves, &grid, &heavy_penalty).unwrap();
    let edf = fit.edf;
    assert!(edf < 5.0, "Large lambda should reduce EDF; edf={}", edf);
}
1518
/// Effective degrees of freedom must not grow (beyond a 0.01 tolerance)
/// as lambda increases.
#[test]
fn test_smooth_basis_lambda_gradient_edf() {
    let n_points = 50;
    let (curves, grid) = make_test_data(3, n_points);

    // Start from +inf so the first lambda always passes the comparison.
    let mut last_edf = f64::INFINITY;
    for lam in [1e-8, 1e-4, 1e-2, 1.0, 1e2] {
        let params = make_bspline_fdpar(&grid, 12, lam);
        let edf = smooth_basis(&curves, &grid, &params).unwrap().edf;
        assert!(
            edf <= last_edf + 0.01,
            "EDF should decrease: lambda={}, edf={}, prev_edf={}",
            lam,
            edf,
            last_edf
        );
        last_edf = edf;
    }
}
1541
/// The residual sum of squares must not shrink (beyond a 1e-8 tolerance)
/// as lambda increases.
#[test]
fn test_smooth_basis_lambda_gradient_rss() {
    let (n_curves, n_points) = (2, 50);
    let (curves, grid) = make_test_data(n_curves, n_points);

    // A negative sentinel guarantees the first lambda passes the check.
    let mut last_rss = -1.0;
    for lam in [0.0, 1e-6, 1e-2, 1.0, 1e4] {
        let params = make_bspline_fdpar(&grid, 12, lam);
        let fit = smooth_basis(&curves, &grid, &params).unwrap();

        // Accumulate squared residuals in row-major order.
        let rss = (0..n_curves)
            .flat_map(|i| (0..n_points).map(move |j| (i, j)))
            .fold(0.0, |acc, (i, j)| {
                acc + (curves[(i, j)] - fit.fitted[(i, j)]).powi(2)
            });

        assert!(
            rss >= last_rss - 1e-8,
            "RSS should increase: lambda={}, rss={}, prev_rss={}",
            lam,
            rss,
            last_rss
        );
        last_rss = rss;
    }
}
1569
/// A data matrix with zero rows is rejected with an error.
#[test]
fn test_smooth_basis_empty_data_rows() {
    let grid = uniform_grid(50);
    let no_curves = FdMatrix::zeros(0, 50);
    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    assert!(smooth_basis(&no_curves, &grid, &params).is_err());
}
1580
/// A data matrix with zero columns (and empty argvals) is rejected with
/// an error.
#[test]
fn test_smooth_basis_empty_data_cols() {
    let no_points = FdMatrix::zeros(5, 0);
    let nbasis = 10;
    let params = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis,
        lambda: 1e-4,
        lfd_order: 2,
        // Placeholder penalty of size nbasis * nbasis.
        penalty_matrix: vec![0.0; 100],
    };
    let outcome = smooth_basis(&no_points, &[], &params);
    assert!(outcome.is_err());
}
1594
/// Argvals of length 50 paired with 40-column data is rejected with an
/// error.
#[test]
fn test_smooth_basis_mismatched_argvals() {
    let grid = uniform_grid(50);
    let short_curves = FdMatrix::zeros(3, 40);
    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    assert!(smooth_basis(&short_curves, &grid, &params).is_err());
}
1603
/// Requesting a single basis function is rejected with an error.
#[test]
fn test_smooth_basis_nbasis_too_small() {
    let grid = uniform_grid(50);
    let curves = FdMatrix::zeros(3, 50);
    let params = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis: 1,
        lambda: 1e-4,
        lfd_order: 2,
        penalty_matrix: vec![0.0; 1],
    };
    assert!(smooth_basis(&curves, &grid, &params).is_err());
}
1619
/// The error returned for zero-row data is the `InvalidDimension` variant.
#[test]
fn test_smooth_basis_error_is_invalid_dimension() {
    let grid = uniform_grid(50);
    let no_curves = FdMatrix::zeros(0, 50);
    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    let failure = smooth_basis(&no_curves, &grid, &params).unwrap_err();
    assert!(
        matches!(failure, crate::FdarError::InvalidDimension { .. }),
        "Expected InvalidDimension, got {:?}",
        failure
    );
}
1631
/// First- and second-derivative B-spline penalties have the same size
/// but different entries.
#[test]
fn test_bspline_penalty_matrix_different_orders() {
    let grid = uniform_grid(101);
    let first_deriv = bspline_penalty_matrix(&grid, 10, 4, 1);
    let second_deriv = bspline_penalty_matrix(&grid, 10, 4, 2);
    assert_eq!(first_deriv.len(), second_deriv.len());

    // Total absolute element-wise difference between the two penalties.
    let total_gap = first_deriv
        .iter()
        .zip(&second_deriv)
        .map(|(a, b)| (a - b).abs())
        .sum::<f64>();
    assert!(
        total_gap > 1e-10,
        "Different lfd_orders should produce different penalties"
    );
}
1650
/// Degenerate inputs to `bspline_penalty_matrix` yield an all-zero penalty
/// of size `nbasis * nbasis`.
///
/// The length is asserted alongside the zero check because
/// `iter().all(..)` is vacuously true for an empty vector, so a zero-only
/// check would not notice a result of the wrong size.
#[test]
fn test_bspline_penalty_matrix_edge_cases() {
    // Fewer than two evaluation points.
    let single_point = vec![0.0];
    let p = bspline_penalty_matrix(&single_point, 10, 4, 2);
    assert_eq!(p.len(), 10 * 10);
    assert!(p.iter().all(|&v| v == 0.0));

    // Fewer than two basis functions.
    let grid = uniform_grid(50);
    let p2 = bspline_penalty_matrix(&grid, 1, 4, 2);
    assert_eq!(p2.len(), 1);
    assert!(p2.iter().all(|&v| v == 0.0));

    // Derivative order not below the spline order.
    let p3 = bspline_penalty_matrix(&grid, 10, 4, 4);
    assert_eq!(p3.len(), 10 * 10);
    assert!(p3.iter().all(|&v| v == 0.0));
}
1668
/// Diagonal entries of the B-spline penalty are non-negative (up to a
/// 1e-10 tolerance) for several basis sizes.
#[test]
fn test_bspline_penalty_nonnegative_diagonal() {
    let grid = uniform_grid(101);
    for nbasis in [5, 10, 20] {
        let penalty = bspline_penalty_matrix(&grid, nbasis, 4, 2);
        // Recover the side length from the flat square matrix.
        let side = (penalty.len() as f64).sqrt() as usize;
        for i in 0..side {
            let diag = penalty[i + i * side];
            assert!(
                diag >= -1e-10,
                "Diagonal ({},{}) negative for nbasis={}: {}",
                i,
                i,
                nbasis,
                diag
            );
        }
    }
}
1687
/// The Fourier penalty is zero for the constant term, strictly increases
/// with frequency, and is equal for the sine/cosine pair of a frequency.
#[test]
fn test_fourier_penalty_increasing_with_frequency() {
    let k = 11;
    let penalty = fourier_penalty_matrix(k, 1.0, 2);
    let diag = |i: usize| penalty[i + i * k];

    assert!(penalty[0].abs() < 1e-15);

    let mut last_eigenval = 0.0;
    for freq in 1..=5 {
        // Basis layout used by this test: index 2f-1 is sine, 2f is cosine.
        let sin_pos = 2 * freq - 1;
        let eigenval = diag(sin_pos);
        assert!(
            eigenval > last_eigenval,
            "Higher frequency should have larger penalty: freq={}, eigenval={}, prev={}",
            freq,
            eigenval,
            last_eigenval
        );
        last_eigenval = eigenval;

        let cos_pos = 2 * freq;
        if cos_pos < k {
            assert!(
                (diag(cos_pos) - eigenval).abs() < 1e-10,
                "Sin and cos penalty should match at freq {}",
                freq
            );
        }
    }
}
1718
/// Doubling the period gives strictly smaller diagonal penalties for the
/// non-constant terms (or both entries are exactly zero).
#[test]
fn test_fourier_penalty_different_periods() {
    let short_period = fourier_penalty_matrix(7, 1.0, 2);
    let long_period = fourier_penalty_matrix(7, 2.0, 2);
    for i in 1..7 {
        let pos = i + i * 7;
        let (short, long) = (short_period[pos], long_period[pos]);
        let both_zero = short == 0.0 && long == 0.0;
        assert!(
            long < short || both_zero,
            "Longer period should have smaller penalties at i={}",
            i
        );
    }
}
1732
/// For a first-derivative penalty with period 1, the (1, 1) diagonal entry
/// equals `(2*pi)^2`.
#[test]
fn test_fourier_penalty_first_order() {
    let penalty = fourier_penalty_matrix(5, 1.0, 1);
    let expected = (2.0 * PI).powi(2);
    // Flat index of entry (1, 1) in a 5 x 5 matrix.
    let got = penalty[1 + 5];
    assert!(
        (got - expected).abs() < 1e-6,
        "First-order penalty eigenval: got {}, expected {}",
        got,
        expected
    );
}
1747
/// Zero basis functions produce an empty penalty matrix.
#[test]
fn test_fourier_penalty_zero_nbasis() {
    assert!(fourier_penalty_matrix(0, 1.0, 2).is_empty());
}
1753
/// A single basis function gives a 1 x 1 penalty whose only entry is zero.
#[test]
fn test_fourier_penalty_nbasis_one() {
    let penalty = fourier_penalty_matrix(1, 1.0, 2);
    assert_eq!(penalty.len(), 1);
    let only_entry = penalty[0];
    assert!(only_entry.abs() < 1e-15);
}
1760
/// GCV-based lambda selection on B-spline data returns a result with the
/// input shape, a finite GCV score and positive EDF.
#[test]
fn test_smooth_basis_gcv_returns_valid_result() {
    let (curves, grid) = make_test_data(5, 50);
    let basis = BasisType::Bspline { order: 4 };
    let selected = smooth_basis_gcv(&curves, &grid, &basis, 12, 2, (-6.0, 2.0), 20);
    assert!(selected.is_some());
    let best = selected.unwrap();
    assert_eq!(best.fitted.shape(), (5, 50));
    assert!(best.gcv.is_finite());
    assert!(best.edf > 0.0);
}
1774
/// GCV-based lambda selection works with a Fourier basis and preserves
/// the number of curves and basis functions.
#[test]
fn test_smooth_basis_gcv_fourier() {
    let n_points = 80;
    let grid = uniform_grid(n_points);

    // All three curves share the same two-harmonic signal.
    let signal: Vec<f64> = (0..n_points)
        .map(|j| (2.0 * PI * grid[j]).sin() + 0.5 * (4.0 * PI * grid[j]).cos())
        .collect();
    let mut curves = FdMatrix::zeros(3, n_points);
    for row in 0..3 {
        for (col, &value) in signal.iter().enumerate() {
            curves[(row, col)] = value;
        }
    }

    let basis = BasisType::Fourier { period: 1.0 };
    let selected = smooth_basis_gcv(&curves, &grid, &basis, 9, 2, (-8.0, 4.0), 25);
    assert!(selected.is_some());
    let best = selected.unwrap();
    assert_eq!(best.fitted.nrows(), 3);
    assert_eq!(best.nbasis, 9);
}
1792
/// The GCV score of the selected fit is finite and strictly positive.
#[test]
fn test_smooth_basis_gcv_selects_finite_gcv() {
    let (curves, grid) = make_test_data(5, 60);
    let basis = BasisType::Bspline { order: 4 };
    let best = smooth_basis_gcv(&curves, &grid, &basis, 12, 2, (-6.0, 2.0), 15).unwrap();
    let score = best.gcv;
    assert!(score.is_finite());
    assert!(score > 0.0);
}
1801
/// GCV selection returns `None` when the data matrix has no rows.
#[test]
fn test_smooth_basis_gcv_empty_data() {
    let no_curves = FdMatrix::zeros(0, 50);
    let grid = uniform_grid(50);
    let basis = BasisType::Bspline { order: 4 };
    let selected = smooth_basis_gcv(&no_curves, &grid, &basis, 10, 2, (-6.0, 2.0), 10);
    assert!(selected.is_none());
}
1811
/// GCV selection returns `None` when there are no evaluation points.
#[test]
fn test_smooth_basis_gcv_empty_argvals() {
    let no_points = FdMatrix::zeros(5, 0);
    let basis = BasisType::Bspline { order: 4 };
    let selected = smooth_basis_gcv(&no_points, &[], &basis, 10, 2, (-6.0, 2.0), 10);
    assert!(selected.is_none());
}
1819
/// GCV selection returns `None` when only one basis function is requested.
#[test]
fn test_smooth_basis_gcv_nbasis_too_small() {
    let (curves, grid) = make_test_data(5, 50);
    let basis = BasisType::Bspline { order: 4 };
    let selected = smooth_basis_gcv(&curves, &grid, &basis, 1, 2, (-6.0, 2.0), 10);
    assert!(selected.is_none());
}
1827
/// GCV selection returns `None` when the lambda grid has a single point.
#[test]
fn test_smooth_basis_gcv_ngrid_too_small() {
    let (curves, grid) = make_test_data(5, 50);
    let basis = BasisType::Bspline { order: 4 };
    let selected = smooth_basis_gcv(&curves, &grid, &basis, 10, 2, (-6.0, 2.0), 1);
    assert!(selected.is_none());
}
1835
/// GCV selection succeeds on a narrow log-lambda interval with few grid
/// points.
#[test]
fn test_smooth_basis_gcv_narrow_range() {
    let (curves, grid) = make_test_data(3, 50);
    let basis = BasisType::Bspline { order: 4 };
    let log_range = (-3.0, -2.0);
    let selected = smooth_basis_gcv(&curves, &grid, &basis, 10, 2, log_range, 5);
    assert!(selected.is_some());
}
1844
/// GCV selection succeeds on a wide log-lambda interval.
#[test]
fn test_smooth_basis_gcv_wide_range() {
    let (curves, grid) = make_test_data(3, 50);
    let basis = BasisType::Bspline { order: 4 };
    let log_range = (-12.0, 8.0);
    let selected = smooth_basis_gcv(&curves, &grid, &basis, 10, 2, log_range, 30);
    assert!(selected.is_some());
}
1853
/// The result carries one score per candidate and echoes the candidate
/// list unchanged.
#[test]
fn test_basis_nbasis_cv_scores_length() {
    let (curves, grid) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![4, 6, 8, 10, 12];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    )
    .unwrap();
    assert_eq!(outcome.scores.len(), 5);
    assert_eq!(outcome.nbasis_range.len(), 5);
    assert_eq!(outcome.nbasis_range, candidates);
}
1874
/// For GCV, AIC and BIC the selected `nbasis` is one of the candidates.
#[test]
fn test_basis_nbasis_cv_optimal_within_range() {
    let (curves, grid) = make_test_data(8, 50);
    let candidates: Vec<usize> = vec![5, 7, 9, 11, 13, 15];
    let basis = BasisType::Bspline { order: 4 };
    let criteria = [
        BasisCriterion::Gcv,
        BasisCriterion::Aic,
        BasisCriterion::Bic,
    ];
    for criterion in criteria {
        let outcome =
            basis_nbasis_cv(&curves, &grid, &candidates, &basis, criterion, 5, 1e-4).unwrap();
        let chosen = outcome.optimal_nbasis;
        assert!(
            candidates.contains(&chosen),
            "optimal_nbasis {} not in range for {:?}",
            chosen,
            criterion
        );
    }
}
1902
/// GCV-based `nbasis` selection with a Fourier basis picks one of the
/// candidates.
#[test]
fn test_basis_nbasis_cv_fourier_gcv() {
    let n_points = 80;
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(5, n_points);
    for i in 0..5 {
        for j in 0..n_points {
            let fundamental = (2.0 * PI * grid[j]).sin();
            let harmonic = 0.3 * (4.0 * PI * grid[j]).cos();
            // Deterministic pseudo-noise that varies by curve and point.
            let jitter = 0.02 * ((i * 7 + j * 3) % 10) as f64;
            curves[(i, j)] = fundamental + harmonic + jitter;
        }
    }

    let candidates: Vec<usize> = vec![5, 7, 9, 11];
    let basis = BasisType::Fourier { period: 1.0 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    )
    .unwrap();
    assert!(candidates.contains(&outcome.optimal_nbasis));
}
1928
/// Cross-validated `nbasis` selection with a Fourier basis picks one of
/// the candidates and records the `Cv` criterion.
#[test]
fn test_basis_nbasis_cv_fourier_cv() {
    let (n_curves, n_points) = (10, 60);
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(n_curves, n_points);
    for i in 0..n_curves {
        for j in 0..n_points {
            let base = (2.0 * PI * grid[j]).sin();
            // Deterministic pseudo-noise that varies by curve and point.
            let jitter = 0.02 * ((i * 11 + j) % 15) as f64;
            curves[(i, j)] = base + jitter;
        }
    }

    let candidates: Vec<usize> = vec![5, 7, 9];
    let basis = BasisType::Fourier { period: 1.0 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Cv,
        5,
        1e-4,
    )
    .unwrap();
    assert!(candidates.contains(&outcome.optimal_nbasis));
    assert_eq!(outcome.criterion, BasisCriterion::Cv);
}
1954
/// An invalid candidate (`nbasis = 1`) gets an infinite score and is never
/// selected as optimal.
#[test]
fn test_basis_nbasis_cv_with_nbasis_below_minimum() {
    let (curves, grid) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![1, 5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    )
    .unwrap();
    let chosen = outcome.optimal_nbasis;
    assert!(
        chosen >= 5,
        "Should skip invalid nbasis=1, got optimal={}",
        chosen
    );
    assert!(outcome.scores[0].is_infinite());
}
1978
/// An empty candidate list yields `None`.
#[test]
fn test_basis_nbasis_cv_empty_range() {
    let (curves, grid) = make_test_data(5, 50);
    let no_candidates: Vec<usize> = Vec::new();
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &no_candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    );
    assert!(outcome.is_none());
}
1994
/// A data matrix with zero rows yields `None`.
#[test]
fn test_basis_nbasis_cv_empty_data() {
    let no_curves = FdMatrix::zeros(0, 50);
    let grid = uniform_grid(50);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &no_curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    );
    assert!(outcome.is_none());
}
2011
/// Argvals of length 40 paired with 50-column data yields `None`.
#[test]
fn test_basis_nbasis_cv_mismatched_argvals() {
    let curves = FdMatrix::zeros(5, 50);
    let short_grid = uniform_grid(40);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &curves,
        &short_grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    );
    assert!(outcome.is_none());
}
2028
/// With a single candidate, that candidate is selected and exactly one
/// score is reported.
#[test]
fn test_basis_nbasis_cv_single_nbasis() {
    let (curves, grid) = make_test_data(5, 50);
    let only_candidate: Vec<usize> = vec![10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &only_candidate,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    )
    .unwrap();
    assert_eq!(outcome.optimal_nbasis, 10);
    assert_eq!(outcome.scores.len(), 1);
}
2046
/// BIC must not select an `nbasis` more than 4 above the one AIC selects
/// on the same data.
#[test]
fn test_basis_nbasis_cv_bic_penalizes_more_than_aic() {
    let (curves, grid) = make_test_data(5, 80);
    let candidates: Vec<usize> = (4..=20).step_by(2).collect();
    let basis = BasisType::Bspline { order: 4 };

    // Identical call for both criteria; only the criterion differs.
    let select = |criterion: BasisCriterion| {
        basis_nbasis_cv(&curves, &grid, &candidates, &basis, criterion, 5, 1e-4).unwrap()
    };
    let aic_choice = select(BasisCriterion::Aic).optimal_nbasis;
    let bic_choice = select(BasisCriterion::Bic).optimal_nbasis;

    assert!(
        bic_choice <= aic_choice + 4,
        "BIC selected {} vs AIC selected {} -- BIC should not select much more than AIC",
        bic_choice,
        aic_choice
    );
}
2083
/// For noise-free sine curves and a small lambda, the largest absolute
/// residual stays below 0.1.
#[test]
fn test_smooth_basis_fitted_close_to_data() {
    let (n_curves, n_points) = (3, 50);
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(n_curves, n_points);
    for j in 0..n_points {
        let value = (2.0 * PI * grid[j]).sin();
        for i in 0..n_curves {
            curves[(i, j)] = value;
        }
    }

    let params = make_bspline_fdpar(&grid, 15, 1e-6);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();

    // Largest absolute residual over every curve and grid point.
    let worst = (0..n_curves)
        .flat_map(|i| (0..n_points).map(move |j| (i, j)))
        .map(|(i, j)| (curves[(i, j)] - fit.fitted[(i, j)]).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        worst < 0.1,
        "Fitted should be close to smooth data; max_err={}",
        worst
    );
}
2114
/// Constant curves are reproduced to within 0.01 at every point.
#[test]
fn test_smooth_basis_constant_data() {
    let (n_curves, n_points) = (2, 50);
    let level = 3.15;
    let grid = uniform_grid(n_points);

    let mut curves = FdMatrix::zeros(n_curves, n_points);
    for i in 0..n_curves {
        for j in 0..n_points {
            curves[(i, j)] = level;
        }
    }

    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();
    for i in 0..n_curves {
        for j in 0..n_points {
            let got = fit.fitted[(i, j)];
            assert!(
                (got - level).abs() < 0.01,
                "Constant data should be fit well at ({},{}): got {}",
                i,
                j,
                got
            );
        }
    }
}
2141
/// A straight line `2t + 1` is reproduced to within 0.05 at every point.
#[test]
fn test_smooth_basis_linear_data() {
    let n_points = 50;
    let grid = uniform_grid(n_points);
    let line_at = |j: usize| 2.0 * grid[j] + 1.0;

    let mut curves = FdMatrix::zeros(1, n_points);
    for j in 0..n_points {
        curves[(0, j)] = line_at(j);
    }

    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();
    for j in 0..n_points {
        let expected = line_at(j);
        let got = fit.fitted[(0, j)];
        assert!(
            (got - expected).abs() < 0.05,
            "Linear data should be fit well at j={}: got {}, expected {}",
            j,
            got,
            expected
        );
    }
}
2164
/// Effective degrees of freedom lie in `(0, m]`, where `m` is the number
/// of evaluation points.
#[test]
fn test_smooth_basis_edf_bounded() {
    let n_points = 50;
    let (curves, grid) = make_test_data(3, n_points);
    let params = make_bspline_fdpar(&grid, 12, 1e-4);
    let edf = smooth_basis(&curves, &grid, &params).unwrap().edf;
    let in_bounds = edf > 0.0 && edf <= n_points as f64;
    assert!(
        in_bounds,
        "EDF should be in (0, {}]; got {}",
        n_points,
        edf
    );
}
2181
/// GCV, AIC and BIC are all finite for an ordinary fit.
#[test]
fn test_smooth_basis_gcv_aic_bic_all_finite() {
    let (curves, grid) = make_test_data(4, 60);
    let params = make_bspline_fdpar(&grid, 12, 1e-3);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();
    let (gcv, aic, bic) = (fit.gcv, fit.aic, fit.bic);
    assert!(gcv.is_finite(), "GCV should be finite: {}", gcv);
    assert!(aic.is_finite(), "AIC should be finite: {}", aic);
    assert!(bic.is_finite(), "BIC should be finite: {}", bic);
}
2191
/// The penalty matrix stored in the result has `nbasis * nbasis` entries.
#[test]
fn test_smooth_basis_penalty_matrix_in_result() {
    let (curves, grid) = make_test_data(3, 50);
    let requested = 10;
    let params = make_bspline_fdpar(&grid, requested, 1e-4);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();

    let k = fit.nbasis;
    let stored_len = fit.penalty_matrix.len();
    assert_eq!(
        stored_len,
        k * k,
        "Penalty matrix should be k*k = {}*{} = {}; got {}",
        k,
        k,
        k * k,
        stored_len
    );
}
2211
/// Identical input curves yield coefficients equal to within 1e-10.
#[test]
fn test_smooth_basis_identical_curves_same_coefficients() {
    let (n_curves, n_points) = (4, 50);
    let grid = uniform_grid(n_points);
    let template: Vec<f64> = (0..n_points)
        .map(|j| (2.0 * PI * grid[j]).sin())
        .collect();

    // Fill every row with the same template curve.
    let mut curves = FdMatrix::zeros(n_curves, n_points);
    for row in 0..n_curves {
        for (col, &value) in template.iter().enumerate() {
            curves[(row, col)] = value;
        }
    }

    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();
    let coefs = &fit.coefficients;

    // Compare every later row against row 0.
    for i in 1..n_curves {
        for j in 0..coefs.ncols() {
            let gap = (coefs[(i, j)] - coefs[(0, j)]).abs();
            assert!(
                gap < 1e-10,
                "Identical curves should have identical coefficients: curve {} col {} differs",
                i,
                j
            );
        }
    }
}
2242
/// Cross-validated selection succeeds for several fold counts and always
/// picks one of the candidates.
#[test]
fn test_basis_nbasis_cv_different_nfolds() {
    let (curves, grid) = make_test_data(12, 50);
    let candidates: Vec<usize> = vec![5, 8, 11];
    let basis = BasisType::Bspline { order: 4 };
    for nfolds in [2, 3, 5, 10] {
        let outcome = basis_nbasis_cv(
            &curves,
            &grid,
            &candidates,
            &basis,
            BasisCriterion::Cv,
            nfolds,
            1e-4,
        );
        assert!(outcome.is_some(), "CV should succeed with nfolds={}", nfolds);
        let chosen = outcome.unwrap().optimal_nbasis;
        assert!(candidates.contains(&chosen));
    }
}
2264
/// Smoothing succeeds with 40 basis functions on 100 points when a
/// penalty is applied.
#[test]
fn test_smooth_basis_many_basis_functions() {
    let n_points = 100;
    let (curves, grid) = make_test_data(2, n_points);
    let params = make_bspline_fdpar(&grid, 40, 1e-2);
    let outcome = smooth_basis(&curves, &grid, &params);
    assert!(
        outcome.is_ok(),
        "Should handle many basis functions with penalty"
    );
}
2279
/// B-spline and Fourier fits of the same data differ on the first curve.
#[test]
fn test_smooth_basis_bspline_vs_fourier_different_results() {
    let n_points = 50;
    let (curves, grid) = make_test_data(2, n_points);

    let bspline_params = make_bspline_fdpar(&grid, 9, 1e-4);
    let fourier_params = make_fourier_fdpar(9, 1.0, 1e-4);
    let bspline_fit = smooth_basis(&curves, &grid, &bspline_params).unwrap();
    let fourier_fit = smooth_basis(&curves, &grid, &fourier_params).unwrap();

    // Summed absolute gap between the two fits along curve 0.
    let total_gap = (0..n_points)
        .map(|j| (bspline_fit.fitted[(0, j)] - fourier_fit.fitted[(0, j)]).abs())
        .sum::<f64>();
    assert!(
        total_gap > 1e-10,
        "B-spline and Fourier fits should differ for the same data"
    );
}
2300
/// The GCV score is strictly positive when the data carries deterministic
/// pseudo-noise on top of a sine.
#[test]
fn test_smooth_basis_gcv_positive_for_noisy_data() {
    let n_points = 50;
    let grid = uniform_grid(n_points);
    let mut curves = FdMatrix::zeros(1, n_points);
    for j in 0..n_points {
        let signal = (2.0 * PI * grid[j]).sin();
        // Pseudo-noise in [0, 0.5), shifted down by 0.25 below.
        let jitter = 0.5 * ((j * 37) % 20) as f64 / 20.0;
        curves[(0, j)] = signal + jitter - 0.25;
    }

    let params = make_bspline_fdpar(&grid, 10, 1e-3);
    let fit = smooth_basis(&curves, &grid, &params).unwrap();
    assert!(fit.gcv > 0.0, "GCV should be positive for noisy data");
}
2316
/// First- and second-derivative penalties both fit successfully and give
/// different fitted values on the first curve.
#[test]
fn test_smooth_basis_different_lfd_orders() {
    let n_points = 50;
    let (curves, grid) = make_test_data(2, n_points);

    // Build the matching penalty, fit, and require success for one order.
    let fit_with = |lfd_order: usize| {
        let params = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis: 10,
            lambda: 1e-2,
            lfd_order,
            penalty_matrix: bspline_penalty_matrix(&grid, 10, 4, lfd_order),
        };
        let outcome = smooth_basis(&curves, &grid, &params);
        assert!(outcome.is_ok());
        outcome.unwrap()
    };

    let first = fit_with(1);
    let second = fit_with(2);

    let total_gap = (0..n_points)
        .map(|j| (first.fitted[(0, j)] - second.fitted[(0, j)]).abs())
        .sum::<f64>();
    assert!(
        total_gap > 1e-10,
        "Different lfd_orders should produce different fits"
    );
}
2359
/// Result fields are consistent: the candidate list and criterion are
/// echoed, and `optimal_nbasis` is the candidate with the minimum score.
#[test]
fn test_basis_nbasis_cv_result_fields() {
    let (curves, grid) = make_test_data(6, 50);
    let candidates: Vec<usize> = vec![5, 7, 9, 11, 13];
    let basis = BasisType::Bspline { order: 4 };
    let outcome = basis_nbasis_cv(
        &curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Aic,
        5,
        1e-4,
    )
    .unwrap();

    assert!(candidates.contains(&outcome.optimal_nbasis));
    assert_eq!(outcome.scores.len(), candidates.len());
    assert_eq!(outcome.nbasis_range, candidates);
    assert_eq!(outcome.criterion, BasisCriterion::Aic);

    // Find the first index whose score equals the minimum.
    let lowest = outcome
        .scores
        .iter()
        .copied()
        .fold(f64::INFINITY, |acc, s| acc.min(s));
    let winner = outcome
        .scores
        .iter()
        .position(|&s| (s - lowest).abs() < 1e-15)
        .unwrap();
    assert_eq!(outcome.optimal_nbasis, candidates[winner]);
}
2390
/// A cloned selection result compares equal to the original.
#[test]
fn test_basis_nbasis_cv_result_clone() {
    let (curves, grid) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let original = basis_nbasis_cv(
        &curves,
        &grid,
        &candidates,
        &basis,
        BasisCriterion::Gcv,
        5,
        1e-4,
    )
    .unwrap();
    let copy = original.clone();
    assert_eq!(original, copy);
}
2408
/// Smoothing succeeds on a cosine-spaced (non-uniform) grid and keeps the
/// `2 x m` shape.
#[test]
fn test_smooth_basis_nonuniform_argvals() {
    let n_points = 50;
    let last = (n_points - 1) as f64;

    // Map an even grid on [0, 1] through 0.5 * (1 - cos(pi * x)).
    let grid: Vec<f64> = (0..n_points)
        .map(|i| i as f64 / last)
        .map(|x| 0.5 * (1.0 - (PI * x).cos()))
        .collect();

    let mut curves = FdMatrix::zeros(2, n_points);
    for i in 0..2 {
        let offset = 0.1 * i as f64;
        for j in 0..n_points {
            curves[(i, j)] = (2.0 * PI * grid[j]).sin() + offset;
        }
    }

    let params = make_bspline_fdpar(&grid, 10, 1e-4);
    let outcome = smooth_basis(&curves, &grid, &params);
    assert!(outcome.is_ok(), "Should handle non-uniform argvals");
    let fit = outcome.unwrap();
    assert_eq!(fit.fitted.shape(), (2, n_points));
}
2433
/// Smoothing succeeds with a lambda of 1e-15.
#[test]
fn test_smooth_basis_very_small_lambda() {
    let n_points = 50;
    let (curves, grid) = make_test_data(2, n_points);
    let tiny_penalty = make_bspline_fdpar(&grid, 10, 1e-15);
    let outcome = smooth_basis(&curves, &grid, &tiny_penalty);
    assert!(outcome.is_ok(), "Should handle very small lambda");
}
2444
/// Smoothing succeeds with a lambda of 1e10.
#[test]
fn test_smooth_basis_very_large_lambda() {
    let n_points = 50;
    let (curves, grid) = make_test_data(2, n_points);
    let huge_penalty = make_bspline_fdpar(&grid, 10, 1e10);
    let outcome = smooth_basis(&curves, &grid, &huge_penalty);
    assert!(outcome.is_ok(), "Should handle very large lambda");
}
2453
/// Fitting all curves together gives the same fitted values (within
/// 1e-10) as fitting each curve on its own.
#[test]
fn test_smooth_basis_multi_curve_vs_single_curve() {
    let (n_curves, n_points) = (3, 50);
    let (curves, grid) = make_test_data(n_curves, n_points);
    let params = make_bspline_fdpar(&grid, 10, 1e-3);

    let joint_fit = smooth_basis(&curves, &grid, &params).unwrap();

    for i in 0..n_curves {
        // Copy row i into a standalone one-row matrix.
        let mut lone_curve = FdMatrix::zeros(1, n_points);
        (0..n_points).for_each(|j| lone_curve[(0, j)] = curves[(i, j)]);

        let lone_fit = smooth_basis(&lone_curve, &grid, &params).unwrap();
        for j in 0..n_points {
            let gap = (joint_fit.fitted[(i, j)] - lone_fit.fitted[(0, j)]).abs();
            assert!(
                gap < 1e-10,
                "Multi-curve fit should match single-curve fit: curve {} point {}",
                i,
                j
            );
        }
    }
}
2484
/// Every criterion (GCV, AIC, BIC, CV) produces at least one finite score.
#[test]
fn test_basis_nbasis_cv_all_criteria_finite_scores() {
    let (curves, grid) = make_test_data(10, 60);
    let candidates: Vec<usize> = vec![5, 7, 9, 11];
    let basis = BasisType::Bspline { order: 4 };

    let criteria = [
        BasisCriterion::Gcv,
        BasisCriterion::Aic,
        BasisCriterion::Bic,
        BasisCriterion::Cv,
    ];
    for criterion in criteria {
        let outcome =
            basis_nbasis_cv(&curves, &grid, &candidates, &basis, criterion, 5, 1e-4).unwrap();
        let n_finite = outcome.scores.iter().filter(|s| s.is_finite()).count();
        assert!(
            n_finite > 0,
            "At least one score should be finite for {:?}",
            criterion
        );
    }
}
2517
/// `SmoothBasisGcvConfig::default()` exposes the expected default values.
#[test]
fn test_smooth_basis_gcv_config_default() {
    let defaults = SmoothBasisGcvConfig::default();
    assert_eq!(defaults.basis_type, BasisType::Bspline { order: 4 });
    assert_eq!(defaults.nbasis, 15);
    assert_eq!(defaults.lfd_order, 2);
    assert_eq!(defaults.log_lambda_range, (-10.0, 2.0));
    assert_eq!(defaults.n_grid, 50);
}
2529
/// A cloned `SmoothBasisGcvConfig` compares equal to the original.
#[test]
fn test_smooth_basis_gcv_config_clone_eq() {
    let original = SmoothBasisGcvConfig {
        nbasis: 20,
        ..SmoothBasisGcvConfig::default()
    };
    let copy = original.clone();
    assert_eq!(original, copy);
}
2539
/// The `Debug` output names the struct and its `nbasis` field.
#[test]
fn test_smooth_basis_gcv_config_debug() {
    let rendered = format!("{:?}", SmoothBasisGcvConfig::default());
    assert!(rendered.contains("SmoothBasisGcvConfig"));
    assert!(rendered.contains("nbasis"));
}
2547
/// Struct-update syntax overrides only the named fields; the rest keep
/// their defaults.
#[test]
fn test_smooth_basis_gcv_config_partial_override() {
    let custom = SmoothBasisGcvConfig {
        basis_type: BasisType::Fourier { period: 2.0 },
        n_grid: 100,
        ..SmoothBasisGcvConfig::default()
    };
    // Overridden fields.
    assert_eq!(custom.basis_type, BasisType::Fourier { period: 2.0 });
    assert_eq!(custom.n_grid, 100);
    // Fields left at their defaults.
    assert_eq!(custom.nbasis, 15);
    assert_eq!(custom.lfd_order, 2);
}
2561
/// The config-based GCV entry point succeeds with default settings and
/// returns a fit of the input shape with positive EDF and finite GCV.
#[test]
fn test_smooth_basis_gcv_with_config_default() {
    let (curves, grid) = make_test_data(5, 101);
    let defaults = SmoothBasisGcvConfig::default();
    let outcome = smooth_basis_gcv_with_config(&curves, &grid, &defaults);
    assert!(outcome.is_ok(), "GCV with default config should succeed");
    let fit = outcome.unwrap();
    assert_eq!(fit.fitted.shape(), (5, 101));
    assert!(fit.edf > 0.0);
    assert!(fit.gcv.is_finite());
}
2573
/// The config-based GCV entry point succeeds with overridden `nbasis`,
/// lambda range and grid size.
#[test]
fn test_smooth_basis_gcv_with_config_custom() {
    let (curves, grid) = make_test_data(3, 50);
    let custom = SmoothBasisGcvConfig {
        nbasis: 10,
        log_lambda_range: (-6.0, 0.0),
        n_grid: 15,
        ..SmoothBasisGcvConfig::default()
    };
    assert!(smooth_basis_gcv_with_config(&curves, &grid, &custom).is_ok());
}
2586
/// The config-based entry point and the positional-argument entry point
/// produce identical GCV, EDF and `nbasis` for the same settings.
#[test]
fn test_smooth_basis_gcv_with_config_matches_direct() {
    let (curves, grid) = make_test_data(3, 50);
    let settings = SmoothBasisGcvConfig {
        nbasis: 10,
        log_lambda_range: (-6.0, 0.0),
        n_grid: 20,
        ..SmoothBasisGcvConfig::default()
    };

    let via_config = smooth_basis_gcv_with_config(&curves, &grid, &settings).unwrap();
    let via_args = smooth_basis_gcv(
        &curves,
        &grid,
        &settings.basis_type,
        settings.nbasis,
        settings.lfd_order,
        settings.log_lambda_range,
        settings.n_grid,
    )
    .unwrap();

    assert_eq!(via_config.gcv, via_args.gcv);
    assert_eq!(via_config.edf, via_args.edf);
    assert_eq!(via_config.nbasis, via_args.nbasis);
}
2611
/// The config-based GCV entry point succeeds with a Fourier basis.
#[test]
fn test_smooth_basis_gcv_with_config_fourier() {
    let n_points = 100;
    let grid = uniform_grid(n_points);

    // Both curves share the same two-harmonic signal.
    let signal: Vec<f64> = (0..n_points)
        .map(|j| (2.0 * PI * grid[j]).sin() + (4.0 * PI * grid[j]).cos())
        .collect();
    let mut curves = FdMatrix::zeros(2, n_points);
    for row in 0..2 {
        for (col, &value) in signal.iter().enumerate() {
            curves[(row, col)] = value;
        }
    }

    let settings = SmoothBasisGcvConfig {
        basis_type: BasisType::Fourier { period: 1.0 },
        nbasis: 7,
        n_grid: 20,
        ..SmoothBasisGcvConfig::default()
    };
    assert!(smooth_basis_gcv_with_config(&curves, &grid, &settings).is_ok());
}
2631
/// `BasisNbasisCvConfig::default()` exposes the expected default values.
#[test]
fn test_basis_nbasis_cv_config_default() {
    let defaults = BasisNbasisCvConfig::default();
    assert_eq!(defaults.basis_type, BasisType::Bspline { order: 4 });
    assert_eq!(defaults.nbasis_range, (5, 30));
    let lambda_gap = (defaults.lambda - 1e-4).abs();
    assert!(lambda_gap < 1e-15);
    assert_eq!(defaults.lfd_order, 2);
    assert_eq!(defaults.n_folds, 5);
    assert_eq!(defaults.criterion, BasisCriterion::Gcv);
}
2644
/// A cloned `BasisNbasisCvConfig` compares equal to the original.
#[test]
fn test_basis_nbasis_cv_config_clone_eq() {
    let original = BasisNbasisCvConfig {
        nbasis_range: (4, 15),
        ..BasisNbasisCvConfig::default()
    };
    let copy = original.clone();
    assert_eq!(original, copy);
}
2654
/// The `Debug` rendering of the config should name both the type and the
/// `nbasis_range` field.
#[test]
fn test_basis_nbasis_cv_config_debug() {
    let rendered = format!("{:?}", BasisNbasisCvConfig::default());

    assert!(rendered.contains("BasisNbasisCvConfig"));
    assert!(rendered.contains("nbasis_range"));
}
2662
/// Struct-update syntax should replace only the named fields and leave the
/// rest at the values `Default` supplies.
#[test]
fn test_basis_nbasis_cv_config_partial_override() {
    let cfg = BasisNbasisCvConfig {
        lambda: 1e-2,
        criterion: BasisCriterion::Aic,
        ..Default::default()
    };

    // Overridden fields carry the new values.
    assert_eq!(cfg.criterion, BasisCriterion::Aic);
    assert!((cfg.lambda - 1e-2).abs() < 1e-15);

    // Fields not listed in the literal come from `Default`.
    assert_eq!(cfg.nbasis_range, (5, 30));
    assert_eq!(cfg.n_folds, 5);
}
2676
/// With only `nbasis_range` overridden, nbasis selection should succeed,
/// pick a value inside the range, and report the GCV criterion.
#[test]
fn test_basis_nbasis_cv_with_config_default() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 12),
        ..Default::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok(), "nbasis CV with default config should succeed");

    let selection = outcome.unwrap();
    assert!((5..=12).contains(&selection.optimal_nbasis));
    // The test expects 8 scores for the bounds (5, 12), i.e. 5..=12 inclusive.
    assert_eq!(selection.scores.len(), 8);
    assert_eq!(selection.criterion, BasisCriterion::Gcv);
}
2694
/// Selecting nbasis with the AIC criterion should succeed and echo the
/// criterion back in the result.
#[test]
fn test_basis_nbasis_cv_with_config_aic() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 10),
        criterion: BasisCriterion::Aic,
        ..Default::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());

    let selection = outcome.unwrap();
    assert_eq!(selection.criterion, BasisCriterion::Aic);
}
2707
/// Selecting nbasis with the CV criterion and 3 folds should succeed and
/// echo the criterion back in the result.
#[test]
fn test_basis_nbasis_cv_with_config_cv_folds() {
    let (data, t) = make_test_data(10, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 9),
        criterion: BasisCriterion::Cv,
        n_folds: 3,
        ..Default::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());

    let selection = outcome.unwrap();
    assert_eq!(selection.criterion, BasisCriterion::Cv);
}
2721
/// The config-based nbasis selection must agree with the positional-argument
/// function on the optimum, the scores and the evaluated range.
#[test]
fn test_basis_nbasis_cv_with_config_matches_direct() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 10),
        criterion: BasisCriterion::Bic,
        lambda: 1e-3,
        ..Default::default()
    };

    let via_config = basis_nbasis_cv_with_config(&data, &t, &cfg).unwrap();

    // Reference run: the same candidates listed explicitly, with the
    // remaining settings forwarded by hand from the config.
    let candidates = (5..=10).collect::<Vec<usize>>();
    let via_args = basis_nbasis_cv(
        &data,
        &t,
        &candidates,
        &cfg.basis_type,
        cfg.criterion,
        cfg.n_folds,
        cfg.lambda,
    )
    .unwrap();

    assert_eq!(via_config.optimal_nbasis, via_args.optimal_nbasis);
    assert_eq!(via_config.scores, via_args.scores);
    assert_eq!(via_config.nbasis_range, via_args.nbasis_range);
}
2747
/// A range whose bounds are both 7 should yield exactly one score and select
/// 7 as the optimum.
#[test]
fn test_basis_nbasis_cv_with_config_nbasis_range_expansion() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (7, 7),
        ..Default::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());

    let selection = outcome.unwrap();
    assert_eq!(selection.optimal_nbasis, 7);
    assert_eq!(selection.scores.len(), 1);
}
2761}