1use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8use nalgebra::{DMatrix, DVector, SVD};
9#[cfg(feature = "parallel")]
10use rayon::iter::ParallelIterator;
11use std::f64::consts::PI;
12
13fn svd_pseudoinverse(mat: &DMatrix<f64>) -> Option<DMatrix<f64>> {
18 let n = mat.nrows();
19 let svd = SVD::new(mat.clone(), true, true);
20 let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
21 let eps = 1e-10 * max_sv;
22
23 let u = svd.u.as_ref()?;
24 let v_t = svd.v_t.as_ref()?;
25
26 let s_inv: Vec<f64> = svd
27 .singular_values
28 .iter()
29 .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
30 .collect();
31
32 let mut result = DMatrix::zeros(n, n);
33 for i in 0..n {
34 for j in 0..n {
35 let mut sum = 0.0;
36 for k in 0..n.min(s_inv.len()) {
37 sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
38 }
39 result[(i, j)] = sum;
40 }
41 }
42
43 Some(result)
44}
45
/// Evaluates a model-selection score for a single fit.
///
/// * `criterion == 0` — GCV: `(rss / n) / (1 - edf / n)^2`, or `+inf` when
///   the denominator `1 - edf / n` is within `1e-10` of zero.
/// * `criterion == 1` — AIC: `n * ln(rss / n) + 2 * edf`.
/// * any other value — BIC: `n * ln(rss / n) + ln(n) * edf`.
fn compute_model_criterion(rss: f64, n_points: f64, edf: f64, criterion: i32) -> f64 {
    if criterion == 0 {
        let shrink = 1.0 - edf / n_points;
        return if shrink.abs() > 1e-10 {
            (rss / n_points) / (shrink * shrink)
        } else {
            f64::INFINITY
        };
    }

    // AIC and BIC share the log-MSE term and differ only in the per-df penalty.
    let log_mse_term = n_points * (rss / n_points).ln();
    let penalty_per_df = if criterion == 1 { 2.0 } else { n_points.ln() };
    log_mse_term + penalty_per_df * edf
}
73
/// Builds a uniformly spaced knot vector covering `[t_min, t_max]`.
///
/// The vector holds `order` knots before `t_min`, `nknots` knots from
/// `t_min` in steps of `(t_max - t_min) / (nknots - 1)`, and `order` knots
/// after `t_max`, all with the same spacing.
fn construct_bspline_knots(t_min: f64, t_max: f64, nknots: usize, order: usize) -> Vec<f64> {
    let spacing = (t_max - t_min) / (nknots - 1) as f64;

    let mut knots = Vec::with_capacity(nknots + 2 * order);
    // Extension to the left of the data range.
    knots.extend((0..order).map(|i| t_min - (order - i) as f64 * spacing));
    // Knots across the data range.
    knots.extend((0..nknots).map(|i| t_min + i as f64 * spacing));
    // Extension to the right of the data range.
    knots.extend((1..=order).map(|i| t_max + i as f64 * spacing));
    knots
}
89
/// Evaluates the piecewise-constant (order-one) B-spline indicators at `t_val`.
///
/// Returns a vector with one entry per knot interval; the first interval
/// containing `t_val` gets `1.0`, all others stay `0.0`. Intervals are
/// half-open `[knots[j], knots[j + 1])`, except the interval ending at
/// `knots[t_max_knot_idx]`, which is closed on the right.
fn evaluate_order_zero(t_val: f64, knots: &[f64], t_max_knot_idx: usize) -> Vec<f64> {
    let n_intervals = knots.len() - 1;
    let mut indicator = vec![0.0; n_intervals];

    let containing = (0..n_intervals).find(|&j| {
        let below_upper = if j == t_max_knot_idx - 1 {
            t_val <= knots[j + 1]
        } else {
            t_val < knots[j + 1]
        };
        t_val >= knots[j] && below_upper
    });

    if let Some(j) = containing {
        indicator[j] = 1.0;
    }
    indicator
}
106
/// Performs one step of the B-spline recurrence, raising the order to `k`.
///
/// `b` holds the values of the previous order at `t_val`; the result has
/// `knots.len() - k` entries. A term whose knot span is at most `1e-10` in
/// magnitude contributes `0.0` instead of dividing by it.
fn bspline_recurrence_step(b: &[f64], knots: &[f64], t_val: f64, k: usize) -> Vec<f64> {
    let n_out = knots.len() - k;
    let mut next = Vec::with_capacity(n_out);

    for j in 0..n_out {
        let span_left = knots[j + k - 1] - knots[j];
        let span_right = knots[j + k] - knots[j + 1];

        let from_left = if span_left.abs() > 1e-10 {
            (t_val - knots[j]) / span_left * b[j]
        } else {
            0.0
        };
        let from_right = if span_right.abs() > 1e-10 {
            (knots[j + k] - t_val) / span_right * b[j + 1]
        } else {
            0.0
        };

        next.push(from_left + from_right);
    }

    next
}
127
128pub fn bspline_basis(t: &[f64], nknots: usize, order: usize) -> Vec<f64> {
133 let n = t.len();
134 let nbasis = nknots + order;
135
136 let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
137 let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
138
139 let knots = construct_bspline_knots(t_min, t_max, nknots, order);
140 let t_max_knot_idx = order + nknots - 1;
141
142 let mut basis = vec![0.0; n * nbasis];
143
144 for (ti, &t_val) in t.iter().enumerate() {
145 let mut b = evaluate_order_zero(t_val, &knots, t_max_knot_idx);
146
147 for k in 2..=order {
148 b = bspline_recurrence_step(&b, &knots, t_val, k);
149 }
150
151 for j in 0..nbasis {
152 basis[ti + j * n] = b[j];
153 }
154 }
155
156 basis
157}
158
159pub fn fourier_basis(t: &[f64], nbasis: usize) -> Vec<f64> {
164 let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
165 let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
166 let period = t_max - t_min;
167 fourier_basis_with_period(t, nbasis, period)
168}
169
/// Evaluates a Fourier basis with an explicit `period` at the points `t`.
///
/// Column 0 is the constant `1.0`; the following columns alternate
/// `sin(f * x)`, `cos(f * x)` for `f = 1, 2, ...`, where
/// `x = 2 * PI * (t_i - min(t)) / period`.
///
/// Returns a flat vector of length `t.len() * nbasis` in which the value of
/// basis function `k` at point `i` is stored at `i + k * t.len()`. An empty
/// vector is returned when `t` is empty or `nbasis == 0`.
///
/// Note: a `period` of zero makes `x` non-finite, so the non-constant columns
/// become NaN; callers are expected to pass a positive period.
pub fn fourier_basis_with_period(t: &[f64], nbasis: usize, period: f64) -> Vec<f64> {
    let n = t.len();
    // Without this guard the constant column write below would index an
    // empty vector when `nbasis == 0`.
    if n == 0 || nbasis == 0 {
        return Vec::new();
    }

    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
    let mut basis = vec![0.0; n * nbasis];

    for (i, &ti) in t.iter().enumerate() {
        let x = 2.0 * PI * (ti - t_min) / period;

        basis[i] = 1.0;

        for k in 1..nbasis {
            // Columns 1,2 use frequency 1; columns 3,4 use frequency 2; ...
            let freq = ((k + 1) / 2) as f64;
            basis[i + k * n] = if k % 2 == 1 {
                (freq * x).sin()
            } else {
                (freq * x).cos()
            };
        }
    }

    basis
}
211
212pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
214 if order == 0 {
215 return DMatrix::identity(n, n);
216 }
217
218 let mut d = DMatrix::zeros(n - 1, n);
219 for i in 0..(n - 1) {
220 d[(i, i)] = -1.0;
221 d[(i, i + 1)] = 1.0;
222 }
223
224 let mut result = d;
225 for _ in 1..order {
226 if result.nrows() <= 1 {
227 break;
228 }
229 let rows = result.nrows() - 1;
230 let cols = result.ncols();
231 let mut d_next = DMatrix::zeros(rows, cols);
232 for i in 0..rows {
233 for j in 0..cols {
234 d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
235 }
236 }
237 result = d_next;
238 }
239
240 result
241}
242
/// Result of projecting curves onto a basis, as returned by
/// [`fdata_to_basis_1d`].
pub struct BasisProjectionResult {
    /// Basis coefficients: one row per curve, one column per basis function.
    pub coefficients: FdMatrix,
    /// Number of basis functions actually used (the column count of
    /// `coefficients`); may differ from the requested number.
    pub n_basis: usize,
}
250
251pub fn fdata_to_basis_1d(
259 data: &FdMatrix,
260 argvals: &[f64],
261 nbasis: usize,
262 basis_type: i32,
263) -> Option<BasisProjectionResult> {
264 let n = data.nrows();
265 let m = data.ncols();
266 if n == 0 || m == 0 || argvals.len() != m || nbasis < 2 {
267 return None;
268 }
269
270 let basis = if basis_type == 1 {
271 fourier_basis(argvals, nbasis)
272 } else {
273 bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
275 };
276
277 let actual_nbasis = basis.len() / m;
278 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
279
280 let btb = &b_mat.transpose() * &b_mat;
281 let btb_inv = svd_pseudoinverse(&btb)?;
282 let proj = btb_inv * b_mat.transpose();
283
284 let coefs: Vec<f64> = iter_maybe_parallel!(0..n)
285 .flat_map(|i| {
286 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
287 (0..actual_nbasis)
288 .map(|k| {
289 let mut sum = 0.0;
290 for j in 0..m {
291 sum += proj[(k, j)] * curve[j];
292 }
293 sum
294 })
295 .collect::<Vec<_>>()
296 })
297 .collect();
298
299 Some(BasisProjectionResult {
300 coefficients: FdMatrix::from_column_major(coefs, n, actual_nbasis).unwrap(),
301 n_basis: actual_nbasis,
302 })
303}
304
305pub fn basis_to_fdata_1d(
307 coefs: &FdMatrix,
308 argvals: &[f64],
309 nbasis: usize,
310 basis_type: i32,
311) -> FdMatrix {
312 let n = coefs.nrows();
313 let coefs_ncols = coefs.ncols();
314 let m = argvals.len();
315 if n == 0 || m == 0 || nbasis < 2 {
316 return FdMatrix::zeros(0, 0);
317 }
318
319 let basis = if basis_type == 1 {
320 fourier_basis(argvals, nbasis)
321 } else {
322 bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
324 };
325
326 let actual_nbasis = basis.len() / m;
327
328 let flat: Vec<f64> = iter_maybe_parallel!(0..n)
329 .flat_map(|i| {
330 (0..m)
331 .map(|j| {
332 let mut sum = 0.0;
333 for k in 0..actual_nbasis.min(coefs_ncols) {
334 sum += coefs[(i, k)] * basis[j + k * m];
335 }
336 sum
337 })
338 .collect::<Vec<_>>()
339 })
340 .collect();
341
342 FdMatrix::from_column_major(flat, n, m).unwrap()
343}
344
/// Result of a penalized B-spline fit, as returned by [`pspline_fit_1d`].
pub struct PsplineFitResult {
    /// Fitted basis coefficients: one row per curve, one column per basis function.
    pub coefficients: FdMatrix,
    /// Fitted values: one row per curve, one column per evaluation point.
    pub fitted: FdMatrix,
    /// Effective degrees of freedom (trace of the hat matrix).
    pub edf: f64,
    /// Residual sum of squares, summed over all curves and points.
    pub rss: f64,
    /// Generalized cross-validation score (`+inf` when its denominator is ~0).
    pub gcv: f64,
    /// AIC-style score: `N * ln(rss / N) + 2 * edf`, with `N` the total number of points.
    pub aic: f64,
    /// BIC-style score: `N * ln(rss / N) + ln(N) * edf`, with `N` the total number of points.
    pub bic: f64,
    /// Number of basis functions actually used.
    pub n_basis: usize,
}
364
365pub fn pspline_fit_1d(
367 data: &FdMatrix,
368 argvals: &[f64],
369 nbasis: usize,
370 lambda: f64,
371 order: usize,
372) -> Option<PsplineFitResult> {
373 let n = data.nrows();
374 let m = data.ncols();
375 if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
376 return None;
377 }
378
379 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
381 let actual_nbasis = basis.len() / m;
382 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
383
384 let d = difference_matrix(actual_nbasis, order);
385 let penalty = &d.transpose() * &d;
386
387 let btb = &b_mat.transpose() * &b_mat;
388 let btb_penalized = &btb + lambda * &penalty;
389
390 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
391 let proj = &btb_inv * b_mat.transpose();
392 let h_mat = &b_mat * &proj;
393 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
394
395 let mut all_coefs = FdMatrix::zeros(n, actual_nbasis);
396 let mut all_fitted = FdMatrix::zeros(n, m);
397 let mut total_rss = 0.0;
398
399 for i in 0..n {
400 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
401 let curve_vec = DVector::from_vec(curve.clone());
402
403 let bt_y = b_mat.transpose() * &curve_vec;
404 let coefs = &btb_inv * bt_y;
405
406 for k in 0..actual_nbasis {
407 all_coefs[(i, k)] = coefs[k];
408 }
409
410 let fitted = &b_mat * &coefs;
411 for j in 0..m {
412 all_fitted[(i, j)] = fitted[j];
413 let resid = curve[j] - fitted[j];
414 total_rss += resid * resid;
415 }
416 }
417
418 let total_points = (n * m) as f64;
419
420 let gcv_denom = 1.0 - edf / m as f64;
421 let gcv = if gcv_denom.abs() > 1e-10 {
422 (total_rss / total_points) / (gcv_denom * gcv_denom)
423 } else {
424 f64::INFINITY
425 };
426
427 let mse = total_rss / total_points;
428 let aic = total_points * mse.ln() + 2.0 * edf;
429 let bic = total_points * mse.ln() + total_points.ln() * edf;
430
431 Some(PsplineFitResult {
432 coefficients: all_coefs,
433 fitted: all_fitted,
434 edf,
435 rss: total_rss,
436 gcv,
437 aic,
438 bic,
439 n_basis: actual_nbasis,
440 })
441}
442
/// Result of a least-squares Fourier fit, as returned by [`fourier_fit_1d`].
pub struct FourierFitResult {
    /// Fitted basis coefficients: one row per curve, one column per basis function.
    pub coefficients: FdMatrix,
    /// Fitted values: one row per curve, one column per evaluation point.
    pub fitted: FdMatrix,
    /// Effective degrees of freedom (trace of the hat matrix).
    pub edf: f64,
    /// Residual sum of squares, summed over all curves and points.
    pub rss: f64,
    /// Generalized cross-validation score (`+inf` when its denominator is ~0).
    pub gcv: f64,
    /// AIC-style score: `N * ln(rss / N) + 2 * edf`, with `N` the total number of points.
    pub aic: f64,
    /// BIC-style score: `N * ln(rss / N) + ln(N) * edf`, with `N` the total number of points.
    pub bic: f64,
    /// Number of basis functions actually used (always odd for this fit).
    pub n_basis: usize,
}
462
/// Computes `(gcv, aic, bic)` for a fit over `total_points` observations.
///
/// * GCV: `(rss / N) / (1 - edf / m)^2`, or `+inf` when `1 - edf / m` is
///   within `1e-10` of zero (`m` is the number of points per curve).
/// * AIC: `N * ln(rss / N) + 2 * edf`.
/// * BIC: `N * ln(rss / N) + ln(N) * edf`.
fn compute_fit_criteria(total_rss: f64, total_points: f64, edf: f64, m: usize) -> (f64, f64, f64) {
    let mse = total_rss / total_points;
    let log_mse_term = total_points * mse.ln();

    let shrink = 1.0 - edf / m as f64;
    let gcv = match shrink.abs() > 1e-10 {
        true => mse / (shrink * shrink),
        false => f64::INFINITY,
    };

    (
        gcv,
        log_mse_term + 2.0 * edf,
        log_mse_term + total_points.ln() * edf,
    )
}
476
477pub fn fourier_fit_1d(data: &FdMatrix, argvals: &[f64], nbasis: usize) -> Option<FourierFitResult> {
490 let n = data.nrows();
491 let m = data.ncols();
492 if n == 0 || m == 0 || nbasis < 3 || argvals.len() != m {
493 return None;
494 }
495
496 let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
498
499 let basis = fourier_basis(argvals, nbasis);
500 let actual_nbasis = basis.len() / m;
501 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
502
503 let btb = &b_mat.transpose() * &b_mat;
504 let btb_inv = svd_pseudoinverse(&btb)?;
505 let proj = &btb_inv * b_mat.transpose();
506 let h_mat = &b_mat * &proj;
507 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
508
509 let mut all_coefs = FdMatrix::zeros(n, actual_nbasis);
510 let mut all_fitted = FdMatrix::zeros(n, m);
511 let mut total_rss = 0.0;
512
513 for i in 0..n {
514 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
515 let curve_vec = DVector::from_vec(curve.clone());
516
517 let bt_y = b_mat.transpose() * &curve_vec;
518 let coefs = &btb_inv * bt_y;
519
520 for k in 0..actual_nbasis {
521 all_coefs[(i, k)] = coefs[k];
522 }
523
524 let fitted = &b_mat * &coefs;
525 for j in 0..m {
526 all_fitted[(i, j)] = fitted[j];
527 let resid = curve[j] - fitted[j];
528 total_rss += resid * resid;
529 }
530 }
531
532 let total_points = (n * m) as f64;
533 let (gcv, aic, bic) = compute_fit_criteria(total_rss, total_points, edf, m);
534
535 Some(FourierFitResult {
536 coefficients: all_coefs,
537 fitted: all_fitted,
538 edf,
539 rss: total_rss,
540 gcv,
541 aic,
542 bic,
543 n_basis: actual_nbasis,
544 })
545}
546
547pub fn select_fourier_nbasis_gcv(
560 data: &FdMatrix,
561 argvals: &[f64],
562 min_nbasis: usize,
563 max_nbasis: usize,
564) -> usize {
565 let m = data.ncols();
566 let min_nb = min_nbasis.max(3);
567 let max_nb = max_nbasis.min(m / 2);
569
570 if max_nb <= min_nb {
571 return min_nb;
572 }
573
574 let mut best_nbasis = min_nb;
575 let mut best_gcv = f64::INFINITY;
576
577 let mut nbasis = if min_nb % 2 == 0 { min_nb + 1 } else { min_nb };
579 while nbasis <= max_nb {
580 if let Some(result) = fourier_fit_1d(data, argvals, nbasis) {
581 if result.gcv < best_gcv && result.gcv.is_finite() {
582 best_gcv = result.gcv;
583 best_nbasis = nbasis;
584 }
585 }
586 nbasis += 2;
587 }
588
589 best_nbasis
590}
591
/// Basis chosen for a single curve by [`select_basis_auto_1d`].
#[derive(Debug, Clone)]
pub struct SingleCurveSelection {
    /// Selected basis family: `1` for Fourier, `0` for P-spline.
    pub basis_type: i32,
    /// Number of basis functions of the selected candidate.
    pub nbasis: usize,
    /// Value of the selection criterion for the chosen fit
    /// (`+inf` when no candidate could be fitted).
    pub score: f64,
    /// Coefficients of the chosen fit (empty when no candidate could be fitted).
    pub coefficients: Vec<f64>,
    /// Fitted values at the evaluation points (empty when no candidate could be fitted).
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom of the chosen fit.
    pub edf: f64,
    /// Whether the FFT-based check flagged the curve as seasonal.
    pub seasonal_detected: bool,
    /// Smoothing parameter of the chosen P-spline fit; `NaN` for Fourier fits
    /// and when no candidate could be fitted.
    pub lambda: f64,
}
612
/// Result of automatic basis selection, as returned by [`select_basis_auto_1d`].
pub struct BasisAutoSelectionResult {
    /// One selection per input curve, in curve order.
    pub selections: Vec<SingleCurveSelection>,
    /// The criterion code that was passed to the selection routine.
    pub criterion: i32,
}
620
621fn detect_seasonality_fft(curve: &[f64]) -> bool {
626 use rustfft::{num_complex::Complex, FftPlanner};
627
628 let n = curve.len();
629 if n < 8 {
630 return false;
631 }
632
633 let mean: f64 = curve.iter().sum::<f64>() / n as f64;
635 let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
636
637 let mut planner = FftPlanner::new();
638 let fft = planner.plan_fft_forward(n);
639 fft.process(&mut input);
640
641 let powers: Vec<f64> = input[1..n / 2].iter().map(|c| c.norm_sqr()).collect();
643
644 if powers.is_empty() {
645 return false;
646 }
647
648 let max_power = powers.iter().cloned().fold(0.0_f64, f64::max);
649 let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
650
651 max_power > 2.0 * mean_power
653}
654
655fn fit_curve_fourier(
657 curve: &[f64],
658 m: usize,
659 argvals: &[f64],
660 nbasis: usize,
661 criterion: i32,
662) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
663 let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
664
665 let basis = fourier_basis(argvals, nbasis);
666 let actual_nbasis = basis.len() / m;
667 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
668
669 let btb = &b_mat.transpose() * &b_mat;
670 let btb_inv = svd_pseudoinverse(&btb)?;
671 let proj = &btb_inv * b_mat.transpose();
672 let h_mat = &b_mat * &proj;
673 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
674
675 let curve_vec = DVector::from_column_slice(curve);
676 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
677 let fitted = &b_mat * &coefs;
678
679 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
680 let score = compute_model_criterion(rss, m as f64, edf, criterion);
681
682 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
683 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
684
685 Some((score, coef_vec, fitted_vec, edf))
686}
687
688fn fit_curve_pspline(
690 curve: &[f64],
691 m: usize,
692 argvals: &[f64],
693 nbasis: usize,
694 lambda: f64,
695 order: usize,
696 criterion: i32,
697) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
698 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
699 let actual_nbasis = basis.len() / m;
700 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
701
702 let d = difference_matrix(actual_nbasis, order);
703 let penalty = &d.transpose() * &d;
704 let btb = &b_mat.transpose() * &b_mat;
705 let btb_penalized = &btb + lambda * &penalty;
706
707 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
708 let proj = &btb_inv * b_mat.transpose();
709 let h_mat = &b_mat * &proj;
710 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
711
712 let curve_vec = DVector::from_column_slice(curve);
713 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
714 let fitted = &b_mat * &coefs;
715
716 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
717 let score = compute_model_criterion(rss, m as f64, edf, criterion);
718
719 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
720 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
721
722 Some((score, coef_vec, fitted_vec, edf))
723}
724
/// Best candidate found so far while searching over basis sizes
/// (and, for P-splines, smoothing parameters) for one curve.
struct BasisSearchResult {
    /// Criterion value of this candidate (lower is better).
    score: f64,
    /// Number of basis functions requested for this candidate.
    nbasis: usize,
    /// Fitted basis coefficients.
    coefs: Vec<f64>,
    /// Fitted values at the evaluation points.
    fitted: Vec<f64>,
    /// Effective degrees of freedom of the fit.
    edf: f64,
    /// Smoothing parameter used; `NaN` for Fourier candidates.
    lambda: f64,
}
734
735fn search_fourier_basis(
737 curve: &[f64],
738 m: usize,
739 argvals: &[f64],
740 fourier_min: usize,
741 fourier_max: usize,
742 seasonal: bool,
743 criterion: i32,
744) -> Option<BasisSearchResult> {
745 let fourier_start = if seasonal {
746 fourier_min.max(5)
747 } else {
748 fourier_min
749 };
750 let mut nb = if fourier_start % 2 == 0 {
751 fourier_start + 1
752 } else {
753 fourier_start
754 };
755
756 let mut best: Option<BasisSearchResult> = None;
757 while nb <= fourier_max {
758 if let Some((score, coefs, fitted, edf)) =
759 fit_curve_fourier(curve, m, argvals, nb, criterion)
760 {
761 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
762 best = Some(BasisSearchResult {
763 score,
764 nbasis: nb,
765 coefs,
766 fitted,
767 edf,
768 lambda: f64::NAN,
769 });
770 }
771 }
772 nb += 2;
773 }
774 best
775}
776
777fn try_pspline_fit_update(
779 curve: &[f64],
780 m: usize,
781 argvals: &[f64],
782 nb: usize,
783 lam: f64,
784 criterion: i32,
785 best: &mut Option<BasisSearchResult>,
786) {
787 if let Some((score, coefs, fitted, edf)) =
788 fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
789 {
790 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
791 *best = Some(BasisSearchResult {
792 score,
793 nbasis: nb,
794 coefs,
795 fitted,
796 edf,
797 lambda: lam,
798 });
799 }
800 }
801}
802
803fn search_pspline_basis(
805 curve: &[f64],
806 m: usize,
807 argvals: &[f64],
808 pspline_min: usize,
809 pspline_max: usize,
810 lambda_grid: &[f64],
811 auto_lambda: bool,
812 lambda: f64,
813 criterion: i32,
814) -> Option<BasisSearchResult> {
815 let mut best: Option<BasisSearchResult> = None;
816 for nb in pspline_min..=pspline_max {
817 let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
818 Box::new(lambda_grid.iter().copied())
819 } else {
820 Box::new(std::iter::once(lambda))
821 };
822 for lam in lambdas {
823 try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
824 }
825 }
826 best
827}
828
829pub fn select_basis_auto_1d(
847 data: &FdMatrix,
848 argvals: &[f64],
849 criterion: i32,
850 nbasis_min: usize,
851 nbasis_max: usize,
852 lambda_pspline: f64,
853 use_seasonal_hint: bool,
854) -> BasisAutoSelectionResult {
855 let n = data.nrows();
856 let m = data.ncols();
857 let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
858 let fourier_max = if nbasis_max > 0 {
859 nbasis_max.min(m / 3).min(25)
860 } else {
861 (m / 3).min(25)
862 };
863 let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
864 let pspline_max = if nbasis_max > 0 {
865 nbasis_max.min(m / 2).min(40)
866 } else {
867 (m / 2).min(40)
868 };
869
870 let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
871 let auto_lambda = lambda_pspline < 0.0;
872
873 let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
874 .map(|i| {
875 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
876 let seasonal_detected = if use_seasonal_hint {
877 detect_seasonality_fft(&curve)
878 } else {
879 false
880 };
881
882 let fourier_best = search_fourier_basis(
883 &curve,
884 m,
885 argvals,
886 fourier_min,
887 fourier_max,
888 seasonal_detected,
889 criterion,
890 );
891 let pspline_best = search_pspline_basis(
892 &curve,
893 m,
894 argvals,
895 pspline_min,
896 pspline_max,
897 &lambda_grid,
898 auto_lambda,
899 lambda_pspline,
900 criterion,
901 );
902
903 let (basis_type, result) = match (fourier_best, pspline_best) {
905 (Some(f), Some(p)) => {
906 if f.score <= p.score {
907 (1i32, f)
908 } else {
909 (0i32, p)
910 }
911 }
912 (Some(f), None) => (1, f),
913 (None, Some(p)) => (0, p),
914 (None, None) => {
915 return SingleCurveSelection {
916 basis_type: 0,
917 nbasis: pspline_min,
918 score: f64::INFINITY,
919 coefficients: Vec::new(),
920 fitted: Vec::new(),
921 edf: 0.0,
922 seasonal_detected,
923 lambda: f64::NAN,
924 };
925 }
926 };
927
928 SingleCurveSelection {
929 basis_type,
930 nbasis: result.nbasis,
931 score: result.score,
932 coefficients: result.coefs,
933 fitted: result.fitted,
934 edf: result.edf,
935 seasonal_detected,
936 lambda: result.lambda,
937 }
938 })
939 .collect();
940
941 BasisAutoSelectionResult {
942 selections,
943 criterion,
944 }
945}
946
// Unit tests for basis construction, projection, fitting and selection.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // `n` evenly spaced points on [0, 1], endpoints included.
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    // sin(2 * PI * freq * t) evaluated at every point of `t`.
    fn sine_wave(t: &[f64], freq: f64) -> Vec<f64> {
        t.iter().map(|&ti| (2.0 * PI * freq * ti).sin()).collect()
    }

    // ---- B-spline basis ----

    // The flat basis holds (nknots + order) values per evaluation point.
    #[test]
    fn test_bspline_basis_dimensions() {
        let t = uniform_grid(50);
        let nknots = 10;
        let order = 4;
        let basis = bspline_basis(&t, nknots, order);

        let expected_nbasis = nknots + order;
        assert_eq!(basis.len(), t.len() * expected_nbasis);
    }

    // At every evaluation point the basis functions sum to one.
    #[test]
    fn test_bspline_basis_partition_of_unity() {
        let t = uniform_grid(50);
        let nknots = 8;
        let order = 4;
        let basis = bspline_basis(&t, nknots, order);

        let nbasis = nknots + order;
        for i in 0..t.len() {
            let sum: f64 = (0..nbasis).map(|j| basis[i + j * t.len()]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "B-spline partition of unity failed at point {}: sum = {}",
                i,
                sum
            );
        }
    }

    // No basis value is negative (up to rounding).
    #[test]
    fn test_bspline_basis_non_negative() {
        let t = uniform_grid(50);
        let basis = bspline_basis(&t, 8, 4);

        for &val in &basis {
            assert!(val >= -1e-10, "B-spline values should be non-negative");
        }
    }

    // Evaluation at both endpoints of the range yields finite values.
    #[test]
    fn test_bspline_basis_boundary() {
        let t = vec![0.0, 0.5, 1.0];
        let basis = bspline_basis(&t, 5, 4);

        for &val in &basis {
            assert!(val.is_finite(), "B-spline should produce finite values");
        }
    }

    // ---- Fourier basis ----

    // The flat basis holds `nbasis` values per evaluation point.
    #[test]
    fn test_fourier_basis_dimensions() {
        let t = uniform_grid(50);
        let nbasis = 7;
        let basis = fourier_basis(&t, nbasis);

        assert_eq!(basis.len(), t.len() * nbasis);
    }

    // The first column (the first `t.len()` entries) is constant.
    #[test]
    fn test_fourier_basis_constant_first_column() {
        let t = uniform_grid(50);
        let nbasis = 7;
        let basis = fourier_basis(&t, nbasis);

        let first_val = basis[0];
        for i in 0..t.len() {
            assert!(
                (basis[i] - first_val).abs() < 1e-10,
                "First Fourier column should be constant"
            );
        }
    }

    // Every basis value lies in [-1, 1] (up to rounding).
    #[test]
    fn test_fourier_basis_sin_cos_range() {
        let t = uniform_grid(100);
        let nbasis = 11;
        let basis = fourier_basis(&t, nbasis);

        for &val in &basis {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&val));
        }
    }

    // An explicit period keeps the dimensions and the constant first column.
    #[test]
    fn test_fourier_basis_with_period() {
        let t = uniform_grid(100);
        let nbasis = 5;
        let period = 0.5;
        let basis = fourier_basis_with_period(&t, nbasis, period);

        assert_eq!(basis.len(), t.len() * nbasis);
        let first_val = basis[0];
        for i in 0..t.len() {
            assert!((basis[i] - first_val).abs() < 1e-10);
        }
    }

    // Changing the period changes the first sine column.
    #[test]
    fn test_fourier_basis_period_affects_frequency() {
        let t = uniform_grid(100);
        let nbasis = 5;

        let basis1 = fourier_basis_with_period(&t, nbasis, 1.0);
        let basis2 = fourier_basis_with_period(&t, nbasis, 0.5);

        let n = t.len();
        let mut any_different = false;
        for i in 0..n {
            // Entries n..2n hold the second column (the first sine term).
            if (basis1[i + n] - basis2[i + n]).abs() > 1e-10 {
                any_different = true;
                break;
            }
        }
        assert!(
            any_different,
            "Different periods should produce different bases"
        );
    }

    // ---- Difference matrix ----

    // Order zero gives the identity matrix.
    #[test]
    fn test_difference_matrix_order_zero() {
        let d = difference_matrix(5, 0);
        assert_eq!(d.nrows(), 5);
        assert_eq!(d.ncols(), 5);

        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((d[(i, j)] - expected).abs() < 1e-10);
            }
        }
    }

    // First differences of 1..=5 are all one.
    #[test]
    fn test_difference_matrix_first_order() {
        let d = difference_matrix(5, 1);
        assert_eq!(d.nrows(), 4);
        assert_eq!(d.ncols(), 5);

        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let dx = &d * x;
        for i in 0..4 {
            assert!((dx[i] - 1.0).abs() < 1e-10);
        }
    }

    // Second differences of a linear sequence vanish.
    #[test]
    fn test_difference_matrix_second_order() {
        let d = difference_matrix(5, 2);
        assert_eq!(d.nrows(), 3);
        assert_eq!(d.ncols(), 5);

        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let dx = &d * x;
        for i in 0..3 {
            assert!(dx[i].abs() < 1e-10, "Second diff of linear should be zero");
        }
    }

    // Second differences of the squares 1, 4, 9, ... are all two.
    #[test]
    fn test_difference_matrix_quadratic() {
        let d = difference_matrix(5, 2);

        let x = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);
        let dx = &d * x;
        for i in 0..3 {
            assert!(
                (dx[i] - 2.0).abs() < 1e-10,
                "Second diff of squares should be 2"
            );
        }
    }

    // ---- Projection and reconstruction ----

    // Builds an n x m `FdMatrix` from a row-major buffer by rearranging it
    // into the column-major order (`i + j * n`) that `from_column_major` takes.
    fn make_matrix(flat_row_major: &[f64], n: usize, m: usize) -> FdMatrix {
        let mut col_major = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                col_major[i + j * n] = flat_row_major[i * m + j];
            }
        }
        FdMatrix::from_column_major(col_major, n, m).unwrap()
    }

    // B-spline projection returns one coefficient row per curve.
    #[test]
    fn test_fdata_to_basis_1d_bspline() {
        let t = uniform_grid(50);
        let n = 5;
        let m = t.len();

        // Curve i is the line t + 0.1 * i.
        let flat: Vec<f64> = (0..n)
            .flat_map(|i| t.iter().map(move |&ti| ti + i as f64 * 0.1))
            .collect();
        let data = make_matrix(&flat, n, m);

        let result = fdata_to_basis_1d(&data, &t, 10, 0);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis > 0);
        assert_eq!(res.coefficients.nrows(), n);
        assert_eq!(res.coefficients.ncols(), res.n_basis);
    }

    // Fourier projection uses exactly the requested number of functions.
    #[test]
    fn test_fdata_to_basis_1d_fourier() {
        let t = uniform_grid(50);
        let n = 5;
        let m = t.len();

        // Five identical sine curves.
        let flat: Vec<f64> = (0..n).flat_map(|_| sine_wave(&t, 2.0)).collect();
        let data = make_matrix(&flat, n, m);

        let result = fdata_to_basis_1d(&data, &t, 7, 1);
        assert!(result.is_some());

        let res = result.unwrap();
        assert_eq!(res.n_basis, 7);
    }

    // Empty data and too few basis functions are rejected.
    #[test]
    fn test_fdata_to_basis_1d_invalid_input() {
        let t = uniform_grid(50);

        // No curves at all.
        let empty = FdMatrix::zeros(0, 50);
        let result = fdata_to_basis_1d(&empty, &t, 10, 0);
        assert!(result.is_none());

        // nbasis below the minimum of 2.
        let data = FdMatrix::zeros(1, 50);
        let result = fdata_to_basis_1d(&data, &t, 1, 0);
        assert!(result.is_none());
    }

    // Projecting a single sine curve and reconstructing it stays close to
    // the input.
    #[test]
    fn test_basis_roundtrip() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        let raw = sine_wave(&t, 1.0);
        let data = FdMatrix::from_column_major(raw.clone(), n, m).unwrap();

        let proj = fdata_to_basis_1d(&data, &t, 5, 1).unwrap();

        let reconstructed = basis_to_fdata_1d(&proj.coefficients, &t, proj.n_basis, 1);

        // Largest absolute deviation between input and reconstruction.
        let mut max_error = 0.0;
        for j in 0..m {
            let err = (raw[j] - reconstructed[(0, j)]).abs();
            if err > max_error {
                max_error = err;
            }
        }
        assert!(max_error < 0.5, "Roundtrip error too large: {}", max_error);
    }

    // Empty coefficients and argvals give an empty result.
    #[test]
    fn test_basis_to_fdata_empty_input() {
        let empty = FdMatrix::zeros(0, 0);
        let result = basis_to_fdata_1d(&empty, &[], 5, 0);
        assert!(result.is_empty());
    }

    // ---- P-spline fitting ----

    // A basic fit returns correctly shaped, finite results.
    #[test]
    fn test_pspline_fit_1d_basic() {
        let t = uniform_grid(50);
        let n = 3;
        let m = t.len();

        // Sine curves with a deterministic perturbation depending on (i, j).
        let flat: Vec<f64> = (0..n)
            .flat_map(|i| {
                t.iter()
                    .enumerate()
                    .map(move |(j, &ti)| (2.0 * PI * ti).sin() + 0.1 * (i * j) as f64 % 1.0)
            })
            .collect();
        let data = make_matrix(&flat, n, m);

        let result = pspline_fit_1d(&data, &t, 15, 1.0, 2);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis > 0);
        assert_eq!(res.fitted.nrows(), n);
        assert_eq!(res.fitted.ncols(), m);
        assert!(res.rss >= 0.0);
        assert!(res.edf > 0.0);
        assert!(res.gcv.is_finite());
    }

    // A larger lambda yields fewer effective degrees of freedom.
    #[test]
    fn test_pspline_fit_1d_smoothness() {
        let t = uniform_grid(50);
        let n = 1;
        let m = t.len();

        // Sine curve with a deterministic pseudo-noise term.
        let raw: Vec<f64> = t
            .iter()
            .enumerate()
            .map(|(i, &ti)| (2.0 * PI * ti).sin() + 0.3 * ((i * 17) % 100) as f64 / 100.0)
            .collect();
        let data = FdMatrix::from_column_major(raw, n, m).unwrap();

        let low_lambda = pspline_fit_1d(&data, &t, 15, 0.01, 2).unwrap();
        let high_lambda = pspline_fit_1d(&data, &t, 15, 100.0, 2).unwrap();

        assert!(high_lambda.edf < low_lambda.edf);
    }

    // Data without curves is rejected.
    #[test]
    fn test_pspline_fit_1d_invalid_input() {
        let t = uniform_grid(50);
        let empty = FdMatrix::zeros(0, 50);
        let result = pspline_fit_1d(&empty, &t, 15, 1.0, 2);
        assert!(result.is_none());
    }

    // ---- Fourier fitting ----

    // A pure sine is reproduced with negligible residual.
    #[test]
    fn test_fourier_fit_1d_sine_wave() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        let raw = sine_wave(&t, 2.0);
        let data = FdMatrix::from_column_major(raw, n, m).unwrap();

        let result = fourier_fit_1d(&data, &t, 11);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.rss < 1e-6, "Pure sine should have near-zero RSS");
    }

    // An even request results in an odd number of basis functions.
    #[test]
    fn test_fourier_fit_1d_makes_nbasis_odd() {
        let t = uniform_grid(50);
        let raw = sine_wave(&t, 1.0);
        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();

        let result = fourier_fit_1d(&data, &t, 6);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis % 2 == 1);
    }

    // GCV, AIC and BIC are all finite for this fit.
    #[test]
    fn test_fourier_fit_1d_criteria() {
        let t = uniform_grid(50);
        let raw = sine_wave(&t, 2.0);
        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();

        let result = fourier_fit_1d(&data, &t, 9).unwrap();

        assert!(result.gcv.is_finite());
        assert!(result.aic.is_finite());
        assert!(result.bic.is_finite());
    }

    // Fewer than three basis functions are rejected.
    #[test]
    fn test_fourier_fit_1d_invalid_nbasis() {
        let t = uniform_grid(50);
        let raw = sine_wave(&t, 1.0);
        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();

        let result = fourier_fit_1d(&data, &t, 2);
        assert!(result.is_none());
    }

    // ---- GCV-based selection of the Fourier basis size ----

    // The selected size is odd and lies within the requested range.
    #[test]
    fn test_select_fourier_nbasis_gcv_range() {
        let t = uniform_grid(100);
        let raw = sine_wave(&t, 3.0);
        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();

        let best = select_fourier_nbasis_gcv(&data, &t, 3, 15);

        assert!((3..=15).contains(&best));
        assert!(best % 2 == 1, "Selected nbasis should be odd");
    }

    // The selected size never falls below the requested minimum.
    #[test]
    fn test_select_fourier_nbasis_gcv_respects_min() {
        let t = uniform_grid(50);
        let raw = sine_wave(&t, 1.0);
        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();

        let best = select_fourier_nbasis_gcv(&data, &t, 7, 15);
        assert!(best >= 7);
    }

    // ---- Automatic basis selection ----

    // One non-empty selection is returned per curve.
    #[test]
    fn test_select_basis_auto_1d_returns_results() {
        let t = uniform_grid(50);
        let n = 3;
        let m = t.len();

        let flat: Vec<f64> = (0..n).flat_map(|i| sine_wave(&t, 1.0 + i as f64)).collect();
        let data = make_matrix(&flat, n, m);

        let result = select_basis_auto_1d(&data, &t, 0, 5, 15, 1.0, false);

        assert_eq!(result.selections.len(), n);
        for sel in &result.selections {
            assert!(sel.nbasis >= 3);
            assert!(!sel.coefficients.is_empty());
            assert_eq!(sel.fitted.len(), m);
        }
    }

    // A sine with five cycles is flagged as seasonal when the hint is on.
    #[test]
    fn test_select_basis_auto_1d_seasonal_hint() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        let raw = sine_wave(&t, 5.0);
        let data = FdMatrix::from_column_major(raw, n, m).unwrap();

        let result = select_basis_auto_1d(&data, &t, 0, 0, 0, -1.0, true);

        assert_eq!(result.selections.len(), 1);
        assert!(result.selections[0].seasonal_detected);
    }

    // A constant curve is not flagged as seasonal.
    #[test]
    fn test_select_basis_auto_1d_non_seasonal() {
        let t = uniform_grid(50);
        let n = 1;
        let m = t.len();

        let raw: Vec<f64> = vec![1.0; m];
        let data = FdMatrix::from_column_major(raw, n, m).unwrap();

        let result = select_basis_auto_1d(&data, &t, 0, 0, 0, -1.0, true);

        assert!(!result.selections[0].seasonal_detected);
    }

    // The criterion code passed in is echoed in the result.
    #[test]
    fn test_select_basis_auto_1d_criterion_options() {
        let t = uniform_grid(50);
        let raw = sine_wave(&t, 2.0);
        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();

        let gcv_result = select_basis_auto_1d(&data, &t, 0, 0, 0, 1.0, false);
        let aic_result = select_basis_auto_1d(&data, &t, 1, 0, 0, 1.0, false);
        let bic_result = select_basis_auto_1d(&data, &t, 2, 0, 0, 1.0, false);

        assert_eq!(gcv_result.criterion, 0);
        assert_eq!(aic_result.criterion, 1);
        assert_eq!(bic_result.criterion, 2);
    }
}