1use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
/// Family of basis functions used to represent each sampled curve.
#[derive(Debug, Clone, PartialEq)]
pub enum BasisType {
    /// B-spline basis of the given `order`; derivatives of order `>= order` are treated as
    /// identically zero when building roughness penalties.
    Bspline { order: usize },
    /// Fourier basis whose functions repeat with the given `period`.
    Fourier { period: f64 },
}
28
29#[derive(Debug, Clone, PartialEq)]
31pub struct FdPar {
32 pub basis_type: BasisType,
34 pub nbasis: usize,
36 pub lambda: f64,
38 pub lfd_order: usize,
40 pub penalty_matrix: Vec<f64>,
42}
43
44#[derive(Debug, Clone, PartialEq)]
46pub struct SmoothBasisResult {
47 pub coefficients: FdMatrix,
49 pub fitted: FdMatrix,
51 pub edf: f64,
53 pub gcv: f64,
55 pub aic: f64,
57 pub bic: f64,
59 pub penalty_matrix: Vec<f64>,
61 pub nbasis: usize,
63}
64
65pub fn bspline_penalty_matrix(
82 argvals: &[f64],
83 nbasis: usize,
84 order: usize,
85 lfd_order: usize,
86) -> Vec<f64> {
87 if nbasis < 2 || order < 1 || lfd_order >= order || argvals.len() < 2 {
88 return vec![0.0; nbasis * nbasis];
89 }
90
91 let nknots = nbasis.saturating_sub(order).max(2);
92
93 let n_sub = 10;
95 let t_min = argvals[0];
96 let t_max = argvals[argvals.len() - 1];
97 let n_quad = (argvals.len() - 1) * n_sub + 1;
98 let quad_t: Vec<f64> = (0..n_quad)
99 .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100 .collect();
101
102 let basis_fine = bspline_basis(&quad_t, nknots, order);
104 let actual_nbasis = basis_fine.len() / n_quad;
105
106 let h = (t_max - t_min) / (n_quad - 1) as f64;
108 let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110 let weights = simpsons_weights(&quad_t);
112
113 integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Builds the diagonal roughness penalty for a Fourier basis.
///
/// Slot 0 (the constant function) is left unpenalized. Every later slot `idx` belongs to
/// harmonic `h = (idx + 1) / 2`, so each harmonic covers two consecutive slots (presumably the
/// sine/cosine pair produced by `fourier_basis_with_period`). That slot receives
/// `ω_h^(2·lfd_order)` with `ω_h = 2πh / period`, the squared amplitude gain of the
/// `lfd_order`-th derivative at that frequency.
///
/// Returns an `nbasis × nbasis` matrix flattened column-major; all off-diagonal entries are zero.
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let exponent = 2 * lfd_order as i32;
    let mut penalty = vec![0.0; nbasis * nbasis];
    for idx in 1..nbasis {
        let harmonic = (idx + 1) / 2;
        let omega = 2.0 * PI * harmonic as f64 / period;
        // Diagonal entry (idx, idx) of a column-major square matrix.
        penalty[idx * (nbasis + 1)] = omega.powi(exponent);
    }
    penalty
}
158
159pub fn smooth_basis(
174 data: &FdMatrix,
175 argvals: &[f64],
176 fdpar: &FdPar,
177) -> Result<SmoothBasisResult, crate::FdarError> {
178 let (n, m) = data.shape();
179 if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
180 return Err(crate::FdarError::InvalidDimension {
181 parameter: "data/argvals/fdpar",
182 expected: "n > 0, m > 0, argvals.len() == m, nbasis >= 2".to_string(),
183 actual: format!(
184 "n={}, m={}, argvals.len()={}, nbasis={}",
185 n,
186 m,
187 argvals.len(),
188 fdpar.nbasis
189 ),
190 });
191 }
192
193 let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
195 let k = actual_nbasis;
196
197 let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
198 let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
199
200 let btb = b_mat.transpose() * &b_mat;
202 let ridge_eps = 1e-10;
203 let system: DMatrix<f64> =
204 &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
205
206 let system_inv =
208 invert_penalized_system(&system, k).ok_or_else(|| crate::FdarError::ComputationFailed {
209 operation: "matrix inversion",
210 detail: "failed to invert penalized system (Φ'Φ + λR)".to_string(),
211 })?;
212
213 let h_mat = &b_mat * &system_inv * b_mat.transpose();
215 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
216
217 let proj = &system_inv * b_mat.transpose();
219 let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
220
221 let total_points = (n * m) as f64;
222 let gcv = compute_gcv(total_rss, total_points, edf, m);
223 let mse = total_rss / total_points;
224 let total_edf = n as f64 * edf;
226 let aic = total_points * mse.max(1e-300).ln() + 2.0 * total_edf;
227 let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * total_edf;
228
229 Ok(SmoothBasisResult {
230 coefficients: all_coefs,
231 fitted: all_fitted,
232 edf,
233 gcv,
234 aic,
235 bic,
236 penalty_matrix: fdpar.penalty_matrix.clone(),
237 nbasis: k,
238 })
239}
240
241pub fn smooth_basis_gcv(
254 data: &FdMatrix,
255 argvals: &[f64],
256 basis_type: &BasisType,
257 nbasis: usize,
258 lfd_order: usize,
259 log_lambda_range: (f64, f64),
260 n_grid: usize,
261) -> Option<SmoothBasisResult> {
262 let m = argvals.len();
263 if m == 0 || nbasis < 2 || n_grid < 2 {
264 return None;
265 }
266
267 let penalty = match basis_type {
269 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
270 BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
271 };
272
273 let (lo, hi) = log_lambda_range;
274 let mut best_gcv = f64::INFINITY;
275 let mut best_result: Option<SmoothBasisResult> = None;
276
277 for i in 0..n_grid {
278 let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
279 let lam = 10.0_f64.powf(log_lam);
280
281 let fdpar = FdPar {
282 basis_type: basis_type.clone(),
283 nbasis,
284 lambda: lam,
285 lfd_order,
286 penalty_matrix: penalty.clone(),
287 };
288
289 if let Ok(result) = smooth_basis(data, argvals, &fdpar) {
290 if result.gcv < best_gcv {
291 best_gcv = result.gcv;
292 best_result = Some(result);
293 }
294 }
295 }
296
297 best_result
298}
299
300fn differentiate_basis_columns(
304 basis: &[f64],
305 n_quad: usize,
306 nbasis: usize,
307 h: f64,
308 lfd_order: usize,
309) -> Vec<f64> {
310 let mut deriv = basis.to_vec();
311 for _ in 0..lfd_order {
312 let mut new_deriv = vec![0.0; n_quad * nbasis];
313 for j in 0..nbasis {
314 let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
315 let grad = crate::helpers::gradient_uniform(&col, h);
316 for i in 0..n_quad {
317 new_deriv[i + j * n_quad] = grad[i];
318 }
319 }
320 deriv = new_deriv;
321 }
322 deriv
323}
324
/// Computes the weighted inner products `Σᵢ wᵢ · dⱼ(i) · dₗ(i)` between every pair of columns.
///
/// `deriv_basis` is column-major with `n_quad` rows and at least `k` columns; `weights` holds
/// the quadrature weights (at least `n_quad` of them). Only the upper triangle is computed;
/// each value is mirrored, so the returned `k × k` column-major matrix is exactly symmetric.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let column = |c: usize| &deriv_basis[c * n_quad..(c + 1) * n_quad];
    let mut gram = vec![0.0; k * k];
    for row in 0..k {
        let w = &weights[..n_quad];
        let left = column(row);
        for col in row..k {
            let right = column(col);
            let integral = left
                .iter()
                .zip(right)
                .zip(w)
                .map(|((&a, &b), &wt)| a * b * wt)
                .fold(0.0, |acc, term| acc + term);
            gram[row + col * k] = integral;
            gram[col + row * k] = integral;
        }
    }
    gram
}
345
346fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
348 let m = argvals.len();
349 match basis_type {
350 BasisType::Bspline { order } => {
351 let nknots = nbasis.saturating_sub(*order).max(2);
352 let basis = bspline_basis(argvals, nknots, *order);
353 let actual = basis.len() / m;
354 (basis, actual)
355 }
356 BasisType::Fourier { period } => {
357 let basis = fourier_basis_with_period(argvals, nbasis, *period);
358 (basis, nbasis)
359 }
360 }
361}
362
363fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
365 if let Some(chol) = system.clone().cholesky() {
366 return Some(chol.inverse());
367 }
368 let svd = nalgebra::SVD::new(system.clone(), true, true);
370 let u = svd.u.as_ref()?;
371 let v_t = svd.v_t.as_ref()?;
372 let max_sv: f64 = svd.singular_values.iter().copied().fold(0.0_f64, f64::max);
373 let eps = 1e-10 * max_sv;
374 let mut inv = DMatrix::<f64>::zeros(k, k);
375 for ii in 0..k {
376 for jj in 0..k {
377 let mut sum = 0.0;
378 for s in 0..k.min(svd.singular_values.len()) {
379 if svd.singular_values[s] > eps {
380 sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
381 }
382 }
383 inv[(ii, jj)] = sum;
384 }
385 }
386 Some(inv)
387}
388
389fn project_all_curves(
391 data: &FdMatrix,
392 b_mat: &DMatrix<f64>,
393 proj: &DMatrix<f64>,
394 n: usize,
395 m: usize,
396 k: usize,
397) -> (FdMatrix, FdMatrix, f64) {
398 let mut all_coefs = FdMatrix::zeros(n, k);
399 let mut all_fitted = FdMatrix::zeros(n, m);
400 let mut total_rss = 0.0;
401
402 for i in 0..n {
403 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
404 let y_vec = nalgebra::DVector::from_vec(curve.clone());
405 let coefs = proj * &y_vec;
406
407 for j in 0..k {
408 all_coefs[(i, j)] = coefs[j];
409 }
410 let fitted = b_mat * &coefs;
411 for j in 0..m {
412 all_fitted[(i, j)] = fitted[j];
413 let resid = curve[j] - fitted[j];
414 total_rss += resid * resid;
415 }
416 }
417
418 (all_coefs, all_fitted, total_rss)
419}
420
/// Generalized cross-validation score `(rss / n_points) / (1 - edf / m)²`.
///
/// Returns `f64::INFINITY` when `1 - edf / m` is within `1e-10` of zero or is NaN, i.e. when
/// the smoother (nearly) interpolates the data and GCV is undefined.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let mean_square = rss / n_points;
    let shrink = 1.0 - edf / m as f64;
    if shrink.abs() > 1e-10 {
        mean_square / (shrink * shrink)
    } else {
        f64::INFINITY
    }
}
430
/// Criterion used by [`basis_nbasis_cv`] to score candidate basis sizes (lower is better).
///
/// `Eq` and `Hash` are derived so criteria can key maps or sets (the enum is fieldless).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BasisCriterion {
    /// Generalized cross-validation, as reported by [`smooth_basis`].
    Gcv,
    /// K-fold cross-validated prediction error (uses the `n_folds` argument).
    Cv,
    /// Akaike information criterion, as reported by [`smooth_basis`].
    Aic,
    /// Bayesian information criterion, as reported by [`smooth_basis`].
    Bic,
}
445
446#[derive(Debug, Clone, PartialEq)]
448pub struct BasisNbasisCvResult {
449 pub optimal_nbasis: usize,
451 pub scores: Vec<f64>,
453 pub nbasis_range: Vec<usize>,
455 pub criterion: BasisCriterion,
457}
458
459fn evaluate_nbasis_info_criterion(
461 data: &FdMatrix,
462 argvals: &[f64],
463 nbasis_range: &[usize],
464 basis_type: &BasisType,
465 criterion: BasisCriterion,
466 lambda: f64,
467) -> Vec<f64> {
468 let mut scores = Vec::with_capacity(nbasis_range.len());
469 for &nb in nbasis_range {
470 if nb < 2 {
471 scores.push(f64::INFINITY);
472 continue;
473 }
474 let penalty = match basis_type {
475 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
476 BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
477 };
478 let fdpar = FdPar {
479 basis_type: basis_type.clone(),
480 nbasis: nb,
481 lambda,
482 lfd_order: 2,
483 penalty_matrix: penalty,
484 };
485 match smooth_basis(data, argvals, &fdpar) {
486 Ok(result) => {
487 let score = match criterion {
488 BasisCriterion::Gcv => result.gcv,
489 BasisCriterion::Aic => result.aic,
490 BasisCriterion::Bic => result.bic,
491 _ => unreachable!(),
492 };
493 scores.push(score);
494 }
495 Err(_) => scores.push(f64::INFINITY),
496 }
497 }
498 scores
499}
500
501fn evaluate_nbasis_cv(
503 data: &FdMatrix,
504 argvals: &[f64],
505 nbasis_range: &[usize],
506 basis_type: &BasisType,
507 lambda: f64,
508 n_folds: usize,
509) -> Vec<f64> {
510 let (n, m) = data.shape();
511 let n_folds = n_folds.max(2);
512 let folds = crate::cv::create_folds(n, n_folds, 42);
513 let mut scores = Vec::with_capacity(nbasis_range.len());
514
515 for &nb in nbasis_range {
516 if nb < 2 {
517 scores.push(f64::INFINITY);
518 continue;
519 }
520 let penalty = match basis_type {
521 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nb, *order, 2),
522 BasisType::Fourier { period } => fourier_penalty_matrix(nb, *period, 2),
523 };
524
525 let mut total_mse = 0.0;
526 let mut total_count = 0;
527
528 for fold in 0..n_folds {
529 let (train_idx, test_idx) = crate::cv::fold_indices(&folds, fold);
530 if train_idx.is_empty() || test_idx.is_empty() {
531 continue;
532 }
533 let train_data = crate::cv::subset_rows(data, &train_idx);
534 let fdpar = FdPar {
535 basis_type: basis_type.clone(),
536 nbasis: nb,
537 lambda,
538 lfd_order: 2,
539 penalty_matrix: penalty.clone(),
540 };
541
542 if let Ok(train_result) = smooth_basis(&train_data, argvals, &fdpar) {
543 let (basis_flat, actual_k) = evaluate_basis(argvals, basis_type, nb);
544 let b_mat = DMatrix::from_column_slice(m, actual_k, &basis_flat);
545 let r_mat =
546 DMatrix::from_column_slice(actual_k, actual_k, &train_result.penalty_matrix);
547 let btb = b_mat.transpose() * &b_mat;
548 let ridge_eps = 1e-10;
549 let system: DMatrix<f64> = &btb
550 + lambda * &r_mat
551 + ridge_eps * DMatrix::<f64>::identity(actual_k, actual_k);
552
553 if let Some(system_inv) = invert_penalized_system(&system, actual_k) {
554 let proj = &system_inv * b_mat.transpose();
555 for &ti in &test_idx {
556 let curve: Vec<f64> = (0..m).map(|j| data[(ti, j)]).collect();
557 let y_vec = nalgebra::DVector::from_vec(curve.clone());
558 let coefs = &proj * &y_vec;
559 let fitted = &b_mat * &coefs;
560 let mse: f64 =
561 (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum::<f64>() / m as f64;
562 total_mse += mse;
563 total_count += 1;
564 }
565 }
566 }
567 }
568
569 if total_count > 0 {
570 scores.push(total_mse / f64::from(total_count));
571 } else {
572 scores.push(f64::INFINITY);
573 }
574 }
575 scores
576}
577
578pub fn basis_nbasis_cv(
581 data: &FdMatrix,
582 argvals: &[f64],
583 nbasis_range: &[usize],
584 basis_type: &BasisType,
585 criterion: BasisCriterion,
586 n_folds: usize,
587 lambda: f64,
588) -> Option<BasisNbasisCvResult> {
589 let (n, m) = data.shape();
590 if n == 0 || m == 0 || argvals.len() != m || nbasis_range.is_empty() {
591 return None;
592 }
593
594 let scores = match criterion {
595 BasisCriterion::Gcv | BasisCriterion::Aic | BasisCriterion::Bic => {
596 evaluate_nbasis_info_criterion(
597 data,
598 argvals,
599 nbasis_range,
600 basis_type,
601 criterion,
602 lambda,
603 )
604 }
605 BasisCriterion::Cv => {
606 evaluate_nbasis_cv(data, argvals, nbasis_range, basis_type, lambda, n_folds)
607 }
608 };
609
610 let (best_idx, _) = scores
611 .iter()
612 .enumerate()
613 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
614
615 Some(BasisNbasisCvResult {
616 optimal_nbasis: nbasis_range[best_idx],
617 scores,
618 nbasis_range: nbasis_range.to_vec(),
619 criterion,
620 })
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points on `[0, 1]`, endpoints included (requires `m >= 2`).
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    #[test]
    fn test_bspline_penalty_matrix_symmetric() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
        // The penalty is sized by the evaluated basis, so recover its side length from the data.
        let actual_k = (penalty.len() as f64).sqrt() as usize;
        assert_eq!(actual_k * actual_k, penalty.len(), "penalty must be square");
        for i in 0..actual_k {
            for j in 0..actual_k {
                assert!(
                    (penalty[i + j * actual_k] - penalty[j + i * actual_k]).abs() < 1e-10,
                    "Penalty matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_bspline_penalty_matrix_positive_semidefinite() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
        let k = (penalty.len() as f64).sqrt() as usize;
        // A non-negative diagonal is a necessary (not sufficient) condition for PSD.
        for i in 0..k {
            assert!(
                penalty[i + i * k] >= -1e-10,
                "Diagonal element {} is negative: {}",
                i,
                penalty[i + i * k]
            );
        }
    }

    #[test]
    fn test_fourier_penalty_diagonal() {
        let penalty = fourier_penalty_matrix(7, 1.0, 2);
        for i in 0..7 {
            for j in 0..7 {
                if i != j {
                    assert!(
                        penalty[i + j * 7].abs() < 1e-10,
                        "Off-diagonal ({},{}) = {}",
                        i,
                        j,
                        penalty[i + j * 7]
                    );
                }
            }
        }
        // Constant term is unpenalized; higher harmonics are penalized more.
        assert!(penalty[0].abs() < 1e-10);
        assert!(penalty[1 + 7] > 0.0);
        assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);
    }

    #[test]
    fn test_fourier_penalty_eigenvalues() {
        let period = 2.0;
        let lfd_order = 2;
        let penalty = fourier_penalty_matrix(5, period, lfd_order);
        let expected = |h: f64| (2.0 * PI * h / period).powi(2 * lfd_order as i32);
        assert_eq!(penalty[0], 0.0);
        for &(idx, h) in &[(1_usize, 1.0_f64), (2, 1.0), (3, 2.0), (4, 2.0)] {
            let got = penalty[idx * 5 + idx];
            assert!(
                (got - expected(h)).abs() <= 1e-12 * expected(h),
                "slot {}: got {}, expected {}",
                idx,
                got,
                expected(h)
            );
        }
    }

    #[test]
    fn test_integrate_symmetric_penalty_weighted_inner_products() {
        // Two columns of three quadrature points each, column-major.
        let deriv = [1.0, 2.0, 3.0, 0.0, 1.0, 0.0];
        let weights = [0.5, 1.0, 2.0];
        let penalty = integrate_symmetric_penalty(&deriv, &weights, 2, 3);
        assert_eq!(penalty, vec![22.5, 2.0, 2.0, 1.0]);
    }

    #[test]
    fn test_compute_gcv_values() {
        let gcv = compute_gcv(10.0, 100.0, 2.0, 10);
        assert!((gcv - 0.15625).abs() < 1e-12);
        // An interpolating smoother (edf == m) has undefined GCV.
        assert!(compute_gcv(1.0, 1.0, 10.0, 10).is_infinite());
    }

    #[test]
    fn test_smooth_basis_bspline() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
            }
        }

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);

        let fdpar = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e-4,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let result = smooth_basis(&data, &t, &fdpar);
        assert!(result.is_ok(), "smooth_basis should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted.shape(), (n, m));
        assert_eq!(res.coefficients.nrows(), n);
        assert!(res.edf > 0.0, "EDF should be positive");
        assert!(res.gcv > 0.0, "GCV should be positive");
    }

    #[test]
    fn test_smooth_basis_fourier() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
            }
        }

        let nbasis = 7;
        let period = 1.0;
        let penalty = fourier_penalty_matrix(nbasis, period, 2);

        let fdpar = FdPar {
            basis_type: BasisType::Fourier { period },
            nbasis,
            lambda: 1e-6,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let result = smooth_basis(&data, &t, &fdpar);
        assert!(result.is_ok());

        let res = result.unwrap();
        // The signal lies in the span of the first two harmonics, so the fit should be tight.
        for j in 0..m {
            let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
            assert!(
                (res.fitted[(0, j)] - expected).abs() < 0.1,
                "Fourier fit poor at j={}: got {}, expected {}",
                j,
                res.fitted[(0, j)],
                expected
            );
        }
    }

    #[test]
    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
            }
        }

        let basis_type = BasisType::Bspline { order: 4 };
        let result = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25);
        assert!(result.is_some(), "GCV search should succeed");
    }

    #[test]
    fn test_smooth_basis_large_lambda_reduces_edf() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin();
            }
        }

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);

        let fdpar_small = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e-8,
            lfd_order: 2,
            penalty_matrix: penalty.clone(),
        };
        let fdpar_large = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e2,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();

        assert!(
            res_large.edf < res_small.edf,
            "Larger lambda should reduce EDF: {} vs {}",
            res_large.edf,
            res_small.edf
        );
    }

    #[test]
    fn test_basis_nbasis_cv_gcv() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
            }
        }

        let nbasis_range: Vec<usize> = (4..=20).step_by(2).collect();
        let result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Gcv,
            5,
            1e-4,
        );
        assert!(result.is_some());
        let res = result.unwrap();
        assert!(nbasis_range.contains(&res.optimal_nbasis));
        assert_eq!(res.scores.len(), nbasis_range.len());
        assert_eq!(res.criterion, BasisCriterion::Gcv);
    }

    #[test]
    fn test_basis_nbasis_cv_aic_bic() {
        let m = 51;
        let n = 5;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin();
            }
        }

        let nbasis_range: Vec<usize> = vec![5, 7, 9, 11];
        let aic_result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Aic,
            5,
            0.0,
        );
        let bic_result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Bic,
            5,
            0.0,
        );
        assert!(aic_result.is_some());
        assert!(bic_result.is_some());
    }

    #[test]
    fn test_basis_nbasis_cv_kfold() {
        let m = 51;
        let n = 10;
        let t = uniform_grid(m);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.05 * ((i * 7 + j * 3) % 10) as f64;
            }
        }

        let nbasis_range: Vec<usize> = vec![5, 7, 9];
        let result = basis_nbasis_cv(
            &data,
            &t,
            &nbasis_range,
            &BasisType::Bspline { order: 4 },
            BasisCriterion::Cv,
            5,
            1e-4,
        );
        assert!(result.is_some());
        let res = result.unwrap();
        assert!(nbasis_range.contains(&res.optimal_nbasis));
        assert_eq!(res.criterion, BasisCriterion::Cv);
    }
}