fdars_core/
basis.rs

1//! Basis representation functions for functional data.
2//!
3//! This module provides B-spline and Fourier basis expansions for representing
4//! functional data in a finite-dimensional basis.
5
6use nalgebra::{DMatrix, DVector, SVD};
7use rayon::prelude::*;
8use std::f64::consts::PI;
9
/// Compute a B-spline basis matrix for the given evaluation points.
///
/// Places `nknots` uniformly spaced knots on `[min(t), max(t)]` (both end points
/// included) and extends the knot vector by `order` further knots with the same spacing
/// on each side. On this extended knot vector, `nknots + order` B-splines of the given
/// `order` (degree `order - 1`) are evaluated with the Cox–de Boor recursion.
///
/// # Arguments
/// * `t` - Evaluation points (need not be sorted)
/// * `nknots` - Number of knots on the data range, including both end points
/// * `order` - B-spline order (4 gives cubic splines)
///
/// # Returns
/// Column-major matrix (`t.len()` x `nknots + order`) stored as a flat vector: basis
/// function `j` evaluated at `t[i]` is at index `i + j * t.len()`. Empty when `t` is empty.
///
/// # Panics
/// Panics if `nknots < 2` (the knot spacing `(max - min) / (nknots - 1)` is undefined)
/// or if `order == 0`.
pub fn bspline_basis(t: &[f64], nknots: usize, order: usize) -> Vec<f64> {
    assert!(
        nknots >= 2,
        "bspline_basis: nknots must be at least 2, got {}",
        nknots
    );
    assert!(order >= 1, "bspline_basis: order must be at least 1");

    let n = t.len();
    // Total knots: order (left) + nknots (data range) + order (right) = 2*order + nknots.
    // Number of B-spline basis functions: total_knots - order = nknots + order.
    let nbasis = nknots + order;

    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
    let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let dt = (t_max - t_min) / (nknots - 1) as f64;

    let mut knots = Vec::with_capacity(nknots + 2 * order);
    knots.extend((0..order).map(|i| t_min - (order - i) as f64 * dt));
    knots.extend((0..nknots - 1).map(|i| t_min + i as f64 * dt));
    // Pin the last data-range knot to t_max exactly: `t_min + (nknots - 1) * dt` can
    // differ from t_max by rounding, which would move t_max out of the closed interval
    // that is meant to contain it.
    knots.push(t_max);
    knots.extend((1..=order).map(|i| t_max + i as f64 * dt));

    // Index of the last interval inside the data range, [t_max - dt, t_max].
    let last_data_interval = order + nknots - 2;
    let n_intervals = knots.len() - 1;

    let mut basis = vec![0.0; n * nbasis];

    for (ti, &t_val) in t.iter().enumerate() {
        // Order-1 B-splines are indicators of the knot interval containing t_val.
        // Intervals are half-open [k_j, k_{j+1}) except the last one inside the data
        // range, which is closed so that t_max itself is covered.
        let mut b = vec![0.0; n_intervals];
        let containing = (0..n_intervals).find(|&j| {
            let below_upper = if j == last_data_interval {
                t_val <= knots[j + 1]
            } else {
                t_val < knots[j + 1]
            };
            t_val >= knots[j] && below_upper
        });
        if let Some(j) = containing {
            b[j] = 1.0;
        }

        // Cox–de Boor recursion: raise the order one step at a time. Zero-width knot
        // spans contribute nothing (0/0 is treated as 0).
        for k in 2..=order {
            let len = knots.len() - k;
            let mut next = vec![0.0; len];
            for (j, slot) in next.iter_mut().enumerate() {
                let d1 = knots[j + k - 1] - knots[j];
                let d2 = knots[j + k] - knots[j + 1];

                let left = if d1.abs() > 1e-10 {
                    (t_val - knots[j]) / d1 * b[j]
                } else {
                    0.0
                };
                let right = if d2.abs() > 1e-10 {
                    (knots[j + k] - t_val) / d2 * b[j + 1]
                } else {
                    0.0
                };
                *slot = left + right;
            }
            b = next;
        }

        // After the recursion b holds exactly knots.len() - order = nbasis values.
        for (j, &value) in b.iter().enumerate() {
            basis[ti + j * n] = value;
        }
    }

    basis
}
90
91/// Compute Fourier basis matrix.
92///
93/// The period is automatically set to the range of evaluation points (t_max - t_min).
94/// For explicit period control, use `fourier_basis_with_period`.
95pub fn fourier_basis(t: &[f64], nbasis: usize) -> Vec<f64> {
96    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
97    let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
98    let period = t_max - t_min;
99    fourier_basis_with_period(t, nbasis, period)
100}
101
/// Compute Fourier basis matrix with explicit period.
///
/// This function creates a Fourier basis expansion where the period can be specified
/// independently of the evaluation range. This is essential for seasonal analysis
/// where the seasonal period may differ from the observation window.
///
/// Columns are ordered `1, sin(x), cos(x), sin(2x), cos(2x), ...` with
/// `x = 2π (t - min(t)) / period`; an even `nbasis` therefore ends with a sine column.
///
/// # Arguments
/// * `t` - Evaluation points
/// * `nbasis` - Number of basis functions (1 constant + pairs of sin/cos)
/// * `period` - The period for the Fourier basis
///
/// # Returns
/// Column-major matrix (n_points x nbasis) stored as flat vector: column `k` at point
/// `i` is at index `i + k * n_points`. Empty when `nbasis == 0` or `t` is empty.
pub fn fourier_basis_with_period(t: &[f64], nbasis: usize, period: f64) -> Vec<f64> {
    let n = t.len();
    if nbasis == 0 {
        // No columns: nothing to write (the constant column below would be out of bounds).
        return Vec::new();
    }

    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
    let mut basis = vec![0.0; n * nbasis];

    for (i, &ti) in t.iter().enumerate() {
        let x = 2.0 * PI * (ti - t_min) / period;

        basis[i] = 1.0;

        // Column k >= 1 holds harmonic (k + 1) / 2: sine for odd k, cosine for even k.
        for k in 1..nbasis {
            let freq = ((k + 1) / 2) as f64;
            basis[i + k * n] = if k % 2 == 1 {
                (freq * x).sin()
            } else {
                (freq * x).cos()
            };
        }
    }

    basis
}
143
144/// Compute difference matrix for P-spline penalty.
145pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
146    if order == 0 {
147        return DMatrix::identity(n, n);
148    }
149
150    let mut d = DMatrix::zeros(n - 1, n);
151    for i in 0..(n - 1) {
152        d[(i, i)] = -1.0;
153        d[(i, i + 1)] = 1.0;
154    }
155
156    let mut result = d;
157    for _ in 1..order {
158        if result.nrows() <= 1 {
159            break;
160        }
161        let rows = result.nrows() - 1;
162        let cols = result.ncols();
163        let mut d_next = DMatrix::zeros(rows, cols);
164        for i in 0..rows {
165            for j in 0..cols {
166                d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
167            }
168        }
169        result = d_next;
170    }
171
172    result
173}
174
/// Result of basis projection.
#[derive(Debug, Clone)]
pub struct BasisProjectionResult {
    /// Coefficient matrix (n_samples x n_basis) stored as a flat vector.
    /// `fdata_to_basis_1d` fills it sample by sample: the coefficient of basis
    /// function `k` for sample `i` is at index `i * n_basis + k`.
    pub coefficients: Vec<f64>,
    /// Number of basis functions used
    pub n_basis: usize,
}
182
183/// Project functional data to basis coefficients.
184///
185/// # Arguments
186/// * `data` - Column-major matrix (n x m)
187/// * `n` - Number of samples
188/// * `m` - Number of evaluation points
189/// * `argvals` - Evaluation points
190/// * `nbasis` - Number of basis functions
191/// * `basis_type` - 0 = B-spline, 1 = Fourier
192pub fn fdata_to_basis_1d(
193    data: &[f64],
194    n: usize,
195    m: usize,
196    argvals: &[f64],
197    nbasis: usize,
198    basis_type: i32,
199) -> Option<BasisProjectionResult> {
200    if n == 0 || m == 0 || argvals.len() != m || nbasis < 2 {
201        return None;
202    }
203
204    let basis = if basis_type == 1 {
205        fourier_basis(argvals, nbasis)
206    } else {
207        // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
208        bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
209    };
210
211    let actual_nbasis = basis.len() / m;
212    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
213
214    let btb = &b_mat.transpose() * &b_mat;
215    let btb_svd = SVD::new(btb, true, true);
216
217    let max_sv = btb_svd.singular_values.iter().cloned().fold(0.0, f64::max);
218    let eps = 1e-10 * max_sv;
219
220    let s_inv: Vec<f64> = btb_svd
221        .singular_values
222        .iter()
223        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
224        .collect();
225
226    let v = btb_svd.v_t.as_ref()?.transpose();
227    let u_t = btb_svd.u.as_ref()?.transpose();
228
229    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
230    for i in 0..actual_nbasis {
231        for j in 0..actual_nbasis {
232            let mut sum = 0.0;
233            for k in 0..actual_nbasis.min(s_inv.len()) {
234                sum += v[(i, k)] * s_inv[k] * u_t[(k, j)];
235            }
236            btb_inv[(i, j)] = sum;
237        }
238    }
239
240    let proj = btb_inv * b_mat.transpose();
241
242    let coefs: Vec<f64> = (0..n)
243        .into_par_iter()
244        .flat_map(|i| {
245            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
246            (0..actual_nbasis)
247                .map(|k| {
248                    let mut sum = 0.0;
249                    for j in 0..m {
250                        sum += proj[(k, j)] * curve[j];
251                    }
252                    sum
253                })
254                .collect::<Vec<_>>()
255        })
256        .collect();
257
258    Some(BasisProjectionResult {
259        coefficients: coefs,
260        n_basis: actual_nbasis,
261    })
262}
263
264/// Reconstruct functional data from basis coefficients.
265pub fn basis_to_fdata_1d(
266    coefs: &[f64],
267    n: usize,
268    coefs_ncols: usize,
269    argvals: &[f64],
270    nbasis: usize,
271    basis_type: i32,
272) -> Vec<f64> {
273    let m = argvals.len();
274    if n == 0 || m == 0 || nbasis < 2 {
275        return Vec::new();
276    }
277
278    let basis = if basis_type == 1 {
279        fourier_basis(argvals, nbasis)
280    } else {
281        // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
282        bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
283    };
284
285    let actual_nbasis = basis.len() / m;
286
287    (0..n)
288        .into_par_iter()
289        .flat_map(|i| {
290            (0..m)
291                .map(|j| {
292                    let mut sum = 0.0;
293                    for k in 0..actual_nbasis.min(coefs_ncols) {
294                        sum += coefs[i + k * n] * basis[j + k * m];
295                    }
296                    sum
297                })
298                .collect::<Vec<_>>()
299        })
300        .collect()
301}
302
/// Result of P-spline fitting.
#[derive(Debug, Clone)]
pub struct PsplineFitResult {
    /// Coefficient matrix (n x n_basis), column-major as filled by `pspline_fit_1d`:
    /// coefficient `k` of curve `i` is at index `i + k * n`
    pub coefficients: Vec<f64>,
    /// Fitted values (n x m), column-major: curve `i` at point `j` is at index `i + j * n`
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom of the single-curve smoother (trace of its hat matrix)
    pub edf: f64,
    /// Residual sum of squares, summed over all curves
    pub rss: f64,
    /// GCV score
    pub gcv: f64,
    /// AIC
    pub aic: f64,
    /// BIC
    pub bic: f64,
    /// Number of basis functions
    pub n_basis: usize,
}
322
323/// Fit P-splines to functional data.
324pub fn pspline_fit_1d(
325    data: &[f64],
326    n: usize,
327    m: usize,
328    argvals: &[f64],
329    nbasis: usize,
330    lambda: f64,
331    order: usize,
332) -> Option<PsplineFitResult> {
333    if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
334        return None;
335    }
336
337    // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
338    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
339    let actual_nbasis = basis.len() / m;
340    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
341
342    let d = difference_matrix(actual_nbasis, order);
343    let penalty = &d.transpose() * &d;
344
345    let btb = &b_mat.transpose() * &b_mat;
346    let btb_penalized = &btb + lambda * &penalty;
347
348    // Use SVD pseudoinverse for robustness with singular matrices
349    let svd = SVD::new(btb_penalized.clone(), true, true);
350    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
351    let eps = 1e-10 * max_sv;
352
353    let u = svd.u.as_ref()?;
354    let v_t = svd.v_t.as_ref()?;
355
356    let s_inv: Vec<f64> = svd
357        .singular_values
358        .iter()
359        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
360        .collect();
361
362    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
363    for i in 0..actual_nbasis {
364        for j in 0..actual_nbasis {
365            let mut sum = 0.0;
366            for k in 0..actual_nbasis.min(s_inv.len()) {
367                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
368            }
369            btb_inv[(i, j)] = sum;
370        }
371    }
372
373    let proj = &btb_inv * b_mat.transpose();
374    let h_mat = &b_mat * &proj;
375    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
376
377    let mut all_coefs = vec![0.0; n * actual_nbasis];
378    let mut all_fitted = vec![0.0; n * m];
379    let mut total_rss = 0.0;
380
381    for i in 0..n {
382        let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
383        let curve_vec = DVector::from_vec(curve.clone());
384
385        let bt_y = b_mat.transpose() * &curve_vec;
386        let coefs = &btb_inv * bt_y;
387
388        for k in 0..actual_nbasis {
389            all_coefs[i + k * n] = coefs[k];
390        }
391
392        let fitted = &b_mat * &coefs;
393        for j in 0..m {
394            all_fitted[i + j * n] = fitted[j];
395            let resid = curve[j] - fitted[j];
396            total_rss += resid * resid;
397        }
398    }
399
400    let total_points = (n * m) as f64;
401
402    let gcv_denom = 1.0 - edf / m as f64;
403    let gcv = if gcv_denom.abs() > 1e-10 {
404        (total_rss / total_points) / (gcv_denom * gcv_denom)
405    } else {
406        f64::INFINITY
407    };
408
409    let mse = total_rss / total_points;
410    let aic = total_points * mse.ln() + 2.0 * edf;
411    let bic = total_points * mse.ln() + total_points.ln() * edf;
412
413    Some(PsplineFitResult {
414        coefficients: all_coefs,
415        fitted: all_fitted,
416        edf,
417        rss: total_rss,
418        gcv,
419        aic,
420        bic,
421        n_basis: actual_nbasis,
422    })
423}
424
/// Result of Fourier basis fitting.
#[derive(Debug, Clone)]
pub struct FourierFitResult {
    /// Coefficient matrix (n x n_basis), column-major as filled by `fourier_fit_1d`:
    /// coefficient `k` of curve `i` is at index `i + k * n`
    pub coefficients: Vec<f64>,
    /// Fitted values (n x m), column-major: curve `i` at point `j` is at index `i + j * n`
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom: trace of the single-curve hat matrix, which for this
    /// unpenalized fit is the numerical rank of the basis matrix (normally n_basis)
    pub edf: f64,
    /// Residual sum of squares, summed over all curves
    pub rss: f64,
    /// GCV score
    pub gcv: f64,
    /// AIC
    pub aic: f64,
    /// BIC
    pub bic: f64,
    /// Number of basis functions
    pub n_basis: usize,
}
444
445/// Fit Fourier basis to functional data using least squares.
446///
447/// Projects data onto Fourier basis and reconstructs fitted values.
448/// Unlike P-splines, this uses unpenalized least squares projection.
449///
450/// # Arguments
451/// * `data` - Column-major matrix (n x m)
452/// * `n` - Number of samples
453/// * `m` - Number of evaluation points
454/// * `argvals` - Evaluation points
455/// * `nbasis` - Number of Fourier basis functions (should be odd: 1 constant + pairs of sin/cos)
456///
457/// # Returns
458/// FourierFitResult with coefficients, fitted values, and model selection criteria
459pub fn fourier_fit_1d(
460    data: &[f64],
461    n: usize,
462    m: usize,
463    argvals: &[f64],
464    nbasis: usize,
465) -> Option<FourierFitResult> {
466    if n == 0 || m == 0 || nbasis < 3 || argvals.len() != m {
467        return None;
468    }
469
470    // Ensure nbasis is odd (1 constant + pairs of sin/cos)
471    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
472
473    let basis = fourier_basis(argvals, nbasis);
474    let actual_nbasis = basis.len() / m;
475    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
476
477    let btb = &b_mat.transpose() * &b_mat;
478
479    // Use SVD pseudoinverse for robustness
480    let svd = SVD::new(btb.clone(), true, true);
481    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
482    let eps = 1e-10 * max_sv;
483
484    let u = svd.u.as_ref()?;
485    let v_t = svd.v_t.as_ref()?;
486
487    let s_inv: Vec<f64> = svd
488        .singular_values
489        .iter()
490        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
491        .collect();
492
493    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
494    for i in 0..actual_nbasis {
495        for j in 0..actual_nbasis {
496            let mut sum = 0.0;
497            for k in 0..actual_nbasis.min(s_inv.len()) {
498                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
499            }
500            btb_inv[(i, j)] = sum;
501        }
502    }
503
504    let proj = &btb_inv * b_mat.transpose();
505
506    // For unpenalized fit, hat matrix H = B * (B'B)^{-1} * B'
507    let h_mat = &b_mat * &proj;
508    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
509
510    let mut all_coefs = vec![0.0; n * actual_nbasis];
511    let mut all_fitted = vec![0.0; n * m];
512    let mut total_rss = 0.0;
513
514    for i in 0..n {
515        let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
516        let curve_vec = DVector::from_vec(curve.clone());
517
518        let bt_y = b_mat.transpose() * &curve_vec;
519        let coefs = &btb_inv * bt_y;
520
521        for k in 0..actual_nbasis {
522            all_coefs[i + k * n] = coefs[k];
523        }
524
525        let fitted = &b_mat * &coefs;
526        for j in 0..m {
527            all_fitted[i + j * n] = fitted[j];
528            let resid = curve[j] - fitted[j];
529            total_rss += resid * resid;
530        }
531    }
532
533    let total_points = (n * m) as f64;
534
535    // GCV: RSS / n * (1 - edf/m)^2
536    let gcv_denom = 1.0 - edf / m as f64;
537    let gcv = if gcv_denom.abs() > 1e-10 {
538        (total_rss / total_points) / (gcv_denom * gcv_denom)
539    } else {
540        f64::INFINITY
541    };
542
543    let mse = total_rss / total_points;
544    let aic = total_points * mse.ln() + 2.0 * edf;
545    let bic = total_points * mse.ln() + total_points.ln() * edf;
546
547    Some(FourierFitResult {
548        coefficients: all_coefs,
549        fitted: all_fitted,
550        edf,
551        rss: total_rss,
552        gcv,
553        aic,
554        bic,
555        n_basis: actual_nbasis,
556    })
557}
558
559/// Select optimal number of Fourier basis functions using GCV.
560///
561/// Performs grid search over nbasis values and returns the one with minimum GCV.
562///
563/// # Arguments
564/// * `data` - Column-major matrix (n x m)
565/// * `n` - Number of samples
566/// * `m` - Number of evaluation points
567/// * `argvals` - Evaluation points
568/// * `min_nbasis` - Minimum number of basis functions to try
569/// * `max_nbasis` - Maximum number of basis functions to try
570///
571/// # Returns
572/// Optimal number of basis functions
573pub fn select_fourier_nbasis_gcv(
574    data: &[f64],
575    n: usize,
576    m: usize,
577    argvals: &[f64],
578    min_nbasis: usize,
579    max_nbasis: usize,
580) -> usize {
581    let min_nb = min_nbasis.max(3);
582    // Ensure max doesn't exceed m (can't have more parameters than data points)
583    let max_nb = max_nbasis.min(m / 2);
584
585    if max_nb <= min_nb {
586        return min_nb;
587    }
588
589    let mut best_nbasis = min_nb;
590    let mut best_gcv = f64::INFINITY;
591
592    // Test odd values only (1 constant + pairs of sin/cos)
593    let mut nbasis = if min_nb % 2 == 0 { min_nb + 1 } else { min_nb };
594    while nbasis <= max_nb {
595        if let Some(result) = fourier_fit_1d(data, n, m, argvals, nbasis) {
596            if result.gcv < best_gcv && result.gcv.is_finite() {
597                best_gcv = result.gcv;
598                best_nbasis = nbasis;
599            }
600        }
601        nbasis += 2;
602    }
603
604    best_nbasis
605}
606
/// Result of automatic basis selection for a single curve.
#[derive(Debug, Clone)]
pub struct SingleCurveSelection {
    /// Selected basis type: 0 = P-spline, 1 = Fourier
    pub basis_type: i32,
    /// Selected number of basis functions
    pub nbasis: usize,
    /// Best criterion score (GCV, AIC, or BIC); infinite if no candidate scored finitely
    pub score: f64,
    /// Coefficients for the selected basis
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom
    pub edf: f64,
    /// Whether the FFT check flagged the curve as seasonal (false when the hint is off)
    pub seasonal_detected: bool,
    /// Lambda value (for P-spline, NaN for Fourier)
    pub lambda: f64,
}

