1use crate::iter_maybe_parallel;
7use nalgebra::{DMatrix, DVector, SVD};
8#[cfg(feature = "parallel")]
9use rayon::iter::ParallelIterator;
10use std::f64::consts::PI;
11
12fn svd_pseudoinverse(mat: &DMatrix<f64>) -> Option<DMatrix<f64>> {
17 let n = mat.nrows();
18 let svd = SVD::new(mat.clone(), true, true);
19 let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
20 let eps = 1e-10 * max_sv;
21
22 let u = svd.u.as_ref()?;
23 let v_t = svd.v_t.as_ref()?;
24
25 let s_inv: Vec<f64> = svd
26 .singular_values
27 .iter()
28 .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
29 .collect();
30
31 let mut result = DMatrix::zeros(n, n);
32 for i in 0..n {
33 for j in 0..n {
34 let mut sum = 0.0;
35 for k in 0..n.min(s_inv.len()) {
36 sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
37 }
38 result[(i, j)] = sum;
39 }
40 }
41
42 Some(result)
43}
44
/// Scores a fitted model with the criterion selected by `criterion`.
///
/// * `0` — GCV: `(rss / n) / (1 - edf / n)²`. Returns `f64::INFINITY` when
///   the denominator vanishes (`edf ≈ n`).
/// * `1` — AIC: `n · ln(rss / n) + 2 · edf`.
/// * any other value — BIC: `n · ln(rss / n) + ln(n) · edf`.
///
/// Here `n = n_points`, and lower scores are better. For an exact fit
/// (`rss == 0`), the log of the mean squared error is replaced by
/// `ln(f64::MIN_POSITIVE)`. AIC/BIC then stay finite instead of `-inf`; a
/// `-inf` score would make callers that skip non-finite scores discard the
/// best possible fit.
fn compute_model_criterion(rss: f64, n_points: f64, edf: f64, criterion: i32) -> f64 {
    let log_mse = || {
        let mse = rss / n_points;
        if mse == 0.0 {
            f64::MIN_POSITIVE.ln()
        } else {
            mse.ln()
        }
    };

    match criterion {
        0 => {
            let gcv_denom = 1.0 - edf / n_points;
            if gcv_denom.abs() > 1e-10 {
                (rss / n_points) / (gcv_denom * gcv_denom)
            } else {
                f64::INFINITY
            }
        }
        1 => n_points * log_mse() + 2.0 * edf,
        _ => n_points * log_mse() + n_points.ln() * edf,
    }
}
72
/// Builds an equally spaced knot vector for B-splines of the given `order`.
///
/// `nknots` knots span `[t_min, t_max]` uniformly. `order` extra knots are
/// added on each side with the same spacing, giving `nknots + 2 * order`
/// knots in total.
///
/// Degenerate inputs no longer panic or produce NaN/inf knots:
/// * fewer than two knots are treated as a single interval;
/// * a zero-width or non-finite spacing (e.g. `t_min == t_max`) falls back
///   to unit spacing starting at `t_min`, so the knots stay distinct.
fn construct_bspline_knots(t_min: f64, t_max: f64, nknots: usize, order: usize) -> Vec<f64> {
    // `nknots - 1` would underflow for 0 and divide by zero for 1.
    let intervals = nknots.saturating_sub(1).max(1);
    let mut dt = (t_max - t_min) / intervals as f64;
    let mut right_edge = t_max;
    if !(dt.is_finite() && dt > 0.0) {
        dt = 1.0;
        // Keep the right padding contiguous with the last interior knot.
        right_edge = t_min + intervals as f64;
    }

    let mut knots = Vec::with_capacity(nknots + 2 * order);
    for i in 0..order {
        knots.push(t_min - (order - i) as f64 * dt);
    }
    for i in 0..nknots {
        knots.push(t_min + i as f64 * dt);
    }
    for i in 1..=order {
        knots.push(right_edge + i as f64 * dt);
    }
    knots
}
88
/// Evaluates the piecewise-constant (order-1) B-spline basis at `t_val`.
///
/// The result has one entry per knot interval (`knots.len() - 1`, or none
/// when fewer than two knots are given). The entry is `1.0` for the first
/// interval containing `t_val` and `0.0` everywhere else.
///
/// Intervals are half-open, `[knots[j], knots[j + 1])`, except the one that
/// ends at `knots[t_max_knot_idx]`. That interval is closed, so a value equal
/// to its right end is still covered.
fn evaluate_order_zero(t_val: f64, knots: &[f64], t_max_knot_idx: usize) -> Vec<f64> {
    let mut b0 = vec![0.0; knots.len().saturating_sub(1)];
    let hit = knots.windows(2).enumerate().position(|(j, w)| {
        // `j + 1 == idx` rather than `j == idx - 1`: no underflow when idx == 0.
        let closed = j + 1 == t_max_knot_idx;
        t_val >= w[0] && (t_val < w[1] || (closed && t_val <= w[1]))
    });
    if let Some(j) = hit {
        b0[j] = 1.0;
    }
    b0
}
105
/// Performs one Cox–de Boor recursion step, raising the B-spline order to `k`.
///
/// `b` holds the order-`(k - 1)` basis values at `t_val`, and the result
/// holds the `knots.len() - k` order-`k` values. A term whose knot span is
/// numerically zero (repeated knots) contributes nothing, following the
/// usual `0 / 0 := 0` convention.
fn bspline_recurrence_step(b: &[f64], knots: &[f64], t_val: f64, k: usize) -> Vec<f64> {
    // Weighted contribution of one lower-order basis value; zero for degenerate spans.
    let weighted = |numer: f64, span: f64, value: f64| -> f64 {
        if span.abs() > 1e-10 {
            numer / span * value
        } else {
            0.0
        }
    };

    let count = knots.len() - k;
    let mut out = Vec::with_capacity(count);
    for j in 0..count {
        let rising = weighted(t_val - knots[j], knots[j + k - 1] - knots[j], b[j]);
        let falling = weighted(knots[j + k] - t_val, knots[j + k] - knots[j + 1], b[j + 1]);
        out.push(rising + falling);
    }
    out
}
126
127pub fn bspline_basis(t: &[f64], nknots: usize, order: usize) -> Vec<f64> {
132 let n = t.len();
133 let nbasis = nknots + order;
134
135 let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
136 let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
137
138 let knots = construct_bspline_knots(t_min, t_max, nknots, order);
139 let t_max_knot_idx = order + nknots - 1;
140
141 let mut basis = vec![0.0; n * nbasis];
142
143 for (ti, &t_val) in t.iter().enumerate() {
144 let mut b = evaluate_order_zero(t_val, &knots, t_max_knot_idx);
145
146 for k in 2..=order {
147 b = bspline_recurrence_step(&b, &knots, t_val, k);
148 }
149
150 for j in 0..nbasis {
151 basis[ti + j * n] = b[j];
152 }
153 }
154
155 basis
156}
157
158pub fn fourier_basis(t: &[f64], nbasis: usize) -> Vec<f64> {
163 let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
164 let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
165 let period = t_max - t_min;
166 fourier_basis_with_period(t, nbasis, period)
167}
168
/// Evaluates a Fourier basis with an explicit `period` at the points `t`.
///
/// Column 0 is the constant `1`. Later columns alternate `sin(f·x)` and
/// `cos(f·x)` for frequencies `f = 1, 2, …`, where
/// `x = 2π (t - min(t)) / period`.
///
/// The result is the column-major `t.len() × nbasis` matrix flattened into a
/// `Vec`; the value of column `k` at point `t[i]` is at index
/// `i + k * t.len()`. `nbasis == 0` yields an empty vector instead of
/// panicking.
pub fn fourier_basis_with_period(t: &[f64], nbasis: usize, period: f64) -> Vec<f64> {
    let n = t.len();
    let mut basis = vec![0.0; n * nbasis];
    if nbasis == 0 {
        // No columns: writing the constant column would index out of bounds.
        return basis;
    }

    let t_min = t.iter().copied().fold(f64::INFINITY, f64::min);

    for (i, &ti) in t.iter().enumerate() {
        let x = 2.0 * PI * (ti - t_min) / period;
        basis[i] = 1.0;
        for k in 1..nbasis {
            // Columns 2f-1 and 2f hold sin(f·x) and cos(f·x) respectively.
            let freq = ((k + 1) / 2) as f64;
            basis[i + k * n] = if k % 2 == 1 {
                (freq * x).sin()
            } else {
                (freq * x).cos()
            };
        }
    }

    basis
}
210
211pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
213 if order == 0 {
214 return DMatrix::identity(n, n);
215 }
216
217 let mut d = DMatrix::zeros(n - 1, n);
218 for i in 0..(n - 1) {
219 d[(i, i)] = -1.0;
220 d[(i, i + 1)] = 1.0;
221 }
222
223 let mut result = d;
224 for _ in 1..order {
225 if result.nrows() <= 1 {
226 break;
227 }
228 let rows = result.nrows() - 1;
229 let cols = result.ncols();
230 let mut d_next = DMatrix::zeros(rows, cols);
231 for i in 0..rows {
232 for j in 0..cols {
233 d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
234 }
235 }
236 result = d_next;
237 }
238
239 result
240}
241
/// Coefficients obtained by projecting curves onto a basis.
#[derive(Debug, Clone)]
pub struct BasisProjectionResult {
    /// Basis coefficients, one per (curve, basis function) pair.
    pub coefficients: Vec<f64>,
    /// Number of basis functions actually used in the projection.
    pub n_basis: usize,
}
249
250pub fn fdata_to_basis_1d(
260 data: &[f64],
261 n: usize,
262 m: usize,
263 argvals: &[f64],
264 nbasis: usize,
265 basis_type: i32,
266) -> Option<BasisProjectionResult> {
267 if n == 0 || m == 0 || argvals.len() != m || nbasis < 2 {
268 return None;
269 }
270
271 let basis = if basis_type == 1 {
272 fourier_basis(argvals, nbasis)
273 } else {
274 bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
276 };
277
278 let actual_nbasis = basis.len() / m;
279 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
280
281 let btb = &b_mat.transpose() * &b_mat;
282 let btb_inv = svd_pseudoinverse(&btb)?;
283 let proj = btb_inv * b_mat.transpose();
284
285 let coefs: Vec<f64> = iter_maybe_parallel!(0..n)
286 .flat_map(|i| {
287 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
288 (0..actual_nbasis)
289 .map(|k| {
290 let mut sum = 0.0;
291 for j in 0..m {
292 sum += proj[(k, j)] * curve[j];
293 }
294 sum
295 })
296 .collect::<Vec<_>>()
297 })
298 .collect();
299
300 Some(BasisProjectionResult {
301 coefficients: coefs,
302 n_basis: actual_nbasis,
303 })
304}
305
306pub fn basis_to_fdata_1d(
308 coefs: &[f64],
309 n: usize,
310 coefs_ncols: usize,
311 argvals: &[f64],
312 nbasis: usize,
313 basis_type: i32,
314) -> Vec<f64> {
315 let m = argvals.len();
316 if n == 0 || m == 0 || nbasis < 2 {
317 return Vec::new();
318 }
319
320 let basis = if basis_type == 1 {
321 fourier_basis(argvals, nbasis)
322 } else {
323 bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
325 };
326
327 let actual_nbasis = basis.len() / m;
328
329 iter_maybe_parallel!(0..n)
330 .flat_map(|i| {
331 (0..m)
332 .map(|j| {
333 let mut sum = 0.0;
334 for k in 0..actual_nbasis.min(coefs_ncols) {
335 sum += coefs[i + k * n] * basis[j + k * m];
336 }
337 sum
338 })
339 .collect::<Vec<_>>()
340 })
341 .collect()
342}
343
/// Result of `pspline_fit_1d`.
#[derive(Debug, Clone)]
pub struct PsplineFitResult {
    /// Basis coefficients, `n × n_basis`, column-major (`i + k * n`).
    pub coefficients: Vec<f64>,
    /// Fitted values, `n × m`, column-major (`i + j * n`).
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom (trace of the hat matrix), shared by all curves.
    pub edf: f64,
    /// Residual sum of squares pooled over all curves.
    pub rss: f64,
    /// Generalized cross-validation score.
    pub gcv: f64,
    /// Akaike information criterion.
    pub aic: f64,
    /// Bayesian information criterion.
    pub bic: f64,
    /// Number of basis functions actually used.
    pub n_basis: usize,
}
363
364pub fn pspline_fit_1d(
366 data: &[f64],
367 n: usize,
368 m: usize,
369 argvals: &[f64],
370 nbasis: usize,
371 lambda: f64,
372 order: usize,
373) -> Option<PsplineFitResult> {
374 if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
375 return None;
376 }
377
378 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
380 let actual_nbasis = basis.len() / m;
381 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
382
383 let d = difference_matrix(actual_nbasis, order);
384 let penalty = &d.transpose() * &d;
385
386 let btb = &b_mat.transpose() * &b_mat;
387 let btb_penalized = &btb + lambda * &penalty;
388
389 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
390 let proj = &btb_inv * b_mat.transpose();
391 let h_mat = &b_mat * &proj;
392 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
393
394 let mut all_coefs = vec![0.0; n * actual_nbasis];
395 let mut all_fitted = vec![0.0; n * m];
396 let mut total_rss = 0.0;
397
398 for i in 0..n {
399 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
400 let curve_vec = DVector::from_vec(curve.clone());
401
402 let bt_y = b_mat.transpose() * &curve_vec;
403 let coefs = &btb_inv * bt_y;
404
405 for k in 0..actual_nbasis {
406 all_coefs[i + k * n] = coefs[k];
407 }
408
409 let fitted = &b_mat * &coefs;
410 for j in 0..m {
411 all_fitted[i + j * n] = fitted[j];
412 let resid = curve[j] - fitted[j];
413 total_rss += resid * resid;
414 }
415 }
416
417 let total_points = (n * m) as f64;
418
419 let gcv_denom = 1.0 - edf / m as f64;
420 let gcv = if gcv_denom.abs() > 1e-10 {
421 (total_rss / total_points) / (gcv_denom * gcv_denom)
422 } else {
423 f64::INFINITY
424 };
425
426 let mse = total_rss / total_points;
427 let aic = total_points * mse.ln() + 2.0 * edf;
428 let bic = total_points * mse.ln() + total_points.ln() * edf;
429
430 Some(PsplineFitResult {
431 coefficients: all_coefs,
432 fitted: all_fitted,
433 edf,
434 rss: total_rss,
435 gcv,
436 aic,
437 bic,
438 n_basis: actual_nbasis,
439 })
440}
441
/// Result of `fourier_fit_1d`.
#[derive(Debug, Clone)]
pub struct FourierFitResult {
    /// Basis coefficients, `n × n_basis`, column-major (`i + k * n`).
    pub coefficients: Vec<f64>,
    /// Fitted values, `n × m`, column-major (`i + j * n`).
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom (trace of the hat matrix), shared by all curves.
    pub edf: f64,
    /// Residual sum of squares pooled over all curves.
    pub rss: f64,
    /// Generalized cross-validation score.
    pub gcv: f64,
    /// Akaike information criterion.
    pub aic: f64,
    /// Bayesian information criterion.
    pub bic: f64,
    /// Number of basis functions actually used (always odd).
    pub n_basis: usize,
}
461
/// Computes `(gcv, aic, bic)` for fits that share one smoother across curves.
///
/// `total_rss` and `total_points` (`N`) pool all curves. `edf` is the
/// per-curve effective degrees of freedom and `m` the number of points per
/// curve.
/// * GCV = `(rss / N) / (1 - edf / m)²`, or `f64::INFINITY` if the
///   denominator vanishes.
/// * AIC = `N · ln(rss / N) + 2 · edf`.
/// * BIC = `N · ln(rss / N) + ln(N) · edf`.
///
/// For an exact fit (`rss == 0`), `ln(f64::MIN_POSITIVE)` replaces
/// `ln(rss / N)`, so AIC/BIC stay finite instead of `-inf`.
fn compute_fit_criteria(total_rss: f64, total_points: f64, edf: f64, m: usize) -> (f64, f64, f64) {
    let gcv_denom = 1.0 - edf / m as f64;
    let gcv = if gcv_denom.abs() > 1e-10 {
        (total_rss / total_points) / (gcv_denom * gcv_denom)
    } else {
        f64::INFINITY
    };

    let mse = total_rss / total_points;
    // ln(0) = -inf would make an exact fit look "invalid" to finiteness checks.
    let log_mse = if mse == 0.0 {
        f64::MIN_POSITIVE.ln()
    } else {
        mse.ln()
    };
    let aic = total_points * log_mse + 2.0 * edf;
    let bic = total_points * log_mse + total_points.ln() * edf;
    (gcv, aic, bic)
}
475
476pub fn fourier_fit_1d(
491 data: &[f64],
492 n: usize,
493 m: usize,
494 argvals: &[f64],
495 nbasis: usize,
496) -> Option<FourierFitResult> {
497 if n == 0 || m == 0 || nbasis < 3 || argvals.len() != m {
498 return None;
499 }
500
501 let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
503
504 let basis = fourier_basis(argvals, nbasis);
505 let actual_nbasis = basis.len() / m;
506 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
507
508 let btb = &b_mat.transpose() * &b_mat;
509 let btb_inv = svd_pseudoinverse(&btb)?;
510 let proj = &btb_inv * b_mat.transpose();
511 let h_mat = &b_mat * &proj;
512 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
513
514 let mut all_coefs = vec![0.0; n * actual_nbasis];
515 let mut all_fitted = vec![0.0; n * m];
516 let mut total_rss = 0.0;
517
518 for i in 0..n {
519 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
520 let curve_vec = DVector::from_vec(curve.clone());
521
522 let bt_y = b_mat.transpose() * &curve_vec;
523 let coefs = &btb_inv * bt_y;
524
525 for k in 0..actual_nbasis {
526 all_coefs[i + k * n] = coefs[k];
527 }
528
529 let fitted = &b_mat * &coefs;
530 for j in 0..m {
531 all_fitted[i + j * n] = fitted[j];
532 let resid = curve[j] - fitted[j];
533 total_rss += resid * resid;
534 }
535 }
536
537 let total_points = (n * m) as f64;
538 let (gcv, aic, bic) = compute_fit_criteria(total_rss, total_points, edf, m);
539
540 Some(FourierFitResult {
541 coefficients: all_coefs,
542 fitted: all_fitted,
543 edf,
544 rss: total_rss,
545 gcv,
546 aic,
547 bic,
548 n_basis: actual_nbasis,
549 })
550}
551
552pub fn select_fourier_nbasis_gcv(
567 data: &[f64],
568 n: usize,
569 m: usize,
570 argvals: &[f64],
571 min_nbasis: usize,
572 max_nbasis: usize,
573) -> usize {
574 let min_nb = min_nbasis.max(3);
575 let max_nb = max_nbasis.min(m / 2);
577
578 if max_nb <= min_nb {
579 return min_nb;
580 }
581
582 let mut best_nbasis = min_nb;
583 let mut best_gcv = f64::INFINITY;
584
585 let mut nbasis = if min_nb % 2 == 0 { min_nb + 1 } else { min_nb };
587 while nbasis <= max_nb {
588 if let Some(result) = fourier_fit_1d(data, n, m, argvals, nbasis) {
589 if result.gcv < best_gcv && result.gcv.is_finite() {
590 best_gcv = result.gcv;
591 best_nbasis = nbasis;
592 }
593 }
594 nbasis += 2;
595 }
596
597 best_nbasis
598}
599
/// Basis choice and fit for a single curve, produced by `select_basis_auto_1d`.
#[derive(Debug, Clone)]
pub struct SingleCurveSelection {
    /// Selected basis family: `0` = P-spline (B-spline), `1` = Fourier.
    pub basis_type: i32,
    /// Basis size that was searched over for the selected fit.
    pub nbasis: usize,
    /// Criterion score of the selected fit (lower is better; `inf` if no fit succeeded).
    pub score: f64,
    /// Basis coefficients of the selected fit.
    pub coefficients: Vec<f64>,
    /// Fitted values at the curve's evaluation points.
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom of the selected fit.
    pub edf: f64,
    /// Whether the FFT seasonality heuristic flagged this curve.
    pub seasonal_detected: bool,
    /// Smoothing parameter of the selected P-spline (`NaN` for Fourier fits).
    pub lambda: f64,
}
620
621pub struct BasisAutoSelectionResult {
623 pub selections: Vec<SingleCurveSelection>,
625 pub criterion: i32,
627}
628
629fn detect_seasonality_fft(curve: &[f64]) -> bool {
634 use rustfft::{num_complex::Complex, FftPlanner};
635
636 let n = curve.len();
637 if n < 8 {
638 return false;
639 }
640
641 let mean: f64 = curve.iter().sum::<f64>() / n as f64;
643 let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
644
645 let mut planner = FftPlanner::new();
646 let fft = planner.plan_fft_forward(n);
647 fft.process(&mut input);
648
649 let powers: Vec<f64> = input[1..n / 2].iter().map(|c| c.norm_sqr()).collect();
651
652 if powers.is_empty() {
653 return false;
654 }
655
656 let max_power = powers.iter().cloned().fold(0.0_f64, f64::max);
657 let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
658
659 max_power > 2.0 * mean_power
661}
662
663fn fit_curve_fourier(
665 curve: &[f64],
666 m: usize,
667 argvals: &[f64],
668 nbasis: usize,
669 criterion: i32,
670) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
671 let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
672
673 let basis = fourier_basis(argvals, nbasis);
674 let actual_nbasis = basis.len() / m;
675 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
676
677 let btb = &b_mat.transpose() * &b_mat;
678 let btb_inv = svd_pseudoinverse(&btb)?;
679 let proj = &btb_inv * b_mat.transpose();
680 let h_mat = &b_mat * &proj;
681 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
682
683 let curve_vec = DVector::from_column_slice(curve);
684 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
685 let fitted = &b_mat * &coefs;
686
687 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
688 let score = compute_model_criterion(rss, m as f64, edf, criterion);
689
690 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
691 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
692
693 Some((score, coef_vec, fitted_vec, edf))
694}
695
696fn fit_curve_pspline(
698 curve: &[f64],
699 m: usize,
700 argvals: &[f64],
701 nbasis: usize,
702 lambda: f64,
703 order: usize,
704 criterion: i32,
705) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
706 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
707 let actual_nbasis = basis.len() / m;
708 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
709
710 let d = difference_matrix(actual_nbasis, order);
711 let penalty = &d.transpose() * &d;
712 let btb = &b_mat.transpose() * &b_mat;
713 let btb_penalized = &btb + lambda * &penalty;
714
715 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
716 let proj = &btb_inv * b_mat.transpose();
717 let h_mat = &b_mat * &proj;
718 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
719
720 let curve_vec = DVector::from_column_slice(curve);
721 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
722 let fitted = &b_mat * &coefs;
723
724 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
725 let score = compute_model_criterion(rss, m as f64, edf, criterion);
726
727 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
728 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
729
730 Some((score, coef_vec, fitted_vec, edf))
731}
732
/// Best candidate found while searching one basis family for a single curve.
#[derive(Debug, Clone)]
struct BasisSearchResult {
    /// Criterion score (lower is better; only finite scores are stored).
    score: f64,
    /// Basis size of the candidate.
    nbasis: usize,
    /// Basis coefficients of the fit.
    coefs: Vec<f64>,
    /// Fitted values at the curve's evaluation points.
    fitted: Vec<f64>,
    /// Effective degrees of freedom of the fit.
    edf: f64,
    /// Smoothing parameter (`NaN` for unpenalized Fourier fits).
    lambda: f64,
}
742
743fn search_fourier_basis(
745 curve: &[f64],
746 m: usize,
747 argvals: &[f64],
748 fourier_min: usize,
749 fourier_max: usize,
750 seasonal: bool,
751 criterion: i32,
752) -> Option<BasisSearchResult> {
753 let fourier_start = if seasonal {
754 fourier_min.max(5)
755 } else {
756 fourier_min
757 };
758 let mut nb = if fourier_start % 2 == 0 {
759 fourier_start + 1
760 } else {
761 fourier_start
762 };
763
764 let mut best: Option<BasisSearchResult> = None;
765 while nb <= fourier_max {
766 if let Some((score, coefs, fitted, edf)) =
767 fit_curve_fourier(curve, m, argvals, nb, criterion)
768 {
769 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
770 best = Some(BasisSearchResult {
771 score,
772 nbasis: nb,
773 coefs,
774 fitted,
775 edf,
776 lambda: f64::NAN,
777 });
778 }
779 }
780 nb += 2;
781 }
782 best
783}
784
785fn try_pspline_fit_update(
787 curve: &[f64],
788 m: usize,
789 argvals: &[f64],
790 nb: usize,
791 lam: f64,
792 criterion: i32,
793 best: &mut Option<BasisSearchResult>,
794) {
795 if let Some((score, coefs, fitted, edf)) =
796 fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
797 {
798 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
799 *best = Some(BasisSearchResult {
800 score,
801 nbasis: nb,
802 coefs,
803 fitted,
804 edf,
805 lambda: lam,
806 });
807 }
808 }
809}
810
811fn search_pspline_basis(
813 curve: &[f64],
814 m: usize,
815 argvals: &[f64],
816 pspline_min: usize,
817 pspline_max: usize,
818 lambda_grid: &[f64],
819 auto_lambda: bool,
820 lambda: f64,
821 criterion: i32,
822) -> Option<BasisSearchResult> {
823 let mut best: Option<BasisSearchResult> = None;
824 for nb in pspline_min..=pspline_max {
825 let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
826 Box::new(lambda_grid.iter().copied())
827 } else {
828 Box::new(std::iter::once(lambda))
829 };
830 for lam in lambdas {
831 try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
832 }
833 }
834 best
835}
836
837pub fn select_basis_auto_1d(
857 data: &[f64],
858 n: usize,
859 m: usize,
860 argvals: &[f64],
861 criterion: i32,
862 nbasis_min: usize,
863 nbasis_max: usize,
864 lambda_pspline: f64,
865 use_seasonal_hint: bool,
866) -> BasisAutoSelectionResult {
867 let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
868 let fourier_max = if nbasis_max > 0 {
869 nbasis_max.min(m / 3).min(25)
870 } else {
871 (m / 3).min(25)
872 };
873 let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
874 let pspline_max = if nbasis_max > 0 {
875 nbasis_max.min(m / 2).min(40)
876 } else {
877 (m / 2).min(40)
878 };
879
880 let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
881 let auto_lambda = lambda_pspline < 0.0;
882
883 let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
884 .map(|i| {
885 let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
886 let seasonal_detected = if use_seasonal_hint {
887 detect_seasonality_fft(&curve)
888 } else {
889 false
890 };
891
892 let fourier_best = search_fourier_basis(
893 &curve,
894 m,
895 argvals,
896 fourier_min,
897 fourier_max,
898 seasonal_detected,
899 criterion,
900 );
901 let pspline_best = search_pspline_basis(
902 &curve,
903 m,
904 argvals,
905 pspline_min,
906 pspline_max,
907 &lambda_grid,
908 auto_lambda,
909 lambda_pspline,
910 criterion,
911 );
912
913 let (basis_type, result) = match (fourier_best, pspline_best) {
915 (Some(f), Some(p)) => {
916 if f.score <= p.score {
917 (1i32, f)
918 } else {
919 (0i32, p)
920 }
921 }
922 (Some(f), None) => (1, f),
923 (None, Some(p)) => (0, p),
924 (None, None) => {
925 return SingleCurveSelection {
926 basis_type: 0,
927 nbasis: pspline_min,
928 score: f64::INFINITY,
929 coefficients: Vec::new(),
930 fitted: Vec::new(),
931 edf: 0.0,
932 seasonal_detected,
933 lambda: f64::NAN,
934 };
935 }
936 };
937
938 SingleCurveSelection {
939 basis_type,
940 nbasis: result.nbasis,
941 score: result.score,
942 coefficients: result.coefs,
943 fitted: result.fitted,
944 edf: result.edf,
945 seasonal_detected,
946 lambda: result.lambda,
947 }
948 })
949 .collect();
950
951 BasisAutoSelectionResult {
952 selections,
953 criterion,
954 }
955}
956
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Returns `n` equally spaced points on `[0, 1]`.
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Samples `sin(2π · freq · t)` at each point of `t`.
    fn sine_wave(t: &[f64], freq: f64) -> Vec<f64> {
        t.iter().map(|&ti| (2.0 * PI * freq * ti).sin()).collect()
    }

    /// Packs equal-length curves into the column-major `n × m` layout the
    /// functions under test read: `data[i + j * n]` is curve `i` at point `j`.
    fn to_column_major(curves: &[Vec<f64>]) -> Vec<f64> {
        let n = curves.len();
        let m = curves.first().map_or(0, Vec::len);
        let mut data = vec![0.0; n * m];
        for (i, curve) in curves.iter().enumerate() {
            for (j, &v) in curve.iter().enumerate() {
                data[i + j * n] = v;
            }
        }
        data
    }

    #[test]
    fn test_bspline_basis_dimensions() {
        let t = uniform_grid(50);
        let nknots = 10;
        let order = 4;
        let basis = bspline_basis(&t, nknots, order);

        let expected_nbasis = nknots + order;
        assert_eq!(basis.len(), t.len() * expected_nbasis);
    }

    #[test]
    fn test_bspline_basis_partition_of_unity() {
        let t = uniform_grid(50);
        let nknots = 8;
        let order = 4;
        let basis = bspline_basis(&t, nknots, order);

        let nbasis = nknots + order;
        for i in 0..t.len() {
            let sum: f64 = (0..nbasis).map(|j| basis[i + j * t.len()]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "B-spline partition of unity failed at point {}: sum = {}",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_bspline_basis_non_negative() {
        let t = uniform_grid(50);
        let basis = bspline_basis(&t, 8, 4);

        for &val in &basis {
            assert!(val >= -1e-10, "B-spline values should be non-negative");
        }
    }

    #[test]
    fn test_bspline_basis_boundary() {
        let t = vec![0.0, 0.5, 1.0];
        let basis = bspline_basis(&t, 5, 4);

        for &val in &basis {
            assert!(val.is_finite(), "B-spline should produce finite values");
        }
    }

    #[test]
    fn test_fourier_basis_dimensions() {
        let t = uniform_grid(50);
        let nbasis = 7;
        let basis = fourier_basis(&t, nbasis);

        assert_eq!(basis.len(), t.len() * nbasis);
    }

    #[test]
    fn test_fourier_basis_constant_first_column() {
        let t = uniform_grid(50);
        let nbasis = 7;
        let basis = fourier_basis(&t, nbasis);

        let first_val = basis[0];
        for i in 0..t.len() {
            assert!(
                (basis[i] - first_val).abs() < 1e-10,
                "First Fourier column should be constant"
            );
        }
    }

    #[test]
    fn test_fourier_basis_sin_cos_range() {
        let t = uniform_grid(100);
        let nbasis = 11;
        let basis = fourier_basis(&t, nbasis);

        for &val in &basis {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&val));
        }
    }

    #[test]
    fn test_fourier_basis_with_period() {
        let t = uniform_grid(100);
        let nbasis = 5;
        let period = 0.5;
        let basis = fourier_basis_with_period(&t, nbasis, period);

        assert_eq!(basis.len(), t.len() * nbasis);
        let first_val = basis[0];
        for i in 0..t.len() {
            assert!((basis[i] - first_val).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fourier_basis_period_affects_frequency() {
        let t = uniform_grid(100);
        let nbasis = 5;

        let basis1 = fourier_basis_with_period(&t, nbasis, 1.0);
        let basis2 = fourier_basis_with_period(&t, nbasis, 0.5);

        let n = t.len();
        let mut any_different = false;
        for i in 0..n {
            if (basis1[i + n] - basis2[i + n]).abs() > 1e-10 {
                any_different = true;
                break;
            }
        }
        assert!(
            any_different,
            "Different periods should produce different bases"
        );
    }

    #[test]
    fn test_difference_matrix_order_zero() {
        let d = difference_matrix(5, 0);
        assert_eq!(d.nrows(), 5);
        assert_eq!(d.ncols(), 5);

        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((d[(i, j)] - expected).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_difference_matrix_first_order() {
        let d = difference_matrix(5, 1);
        assert_eq!(d.nrows(), 4);
        assert_eq!(d.ncols(), 5);

        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let dx = &d * x;
        for i in 0..4 {
            assert!((dx[i] - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_difference_matrix_second_order() {
        let d = difference_matrix(5, 2);
        assert_eq!(d.nrows(), 3);
        assert_eq!(d.ncols(), 5);

        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let dx = &d * x;
        for i in 0..3 {
            assert!(dx[i].abs() < 1e-10, "Second diff of linear should be zero");
        }
    }

    #[test]
    fn test_difference_matrix_quadratic() {
        let d = difference_matrix(5, 2);

        let x = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);
        let dx = &d * x;
        for i in 0..3 {
            assert!(
                (dx[i] - 2.0).abs() < 1e-10,
                "Second diff of squares should be 2"
            );
        }
    }

    #[test]
    fn test_fdata_to_basis_1d_bspline() {
        let t = uniform_grid(50);
        let n = 5;
        let m = t.len();

        let curves: Vec<Vec<f64>> = (0..n)
            .map(|i| t.iter().map(|&ti| ti + i as f64 * 0.1).collect())
            .collect();
        let data = to_column_major(&curves);

        let result = fdata_to_basis_1d(&data, n, m, &t, 10, 0);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis > 0);
        assert_eq!(res.coefficients.len(), n * res.n_basis);
    }

    #[test]
    fn test_fdata_to_basis_1d_fourier() {
        let t = uniform_grid(50);
        let n = 5;
        let m = t.len();

        let curves: Vec<Vec<f64>> = (0..n).map(|_| sine_wave(&t, 2.0)).collect();
        let data = to_column_major(&curves);

        let result = fdata_to_basis_1d(&data, n, m, &t, 7, 1);
        assert!(result.is_some());

        let res = result.unwrap();
        assert_eq!(res.n_basis, 7);
    }

    #[test]
    fn test_fdata_to_basis_1d_invalid_input() {
        let t = uniform_grid(50);

        let result = fdata_to_basis_1d(&[], 0, 50, &t, 10, 0);
        assert!(result.is_none());

        let data = vec![0.0; 50];
        let result = fdata_to_basis_1d(&data, 1, 50, &t, 1, 0);
        assert!(result.is_none());
    }

    #[test]
    fn test_basis_roundtrip() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        let data = sine_wave(&t, 1.0);

        let proj = fdata_to_basis_1d(&data, n, m, &t, 5, 1).unwrap();

        let reconstructed =
            basis_to_fdata_1d(&proj.coefficients, n, proj.n_basis, &t, proj.n_basis, 1);

        let mut max_error = 0.0;
        for i in 0..m {
            let err = (data[i] - reconstructed[i]).abs();
            if err > max_error {
                max_error = err;
            }
        }
        assert!(max_error < 0.5, "Roundtrip error too large: {}", max_error);
    }

    #[test]
    fn test_basis_to_fdata_empty_input() {
        let result = basis_to_fdata_1d(&[], 0, 0, &[], 5, 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_pspline_fit_1d_basic() {
        let t = uniform_grid(50);
        let n = 3;
        let m = t.len();

        let curves: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                t.iter()
                    .enumerate()
                    .map(|(j, &ti)| (2.0 * PI * ti).sin() + 0.1 * (i * j) as f64 % 1.0)
                    .collect()
            })
            .collect();
        let data = to_column_major(&curves);

        let result = pspline_fit_1d(&data, n, m, &t, 15, 1.0, 2);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis > 0);
        assert_eq!(res.fitted.len(), n * m);
        assert!(res.rss >= 0.0);
        assert!(res.edf > 0.0);
        assert!(res.gcv.is_finite());
    }

    #[test]
    fn test_pspline_fit_1d_column_major_layout() {
        let t = uniform_grid(40);
        let m = t.len();
        let curves = vec![vec![0.0; m], vec![3.0; m]];
        let data = to_column_major(&curves);

        let res = pspline_fit_1d(&data, 2, m, &t, 10, 1.0, 2).unwrap();
        for j in 0..m {
            assert!(res.fitted[j * 2].abs() < 1e-6);
            assert!((res.fitted[1 + j * 2] - 3.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_pspline_fit_1d_smoothness() {
        let t = uniform_grid(50);
        let n = 1;
        let m = t.len();

        let data: Vec<f64> = t
            .iter()
            .enumerate()
            .map(|(i, &ti)| (2.0 * PI * ti).sin() + 0.3 * ((i * 17) % 100) as f64 / 100.0)
            .collect();

        let low_lambda = pspline_fit_1d(&data, n, m, &t, 15, 0.01, 2).unwrap();
        let high_lambda = pspline_fit_1d(&data, n, m, &t, 15, 100.0, 2).unwrap();

        assert!(high_lambda.edf < low_lambda.edf);
    }

    #[test]
    fn test_pspline_fit_1d_invalid_input() {
        let t = uniform_grid(50);
        let result = pspline_fit_1d(&[], 0, 50, &t, 15, 1.0, 2);
        assert!(result.is_none());
    }

    #[test]
    fn test_fourier_fit_1d_sine_wave() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        let data = sine_wave(&t, 2.0);

        let result = fourier_fit_1d(&data, n, m, &t, 11);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.rss < 1e-6, "Pure sine should have near-zero RSS");
    }

    #[test]
    fn test_fourier_fit_1d_makes_nbasis_odd() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 1.0);

        let result = fourier_fit_1d(&data, 1, t.len(), &t, 6);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis % 2 == 1);
    }

    #[test]
    fn test_fourier_fit_1d_criteria() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 2.0);

        let result = fourier_fit_1d(&data, 1, t.len(), &t, 9).unwrap();

        assert!(result.gcv.is_finite());
        assert!(result.aic.is_finite());
        assert!(result.bic.is_finite());
    }

    #[test]
    fn test_fourier_fit_1d_invalid_nbasis() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 1.0);

        let result = fourier_fit_1d(&data, 1, t.len(), &t, 2);
        assert!(result.is_none());
    }

    #[test]
    fn test_select_fourier_nbasis_gcv_range() {
        let t = uniform_grid(100);
        let data = sine_wave(&t, 3.0);

        let best = select_fourier_nbasis_gcv(&data, 1, t.len(), &t, 3, 15);

        assert!((3..=15).contains(&best));
        assert!(best % 2 == 1, "Selected nbasis should be odd");
    }

    #[test]
    fn test_select_fourier_nbasis_gcv_respects_min() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 1.0);

        let best = select_fourier_nbasis_gcv(&data, 1, t.len(), &t, 7, 15);
        assert!(best >= 7);
    }

    #[test]
    fn test_select_basis_auto_1d_returns_results() {
        let t = uniform_grid(50);
        let n = 3;
        let m = t.len();

        let curves: Vec<Vec<f64>> = (0..n).map(|i| sine_wave(&t, 1.0 + i as f64)).collect();
        let data = to_column_major(&curves);

        let result = select_basis_auto_1d(&data, n, m, &t, 0, 5, 15, 1.0, false);

        assert_eq!(result.selections.len(), n);
        for sel in &result.selections {
            assert!(sel.nbasis >= 3);
            assert!(!sel.coefficients.is_empty());
            assert_eq!(sel.fitted.len(), m);
        }
    }

    #[test]
    fn test_select_basis_auto_1d_seasonal_hint() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        let data = sine_wave(&t, 5.0);

        let result = select_basis_auto_1d(&data, n, m, &t, 0, 0, 0, -1.0, true);

        assert_eq!(result.selections.len(), 1);
        assert!(result.selections[0].seasonal_detected);
    }

    #[test]
    fn test_select_basis_auto_1d_non_seasonal() {
        let t = uniform_grid(50);
        let n = 1;
        let m = t.len();

        let data: Vec<f64> = vec![1.0; m];

        let result = select_basis_auto_1d(&data, n, m, &t, 0, 0, 0, -1.0, true);

        assert!(!result.selections[0].seasonal_detected);
    }

    #[test]
    fn test_select_basis_auto_1d_criterion_options() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 2.0);

        let gcv_result = select_basis_auto_1d(&data, 1, t.len(), &t, 0, 0, 0, 1.0, false);
        let aic_result = select_basis_auto_1d(&data, 1, t.len(), &t, 1, 0, 0, 1.0, false);
        let bic_result = select_basis_auto_1d(&data, 1, t.len(), &t, 2, 0, 0, 1.0, false);

        assert_eq!(gcv_result.criterion, 0);
        assert_eq!(aic_result.criterion, 1);
        assert_eq!(bic_result.criterion, 2);
    }
}