1use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
/// Family of basis functions used to represent a smoothed curve.
#[derive(Clone, Debug, PartialEq)]
pub enum BasisType {
    /// B-spline basis; `order` is forwarded unchanged to `bspline_basis`
    /// (conventionally order = polynomial degree + 1).
    Bspline {
        order: usize,
    },
    /// Fourier basis; `period` is forwarded to `fourier_basis_with_period`.
    Fourier {
        period: f64,
    },
}
28
29#[derive(Debug, Clone, PartialEq)]
31pub struct FdPar {
32 pub basis_type: BasisType,
34 pub nbasis: usize,
36 pub lambda: f64,
38 pub lfd_order: usize,
40 pub penalty_matrix: Vec<f64>,
42}
43
44#[derive(Debug, Clone, PartialEq)]
46pub struct SmoothBasisResult {
47 pub coefficients: FdMatrix,
49 pub fitted: FdMatrix,
51 pub edf: f64,
53 pub gcv: f64,
55 pub aic: f64,
57 pub bic: f64,
59 pub penalty_matrix: Vec<f64>,
61 pub nbasis: usize,
63}
64
65pub fn bspline_penalty_matrix(
82 argvals: &[f64],
83 nbasis: usize,
84 order: usize,
85 lfd_order: usize,
86) -> Vec<f64> {
87 if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
88 return vec![0.0; nbasis * nbasis];
89 }
90
91 let nknots = nbasis.saturating_sub(order).max(2);
92
93 let n_sub = 10;
95 let t_min = argvals[0];
96 let t_max = argvals[argvals.len() - 1];
97 let n_quad = (argvals.len() - 1) * n_sub + 1;
98 let quad_t: Vec<f64> = (0..n_quad)
99 .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100 .collect();
101
102 let basis_fine = bspline_basis(&quad_t, nknots, order);
104 let actual_nbasis = basis_fine.len() / n_quad;
105
106 let h = (t_max - t_min) / (n_quad - 1) as f64;
108 let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110 let weights = simpsons_weights(&quad_t);
112
113 integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Diagonal roughness penalty for a Fourier basis of `nbasis` functions.
///
/// Index 0 (the constant term) is left unpenalised. Indices `2f - 1` and `2f`
/// share harmonic `f` and receive `(2π f / period)^(2 · lfd_order)` on the
/// diagonal. All off-diagonal entries are zero. The result is a flattened
/// `nbasis × nbasis` matrix.
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let stride = nbasis + 1; // distance between consecutive diagonal slots
    let exponent = 2 * lfd_order as i32;
    let mut diag_penalty = vec![0.0; nbasis * nbasis];

    for slot in 1..nbasis {
        // Slots 1,2 -> harmonic 1; slots 3,4 -> harmonic 2; ...
        let harmonic = ((slot + 1) / 2) as f64;
        let omega = 2.0 * PI * harmonic / period;
        diag_penalty[slot * stride] = omega.powi(exponent);
    }

    diag_penalty
}
158
159pub fn smooth_basis(
174 data: &FdMatrix,
175 argvals: &[f64],
176 fdpar: &FdPar,
177) -> Result<SmoothBasisResult, crate::FdarError> {
178 let (n, m) = data.shape();
179 if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
180 return Err(crate::FdarError::InvalidDimension {
181 parameter: "data/argvals/fdpar",
182 expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
183 actual: format!(
184 "n={}, m={}, argvals.len()={}, nbasis={}",
185 n,
186 m,
187 argvals.len(),
188 fdpar.nbasis
189 ),
190 });
191 }
192
193 let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
195 let k = actual_nbasis;
196
197 let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
198 let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
199
200 let btb = b_mat.transpose() * &b_mat;
202 let ridge_eps = 1e-10;
203 let system: DMatrix<f64> =
204 &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
205
206 let system_inv =
208 invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
209 operation: "matrix inversion",
210 detail: "failed to invert penalized system (Φ'Φ + λR); try increasing lambda or reducing the number of basis functions".to_string(),
211 })?;
212
213 let h_mat = &b_mat * &system_inv * b_mat.transpose();
215 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
216
217 let proj = &system_inv * b_mat.transpose();
219 let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
220
221 let total_points = (n * m) as f64;
222 let gcv = compute_gcv(total_rss, total_points, edf, m);
223 let mse = total_rss / total_points;
224 let total_edf = n as f64 * edf;
226 let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
227 let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
228
229 Ok(SmoothBasisResult {
230 coefficients: all_coefs,
231 fitted: all_fitted,
232 edf,
233 gcv,
234 aic,
235 bic,
236 penalty_matrix: fdpar.penalty_matrix.clone(),
237 nbasis: k,
238 })
239}
240
241pub fn smooth_basis_gcv(
254 data: &FdMatrix,
255 argvals: &[f64],
256 basis_type: &BasisType,
257 nbasis: usize,
258 lfd_order: usize,
259 log_lambda_range: (f64, f64),
260 n_grid: usize,
261) -> Option<SmoothBasisResult> {
262 let m = argvals.len();
263 if m == 0 || nbasis < 2 || n_grid < 2 {
264 return None;
265 }
266
267 let penalty = match basis_type {
269 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
270 BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
271 };
272
273 let (lo, hi) = log_lambda_range;
274 let mut best_gcv = f64::INFINITY;
275 let mut best_result: Option<SmoothBasisResult> = None;
276
277 for i in 0..n_grid {
278 let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
279 let lam = 10.0_f64.powf(log_lam);
280
281 let fdpar = FdPar {
282 basis_type: basis_type.clone(),
283 nbasis,
284 lambda: lam,
285 lfd_order,
286 penalty_matrix: penalty.clone(),
287 };
288
289 if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
290 if result.gcv < best_gcv {
291 best_gcv = result.gcv;
292 best_result = Some(result);
293 }
294 }
295 }
296
297 best_result
298}
299
300#[derive(Debug, Clone, PartialEq)]
318pub struct SmoothBasisGcvConfig {
319 pub basis_type: BasisType,
321 pub nbasis: usize,
323 pub lfd_order: usize,
325 pub log_lambda_range: (f64, f64),
327 pub n_grid: usize,
329}
330
331impl Default for SmoothBasisGcvConfig {
332 fn default() -> Self {
333 Self {
334 basis_type: BasisType::Bspline { order: 4 },
335 nbasis: 15,
336 lfd_order: 2,
337 log_lambda_range: (-10.0, 2.0),
338 n_grid: 50,
339 }
340 }
341}
342
343#[must_use = "expensive computation whose result should not be discarded"]
358pub fn smooth_basis_gcv_with_config(
359 data: &FdMatrix,
360 argvals: &[f64],
361 config: &SmoothBasisGcvConfig,
362) -> Result<SmoothBasisResult, crate::FdarError> {
363 smooth_basis_gcv(
364 data,
365 argvals,
366 &config.basis_type,
367 config.nbasis,
368 config.lfd_order,
369 config.log_lambda_range,
370 config.n_grid,
371 )
372 .ok_or_else(|| crate::FdarError::ComputationFailed {
373 operation: "smooth_basis_gcv_with_config",
374 detail: "no valid smoothing result found in GCV lambda search".to_string(),
375 })
376}
377
378#[derive(Debug, Clone, PartialEq)]
394pub struct BasisNbasisCvConfig {
395 pub basis_type: BasisType,
397 pub nbasis_range: (usize, usize),
399 pub lambda: f64,
401 pub lfd_order: usize,
403 pub n_folds: usize,
405 pub criterion: BasisCriterion,
407}
408
409impl Default for BasisNbasisCvConfig {
410 fn default() -> Self {
411 Self {
412 basis_type: BasisType::Bspline { order: 4 },
413 nbasis_range: (5, 30),
414 lambda: 1e-4,
415 lfd_order: 2,
416 n_folds: 5,
417 criterion: BasisCriterion::Gcv,
418 }
419 }
420}
421
422#[must_use = "expensive computation whose result should not be discarded"]
440pub fn basis_nbasis_cv_with_config(
441 data: &FdMatrix,
442 argvals: &[f64],
443 config: &BasisNbasisCvConfig,
444) -> Result<BasisNbasisCvResult, crate::FdarError> {
445 let nbasis_range: Vec<usize> = (config.nbasis_range.0..=config.nbasis_range.1).collect();
446 basis_nbasis_cv(
447 data,
448 argvals,
449 &nbasis_range,
450 &config.basis_type,
451 config.criterion,
452 config.n_folds,
453 config.lambda,
454 )
455 .ok_or_else(|| crate::FdarError::ComputationFailed {
456 operation: "basis_nbasis_cv_with_config",
457 detail: "no valid result found in nbasis CV search".to_string(),
458 })
459}
460
461fn differentiate_basis_columns(
465 basis: &[f64],
466 n_quad: usize,
467 nbasis: usize,
468 h: f64,
469 lfd_order: usize,
470) -> Vec<f64> {
471 let mut deriv = basis.to_vec();
472 for _ in 0..lfd_order {
473 let mut new_deriv = vec![0.0; n_quad * nbasis];
474 for j in 0..nbasis {
475 let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
476 let grad = crate::helpers::gradient_uniform(&col, h);
477 for i in 0..n_quad {
478 new_deriv[i + j * n_quad] = grad[i];
479 }
480 }
481 deriv = new_deriv;
482 }
483 deriv
484}
485
/// Weighted Gram matrix of the differentiated basis.
///
/// `deriv_basis` is column-major `n_quad × k`. Entry `P[j, l]` equals
/// `Σ_i D[i, j] · D[i, l] · weights[i]`. Only the upper triangle is computed;
/// each entry is mirrored into the lower triangle. The result is returned
/// flattened column-major (k × k).
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let mut gram = vec![0.0; k * k];
    for row in 0..k {
        let col_row = &deriv_basis[row * n_quad..(row + 1) * n_quad];
        for col in row..k {
            let col_col = &deriv_basis[col * n_quad..(col + 1) * n_quad];
            let w = &weights[..n_quad];
            let entry = col_row
                .iter()
                .zip(col_col)
                .zip(w)
                .fold(0.0, |acc, ((a, b), wi)| acc + a * b * wi);
            gram[row + col * k] = entry;
            gram[col + row * k] = entry;
        }
    }
    gram
}
506
507fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
509 let m = argvals.len();
510 match basis_type {
511 BasisType::Bspline { order } => {
512 let nknots = nbasis.saturating_sub(*order).max(2);
513 let basis = bspline_basis(argvals, nknots, *order);
514 let actual = basis.len() / m;
515 (basis, actual)
516 }
517 BasisType::Fourier { period } => {
518 let basis = fourier_basis_with_period(argvals, nbasis, *period);
519 (basis, nbasis)
520 }
521 }
522}
523
524fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
526 if let Some(chol) = system.clone().cholesky() {
527 return Some(chol.inverse());
528 }
529 let svd = nalgebra::SVD::new(system.clone(), true, true);
531 let u = svd.u.as_ref()?;
532 let v_t = svd.v_t.as_ref()?;
533 let max_sv: f64 = svd.singular_values.iter().copied().fold(0.0_f64, f64::max);
534 let eps = 1e-10 * max_sv;
535 let mut inv = DMatrix::<f64>::zeros(k, k);
536 for ii in 0..k {
537 for jj in 0..k {
538 let mut sum = 0.0;
539 for s in 0..k.min(svd.singular_values.len()) {
540 if svd.singular_values[s] > eps {
541 sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
542 }
543 }
544 inv[(ii, jj)] = sum;
545 }
546 }
547 Some(inv)
548}
549
550fn project_all_curves(
552 data: &FdMatrix,
553 b_mat: &DMatrix<f64>,
554 proj: &DMatrix<f64>,
555 n: usize,
556 m: usize,
557 k: usize,
558) -> (FdMatrix, FdMatrix, f64) {
559 let mut all_coefs = FdMatrix::zeros(n, k);
560 let mut all_fitted = FdMatrix::zeros(n, m);
561 let mut total_rss = 0.0;
562
563 for i in 0..n {
564 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
565 let y_vec = nalgebra::DVector::from_vec(curve.clone());
566 let coefs = proj * &y_vec;
567
568 for j in 0..k {
569 all_coefs[(i, j)] = coefs[j];
570 }
571 let fitted = b_mat * &coefs;
572 for j in 0..m {
573 all_fitted[(i, j)] = fitted[j];
574 let resid = curve[j] - fitted[j];
575 total_rss += resid * resid;
576 }
577 }
578
579 (all_coefs, all_fitted, total_rss)
580}
581
/// Generalised cross-validation score `(rss / n_points) / (1 - edf / m)²`.
///
/// Returns `f64::INFINITY` when `|1 - edf / m|` does not exceed 1e-10,
/// including when that factor is NaN.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let shrink = 1.0 - edf / m as f64;
    let well_defined = shrink.abs() > 1e-10;
    if !well_defined {
        return f64::INFINITY;
    }
    let mean_rss = rss / n_points;
    mean_rss / (shrink * shrink)
}
591
/// Score used by [`basis_nbasis_cv`] to rank candidate basis sizes (lower is better).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BasisCriterion {
    /// Generalised cross-validation score reported by `smooth_basis`.
    Gcv,
    /// K-fold cross-validated prediction error (uses `n_folds`).
    Cv,
    /// Akaike information criterion reported by `smooth_basis`.
    Aic,
    /// Bayesian information criterion reported by `smooth_basis`.
    Bic,
}
606
607#[derive(Debug, Clone, PartialEq)]
609pub struct BasisNbasisCvResult {
610 pub optimal_nbasis: usize,
612 pub scores: Vec<f64>,
614 pub nbasis_range: Vec<usize>,
616 pub criterion: BasisCriterion,
618}
619
620fn evaluate_nbasis_info_criterion(
622 data: &FdMatrix,
623 argvals: &[f64],
624 nbasis_range: &[usize],
625 basis_type: &BasisType,
626 criterion: BasisCriterion,
627 lambda: f64,
628) -> Vec<f64> {
629 let mut scores = Vec::with_capacity(nbasis_range.len());
630 for &nb in nbasis_range {
631 if nb < 2 {
632 scores.push(f64::INFINITY);
633 continue;
634 }
635 let penalty = match basis_type {
636 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
637 BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
638 };
639 let fdpar = FdPar {
640 basis_type: basis_type.clone(),
641 nbasis: nb,
642 lambda,
643 lfd_order: 2,
644 penalty_matrix: penalty,
645 };
646 match smooth_basis(data, argvals, &fdpar) {
647 Ok(result) => {
648 let score = match criterion {
649 BasisCriterion::Gcv => result.gcv,
650 BasisCriterion::Aic => result.aic,
651 BasisCriterion::Bic => result.bic,
652 BasisCriterion::Cv => unreachable!(),
653 };
654 scores.push(score);
655 }
656 Err(_) => scores.push(f64::INFINITY),
657 }
658 }
659 scores
660}
661
662fn evaluate_nbasis_cv(
664 data: &FdMatrix,
665 argvals: &[f64],
666 nbasis_range: &[usize],
667 basis_type: &BasisType,
668 lambda: f64,
669 n_folds: usize,
670) -> Vec<f64> {
671 let (n, m) = data.shape();
672 let n_folds = n_folds.max(2);
673 let folds = crate::cv::create_folds(n, n_folds, 42);
674 let mut scores = Vec::with_capacity(nbasis_range.len());
675
676 for &nb in nbasis_range {
677 if nb < 2 {
678 scores.push(f64::INFINITY);
679 continue;
680 }
681 let penalty = match basis_type {
682 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
683 BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
684 };
685
686 let mut total_mse = 0.0;
687 let mut total_count = 0;
688
689 for fold in 0..n_folds {
690 let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
691 if train_idx.is_empty() || test_idx.is_empty() {
692 continue;
693 }
694 let train_data = crate::cv::subset_rows(data, &train_idx);
695 let fdpar = FdPar {
696 basis_type: basis_type.clone(),
697 nbasis: nb,
698 lambda,
699 lfd_order: 2,
700 penalty_matrix: penalty.clone(),
701 };
702
703 if let Ok(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
704 let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
705 let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
706 let r_mat =
707 DMatrix::from_column_slice(actual_k, actual_k, &train_result.penalty_matrix);
708 let btb = b_mat.transpose() * &b_mat;
709 let ridge_eps = 1e-10;
710 let system: DMatrix<f64> = &btb
711 + lambda * &r_mat
712 + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
713
714 if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
715 let proj = &system_inv * b_mat.transpose();
716 for &ti in &test_idx {
717 let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
718 let y_vec = nalgebra::DVector::from_vec(curve.clone());
719 let coefs = &proj * &y_vec;
720 let fitted = &b_mat * &coefs;
721 let mse: f64 =
722 (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
723 total_mse += mse;
724 total_count += 1;
725 }
726 }
727 }
728 }
729
730 if total_count > 0 {
731 scores.push(total_mse / f64::from(total_count));
732 } else {
733 scores.push(f64::INFINITY);
734 }
735 }
736 scores
737}
738
739pub fn basis_nbasis_cv(
742 data: &FdMatrix,
743 argvals: &[f64],
744 nbasis_range: &[usize],
745 basis_type: &BasisType,
746 criterion: BasisCriterion,
747 n_folds: usize,
748 lambda: f64,
749) -> Option<BasisNbasisCvResult> {
750 let (n, m) = data.shape();
751 if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
752 return None;
753 }
754
755 let scores = match criterion {
756 BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
757 evaluate_nbasis_info_criterion(
758 data,
759 argvals,
760 nbasis_range,
761 basis_type,
762 criterion,
763 lambda,
764 )
765 }
766 BasisCriterion::Cv => {
767 evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
768 }
769 };
770
771 let (best_idx, _) = scores
772 .iter()
773 .enumerate()
774 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
775
776 Some(BasisNbasisCvResult {
777 optimal_nbasis: nbasis_range[best_idx],
778 scores,
779 nbasis_range: nbasis_range.to_vec(),
780 criterion,
781 })
782}
783
784#[cfg(test)]
785mod tests {
786 use super::*;
787 use crate::test_helpers::uniform_grid;
788 use std::f64::consts::PI;
789
790 #[test]
791 fn test_bspline_penalty_matrix_symmetric() {
792 let t = uniform_grid(101);
793 let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
794 let _k = 15; let actual_k = (penalty.len() as f64).sqrt() as usize;
796 for i in 0..actual_k {
797 for j in 0..actual_k {
798 assert!(
799 (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
800 "Penalty matrix not symmetric at ({}, {})",
801 i,
802 j
803 );
804 }
805 }
806 }
807
808 #[test]
809 fn test_bspline_penalty_matrix_positive_semidefinite() {
810 let t = uniform_grid(101);
811 let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
812 let k = (penalty.len() as f64).sqrt() as usize;
813 for i in 0..k {
815 assert!(
816 penalty[i + i * k] >= -1e-10,
817 "Diagonal element {} is negative: {}",
818 i,
819 penalty[i + i * k]
820 );
821 }
822 }
823
824 #[test]
825 fn test_fourier_penalty_diagonal() {
826 let penalty = fourier_penalty_matrix(7, 1.0, 2);
827 for i in 0..7 {
829 for j in 0..7 {
830 if i != j {
831 assert!(
832 penalty[i + j * 7].abs() < 1e-10,
833 "Off-diagonal ({},{}) = {}",
834 i,
835 j,
836 penalty[i + j * 7]
837 );
838 }
839 }
840 }
841 assert!(penalty[0].abs() < 1e-10);
843 assert!(penalty[1 + 7] > 0.0);
845 assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
846 }
847
848 #[test]
849 fn test_smooth_basis_bspline() {
850 let m = 101;
851 let n = 5;
852 let t = uniform_grid(m);
853
854 let mut data = FdMatrix::zeros(n, m);
856 for i in 0..n {
857 for j in 0..m {
858 data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
859 }
860 }
861
862 let nbasis = 15;
863 let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
864 let _actual_k = (penalty.len() as f64).sqrt() as usize;
865
866 let fdpar = FdPar {
867 basis_type: BasisType::Bspline { order: 4 },
868 nbasis,
869 lambda: 1e-4,
870 lfd_order: 2,
871 penalty_matrix: penalty,
872 };
873
874 let result = smooth_basis(&data, &t, &fdpar);
875 assert!(result.is_ok(), "smooth_basis should succeed");
876
877 let res = result.unwrap();
878 assert_eq!(res.fitted.shape(), (n, m));
879 assert_eq!(res.coefficients.nrows(), n);
880 assert!(res.edf > 0.0, "EDF should be positive");
881 assert!(res.gcv > 0.0, "GCV should be positive");
882 }
883
884 #[test]
885 fn test_smooth_basis_fourier() {
886 let m = 101;
887 let n = 3;
888 let t = uniform_grid(m);
889
890 let mut data = FdMatrix::zeros(n, m);
891 for i in 0..n {
892 for j in 0..m {
893 data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
894 }
895 }
896
897 let nbasis = 7;
898 let period = 1.0;
899 let penalty = fourier_penalty_matrix(nbasis, period, 2);
900
901 let fdpar = FdPar {
902 basis_type: BasisType::Fourier { period },
903 nbasis,
904 lambda: 1e-6,
905 lfd_order: 2,
906 penalty_matrix: penalty,
907 };
908
909 let result = smooth_basis(&data, &t, &fdpar);
910 assert!(result.is_ok());
911
912 let res = result.unwrap();
913 for j in 0..m {
915 let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
916 assert!(
917 (res.fitted[(0, j)] - expected).abs() < 0.1,
918 "Fourier fit poor at j={}: got {}, expected {}",
919 j,
920 res.fitted[(0, j)],
921 expected
922 );
923 }
924 }
925
926 #[test]
927 fn test_smooth_basis_gcv_selects_reasonable_lambda() {
928 let m = 101;
929 let n = 5;
930 let t = uniform_grid(m);
931
932 let mut data = FdMatrix::zeros(n, m);
933 for i in 0..n {
934 for j in 0..m {
935 data[(i, j)] =
936 (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
937 }
938 }
939
940 let basis_type = BasisType::Bspline { order: 4 };
941 let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
942 assert!(result.is_some(), "GCV search should succeed");
943 }
944
945 #[test]
946 fn test_smooth_basis_large_lambda_reduces_edf() {
947 let m = 101;
948 let n = 3;
949 let t = uniform_grid(m);
950
951 let mut data = FdMatrix::zeros(n, m);
952 for i in 0..n {
953 for j in 0..m {
954 data[(i, j)] = (2.0 * PI * t[j]).sin();
955 }
956 }
957
958 let nbasis = 15;
959 let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);
960 let _actual_k = (penalty.len() as f64).sqrt() as usize;
961
962 let fdpar_small = FdPar {
963 basis_type: BasisType::Bspline { order: 4 },
964 nbasis,
965 lambda: 1e-8,
966 lfd_order: 2,
967 penalty_matrix: penalty.clone(),
968 };
969 let fdpar_large = FdPar {
970 basis_type: BasisType::Bspline { order: 4 },
971 nbasis,
972 lambda: 1e2,
973 lfd_order: 2,
974 penalty_matrix: penalty,
975 };
976
977 let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
978 let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
979
980 assert!(
981 res_large.edf < res_small.edf,
982 "Larger lambda should reduce EDF: {} vs {}",
983 res_large.edf,
984 res_small.edf
985 );
986 }
987
988 #[test]
991 fn test_basis_nbasis_cv_gcv() {
992 let m = 101;
993 let n = 5;
994 let t = uniform_grid(m);
995 let mut data = FdMatrix::zeros(n, m);
996 for i in 0..n {
997 for j in 0..m {
998 data[(i, j)] =
999 (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1000 }
1001 }
1002
1003 let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
1004 let result = basis_nbasis_cv(
1005 &data,
1006 &t,
1007 &nbasis_range,
1008 &BasisType::Bspline { order: 4 },
1009 BasisCriterion::Gcv,
1010 5,
1011 1e-4,
1012 );
1013 assert!(result.is_some());
1014 let res = result.unwrap();
1015 assert!(nbasis_range.contains(&res.optimal_nbasis));
1016 assert_eq!(res.scores.len(), nbasis_range.len());
1017 assert_eq!(res.criterion, BasisCriterion::Gcv);
1018 }
1019
1020 #[test]
1021 fn test_basis_nbasis_cv_aic_bic() {
1022 let m = 51;
1023 let n = 5;
1024 let t = uniform_grid(m);
1025 let mut data = FdMatrix::zeros(n, m);
1026 for i in 0..n {
1027 for j in 0..m {
1028 data[(i, j)] = (2.0 * PI * t[j]).sin();
1029 }
1030 }
1031
1032 let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
1033 let aic_result = basis_nbasis_cv(
1034 &data,
1035 &t,
1036 &nbasis_range,
1037 &BasisType::Bspline { order: 4 },
1038 BasisCriterion::Aic,
1039 5,
1040 0.0,
1041 );
1042 let bic_result = basis_nbasis_cv(
1043 &data,
1044 &t,
1045 &nbasis_range,
1046 &BasisType::Bspline { order: 4 },
1047 BasisCriterion::Bic,
1048 5,
1049 0.0,
1050 );
1051 assert!(aic_result.is_some());
1052 assert!(bic_result.is_some());
1053 }
1054
1055 #[test]
1056 fn test_basis_nbasis_cv_kfold() {
1057 let m = 51;
1058 let n = 10;
1059 let t = uniform_grid(m);
1060 let mut data = FdMatrix::zeros(n, m);
1061 for i in 0..n {
1062 for j in 0..m {
1063 data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
1064 }
1065 }
1066
1067 let nbasis_range: Vec<usize> = vec![5, 7, 9];
1068 let result = basis_nbasis_cv(
1069 &data,
1070 &t,
1071 &nbasis_range,
1072 &BasisType::Bspline { order: 4 },
1073 BasisCriterion::Cv,
1074 5,
1075 1e-4,
1076 );
1077 assert!(result.is_some());
1078 let res = result.unwrap();
1079 assert!(nbasis_range.contains(&res.optimal_nbasis));
1080 assert_eq!(res.criterion, BasisCriterion::Cv);
1081 }
1082
1083 fn make_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
1087 let t: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
1088 let mut data = FdMatrix::zeros(n, m);
1089 for i in 0..n {
1090 for j in 0..m {
1091 data[(i, j)] = (2.0 * PI * t[j]).sin()
1092 + 0.1 * (10.0 * t[j]).sin()
1093 + 0.05 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
1094 }
1095 }
1096 (data, t)
1097 }
1098
1099 fn make_bspline_fdpar(argvals: &[f64], nbasis: usize, lambda: f64) -> FdPar {
1101 let penalty = bspline_penalty_matrix(argvals, nbasis, 4, 2);
1102 FdPar {
1103 basis_type: BasisType::Bspline { order: 4 },
1104 nbasis,
1105 lambda,
1106 lfd_order: 2,
1107 penalty_matrix: penalty,
1108 }
1109 }
1110
1111 fn make_fourier_fdpar(nbasis: usize, period: f64, lambda: f64) -> FdPar {
1113 let penalty = fourier_penalty_matrix(nbasis, period, 2);
1114 FdPar {
1115 basis_type: BasisType::Fourier { period },
1116 nbasis,
1117 lambda,
1118 lfd_order: 2,
1119 penalty_matrix: penalty,
1120 }
1121 }
1122
1123 #[test]
1126 fn test_basis_type_bspline_variant() {
1127 let bt = BasisType::Bspline { order: 4 };
1128 assert_eq!(bt, BasisType::Bspline { order: 4 });
1129 assert_ne!(bt, BasisType::Bspline { order: 3 });
1131 }
1132
1133 #[test]
1134 fn test_basis_type_fourier_variant() {
1135 let bt = BasisType::Fourier { period: 1.0 };
1136 assert_eq!(bt, BasisType::Fourier { period: 1.0 });
1137 assert_ne!(bt, BasisType::Fourier { period: 2.0 });
1138 }
1139
1140 #[test]
1141 fn test_basis_type_cross_variant_inequality() {
1142 let bspline = BasisType::Bspline { order: 4 };
1143 let fourier = BasisType::Fourier { period: 1.0 };
1144 assert_ne!(bspline, fourier);
1145 }
1146
1147 #[test]
1148 fn test_basis_type_clone_and_debug() {
1149 let bt = BasisType::Bspline { order: 4 };
1150 let cloned = bt.clone();
1151 assert_eq!(bt, cloned);
1152 let debug_str = format!("{:?}", bt);
1153 assert!(debug_str.contains("Bspline"));
1154 assert!(debug_str.contains("4"));
1155 }
1156
1157 #[test]
1160 fn test_fdpar_construction_and_fields() {
1161 let penalty = vec![1.0, 0.0, 0.0, 1.0];
1162 let fdpar = FdPar {
1163 basis_type: BasisType::Bspline { order: 4 },
1164 nbasis: 2,
1165 lambda: 0.01,
1166 lfd_order: 2,
1167 penalty_matrix: penalty.clone(),
1168 };
1169 assert_eq!(fdpar.nbasis, 2);
1170 assert!((fdpar.lambda - 0.01).abs() < 1e-15);
1171 assert_eq!(fdpar.lfd_order, 2);
1172 assert_eq!(fdpar.penalty_matrix.len(), 4);
1173 }
1174
1175 #[test]
1176 fn test_fdpar_clone_and_debug() {
1177 let t = uniform_grid(50);
1178 let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1179 let cloned = fdpar.clone();
1180 assert_eq!(fdpar, cloned);
1181 let debug_str = format!("{:?}", fdpar);
1182 assert!(debug_str.contains("FdPar"));
1183 }
1184
1185 #[test]
1188 fn test_basis_criterion_variants() {
1189 assert_eq!(BasisCriterion::Gcv, BasisCriterion::Gcv);
1190 assert_eq!(BasisCriterion::Cv, BasisCriterion::Cv);
1191 assert_eq!(BasisCriterion::Aic, BasisCriterion::Aic);
1192 assert_eq!(BasisCriterion::Bic, BasisCriterion::Bic);
1193 assert_ne!(BasisCriterion::Gcv, BasisCriterion::Aic);
1194 assert_ne!(BasisCriterion::Cv, BasisCriterion::Bic);
1195 }
1196
1197 #[test]
1198 fn test_basis_criterion_copy() {
1199 let c = BasisCriterion::Gcv;
1200 let copied = c; assert_eq!(c, copied);
1202 }
1203
1204 #[test]
1205 fn test_basis_criterion_debug() {
1206 let debug_str = format!("{:?}", BasisCriterion::Bic);
1207 assert!(debug_str.contains("Bic"));
1208 }
1209
1210 #[test]
1213 fn test_smooth_basis_result_all_fields() {
1214 let (data, t) = make_test_data(3, 50);
1215 let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1216 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1217
1218 assert_eq!(res.coefficients.nrows(), 3);
1220 assert!(res.coefficients.ncols() > 0);
1221 assert_eq!(res.nbasis, res.coefficients.ncols());
1222 assert_eq!(res.fitted.shape(), (3, 50));
1224 assert!(res.edf > 0.0 && res.edf <= res.nbasis as f64);
1226 assert!(res.gcv.is_finite());
1228 assert!(res.aic.is_finite());
1229 assert!(res.bic.is_finite());
1230 let k = res.nbasis;
1232 assert_eq!(res.penalty_matrix.len(), k * k);
1233 }
1234
1235 #[test]
1236 fn test_smooth_basis_result_clone() {
1237 let (data, t) = make_test_data(2, 50);
1238 let fdpar = make_bspline_fdpar(&t, 8, 1e-3);
1239 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1240 let cloned = res.clone();
1241 assert_eq!(res, cloned);
1242 }
1243
1244 #[test]
1247 fn test_smooth_basis_bspline_coefficient_shape() {
1248 let (data, t) = make_test_data(4, 50);
1249 let nbasis = 12;
1250 let fdpar = make_bspline_fdpar(&t, nbasis, 1e-4);
1251 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1252 assert_eq!(res.coefficients.nrows(), 4);
1253 assert!(res.coefficients.ncols() >= 2);
1255 assert_eq!(res.nbasis, res.coefficients.ncols());
1256 }
1257
1258 #[test]
1259 fn test_smooth_basis_bspline_fitted_values_shape() {
1260 let m = 80;
1261 let n = 6;
1262 let (data, t) = make_test_data(n, m);
1263 let fdpar = make_bspline_fdpar(&t, 15, 1e-4);
1264 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1265 assert_eq!(res.fitted.shape(), (n, m));
1266 }
1267
1268 #[test]
1269 fn test_smooth_basis_bspline_zero_lambda_interpolates() {
1270 let m = 30;
1272 let n = 2;
1273 let (data, t) = make_test_data(n, m);
1274 let fdpar = make_bspline_fdpar(&t, 15, 0.0);
1275 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1276
1277 let mut max_resid = 0.0_f64;
1279 for i in 0..n {
1280 for j in 0..m {
1281 let resid = (data[(i, j)] - res.fitted[(i, j)]).abs();
1282 max_resid = max_resid.max(resid);
1283 }
1284 }
1285 assert!(
1286 max_resid < 0.5,
1287 "Zero-lambda B-spline should closely interpolate; max_resid = {}",
1288 max_resid
1289 );
1290 }
1291
1292 #[test]
1293 fn test_smooth_basis_bspline_large_lambda_oversmooths() {
1294 let m = 50;
1297 let n = 1;
1298 let (data, t) = make_test_data(n, m);
1299
1300 let fdpar_small = make_bspline_fdpar(&t, 15, 1e-6);
1301 let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1302
1303 let fdpar_large = make_bspline_fdpar(&t, 15, 1e6);
1304 let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1305
1306 let compute_variance = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1307 let vals: Vec<f64> = (0..ncols).map(|j| fitted[(row, j)]).collect();
1308 let mean = vals.iter().sum::<f64>() / ncols as f64;
1309 vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / ncols as f64
1310 };
1311
1312 let var_small = compute_variance(&res_small.fitted, 0, m);
1313 let var_large = compute_variance(&res_large.fitted, 0, m);
1314 assert!(
1315 var_large < var_small,
1316 "Large lambda should yield lower variance fit: var_large={}, var_small={}",
1317 var_large,
1318 var_small
1319 );
1320 }
1321
1322 #[test]
1323 fn test_smooth_basis_bspline_penalty_effect_on_smoothness() {
1324 let m = 50;
1326 let n = 1;
1327 let (data, t) = make_test_data(n, m);
1328
1329 let fdpar_small = make_bspline_fdpar(&t, 15, 1e-8);
1330 let fdpar_large = make_bspline_fdpar(&t, 15, 1.0);
1331
1332 let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
1333 let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();
1334
1335 let roughness = |fitted: &FdMatrix, row: usize, ncols: usize| -> f64 {
1337 (1..ncols - 1)
1338 .map(|j| {
1339 let d2 = fitted[(row, j + 1)] - 2.0 * fitted[(row, j)] + fitted[(row, j - 1)];
1340 d2 * d2
1341 })
1342 .sum::<f64>()
1343 };
1344
1345 let r_small = roughness(&res_small.fitted, 0, m);
1346 let r_large = roughness(&res_large.fitted, 0, m);
1347 assert!(
1348 r_large < r_small,
1349 "Larger lambda should produce smoother fit: roughness_large={}, roughness_small={}",
1350 r_large,
1351 r_small
1352 );
1353 }
1354
1355 #[test]
1356 fn test_smooth_basis_bspline_single_curve() {
1357 let m = 50;
1358 let (data, t) = make_test_data(1, m);
1359 let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1360 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1361 assert_eq!(res.fitted.nrows(), 1);
1362 assert_eq!(res.fitted.ncols(), m);
1363 assert!(res.gcv.is_finite());
1364 }
1365
1366 #[test]
1367 fn test_smooth_basis_bspline_many_curves() {
1368 let m = 50;
1369 let n = 20;
1370 let (data, t) = make_test_data(n, m);
1371 let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
1372 let res = smooth_basis(&data, &t, &fdpar).unwrap();
1373 assert_eq!(res.fitted.nrows(), n);
1374 assert_eq!(res.coefficients.nrows(), n);
1375 }
1376
1377 #[test]
1378 fn test_smooth_basis_bspline_minimal_nbasis() {
1379 let m = 50;
1381 let (data, t) = make_test_data(1, m);
1382 let fdpar = make_bspline_fdpar(&t, 2, 1e-4);
1383 let res = smooth_basis(&data, &t, &fdpar);
1384 assert!(res.is_ok());
1386 }
1387
/// Order-3 and order-5 B-spline bases, each with a second-derivative penalty, both smooth
/// successfully.
#[test]
fn test_smooth_basis_bspline_different_orders() {
    let (data, t) = make_test_data(2, 50);
    for order in [3, 5] {
        let fdpar = FdPar {
            basis_type: BasisType::Bspline { order },
            nbasis: 10,
            lambda: 1e-4,
            lfd_order: 2,
            penalty_matrix: bspline_penalty_matrix(&t, 10, order, 2),
        };
        assert!(smooth_basis(&data, &t, &fdpar).is_ok());
    }
}
1416
/// A Fourier fit produces an n x nbasis coefficient matrix and reports the basis size used.
#[test]
fn test_smooth_basis_fourier_coefficient_shape() {
    let (n_curves, n_points, nbasis) = (3, 50, 7);
    let t = uniform_grid(n_points);
    let mut data = FdMatrix::zeros(n_curves, n_points);
    for (j, &x) in t.iter().enumerate() {
        let value = (2.0 * PI * x).sin();
        for i in 0..n_curves {
            data[(i, j)] = value;
        }
    }
    let result = smooth_basis(&data, &t, &make_fourier_fdpar(nbasis, 1.0, 1e-6)).unwrap();
    assert_eq!(result.coefficients.nrows(), n_curves);
    assert_eq!(result.coefficients.ncols(), nbasis);
    assert_eq!(result.nbasis, nbasis);
}
1437
/// A nearly unpenalized 5-function Fourier fit (period 1) reproduces sin(2πt) to within
/// 0.05 at every grid point.
#[test]
fn test_smooth_basis_fourier_fits_pure_sine() {
    let m = 100;
    let t = uniform_grid(m);
    let signal = |x: f64| (2.0 * PI * x).sin();

    let mut data = FdMatrix::zeros(1, m);
    for (j, &x) in t.iter().enumerate() {
        data[(0, j)] = signal(x);
    }
    let res = smooth_basis(&data, &t, &make_fourier_fdpar(5, 1.0, 1e-8)).unwrap();

    for (j, &x) in t.iter().enumerate() {
        let expected = signal(x);
        let got = res.fitted[(0, j)];
        assert!(
            (got - expected).abs() < 0.05,
            "Fourier should fit pure sine; j={}, got={}, expected={}",
            j,
            got,
            expected
        );
    }
}
1461
/// Fourier smoothing with period 1 and with period 2 both return a 1 x m fit.
#[test]
fn test_smooth_basis_fourier_different_periods() {
    let m = 50;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(1, m);
    for (j, &x) in t.iter().enumerate() {
        data[(0, j)] = (2.0 * PI * x).sin();
    }

    for period in [1.0, 2.0] {
        let fit = smooth_basis(&data, &t, &make_fourier_fdpar(7, period, 1e-6)).unwrap();
        assert_eq!(fit.fitted.shape(), (1, m));
    }
}
1483
/// With lambda = 0 (no roughness penalty), a 9-function Fourier fit keeps more than one
/// effective degree of freedom.
#[test]
fn test_smooth_basis_fourier_zero_lambda() {
    let m = 50;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(1, m);
    for (j, &x) in t.iter().enumerate() {
        data[(0, j)] = (2.0 * PI * x).sin() + (4.0 * PI * x).cos();
    }
    let unpenalized = make_fourier_fdpar(9, 1.0, 0.0);
    let res = smooth_basis(&data, &t, &unpenalized).unwrap();
    assert_eq!(res.fitted.shape(), (1, m));
    assert!(res.edf > 1.0);
}
1498
/// A very heavy roughness penalty (lambda = 1e6) on a 9-function Fourier basis pushes the
/// effective degrees of freedom below 5.
#[test]
fn test_smooth_basis_fourier_large_lambda() {
    let m = 50;
    let grid = uniform_grid(m);
    let mut curve = FdMatrix::zeros(1, m);
    for (col, &x) in grid.iter().enumerate() {
        curve[(0, col)] = (2.0 * PI * x).sin();
    }
    let edf = smooth_basis(&curve, &grid, &make_fourier_fdpar(9, 1.0, 1e6))
        .unwrap()
        .edf;
    assert!(edf < 5.0, "Large lambda should reduce EDF; edf={}", edf);
}
1516
/// Effective degrees of freedom must not increase (beyond a 0.01 slack) as lambda grows
/// across five orders of magnitude.
#[test]
fn test_smooth_basis_lambda_gradient_edf() {
    let (data, t) = make_test_data(3, 50);
    let mut prev_edf = f64::INFINITY;
    for lam in [1e-8, 1e-4, 1e-2, 1.0, 1e2] {
        let edf = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 12, lam))
            .unwrap()
            .edf;
        assert!(
            edf <= prev_edf + 0.01,
            "EDF should decrease: lambda={}, edf={}, prev_edf={}",
            lam,
            edf,
            prev_edf
        );
        prev_edf = edf;
    }
}
1539
/// Residual sum of squares must not decrease (up to 1e-8) as lambda grows from 0 to 1e4.
#[test]
fn test_smooth_basis_lambda_gradient_rss() {
    let (n, m) = (2, 50);
    let (data, t) = make_test_data(n, m);
    let mut prev_rss = -1.0;
    for lam in [0.0, 1e-6, 1e-2, 1.0, 1e4] {
        let res = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 12, lam)).unwrap();
        // Row-major accumulation over every (curve, point) residual.
        let rss: f64 = (0..n)
            .flat_map(|i| (0..m).map(move |j| (i, j)))
            .map(|(i, j)| (data[(i, j)] - res.fitted[(i, j)]).powi(2))
            .sum();
        assert!(
            rss >= prev_rss - 1e-8,
            "RSS should increase: lambda={}, rss={}, prev_rss={}",
            lam,
            rss,
            prev_rss
        );
        prev_rss = rss;
    }
}
1567
/// A data matrix with zero curves is rejected.
#[test]
fn test_smooth_basis_empty_data_rows() {
    let t = uniform_grid(50);
    let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
    let no_curves = FdMatrix::zeros(0, 50);
    assert!(smooth_basis(&no_curves, &t, &fdpar).is_err());
}
1578
/// A data matrix with zero evaluation points (and empty argvals) is rejected.
#[test]
fn test_smooth_basis_empty_data_cols() {
    let fdpar = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis: 10,
        lambda: 1e-4,
        lfd_order: 2,
        penalty_matrix: vec![0.0; 10 * 10],
    };
    let no_points = FdMatrix::zeros(5, 0);
    assert!(smooth_basis(&no_points, &[], &fdpar).is_err());
}
1592
/// `argvals` with 50 points cannot describe a 40-column data matrix.
#[test]
fn test_smooth_basis_mismatched_argvals() {
    let t = uniform_grid(50);
    let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
    let narrower = FdMatrix::zeros(3, 40);
    assert!(smooth_basis(&narrower, &t, &fdpar).is_err());
}
1601
/// A single-function basis is rejected.
#[test]
fn test_smooth_basis_nbasis_too_small() {
    let degenerate = FdPar {
        basis_type: BasisType::Bspline { order: 4 },
        nbasis: 1,
        lambda: 1e-4,
        lfd_order: 2,
        penalty_matrix: vec![0.0],
    };
    let t = uniform_grid(50);
    let data = FdMatrix::zeros(3, 50);
    assert!(smooth_basis(&data, &t, &degenerate).is_err());
}
1617
/// Empty input is reported specifically as `FdarError::InvalidDimension`.
#[test]
fn test_smooth_basis_error_is_invalid_dimension() {
    let t = uniform_grid(50);
    let data = FdMatrix::zeros(0, 50);
    let fdpar = make_bspline_fdpar(&t, 10, 1e-4);
    let err = smooth_basis(&data, &t, &fdpar).unwrap_err();
    assert!(
        matches!(err, crate::FdarError::InvalidDimension { .. }),
        "Expected InvalidDimension, got {:?}",
        err
    );
}
1629
/// First- and second-derivative penalties for the same basis have equal size but
/// different entries.
#[test]
fn test_bspline_penalty_matrix_different_orders() {
    let t = uniform_grid(101);
    let first = bspline_penalty_matrix(&t, 10, 4, 1);
    let second = bspline_penalty_matrix(&t, 10, 4, 2);
    assert_eq!(first.len(), second.len());

    let total_abs_diff: f64 = first
        .iter()
        .zip(&second)
        .fold(0.0, |acc, (a, b)| acc + (a - b).abs());
    assert!(
        total_abs_diff > 1e-10,
        "Different lfd_orders should produce different penalties"
    );
}
1648
/// Degenerate inputs to `bspline_penalty_matrix` produce an all-zero `nbasis * nbasis`
/// matrix: fewer than two argvals, nbasis < 2, and a penalty order not below the spline order.
#[test]
fn test_bspline_penalty_matrix_edge_cases() {
    // A single evaluation point cannot span a quadrature interval.
    let p = bspline_penalty_matrix(&[0.0], 10, 4, 2);
    // Pin the length first: `all` is vacuously true on an empty vector.
    assert_eq!(p.len(), 10 * 10);
    assert!(p.iter().all(|&v| v == 0.0));

    // nbasis below 2.
    let t2 = uniform_grid(50);
    let p2 = bspline_penalty_matrix(&t2, 1, 4, 2);
    assert_eq!(p2.len(), 1);
    assert!(p2.iter().all(|&v| v == 0.0));

    // lfd_order == order: the order-th derivative of the spline is not penalized here.
    let p3 = bspline_penalty_matrix(&t2, 10, 4, 4);
    assert_eq!(p3.len(), 10 * 10);
    assert!(p3.iter().all(|&v| v == 0.0));
}
1666
/// Every diagonal entry of the second-derivative B-spline penalty is non-negative
/// (up to 1e-10) for several basis sizes.
#[test]
fn test_bspline_penalty_nonnegative_diagonal() {
    let t = uniform_grid(101);
    for nbasis in [5, 10, 20] {
        let p = bspline_penalty_matrix(&t, nbasis, 4, 2);
        // The returned dimension may differ from `nbasis`, so recover it from the length.
        // The float sqrt truncates; confirm the matrix really is k x k so no diagonal entry
        // is silently skipped.
        let k = (p.len() as f64).sqrt() as usize;
        assert_eq!(
            k * k,
            p.len(),
            "Penalty for nbasis={} is not square: len={}",
            nbasis,
            p.len()
        );
        for i in 0..k {
            assert!(
                p[i + i * k] >= -1e-10,
                "Diagonal ({},{}) negative for nbasis={}: {}",
                i,
                i,
                nbasis,
                p[i + i * k]
            );
        }
    }
}
1685
/// In an 11-function Fourier penalty, the constant term has zero penalty. The sine
/// diagonal grows strictly with frequency, and each cosine diagonal equals its sine partner.
#[test]
fn test_fourier_penalty_increasing_with_frequency() {
    let k = 11;
    let penalty = fourier_penalty_matrix(k, 1.0, 2);
    // Diagonal element (idx, idx) of the k x k penalty.
    let diag = |idx: usize| penalty[idx * (k + 1)];

    assert!(diag(0).abs() < 1e-15);
    let mut prev_eigenval = 0.0;
    for freq in 1..=5 {
        let eigenval = diag(2 * freq - 1);
        assert!(
            eigenval > prev_eigenval,
            "Higher frequency should have larger penalty: freq={}, eigenval={}, prev={}",
            freq,
            eigenval,
            prev_eigenval
        );
        let cos_idx = 2 * freq;
        if cos_idx < k {
            assert!(
                (diag(cos_idx) - eigenval).abs() < 1e-10,
                "Sin and cos penalty should match at freq {}",
                freq
            );
        }
        prev_eigenval = eigenval;
    }
}
1716
/// Doubling the period shrinks every non-constant diagonal penalty entry (unless both are zero).
#[test]
fn test_fourier_penalty_different_periods() {
    let k = 7;
    let short_period = fourier_penalty_matrix(k, 1.0, 2);
    let long_period = fourier_penalty_matrix(k, 2.0, 2);
    for i in 1..k {
        let d = i * (k + 1);
        let (a, b) = (short_period[d], long_period[d]);
        assert!(
            b < a || (a == 0.0 && b == 0.0),
            "Longer period should have smaller penalties at i={}",
            i
        );
    }
}
1730
/// With a first-derivative penalty and period 1, diagonal entry (1,1) equals (2π)².
#[test]
fn test_fourier_penalty_first_order() {
    let penalty = fourier_penalty_matrix(5, 1.0, 1);
    let expected1 = (2.0 * PI).powi(2);
    // Row-major index of element (1, 1) in the 5 x 5 matrix.
    let got = penalty[1 + 5];
    assert!(
        (got - expected1).abs() < 1e-6,
        "First-order penalty eigenval: got {}, expected {}",
        got,
        expected1
    );
}
1745
/// A zero-size Fourier basis yields an empty penalty.
#[test]
fn test_fourier_penalty_zero_nbasis() {
    assert!(fourier_penalty_matrix(0, 1.0, 2).is_empty());
}
1751
/// A constant-only Fourier basis yields a 1 x 1 zero penalty.
#[test]
fn test_fourier_penalty_nbasis_one() {
    let constant_only = fourier_penalty_matrix(1, 1.0, 2);
    assert_eq!(constant_only.len(), 1);
    assert!(constant_only[0].abs() < 1e-15);
}
1758
/// GCV-driven lambda selection with a cubic B-spline basis yields a full-size fit with a
/// finite GCV and positive EDF.
#[test]
fn test_smooth_basis_gcv_returns_valid_result() {
    let (data, t) = make_test_data(5, 50);
    let outcome = smooth_basis_gcv(
        &data,
        &t,
        &BasisType::Bspline { order: 4 },
        12,
        2,
        (-6.0, 2.0),
        20,
    );
    assert!(outcome.is_some());
    let best = outcome.unwrap();
    assert_eq!(best.fitted.shape(), (5, 50));
    assert!(best.gcv.is_finite());
    assert!(best.edf > 0.0);
}
1772
/// GCV selection over a Fourier basis returns one fitted row per curve and keeps nbasis = 9.
#[test]
fn test_smooth_basis_gcv_fourier() {
    let (n_curves, m) = (3, 80);
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n_curves, m);
    for (j, &x) in t.iter().enumerate() {
        let value = (2.0 * PI * x).sin() + 0.5 * (4.0 * PI * x).cos();
        for i in 0..n_curves {
            data[(i, j)] = value;
        }
    }
    let basis = BasisType::Fourier { period: 1.0 };
    let outcome = smooth_basis_gcv(&data, &t, &basis, 9, 2, (-8.0, 4.0), 25);
    assert!(outcome.is_some());
    let best = outcome.unwrap();
    assert_eq!(best.fitted.nrows(), n_curves);
    assert_eq!(best.nbasis, 9);
}
1790
/// The GCV score of the selected fit is finite and strictly positive.
#[test]
fn test_smooth_basis_gcv_selects_finite_gcv() {
    let (data, t) = make_test_data(5, 60);
    let gcv = smooth_basis_gcv(&data, &t, &BasisType::Bspline { order: 4 }, 12, 2, (-6.0, 2.0), 15)
        .unwrap()
        .gcv;
    assert!(gcv.is_finite());
    assert!(gcv > 0.0);
}
1799
/// `smooth_basis_gcv` returns `None` when there are no curves.
#[test]
fn test_smooth_basis_gcv_empty_data() {
    let grid = uniform_grid(50);
    let no_curves = FdMatrix::zeros(0, 50);
    let basis = BasisType::Bspline { order: 4 };
    assert!(smooth_basis_gcv(&no_curves, &grid, &basis, 10, 2, (-6.0, 2.0), 10).is_none());
}
1809
/// `smooth_basis_gcv` returns `None` when there are no evaluation points.
#[test]
fn test_smooth_basis_gcv_empty_argvals() {
    let basis = BasisType::Bspline { order: 4 };
    let no_points = FdMatrix::zeros(5, 0);
    assert!(smooth_basis_gcv(&no_points, &[], &basis, 10, 2, (-6.0, 2.0), 10).is_none());
}
1817
/// `smooth_basis_gcv` returns `None` for a single-function basis.
#[test]
fn test_smooth_basis_gcv_nbasis_too_small() {
    let (data, t) = make_test_data(5, 50);
    let outcome = smooth_basis_gcv(&data, &t, &BasisType::Bspline { order: 4 }, 1, 2, (-6.0, 2.0), 10);
    assert!(outcome.is_none());
}
1825
/// `smooth_basis_gcv` returns `None` when the lambda grid has a single point.
#[test]
fn test_smooth_basis_gcv_ngrid_too_small() {
    let (data, t) = make_test_data(5, 50);
    let outcome = smooth_basis_gcv(&data, &t, &BasisType::Bspline { order: 4 }, 10, 2, (-6.0, 2.0), 1);
    assert!(outcome.is_none());
}
1833
/// A narrow log10-lambda search window [-3, -2] still yields a result.
#[test]
fn test_smooth_basis_gcv_narrow_range() {
    let (data, t) = make_test_data(3, 50);
    let narrow = (-3.0, -2.0);
    let outcome = smooth_basis_gcv(&data, &t, &BasisType::Bspline { order: 4 }, 10, 2, narrow, 5);
    assert!(outcome.is_some());
}
1842
/// A wide log10-lambda search window [-12, 8] still yields a result.
#[test]
fn test_smooth_basis_gcv_wide_range() {
    let (data, t) = make_test_data(3, 50);
    let wide = (-12.0, 8.0);
    let outcome = smooth_basis_gcv(&data, &t, &BasisType::Bspline { order: 4 }, 10, 2, wide, 30);
    assert!(outcome.is_some());
}
1851
/// `basis_nbasis_cv` reports one score per candidate and echoes the candidate list back.
#[test]
fn test_basis_nbasis_cv_scores_length() {
    let (data, t) = make_test_data(5, 50);
    // Candidates 4, 6, 8, 10, 12.
    let candidates: Vec<usize> = (4..=12).step_by(2).collect();
    let res = basis_nbasis_cv(
        &data,
        &t,
        &candidates,
        &BasisType::Bspline { order: 4 },
        BasisCriterion::Gcv,
        5,
        1e-4,
    )
    .unwrap();
    assert_eq!(res.scores.len(), candidates.len());
    assert_eq!(res.nbasis_range.len(), candidates.len());
    assert_eq!(res.nbasis_range, candidates);
}
1872
/// For GCV, AIC and BIC, the selected nbasis is always one of the candidates supplied.
#[test]
fn test_basis_nbasis_cv_optimal_within_range() {
    let (data, t) = make_test_data(8, 50);
    // Candidates 5, 7, 9, 11, 13, 15.
    let candidates: Vec<usize> = (5..=15).step_by(2).collect();
    let basis = BasisType::Bspline { order: 4 };
    for criterion in [BasisCriterion::Gcv, BasisCriterion::Aic, BasisCriterion::Bic] {
        let chosen = basis_nbasis_cv(&data, &t, &candidates, &basis, criterion, 5, 1e-4)
            .unwrap()
            .optimal_nbasis;
        assert!(
            candidates.contains(&chosen),
            "optimal_nbasis {} not in range for {:?}",
            chosen,
            criterion
        );
    }
}
1900
/// GCV-based nbasis selection over a Fourier basis picks one of the candidates.
#[test]
fn test_basis_nbasis_cv_fourier_gcv() {
    let m = 80;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(5, m);
    for i in 0..5 {
        for (j, &x) in t.iter().enumerate() {
            // Deterministic pseudo-noise keeps the test reproducible.
            data[(i, j)] = (2.0 * PI * x).sin()
                + 0.3 * (4.0 * PI * x).cos()
                + 0.02 * ((i * 7 + j * 3) % 10) as f64;
        }
    }
    // Candidates 5, 7, 9, 11.
    let candidates: Vec<usize> = (5..=11).step_by(2).collect();
    let basis = BasisType::Fourier { period: 1.0 };
    let res = basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4)
        .unwrap();
    assert!(candidates.contains(&res.optimal_nbasis));
}
1926
/// K-fold CV selection over a Fourier basis picks a candidate and reports the CV criterion.
#[test]
fn test_basis_nbasis_cv_fourier_cv() {
    let (n, m) = (10, 60);
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for i in 0..n {
        for (j, &x) in t.iter().enumerate() {
            data[(i, j)] = (2.0 * PI * x).sin() + 0.02 * ((i * 11 + j) % 15) as f64;
        }
    }
    // Candidates 5, 7, 9.
    let candidates: Vec<usize> = (5..=9).step_by(2).collect();
    let basis = BasisType::Fourier { period: 1.0 };
    let res = basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Cv, 5, 1e-4)
        .unwrap();
    assert!(candidates.contains(&res.optimal_nbasis));
    assert_eq!(res.criterion, BasisCriterion::Cv);
}
1952
/// An invalid candidate (nbasis = 1) gets an infinite score and is never chosen.
#[test]
fn test_basis_nbasis_cv_with_nbasis_below_minimum() {
    let (data, t) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![1, 5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let res = basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4)
        .unwrap();
    let chosen = res.optimal_nbasis;
    assert!(chosen >= 5, "Should skip invalid nbasis=1, got optimal={}", chosen);
    assert!(res.scores[0].is_infinite());
}
1976
/// An empty candidate list yields `None`.
#[test]
fn test_basis_nbasis_cv_empty_range() {
    let (data, t) = make_test_data(5, 50);
    let no_candidates: Vec<usize> = Vec::new();
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &t, &no_candidates, &basis, BasisCriterion::Gcv, 5, 1e-4);
    assert!(outcome.is_none());
}
1992
/// A data matrix with zero curves yields `None`.
#[test]
fn test_basis_nbasis_cv_empty_data() {
    let grid = uniform_grid(50);
    let no_curves = FdMatrix::zeros(0, 50);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&no_curves, &grid, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4);
    assert!(outcome.is_none());
}
2009
/// `argvals` of length 40 against a 50-column matrix yields `None`.
#[test]
fn test_basis_nbasis_cv_mismatched_argvals() {
    let data = FdMatrix::zeros(5, 50);
    let short_grid = uniform_grid(40);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let outcome =
        basis_nbasis_cv(&data, &short_grid, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4);
    assert!(outcome.is_none());
}
2026
/// With exactly one candidate, that candidate is selected and one score is reported.
#[test]
fn test_basis_nbasis_cv_single_nbasis() {
    let (data, t) = make_test_data(5, 50);
    let only: Vec<usize> = vec![10];
    let basis = BasisType::Bspline { order: 4 };
    let res = basis_nbasis_cv(&data, &t, &only, &basis, BasisCriterion::Gcv, 5, 1e-4).unwrap();
    assert_eq!(res.optimal_nbasis, 10);
    assert_eq!(res.scores.len(), 1);
}
2044
/// BIC's heavier complexity penalty should not lead it to choose a basis more than 4
/// functions larger than AIC's choice.
#[test]
fn test_basis_nbasis_cv_bic_penalizes_more_than_aic() {
    let (data, t) = make_test_data(5, 80);
    let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
    let select = |criterion: BasisCriterion| {
        basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            criterion,
            5,
            1e-4,
        )
        .unwrap()
        .optimal_nbasis
    };
    let aic_choice = select(BasisCriterion::Aic);
    let bic_choice = select(BasisCriterion::Bic);
    assert!(
        bic_choice <= aic_choice + 4,
        "BIC selected {} vs AIC selected {} -- BIC should not select much more than AIC",
        bic_choice,
        aic_choice
    );
}
2081
/// A lightly penalized 15-function B-spline fit tracks a smooth sine within 0.1 everywhere.
#[test]
fn test_smooth_basis_fitted_close_to_data() {
    let (n, m) = (3, 50);
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for (j, &x) in t.iter().enumerate() {
        let value = (2.0 * PI * x).sin();
        for i in 0..n {
            data[(i, j)] = value;
        }
    }
    let res = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 15, 1e-6)).unwrap();

    let max_err = (0..n)
        .flat_map(|i| (0..m).map(move |j| (i, j)))
        .map(|(i, j)| (data[(i, j)] - res.fitted[(i, j)]).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_err < 0.1,
        "Fitted should be close to smooth data; max_err={}",
        max_err
    );
}
2112
/// Constant curves are reproduced to within 0.01 at every point.
#[test]
fn test_smooth_basis_constant_data() {
    let (n, m) = (2, 50);
    let level = 3.15;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(n, m);
    for i in 0..n {
        for j in 0..m {
            data[(i, j)] = level;
        }
    }
    let res = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 10, 1e-4)).unwrap();
    for i in 0..n {
        for j in 0..m {
            let got = res.fitted[(i, j)];
            assert!(
                (got - level).abs() < 0.01,
                "Constant data should be fit well at ({},{}): got {}",
                i,
                j,
                got
            );
        }
    }
}
2139
/// A straight line 2t + 1 is reproduced to within 0.05 at every point.
#[test]
fn test_smooth_basis_linear_data() {
    let m = 50;
    let t = uniform_grid(m);
    let line = |x: f64| 2.0 * x + 1.0;
    let mut data = FdMatrix::zeros(1, m);
    for (j, &x) in t.iter().enumerate() {
        data[(0, j)] = line(x);
    }
    let res = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 10, 1e-4)).unwrap();
    for (j, &x) in t.iter().enumerate() {
        let expected = line(x);
        let got = res.fitted[(0, j)];
        assert!(
            (got - expected).abs() < 0.05,
            "Linear data should be fit well at j={}: got {}, expected {}",
            j,
            got,
            expected
        );
    }
}
2162
/// Effective degrees of freedom lie in (0, m].
#[test]
fn test_smooth_basis_edf_bounded() {
    let m = 50;
    let (data, t) = make_test_data(3, m);
    let edf = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 12, 1e-4))
        .unwrap()
        .edf;
    let upper = m as f64;
    assert!(
        edf > 0.0 && edf <= upper,
        "EDF should be in (0, {}]; got {}",
        m,
        edf
    );
}
2179
/// GCV, AIC and BIC of a routine fit are all finite numbers.
#[test]
fn test_smooth_basis_gcv_aic_bic_all_finite() {
    let (data, t) = make_test_data(4, 60);
    let fit = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 12, 1e-3)).unwrap();
    let (gcv, aic, bic) = (fit.gcv, fit.aic, fit.bic);
    assert!(gcv.is_finite(), "GCV should be finite: {}", gcv);
    assert!(aic.is_finite(), "AIC should be finite: {}", aic);
    assert!(bic.is_finite(), "BIC should be finite: {}", bic);
}
2189
/// The penalty stored in the result is a flattened k x k matrix, where k = `res.nbasis`.
#[test]
fn test_smooth_basis_penalty_matrix_in_result() {
    let (data, t) = make_test_data(3, 50);
    let res = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 10, 1e-4)).unwrap();
    let k = res.nbasis;
    let expected_len = k * k;
    let actual_len = res.penalty_matrix.len();
    assert_eq!(
        actual_len,
        expected_len,
        "Penalty matrix should be k*k = {}*{} = {}; got {}",
        k,
        k,
        expected_len,
        actual_len
    );
}
2209
/// Identical input curves receive coefficient rows that agree to within 1e-10.
#[test]
fn test_smooth_basis_identical_curves_same_coefficients() {
    let (n, m) = (4, 50);
    let t = uniform_grid(m);
    let curve: Vec<f64> = t.iter().map(|&x| (2.0 * PI * x).sin()).collect();
    let mut data = FdMatrix::zeros(n, m);
    for i in 0..n {
        for (j, &v) in curve.iter().enumerate() {
            data[(i, j)] = v;
        }
    }
    let res = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 10, 1e-4)).unwrap();

    let coefs = &res.coefficients;
    for i in 1..n {
        for j in 0..coefs.ncols() {
            assert!(
                (coefs[(i, j)] - coefs[(0, j)]).abs() < 1e-10,
                "Identical curves should have identical coefficients: curve {} col {} differs",
                i,
                j
            );
        }
    }
}
2240
/// K-fold CV selection succeeds, and picks a candidate, for 2, 3, 5 and 10 folds.
#[test]
fn test_basis_nbasis_cv_different_nfolds() {
    let (data, t) = make_test_data(12, 50);
    // Candidates 5, 8, 11.
    let candidates: Vec<usize> = (5..=11).step_by(3).collect();
    let basis = BasisType::Bspline { order: 4 };
    for nfolds in [2, 3, 5, 10] {
        let outcome =
            basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Cv, nfolds, 1e-4);
        assert!(outcome.is_some(), "CV should succeed with nfolds={}", nfolds);
        assert!(candidates.contains(&outcome.unwrap().optimal_nbasis));
    }
}
2262
/// A large (40-function) basis is handled when a non-trivial penalty is applied.
#[test]
fn test_smooth_basis_many_basis_functions() {
    let (data, t) = make_test_data(2, 100);
    let rich_basis = make_bspline_fdpar(&t, 40, 1e-2);
    let outcome = smooth_basis(&data, &t, &rich_basis);
    assert!(
        outcome.is_ok(),
        "Should handle many basis functions with penalty"
    );
}
2277
/// B-spline and Fourier bases of equal size produce distinguishable fits of the same curve.
#[test]
fn test_smooth_basis_bspline_vs_fourier_different_results() {
    let m = 50;
    let (data, t) = make_test_data(2, m);
    let spline = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 9, 1e-4)).unwrap();
    let fourier = smooth_basis(&data, &t, &make_fourier_fdpar(9, 1.0, 1e-4)).unwrap();
    let mut diff = 0.0;
    for j in 0..m {
        diff += (spline.fitted[(0, j)] - fourier.fitted[(0, j)]).abs();
    }
    assert!(
        diff > 1e-10,
        "B-spline and Fourier fits should differ for the same data"
    );
}
2298
/// A sine corrupted by deterministic pseudo-noise has a strictly positive GCV score.
#[test]
fn test_smooth_basis_gcv_positive_for_noisy_data() {
    let m = 50;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(1, m);
    for (j, &x) in t.iter().enumerate() {
        data[(0, j)] = (2.0 * PI * x).sin() + 0.5 * ((j * 37) % 20) as f64 / 20.0 - 0.25;
    }
    let gcv = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 10, 1e-3))
        .unwrap()
        .gcv;
    assert!(gcv > 0.0, "GCV should be positive for noisy data");
}
2314
/// Penalizing the first derivative and penalizing the second derivative both succeed,
/// and they produce different fits.
#[test]
fn test_smooth_basis_different_lfd_orders() {
    let m = 50;
    let (data, t) = make_test_data(2, m);

    let fits: Vec<_> = [1, 2]
        .iter()
        .map(|&lfd| {
            let fdpar = FdPar {
                basis_type: BasisType::Bspline { order: 4 },
                nbasis: 10,
                lambda: 1e-2,
                lfd_order: lfd,
                penalty_matrix: bspline_penalty_matrix(&t, 10, 4, lfd),
            };
            let outcome = smooth_basis(&data, &t, &fdpar);
            assert!(outcome.is_ok());
            outcome.unwrap()
        })
        .collect();

    let diff: f64 = (0..m)
        .map(|j| (fits[0].fitted[(0, j)] - fits[1].fitted[(0, j)]).abs())
        .sum();
    assert!(
        diff > 1e-10,
        "Different lfd_orders should produce different fits"
    );
}
2357
/// AIC-based nbasis selection reports the requested criterion and one score per candidate.
/// Its `optimal_nbasis` must match the first candidate with the smallest score.
#[test]
fn test_basis_nbasis_cv_result_fields() {
    let (data, t) = make_test_data(6, 50);
    let nbasis_range: Vec<usize> = vec![5, 7, 9, 11, 13];
    let res = basis_nbasis_cv(
        &data,
        &t,
        &nbasis_range,
        &BasisType::Bspline { order: 4 },
        BasisCriterion::Aic,
        5,
        1e-4,
    )
    .unwrap();

    assert!(nbasis_range.contains(&res.optimal_nbasis));
    assert_eq!(res.scores.len(), nbasis_range.len());
    assert_eq!(res.nbasis_range, nbasis_range);
    assert_eq!(res.criterion, BasisCriterion::Aic);

    // `f64::min` always returns one of its operands (skipping NaN), so the minimum is an
    // exact element of `scores` and can be located with `==`. A tolerance on `s - min`
    // would evaluate `inf - inf = NaN` if every score were infinite.
    let min_score = res.scores.iter().copied().fold(f64::INFINITY, f64::min);
    assert!(
        min_score.is_finite(),
        "Expected at least one finite AIC score; scores={:?}",
        res.scores
    );
    let best_idx = res
        .scores
        .iter()
        .position(|&s| s == min_score)
        .expect("minimum score must be an element of scores");
    assert_eq!(res.optimal_nbasis, nbasis_range[best_idx]);
}
2388
/// A cloned nbasis-CV result compares equal to the original.
#[test]
fn test_basis_nbasis_cv_result_clone() {
    let (data, t) = make_test_data(5, 50);
    let candidates: Vec<usize> = vec![5, 10];
    let basis = BasisType::Bspline { order: 4 };
    let res = basis_nbasis_cv(&data, &t, &candidates, &basis, BasisCriterion::Gcv, 5, 1e-4)
        .unwrap();
    let duplicate = res.clone();
    assert_eq!(res, duplicate);
}
2406
/// Smoothing works on a non-uniform grid whose points cluster toward both ends of [0, 1].
#[test]
fn test_smooth_basis_nonuniform_argvals() {
    let m = 50;
    // Cosine-spaced grid: t_i = (1 - cos(pi * x_i)) / 2 with x_i uniform on [0, 1].
    let t: Vec<f64> = (0..m)
        .map(|i| 0.5 * (1.0 - (PI * (i as f64 / (m - 1) as f64)).cos()))
        .collect();
    let mut data = FdMatrix::zeros(2, m);
    for i in 0..2 {
        for (j, &x) in t.iter().enumerate() {
            data[(i, j)] = (2.0 * PI * x).sin() + 0.1 * i as f64;
        }
    }
    let outcome = smooth_basis(&data, &t, &make_bspline_fdpar(&t, 10, 1e-4));
    assert!(outcome.is_ok(), "Should handle non-uniform argvals");
    assert_eq!(outcome.unwrap().fitted.shape(), (2, m));
}
2431
/// A nearly zero penalty (lambda = 1e-15) does not cause a failure.
#[test]
fn test_smooth_basis_very_small_lambda() {
    let (data, t) = make_test_data(2, 50);
    let near_zero = make_bspline_fdpar(&t, 10, 1e-15);
    let outcome = smooth_basis(&data, &t, &near_zero);
    assert!(outcome.is_ok(), "Should handle very small lambda");
}
2442
/// An enormous penalty (lambda = 1e10) does not cause a failure.
#[test]
fn test_smooth_basis_very_large_lambda() {
    let (data, t) = make_test_data(2, 50);
    let huge = make_bspline_fdpar(&t, 10, 1e10);
    let outcome = smooth_basis(&data, &t, &huge);
    assert!(outcome.is_ok(), "Should handle very large lambda");
}
2451
/// Smoothing curves as a batch matches smoothing each curve on its own (to within 1e-10).
#[test]
fn test_smooth_basis_multi_curve_vs_single_curve() {
    let (n, m) = (3, 50);
    let (data, t) = make_test_data(n, m);
    let fdpar = make_bspline_fdpar(&t, 10, 1e-3);
    let batch_fit = smooth_basis(&data, &t, &fdpar).unwrap().fitted;

    for i in 0..n {
        let mut row = FdMatrix::zeros(1, m);
        for j in 0..m {
            row[(0, j)] = data[(i, j)];
        }
        let lone_fit = smooth_basis(&row, &t, &fdpar).unwrap().fitted;
        for j in 0..m {
            let gap = (batch_fit[(i, j)] - lone_fit[(0, j)]).abs();
            assert!(
                gap < 1e-10,
                "Multi-curve fit should match single-curve fit: curve {} point {}",
                i,
                j
            );
        }
    }
}
2482
/// Every selection criterion yields at least one finite score over a reasonable range.
#[test]
fn test_basis_nbasis_cv_all_criteria_finite_scores() {
    let (data, t) = make_test_data(10, 60);
    // Candidates 5, 7, 9, 11.
    let candidates: Vec<usize> = (5..=11).step_by(2).collect();
    let basis = BasisType::Bspline { order: 4 };
    let criteria = [
        BasisCriterion::Gcv,
        BasisCriterion::Aic,
        BasisCriterion::Bic,
        BasisCriterion::Cv,
    ];
    for criterion in criteria {
        let scores = basis_nbasis_cv(&data, &t, &candidates, &basis, criterion, 5, 1e-4)
            .unwrap()
            .scores;
        assert!(
            scores.iter().any(|s| s.is_finite()),
            "At least one score should be finite for {:?}",
            criterion
        );
    }
}
2515
/// `SmoothBasisGcvConfig::default()` uses a cubic B-spline basis, nbasis 15, lfd_order 2,
/// log-lambda range (-10, 2) and a 50-point grid.
#[test]
fn test_smooth_basis_gcv_config_default() {
    let SmoothBasisGcvConfig {
        basis_type,
        nbasis,
        lfd_order,
        log_lambda_range,
        n_grid,
        ..
    } = SmoothBasisGcvConfig::default();
    assert_eq!(basis_type, BasisType::Bspline { order: 4 });
    assert_eq!(nbasis, 15);
    assert_eq!(lfd_order, 2);
    assert_eq!(log_lambda_range, (-10.0, 2.0));
    assert_eq!(n_grid, 50);
}
2527
/// A cloned GCV config compares equal to the original.
#[test]
fn test_smooth_basis_gcv_config_clone_eq() {
    let original = SmoothBasisGcvConfig {
        nbasis: 20,
        ..SmoothBasisGcvConfig::default()
    };
    let copy = original.clone();
    assert_eq!(original, copy);
}
2537
/// The `Debug` output names the struct and its `nbasis` field.
#[test]
fn test_smooth_basis_gcv_config_debug() {
    let rendered = format!("{:?}", SmoothBasisGcvConfig::default());
    for needle in ["SmoothBasisGcvConfig", "nbasis"] {
        assert!(rendered.contains(needle));
    }
}
2545
/// Struct-update syntax overrides only the named fields and keeps the other defaults.
#[test]
fn test_smooth_basis_gcv_config_partial_override() {
    let custom = SmoothBasisGcvConfig {
        basis_type: BasisType::Fourier { period: 2.0 },
        n_grid: 100,
        ..SmoothBasisGcvConfig::default()
    };
    // Overridden fields.
    assert_eq!(custom.basis_type, BasisType::Fourier { period: 2.0 });
    assert_eq!(custom.n_grid, 100);
    // Inherited defaults.
    assert_eq!(custom.nbasis, 15);
    assert_eq!(custom.lfd_order, 2);
}
2559
/// The config-based GCV entry point succeeds with default settings and returns a sane fit.
#[test]
fn test_smooth_basis_gcv_with_config_default() {
    let (data, t) = make_test_data(5, 101);
    let outcome = smooth_basis_gcv_with_config(&data, &t, &SmoothBasisGcvConfig::default());
    assert!(outcome.is_ok(), "GCV with default config should succeed");
    let fit = outcome.unwrap();
    assert_eq!(fit.fitted.shape(), (5, 101));
    assert!(fit.edf > 0.0);
    assert!(fit.gcv.is_finite());
}
2571
/// The config-based GCV entry point accepts a custom nbasis, lambda range and grid size.
#[test]
fn test_smooth_basis_gcv_with_config_custom() {
    let (data, t) = make_test_data(3, 50);
    let custom = SmoothBasisGcvConfig {
        nbasis: 10,
        log_lambda_range: (-6.0, 0.0),
        n_grid: 15,
        ..SmoothBasisGcvConfig::default()
    };
    assert!(smooth_basis_gcv_with_config(&data, &t, &custom).is_ok());
}
2584
/// The config wrapper gives exactly the same GCV, EDF and nbasis as calling
/// `smooth_basis_gcv` with the unpacked fields.
#[test]
fn test_smooth_basis_gcv_with_config_matches_direct() {
    let (data, t) = make_test_data(3, 50);
    let config = SmoothBasisGcvConfig {
        nbasis: 10,
        log_lambda_range: (-6.0, 0.0),
        n_grid: 20,
        ..SmoothBasisGcvConfig::default()
    };
    let via_config = smooth_basis_gcv_with_config(&data, &t, &config).unwrap();

    let SmoothBasisGcvConfig {
        basis_type,
        nbasis,
        lfd_order,
        log_lambda_range,
        n_grid,
        ..
    } = &config;
    let direct = smooth_basis_gcv(
        &data,
        &t,
        basis_type,
        *nbasis,
        *lfd_order,
        *log_lambda_range,
        *n_grid,
    )
    .unwrap();

    assert_eq!(via_config.gcv, direct.gcv);
    assert_eq!(via_config.edf, direct.edf);
    assert_eq!(via_config.nbasis, direct.nbasis);
}
2609
/// The config-based GCV entry point works with a Fourier basis.
#[test]
fn test_smooth_basis_gcv_with_config_fourier() {
    let m = 100;
    let t = uniform_grid(m);
    let mut data = FdMatrix::zeros(2, m);
    for (j, &x) in t.iter().enumerate() {
        let value = (2.0 * PI * x).sin() + (4.0 * PI * x).cos();
        data[(0, j)] = value;
        data[(1, j)] = value;
    }
    let fourier_config = SmoothBasisGcvConfig {
        basis_type: BasisType::Fourier { period: 1.0 },
        nbasis: 7,
        n_grid: 20,
        ..SmoothBasisGcvConfig::default()
    };
    assert!(smooth_basis_gcv_with_config(&data, &t, &fourier_config).is_ok());
}
2629
/// `BasisNbasisCvConfig::default()` uses a cubic B-spline basis, range (5, 30),
/// lambda 1e-4, lfd_order 2, 5 folds and GCV.
#[test]
fn test_basis_nbasis_cv_config_default() {
    let BasisNbasisCvConfig {
        basis_type,
        nbasis_range,
        lambda,
        lfd_order,
        n_folds,
        criterion,
        ..
    } = BasisNbasisCvConfig::default();
    assert_eq!(basis_type, BasisType::Bspline { order: 4 });
    assert_eq!(nbasis_range, (5, 30));
    assert!((lambda - 1e-4).abs() < 1e-15);
    assert_eq!(lfd_order, 2);
    assert_eq!(n_folds, 5);
    assert_eq!(criterion, BasisCriterion::Gcv);
}
2642
/// A cloned nbasis-CV config compares equal to the original.
#[test]
fn test_basis_nbasis_cv_config_clone_eq() {
    let original = BasisNbasisCvConfig {
        nbasis_range: (4, 15),
        ..BasisNbasisCvConfig::default()
    };
    let copy = original.clone();
    assert_eq!(original, copy);
}
2652
#[test]
fn test_basis_nbasis_cv_config_debug() {
    // The derived Debug output should name both the type and its fields.
    let rendered = format!("{:?}", BasisNbasisCvConfig::default());
    for needle in ["BasisNbasisCvConfig", "nbasis_range"] {
        assert!(rendered.contains(needle));
    }
}
2660
#[test]
fn test_basis_nbasis_cv_config_partial_override() {
    let overridden = BasisNbasisCvConfig {
        criterion: BasisCriterion::Aic,
        lambda: 1e-2,
        ..BasisNbasisCvConfig::default()
    };

    // Fields set explicitly take the new values...
    assert_eq!(overridden.criterion, BasisCriterion::Aic);
    assert!((overridden.lambda - 1e-2).abs() < 1e-15);

    // ...while the untouched fields keep their defaults.
    assert_eq!(overridden.nbasis_range, (5, 30));
    assert_eq!(overridden.n_folds, 5);
}
2674
#[test]
fn test_basis_nbasis_cv_with_config_default() {
    // Inclusive candidate range; every other setting comes from the defaults.
    let (data, t) = make_test_data(5, 51);
    let (lo, hi) = (5, 12);
    let config = BasisNbasisCvConfig {
        nbasis_range: (lo, hi),
        ..BasisNbasisCvConfig::default()
    };

    let res = basis_nbasis_cv_with_config(&data, &t, &config)
        .expect("nbasis CV with default config should succeed");

    // The chosen nbasis must lie inside the searched range.
    assert!((lo..=hi).contains(&res.optimal_nbasis));
    // One score per candidate in the inclusive range (5..=12 -> 8 scores).
    assert_eq!(res.scores.len(), hi - lo + 1);
    // With no override, the default criterion (GCV) is reported back.
    assert_eq!(res.criterion, BasisCriterion::Gcv);
}
2692
#[test]
fn test_basis_nbasis_cv_with_config_aic() {
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 10),
        criterion: BasisCriterion::Aic,
        ..BasisNbasisCvConfig::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());
    // The result must echo the criterion that was requested.
    assert_eq!(outcome.unwrap().criterion, BasisCriterion::Aic);
}
2705
#[test]
fn test_basis_nbasis_cv_with_config_cv_folds() {
    // K-fold cross-validation with a non-default fold count.
    let (data, t) = make_test_data(10, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 9),
        criterion: BasisCriterion::Cv,
        n_folds: 3,
        ..BasisNbasisCvConfig::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().criterion, BasisCriterion::Cv);
}
2719
#[test]
fn test_basis_nbasis_cv_with_config_matches_direct() {
    // The config wrapper and the positional API must agree for the same settings.
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (5, 10),
        criterion: BasisCriterion::Bic,
        lambda: 1e-3,
        ..BasisNbasisCvConfig::default()
    };
    let via_config = basis_nbasis_cv_with_config(&data, &t, &cfg).unwrap();

    // The positional API takes an explicit candidate list, so expand the
    // inclusive (min, max) pair from the config into one.
    let (lo, hi) = cfg.nbasis_range;
    let candidates: Vec<usize> = (lo..=hi).collect();
    let via_args = basis_nbasis_cv(
        &data,
        &t,
        &candidates,
        &cfg.basis_type,
        cfg.criterion,
        cfg.n_folds,
        cfg.lambda,
    )
    .unwrap();

    assert_eq!(via_config.optimal_nbasis, via_args.optimal_nbasis);
    assert_eq!(via_config.scores, via_args.scores);
    assert_eq!(via_config.nbasis_range, via_args.nbasis_range);
}
2745
#[test]
fn test_basis_nbasis_cv_with_config_nbasis_range_expansion() {
    // A degenerate (7, 7) range expands to exactly one candidate, which must win.
    let (data, t) = make_test_data(5, 51);
    let cfg = BasisNbasisCvConfig {
        nbasis_range: (7, 7),
        ..BasisNbasisCvConfig::default()
    };

    let outcome = basis_nbasis_cv_with_config(&data, &t, &cfg);
    assert!(outcome.is_ok());
    let res = outcome.unwrap();
    assert_eq!(res.optimal_nbasis, 7);
    assert_eq!(res.scores.len(), 1);
}
2759}