/// Result of automatic basis selection for all curves.
#[derive(Debug, Clone)]
pub struct BasisAutoSelectionResult {
    /// Per-curve selection results
    pub selections: Vec<SingleCurveSelection>,
    /// Criterion used (0=GCV, 1=AIC, 2=BIC)
    pub criterion: i32,
}
635
636/// Detect if a curve has seasonal/periodic pattern using FFT.
637///
638/// Returns true if the peak power in the periodogram is significantly
639/// above the mean power level.
640fn detect_seasonality_fft(curve: &[f64]) -> bool {
641    use rustfft::{num_complex::Complex, FftPlanner};
642
643    let n = curve.len();
644    if n < 8 {
645        return false;
646    }
647
648    // Remove mean
649    let mean: f64 = curve.iter().sum::<f64>() / n as f64;
650    let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
651
652    let mut planner = FftPlanner::new();
653    let fft = planner.plan_fft_forward(n);
654    fft.process(&mut input);
655
656    // Compute power spectrum (skip DC component and Nyquist)
657    let powers: Vec<f64> = input[1..n / 2].iter().map(|c| c.norm_sqr()).collect();
658
659    if powers.is_empty() {
660        return false;
661    }
662
663    let max_power = powers.iter().cloned().fold(0.0_f64, f64::max);
664    let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
665
666    // Seasonal if peak is significantly above mean
667    max_power > 2.0 * mean_power
668}
669
670/// Fit a single curve with Fourier basis and compute criterion score.
671fn fit_curve_fourier(
672    curve: &[f64],
673    m: usize,
674    argvals: &[f64],
675    nbasis: usize,
676    criterion: i32,
677) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
678    // Ensure nbasis is odd
679    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
680
681    let basis = fourier_basis(argvals, nbasis);
682    let actual_nbasis = basis.len() / m;
683    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
684
685    let btb = &b_mat.transpose() * &b_mat;
686    let svd = SVD::new(btb.clone(), true, true);
687    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
688    let eps = 1e-10 * max_sv;
689
690    let u = svd.u.as_ref()?;
691    let v_t = svd.v_t.as_ref()?;
692
693    let s_inv: Vec<f64> = svd
694        .singular_values
695        .iter()
696        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
697        .collect();
698
699    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
700    for i in 0..actual_nbasis {
701        for j in 0..actual_nbasis {
702            let mut sum = 0.0;
703            for k in 0..actual_nbasis.min(s_inv.len()) {
704                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
705            }
706            btb_inv[(i, j)] = sum;
707        }
708    }
709
710    let proj = &btb_inv * b_mat.transpose();
711    let h_mat = &b_mat * &proj;
712    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
713
714    let curve_vec = DVector::from_column_slice(curve);
715    let bt_y = b_mat.transpose() * &curve_vec;
716    let coefs = &btb_inv * bt_y;
717
718    let fitted = &b_mat * &coefs;
719    let mut rss = 0.0;
720    for j in 0..m {
721        let resid = curve[j] - fitted[j];
722        rss += resid * resid;
723    }
724
725    let n_points = m as f64;
726    let score = match criterion {
727        0 => {
728            // GCV
729            let gcv_denom = 1.0 - edf / n_points;
730            if gcv_denom.abs() > 1e-10 {
731                (rss / n_points) / (gcv_denom * gcv_denom)
732            } else {
733                f64::INFINITY
734            }
735        }
736        1 => {
737            // AIC
738            let mse = rss / n_points;
739            n_points * mse.ln() + 2.0 * edf
740        }
741        _ => {
742            // BIC
743            let mse = rss / n_points;
744            n_points * mse.ln() + n_points.ln() * edf
745        }
746    };
747
748    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
749    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
750
751    Some((score, coef_vec, fitted_vec, edf))
752}
753
754/// Fit a single curve with P-spline basis and compute criterion score.
755fn fit_curve_pspline(
756    curve: &[f64],
757    m: usize,
758    argvals: &[f64],
759    nbasis: usize,
760    lambda: f64,
761    order: usize,
762    criterion: i32,
763) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
764    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
765    let actual_nbasis = basis.len() / m;
766    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
767
768    let d = difference_matrix(actual_nbasis, order);
769    let penalty = &d.transpose() * &d;
770
771    let btb = &b_mat.transpose() * &b_mat;
772    let btb_penalized = &btb + lambda * &penalty;
773
774    let svd = SVD::new(btb_penalized.clone(), true, true);
775    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
776    let eps = 1e-10 * max_sv;
777
778    let u = svd.u.as_ref()?;
779    let v_t = svd.v_t.as_ref()?;
780
781    let s_inv: Vec<f64> = svd
782        .singular_values
783        .iter()
784        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
785        .collect();
786
787    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
788    for i in 0..actual_nbasis {
789        for j in 0..actual_nbasis {
790            let mut sum = 0.0;
791            for k in 0..actual_nbasis.min(s_inv.len()) {
792                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
793            }
794            btb_inv[(i, j)] = sum;
795        }
796    }
797
798    let proj = &btb_inv * b_mat.transpose();
799    let h_mat = &b_mat * &proj;
800    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
801
802    let curve_vec = DVector::from_column_slice(curve);
803    let bt_y = b_mat.transpose() * &curve_vec;
804    let coefs = &btb_inv * bt_y;
805
806    let fitted = &b_mat * &coefs;
807    let mut rss = 0.0;
808    for j in 0..m {
809        let resid = curve[j] - fitted[j];
810        rss += resid * resid;
811    }
812
813    let n_points = m as f64;
814    let score = match criterion {
815        0 => {
816            // GCV
817            let gcv_denom = 1.0 - edf / n_points;
818            if gcv_denom.abs() > 1e-10 {
819                (rss / n_points) / (gcv_denom * gcv_denom)
820            } else {
821                f64::INFINITY
822            }
823        }
824        1 => {
825            // AIC
826            let mse = rss / n_points;
827            n_points * mse.ln() + 2.0 * edf
828        }
829        _ => {
830            // BIC
831            let mse = rss / n_points;
832            n_points * mse.ln() + n_points.ln() * edf
833        }
834    };
835
836    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
837    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
838
839    Some((score, coef_vec, fitted_vec, edf))
840}
841
842/// Select optimal basis type and parameters for each curve individually.
843///
844/// This function compares Fourier and P-spline bases for each curve,
845/// selecting the optimal basis type and number of basis functions using
846/// model selection criteria (GCV, AIC, or BIC).
847///
848/// # Arguments
849/// * `data` - Column-major matrix (n x m)
850/// * `n` - Number of curves
851/// * `m` - Number of evaluation points per curve
852/// * `argvals` - Evaluation points
853/// * `criterion` - Model selection criterion: 0=GCV, 1=AIC, 2=BIC
854/// * `nbasis_min` - Minimum number of basis functions (0 for auto)
855/// * `nbasis_max` - Maximum number of basis functions (0 for auto)
856/// * `lambda_pspline` - Smoothing parameter for P-spline (negative for auto-select)
857/// * `use_seasonal_hint` - Whether to use FFT to detect seasonality
858///
859/// # Returns
860/// BasisAutoSelectionResult with per-curve selections
861pub fn select_basis_auto_1d(
862    data: &[f64],
863    n: usize,
864    m: usize,
865    argvals: &[f64],
866    criterion: i32,
867    nbasis_min: usize,
868    nbasis_max: usize,
869    lambda_pspline: f64,
870    use_seasonal_hint: bool,
871) -> BasisAutoSelectionResult {
872    // Determine nbasis ranges
873    let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
874    let fourier_max = if nbasis_max > 0 {
875        nbasis_max.min(m / 3).min(25)
876    } else {
877        (m / 3).min(25)
878    };
879
880    let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
881    let pspline_max = if nbasis_max > 0 {
882        nbasis_max.min(m / 2).min(40)
883    } else {
884        (m / 2).min(40)
885    };
886
887    // Lambda grid for P-spline when auto-selecting
888    let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
889    let auto_lambda = lambda_pspline < 0.0;
890
891    let selections: Vec<SingleCurveSelection> = (0..n)
892        .into_par_iter()
893        .map(|i| {
894            // Extract single curve
895            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
896
897            // Detect seasonality if requested
898            let seasonal_detected = if use_seasonal_hint {
899                detect_seasonality_fft(&curve)
900            } else {
901                false
902            };
903
904            let mut best_score = f64::INFINITY;
905            let mut best_basis_type = 0i32; // P-spline
906            let mut best_nbasis = pspline_min;
907            let mut best_coefs = Vec::new();
908            let mut best_fitted = Vec::new();
909            let mut best_edf = 0.0;
910            let mut best_lambda = f64::NAN;
911
912            // Try Fourier bases
913            let fourier_start = if seasonal_detected {
914                fourier_min.max(5)
915            } else {
916                fourier_min
917            };
918            let mut nb = if fourier_start % 2 == 0 {
919                fourier_start + 1
920            } else {
921                fourier_start
922            };
923            while nb <= fourier_max {
924                if let Some((score, coefs, fitted, edf)) =
925                    fit_curve_fourier(&curve, m, argvals, nb, criterion)
926                {
927                    if score < best_score && score.is_finite() {
928                        best_score = score;
929                        best_basis_type = 1; // Fourier
930                        best_nbasis = nb;
931                        best_coefs = coefs;
932                        best_fitted = fitted;
933                        best_edf = edf;
934                        best_lambda = f64::NAN;
935                    }
936                }
937                nb += 2;
938            }
939
940            // Try P-spline bases
941            for nb in pspline_min..=pspline_max {
942                if auto_lambda {
943                    // Search over lambda grid
944                    for &lam in &lambda_grid {
945                        if let Some((score, coefs, fitted, edf)) =
946                            fit_curve_pspline(&curve, m, argvals, nb, lam, 2, criterion)
947                        {
948                            if score < best_score && score.is_finite() {
949                                best_score = score;
950                                best_basis_type = 0; // P-spline
951                                best_nbasis = nb;
952                                best_coefs = coefs;
953                                best_fitted = fitted;
954                                best_edf = edf;
955                                best_lambda = lam;
956                            }
957                        }
958                    }
959                } else {
960                    // Use fixed lambda
961                    if let Some((score, coefs, fitted, edf)) =
962                        fit_curve_pspline(&curve, m, argvals, nb, lambda_pspline, 2, criterion)
963                    {
964                        if score < best_score && score.is_finite() {
965                            best_score = score;
966                            best_basis_type = 0; // P-spline
967                            best_nbasis = nb;
968                            best_coefs = coefs;
969                            best_fitted = fitted;
970                            best_edf = edf;
971                            best_lambda = lambda_pspline;
972                        }
973                    }
974                }
975            }
976
977            SingleCurveSelection {
978                basis_type: best_basis_type,
979                nbasis: best_nbasis,
980                score: best_score,
981                coefficients: best_coefs,
982                fitted: best_fitted,
983                edf: best_edf,
984                seasonal_detected,
985                lambda: best_lambda,
986            }
987        })
988        .collect();
989
990    BasisAutoSelectionResult {
991        selections,
992        criterion,
993    }
994}
995
// Unit tests for basis construction, basis projection, penalized/Fourier
// fitting, and automatic basis selection in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Generate `n` evenly spaced points on `[0, 1]`, both endpoints included.
    ///
    /// Intended for `n >= 2`; with `n == 1` the single value is `0.0 / 0.0` (NaN).
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Evaluate `sin(2π · freq · t)` at every point of `t`.
    fn sine_wave(t: &[f64], freq: f64) -> Vec<f64> {
        t.iter().map(|&ti| (2.0 * PI * freq * ti).sin()).collect()
    }

    // ============== B-spline basis tests ==============

    #[test]
    fn test_bspline_basis_dimensions() {
        let t = uniform_grid(50);
        let nknots = 10;
        let order = 4;
        let basis = bspline_basis(&t, nknots, order);

        // One column of `t.len()` values per basis function.
        let expected_nbasis = nknots + order;
        assert_eq!(basis.len(), t.len() * expected_nbasis);
    }

    #[test]
    fn test_bspline_basis_partition_of_unity() {
        // B-splines should sum to 1 at each point (partition of unity)
        let t = uniform_grid(50);
        let nknots = 8;
        let order = 4;
        let basis = bspline_basis(&t, nknots, order);

        // The basis is flattened column by column: value of function `j` at
        // point `i` lives at `i + j * t.len()`.
        let nbasis = nknots + order;
        for i in 0..t.len() {
            let sum: f64 = (0..nbasis).map(|j| basis[i + j * t.len()]).sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "B-spline partition of unity failed at point {}: sum = {}",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_bspline_basis_non_negative() {
        let t = uniform_grid(50);
        let basis = bspline_basis(&t, 8, 4);

        for &val in &basis {
            assert!(val >= -1e-10, "B-spline values should be non-negative");
        }
    }

    #[test]
    fn test_bspline_basis_boundary() {
        // Basis evaluation must also work exactly at the grid endpoints.
        let t = vec![0.0, 0.5, 1.0];
        let basis = bspline_basis(&t, 5, 4);

        // Should have valid output (no NaN or Inf)
        for &val in &basis {
            assert!(val.is_finite(), "B-spline should produce finite values");
        }
    }

    // ============== Fourier basis tests ==============

    #[test]
    fn test_fourier_basis_dimensions() {
        let t = uniform_grid(50);
        let nbasis = 7;
        let basis = fourier_basis(&t, nbasis);

        assert_eq!(basis.len(), t.len() * nbasis);
    }

    #[test]
    fn test_fourier_basis_constant_first_column() {
        let t = uniform_grid(50);
        let nbasis = 7;
        let basis = fourier_basis(&t, nbasis);

        // The first column is the constant (DC) term: every entry must equal
        // the first one.
        let first_col = &basis[..t.len()];
        let first_val = first_col[0];
        for &val in first_col {
            assert!(
                (val - first_val).abs() < 1e-10,
                "First Fourier column should be constant"
            );
        }
    }

    #[test]
    fn test_fourier_basis_sin_cos_range() {
        let t = uniform_grid(100);
        let nbasis = 11;
        let basis = fourier_basis(&t, nbasis);

        // Sin and cos values should be in [-1, 1]
        for &val in &basis {
            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&val));
        }
    }

    #[test]
    fn test_fourier_basis_with_period() {
        let t = uniform_grid(100);
        let nbasis = 5;
        let period = 0.5;
        let basis = fourier_basis_with_period(&t, nbasis, period);

        assert_eq!(basis.len(), t.len() * nbasis);
        // The constant first column is independent of the period.
        let first_col = &basis[..t.len()];
        let first_val = first_col[0];
        for &val in first_col {
            assert!((val - first_val).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fourier_basis_period_affects_frequency() {
        let t = uniform_grid(100);
        let nbasis = 5;

        let basis1 = fourier_basis_with_period(&t, nbasis, 1.0);
        let basis2 = fourier_basis_with_period(&t, nbasis, 0.5);

        // Compare the second column (first sin term) of both flattened bases;
        // different periods must change it somewhere.
        let n = t.len();
        let col1 = &basis1[n..2 * n];
        let col2 = &basis2[n..2 * n];
        let any_different = col1
            .iter()
            .zip(col2)
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(
            any_different,
            "Different periods should produce different bases"
        );
    }

    // ============== Difference matrix tests ==============

    #[test]
    fn test_difference_matrix_order_zero() {
        let d = difference_matrix(5, 0);
        assert_eq!(d.nrows(), 5);
        assert_eq!(d.ncols(), 5);

        // Should be identity matrix
        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((d[(i, j)] - expected).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_difference_matrix_first_order() {
        let d = difference_matrix(5, 1);
        assert_eq!(d.nrows(), 4);
        assert_eq!(d.ncols(), 5);

        // First order differences: D * [1,2,3,4,5] = [1,1,1,1]
        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let dx = &d * x;
        assert_eq!(dx.len(), 4);
        for &v in dx.iter() {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_difference_matrix_second_order() {
        let d = difference_matrix(5, 2);
        assert_eq!(d.nrows(), 3);
        assert_eq!(d.ncols(), 5);

        // Second order differences of linear: D^2 * [1,2,3,4,5] = [0,0,0]
        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let dx = &d * x;
        assert_eq!(dx.len(), 3);
        for &v in dx.iter() {
            assert!(v.abs() < 1e-10, "Second diff of linear should be zero");
        }
    }

    #[test]
    fn test_difference_matrix_quadratic() {
        let d = difference_matrix(5, 2);

        // Second order differences of quadratic: D^2 * [1,4,9,16,25] = [2,2,2]
        let x = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);
        let dx = &d * x;
        assert_eq!(dx.len(), 3);
        for &v in dx.iter() {
            assert!((v - 2.0).abs() < 1e-10, "Second diff of squares should be 2");
        }
    }

    // ============== Basis projection tests ==============

    #[test]
    fn test_fdata_to_basis_1d_bspline() {
        let t = uniform_grid(50);
        let n = 5;
        let m = t.len();

        // n shifted copies of the identity curve t -> t
        let data: Vec<f64> = (0..n)
            .flat_map(|i| t.iter().map(move |&ti| ti + i as f64 * 0.1))
            .collect();

        let result = fdata_to_basis_1d(&data, n, m, &t, 10, 0);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis > 0);
        assert_eq!(res.coefficients.len(), n * res.n_basis);
    }

    #[test]
    fn test_fdata_to_basis_1d_fourier() {
        let t = uniform_grid(50);
        let n = 5;
        let m = t.len();

        // n identical sine-wave curves
        let data: Vec<f64> = (0..n).flat_map(|_| sine_wave(&t, 2.0)).collect();

        let result = fdata_to_basis_1d(&data, n, m, &t, 7, 1);
        assert!(result.is_some());

        let res = result.unwrap();
        assert_eq!(res.n_basis, 7);
    }

    #[test]
    fn test_fdata_to_basis_1d_invalid_input() {
        let t = uniform_grid(50);

        // Empty data
        let result = fdata_to_basis_1d(&[], 0, 50, &t, 10, 0);
        assert!(result.is_none());

        // nbasis too small
        let data = vec![0.0; 50];
        let result = fdata_to_basis_1d(&data, 1, 50, &t, 1, 0);
        assert!(result.is_none());
    }

    #[test]
    fn test_basis_roundtrip() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        // A single period of a sine wave; a 5-term Fourier basis is expected
        // to reproduce it closely.
        let data = sine_wave(&t, 1.0);

        // Project to Fourier basis with enough terms
        let proj = fdata_to_basis_1d(&data, n, m, &t, 5, 1).unwrap();

        // Reconstruct
        let reconstructed =
            basis_to_fdata_1d(&proj.coefficients, n, proj.n_basis, &t, proj.n_basis, 1);

        // Largest pointwise deviation. Slicing to `m` makes a too-short
        // reconstruction fail loudly instead of being truncated by `zip`.
        let max_error = data
            .iter()
            .zip(&reconstructed[..m])
            .map(|(orig, rec)| (orig - rec).abs())
            .fold(0.0, f64::max);

        // Deliberately loose tolerance: this is a smoke test of the
        // project/reconstruct pipeline, not a precision test.
        assert!(max_error < 0.5, "Roundtrip error too large: {}", max_error);
    }

    #[test]
    fn test_basis_to_fdata_empty_input() {
        let result = basis_to_fdata_1d(&[], 0, 0, &[], 5, 0);
        assert!(result.is_empty());
    }

    // ============== P-spline fitting tests ==============

    #[test]
    fn test_pspline_fit_1d_basic() {
        let t = uniform_grid(50);
        let n = 3;
        let m = t.len();

        // Sine wave plus deterministic pseudo-noise: the fractional part of
        // 0.1 * i * j (curve index i, point index j).
        let data: Vec<f64> = (0..n)
            .flat_map(|i| {
                t.iter()
                    .enumerate()
                    .map(move |(j, &ti)| (2.0 * PI * ti).sin() + (0.1 * (i * j) as f64) % 1.0)
            })
            .collect();

        let result = pspline_fit_1d(&data, n, m, &t, 15, 1.0, 2);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.n_basis > 0);
        assert_eq!(res.fitted.len(), n * m);
        assert!(res.rss >= 0.0);
        assert!(res.edf > 0.0);
        assert!(res.gcv.is_finite());
    }

    #[test]
    fn test_pspline_fit_1d_smoothness() {
        let t = uniform_grid(50);
        let n = 1;
        let m = t.len();

        // Sine wave plus deterministic pseudo-noise in [0, 0.3)
        let data: Vec<f64> = t
            .iter()
            .enumerate()
            .map(|(i, &ti)| (2.0 * PI * ti).sin() + 0.3 * ((i * 17) % 100) as f64 / 100.0)
            .collect();

        let low_lambda = pspline_fit_1d(&data, n, m, &t, 15, 0.01, 2).unwrap();
        let high_lambda = pspline_fit_1d(&data, n, m, &t, 15, 100.0, 2).unwrap();

        // Higher lambda should give lower edf (more smoothing)
        assert!(high_lambda.edf < low_lambda.edf);
    }

    #[test]
    fn test_pspline_fit_1d_invalid_input() {
        let t = uniform_grid(50);
        let result = pspline_fit_1d(&[], 0, 50, &t, 15, 1.0, 2);
        assert!(result.is_none());
    }

    // ============== Fourier fitting tests ==============

    #[test]
    fn test_fourier_fit_1d_sine_wave() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        // Create pure sine wave
        let data = sine_wave(&t, 2.0);

        let result = fourier_fit_1d(&data, n, m, &t, 11);
        assert!(result.is_some());

        let res = result.unwrap();
        assert!(res.rss < 1e-6, "Pure sine should have near-zero RSS");
    }

    #[test]
    fn test_fourier_fit_1d_makes_nbasis_odd() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 1.0);

        // Pass even nbasis
        let result = fourier_fit_1d(&data, 1, t.len(), &t, 6);
        assert!(result.is_some());

        // Should have been adjusted to odd
        let res = result.unwrap();
        assert!(res.n_basis % 2 == 1);
    }

    #[test]
    fn test_fourier_fit_1d_criteria() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 2.0);

        let result = fourier_fit_1d(&data, 1, t.len(), &t, 9).unwrap();

        // All criteria should be finite
        assert!(result.gcv.is_finite());
        assert!(result.aic.is_finite());
        assert!(result.bic.is_finite());
    }

    #[test]
    fn test_fourier_fit_1d_invalid_nbasis() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 1.0);

        // nbasis < 3 should return None
        let result = fourier_fit_1d(&data, 1, t.len(), &t, 2);
        assert!(result.is_none());
    }

    // ============== GCV selection tests ==============

    #[test]
    fn test_select_fourier_nbasis_gcv_range() {
        let t = uniform_grid(100);
        let data = sine_wave(&t, 3.0);

        let best = select_fourier_nbasis_gcv(&data, 1, t.len(), &t, 3, 15);

        assert!((3..=15).contains(&best));
        assert!(best % 2 == 1, "Selected nbasis should be odd");
    }

    #[test]
    fn test_select_fourier_nbasis_gcv_respects_min() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 1.0);

        let best = select_fourier_nbasis_gcv(&data, 1, t.len(), &t, 7, 15);
        assert!(best >= 7);
    }

    // ============== Auto selection tests ==============

    #[test]
    fn test_select_basis_auto_1d_returns_results() {
        let t = uniform_grid(50);
        let n = 3;
        let m = t.len();

        // One sine wave per curve, with frequencies 1, 2, 3
        let data: Vec<f64> = (0..n).flat_map(|i| sine_wave(&t, 1.0 + i as f64)).collect();

        let result = select_basis_auto_1d(&data, n, m, &t, 0, 5, 15, 1.0, false);

        assert_eq!(result.selections.len(), n);
        for sel in &result.selections {
            assert!(sel.nbasis >= 3);
            assert!(!sel.coefficients.is_empty());
            assert_eq!(sel.fitted.len(), m);
        }
    }

    #[test]
    fn test_select_basis_auto_1d_seasonal_hint() {
        let t = uniform_grid(100);
        let n = 1;
        let m = t.len();

        // Strong seasonal pattern
        let data = sine_wave(&t, 5.0);

        let result = select_basis_auto_1d(&data, n, m, &t, 0, 0, 0, -1.0, true);

        assert_eq!(result.selections.len(), 1);
        assert!(result.selections[0].seasonal_detected);
    }

    #[test]
    fn test_select_basis_auto_1d_non_seasonal() {
        let t = uniform_grid(50);
        let n = 1;
        let m = t.len();

        // Constant data (definitely not seasonal)
        let data: Vec<f64> = vec![1.0; m];

        let result = select_basis_auto_1d(&data, n, m, &t, 0, 0, 0, -1.0, true);

        // Check the length first so a missing selection fails with a clear
        // message rather than an index-out-of-bounds panic.
        assert_eq!(result.selections.len(), 1);
        // Constant data shouldn't be detected as seasonal
        assert!(!result.selections[0].seasonal_detected);
    }

    #[test]
    fn test_select_basis_auto_1d_criterion_options() {
        let t = uniform_grid(50);
        let data = sine_wave(&t, 2.0);

        // Criterion codes 0, 1, 2 must be echoed back in the result.
        let gcv_result = select_basis_auto_1d(&data, 1, t.len(), &t, 0, 0, 0, 1.0, false);
        let aic_result = select_basis_auto_1d(&data, 1, t.len(), &t, 1, 0, 0, 1.0, false);
        let bic_result = select_basis_auto_1d(&data, 1, t.len(), &t, 2, 0, 0, 1.0, false);

        assert_eq!(gcv_result.criterion, 0);
        assert_eq!(aic_result.criterion, 1);
        assert_eq!(bic_result.criterion, 2);
    }
}
1482        assert_eq!(bic_result.criterion, 2);
1483    }
1484}