Skip to main content

fdars_core/
basis.rs

1//! Basis representation functions for functional data.
2//!
3//! This module provides B-spline and Fourier basis expansions for representing
4//! functional data in a finite-dimensional basis.
5
6use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8use nalgebra::{DMatrix, DVector, SVD};
9#[cfg(feature = "parallel")]
10use rayon::iter::ParallelIterator;
11use std::f64::consts::PI;
12
13/// Compute pseudoinverse of a symmetric matrix via SVD.
14///
15/// Uses singular value decomposition with threshold-based truncation
16/// for numerical stability with near-singular matrices.
17fn svd_pseudoinverse(mat: &DMatrix<f64>) -> Option<DMatrix<f64>> {
18    let n = mat.nrows();
19    let svd = SVD::new(mat.clone(), true, true);
20    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
21    let eps = 1e-10 * max_sv;
22
23    let u = svd.u.as_ref()?;
24    let v_t = svd.v_t.as_ref()?;
25
26    let s_inv: Vec<f64> = svd
27        .singular_values
28        .iter()
29        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
30        .collect();
31
32    let mut result = DMatrix::zeros(n, n);
33    for i in 0..n {
34        for j in 0..n {
35            let mut sum = 0.0;
36            for k in 0..n.min(s_inv.len()) {
37                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
38            }
39            result[(i, j)] = sum;
40        }
41    }
42
43    Some(result)
44}
45
/// Evaluate a model selection score for a single fit.
///
/// # Arguments
/// * `rss` - Residual sum of squares
/// * `n_points` - Number of data points
/// * `edf` - Effective degrees of freedom
/// * `criterion` - `0` selects GCV, `1` selects AIC, any other value selects BIC
///
/// # Returns
/// * GCV: `(rss / n) / (1 - edf / n)^2`, or `f64::INFINITY` when the
///   denominator term is not larger than `1e-10` in magnitude
/// * AIC: `n * ln(rss / n) + 2 * edf`
/// * BIC: `n * ln(rss / n) + ln(n) * edf`
fn compute_model_criterion(rss: f64, n_points: f64, edf: f64, criterion: i32) -> f64 {
    let mean_sq_err = rss / n_points;

    if criterion == 0 {
        let shrink = 1.0 - edf / n_points;
        return if shrink.abs() > 1e-10 {
            mean_sq_err / (shrink * shrink)
        } else {
            f64::INFINITY
        };
    }

    // AIC and BIC share the likelihood term and differ only in the weight on edf.
    let edf_weight = if criterion == 1 { 2.0 } else { n_points.ln() };
    n_points * mean_sq_err.ln() + edf_weight * edf
}
73
/// Build a uniformly spaced B-spline knot vector with extended boundary knots.
///
/// The `nknots` main knots run from `t_min` to `t_max` with spacing
/// `dt = (t_max - t_min) / (nknots - 1)`. `order` extra knots are placed below
/// `t_min` and `order` more above `t_max`, continuing the same spacing, so the
/// result has `nknots + 2 * order` entries in increasing order (for `dt > 0`).
///
/// When `nknots` is 0 or 1 there is no interval between main knots; the whole
/// range is then used as a single span so that the spacing stays finite instead
/// of underflowing or dividing by zero.
fn construct_bspline_knots(t_min: f64, t_max: f64, nknots: usize, order: usize) -> Vec<f64> {
    // Always at least one span: guards `nknots - 1` against underflow and zero.
    let spans = nknots.saturating_sub(1).max(1);
    let dt = (t_max - t_min) / spans as f64;

    let below = (0..order).map(|i| t_min - (order - i) as f64 * dt);
    let inside = (0..nknots).map(|i| t_min + i as f64 * dt);
    let above = (1..=order).map(|i| t_max + i as f64 * dt);

    let mut knots = Vec::with_capacity(nknots + 2 * order);
    knots.extend(below.chain(inside).chain(above));
    knots
}
89
/// Evaluate the piecewise-constant (lowest order) B-spline basis at one point.
///
/// Returns a vector with one entry per knot interval (`knots.len() - 1`
/// entries). The entry of the first interval containing `t_val` is `1.0`, all
/// others are `0.0`. Intervals are half-open `[knots[j], knots[j + 1])`, except
/// the one ending at `knots[t_max_knot_idx]`, which is closed on the right so
/// that a point sitting exactly on that knot is still covered.
fn evaluate_order_zero(t_val: f64, knots: &[f64], t_max_knot_idx: usize) -> Vec<f64> {
    let mut indicator = vec![0.0; knots.len() - 1];

    let containing = knots.windows(2).enumerate().position(|(j, span)| {
        let under_upper = if j == t_max_knot_idx - 1 {
            t_val <= span[1]
        } else {
            t_val < span[1]
        };
        t_val >= span[0] && under_upper
    });

    if let Some(j) = containing {
        indicator[j] = 1.0;
    }
    indicator
}
106
/// Raise a B-spline basis evaluation by one order (one recurrence step).
///
/// `b` holds the basis values of order `k - 1` at `t_val`; the result holds the
/// `knots.len() - k` values of order `k`. Each new value blends two neighbouring
/// lower-order values, weighted by where `t_val` lies within their knot spans.
/// A term whose knot span is not larger than `1e-10` in magnitude contributes
/// zero, which avoids dividing by a (near-)zero span at repeated knots.
fn bspline_recurrence_step(b: &[f64], knots: &[f64], t_val: f64, k: usize) -> Vec<f64> {
    const MIN_SPAN: f64 = 1e-10;

    let count = knots.len() - k;
    let mut raised = Vec::with_capacity(count);
    for j in 0..count {
        let lower_span = knots[j + k - 1] - knots[j];
        let upper_span = knots[j + k] - knots[j + 1];

        let rising = if lower_span.abs() > MIN_SPAN {
            (t_val - knots[j]) / lower_span * b[j]
        } else {
            0.0
        };
        let falling = if upper_span.abs() > MIN_SPAN {
            (knots[j + k] - t_val) / upper_span * b[j + 1]
        } else {
            0.0
        };

        raised.push(rising + falling);
    }
    raised
}
127
128/// Compute B-spline basis matrix for given knots and grid points.
129///
130/// Creates a B-spline basis with uniformly spaced knots extended beyond the data range.
131/// For order k and nknots interior knots, produces nknots + order basis functions.
132pub fn bspline_basis(t: &[f64], nknots: usize, order: usize) -> Vec<f64> {
133    let n = t.len();
134    let nbasis = nknots + order;
135
136    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
137    let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
138
139    let knots = construct_bspline_knots(t_min, t_max, nknots, order);
140    let t_max_knot_idx = order + nknots - 1;
141
142    let mut basis = vec![0.0; n * nbasis];
143
144    for (ti, &t_val) in t.iter().enumerate() {
145        let mut b = evaluate_order_zero(t_val, &knots, t_max_knot_idx);
146
147        for k in 2..=order {
148            b = bspline_recurrence_step(&b, &knots, t_val, k);
149        }
150
151        for j in 0..nbasis {
152            basis[ti + j * n] = b[j];
153        }
154    }
155
156    basis
157}
158
159/// Compute Fourier basis matrix.
160///
161/// The period is automatically set to the range of evaluation points (t_max - t_min).
162/// For explicit period control, use `fourier_basis_with_period`.
163pub fn fourier_basis(t: &[f64], nbasis: usize) -> Vec<f64> {
164    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
165    let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
166    let period = t_max - t_min;
167    fourier_basis_with_period(t, nbasis, period)
168}
169
/// Compute Fourier basis matrix with explicit period.
///
/// The period can be chosen independently of the evaluation range, which is
/// needed when the seasonal period differs from the observation window.
/// Phases are measured from `min(t)`.
///
/// Column 0 is the constant `1`. After that, columns alternate
/// `sin(f * x)`, `cos(f * x)` for `f = 1, 2, ...`, where
/// `x = 2 * pi * (t - min(t)) / period`. An even `nbasis` ends on a sine column.
///
/// # Arguments
/// * `t` - Evaluation points
/// * `nbasis` - Number of basis functions (1 constant + pairs of sin/cos)
/// * `period` - The period for the Fourier basis; it is not validated, so a
///   zero or non-finite period produces non-finite entries
///
/// # Returns
/// Column-major matrix (n_points x nbasis) stored as flat vector. The vector is
/// empty when `t` is empty or `nbasis` is zero.
pub fn fourier_basis_with_period(t: &[f64], nbasis: usize, period: f64) -> Vec<f64> {
    let n = t.len();
    // Nothing to fill; also keeps the constant-column write below in bounds.
    if n == 0 || nbasis == 0 {
        return Vec::new();
    }

    let origin = t.iter().copied().fold(f64::INFINITY, f64::min);
    let mut basis = vec![0.0; n * nbasis];

    for (i, &ti) in t.iter().enumerate() {
        let phase = 2.0 * PI * (ti - origin) / period;

        basis[i] = 1.0;
        for k in 1..nbasis {
            // Columns 1,2 use frequency 1; columns 3,4 use frequency 2; and so on.
            let harmonic = ((k + 1) / 2) as f64;
            basis[i + k * n] = if k % 2 == 1 {
                (harmonic * phase).sin()
            } else {
                (harmonic * phase).cos()
            };
        }
    }

    basis
}
211
212/// Compute difference matrix for P-spline penalty.
213pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
214    if order == 0 {
215        return DMatrix::identity(n, n);
216    }
217
218    let mut d = DMatrix::zeros(n - 1, n);
219    for i in 0..(n - 1) {
220        d[(i, i)] = -1.0;
221        d[(i, i + 1)] = 1.0;
222    }
223
224    let mut result = d;
225    for _ in 1..order {
226        if result.nrows() <= 1 {
227            break;
228        }
229        let rows = result.nrows() - 1;
230        let cols = result.ncols();
231        let mut d_next = DMatrix::zeros(rows, cols);
232        for i in 0..rows {
233            for j in 0..cols {
234                d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
235            }
236        }
237        result = d_next;
238    }
239
240    result
241}
242
243/// Result of basis projection.
244pub struct BasisProjectionResult {
245    /// Coefficient matrix (n_samples x n_basis)
246    pub coefficients: FdMatrix,
247    /// Number of basis functions used
248    pub n_basis: usize,
249}
250
251/// Project functional data to basis coefficients.
252///
253/// # Arguments
254/// * `data` - Column-major FdMatrix (n x m)
255/// * `argvals` - Evaluation points
256/// * `nbasis` - Number of basis functions
257/// * `basis_type` - 0 = B-spline, 1 = Fourier
258pub fn fdata_to_basis_1d(
259    data: &FdMatrix,
260    argvals: &[f64],
261    nbasis: usize,
262    basis_type: i32,
263) -> Option<BasisProjectionResult> {
264    let n = data.nrows();
265    let m = data.ncols();
266    if n == 0 || m == 0 || argvals.len() != m || nbasis < 2 {
267        return None;
268    }
269
270    let basis = if basis_type == 1 {
271        fourier_basis(argvals, nbasis)
272    } else {
273        // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
274        bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
275    };
276
277    let actual_nbasis = basis.len() / m;
278    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
279
280    let btb = &b_mat.transpose() * &b_mat;
281    let btb_inv = svd_pseudoinverse(&btb)?;
282    let proj = btb_inv * b_mat.transpose();
283
284    let coefs: Vec<f64> = iter_maybe_parallel!(0..n)
285        .flat_map(|i| {
286            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
287            (0..actual_nbasis)
288                .map(|k| {
289                    let mut sum = 0.0;
290                    for j in 0..m {
291                        sum += proj[(k, j)] * curve[j];
292                    }
293                    sum
294                })
295                .collect::<Vec<_>>()
296        })
297        .collect();
298
299    Some(BasisProjectionResult {
300        coefficients: FdMatrix::from_column_major(coefs, n, actual_nbasis).unwrap(),
301        n_basis: actual_nbasis,
302    })
303}
304
305/// Reconstruct functional data from basis coefficients.
306pub fn basis_to_fdata_1d(
307    coefs: &FdMatrix,
308    argvals: &[f64],
309    nbasis: usize,
310    basis_type: i32,
311) -> FdMatrix {
312    let n = coefs.nrows();
313    let coefs_ncols = coefs.ncols();
314    let m = argvals.len();
315    if n == 0 || m == 0 || nbasis < 2 {
316        return FdMatrix::zeros(0, 0);
317    }
318
319    let basis = if basis_type == 1 {
320        fourier_basis(argvals, nbasis)
321    } else {
322        // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
323        bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
324    };
325
326    let actual_nbasis = basis.len() / m;
327
328    let flat: Vec<f64> = iter_maybe_parallel!(0..n)
329        .flat_map(|i| {
330            (0..m)
331                .map(|j| {
332                    let mut sum = 0.0;
333                    for k in 0..actual_nbasis.min(coefs_ncols) {
334                        sum += coefs[(i, k)] * basis[j + k * m];
335                    }
336                    sum
337                })
338                .collect::<Vec<_>>()
339        })
340        .collect();
341
342    FdMatrix::from_column_major(flat, n, m).unwrap()
343}
344
345/// Result of P-spline fitting.
346pub struct PsplineFitResult {
347    /// Coefficient matrix (n x nbasis)
348    pub coefficients: FdMatrix,
349    /// Fitted values (n x m)
350    pub fitted: FdMatrix,
351    /// Effective degrees of freedom
352    pub edf: f64,
353    /// Residual sum of squares
354    pub rss: f64,
355    /// GCV score
356    pub gcv: f64,
357    /// AIC
358    pub aic: f64,
359    /// BIC
360    pub bic: f64,
361    /// Number of basis functions
362    pub n_basis: usize,
363}
364
365/// Fit P-splines to functional data.
366pub fn pspline_fit_1d(
367    data: &FdMatrix,
368    argvals: &[f64],
369    nbasis: usize,
370    lambda: f64,
371    order: usize,
372) -> Option<PsplineFitResult> {
373    let n = data.nrows();
374    let m = data.ncols();
375    if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
376        return None;
377    }
378
379    // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
380    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
381    let actual_nbasis = basis.len() / m;
382    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
383
384    let d = difference_matrix(actual_nbasis, order);
385    let penalty = &d.transpose() * &d;
386
387    let btb = &b_mat.transpose() * &b_mat;
388    let btb_penalized = &btb + lambda * &penalty;
389
390    let btb_inv = svd_pseudoinverse(&btb_penalized)?;
391    let proj = &btb_inv * b_mat.transpose();
392    let h_mat = &b_mat * &proj;
393    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
394
395    let mut all_coefs = FdMatrix::zeros(n, actual_nbasis);
396    let mut all_fitted = FdMatrix::zeros(n, m);
397    let mut total_rss = 0.0;
398
399    for i in 0..n {
400        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
401        let curve_vec = DVector::from_vec(curve.clone());
402
403        let bt_y = b_mat.transpose() * &curve_vec;
404        let coefs = &btb_inv * bt_y;
405
406        for k in 0..actual_nbasis {
407            all_coefs[(i, k)] = coefs[k];
408        }
409
410        let fitted = &b_mat * &coefs;
411        for j in 0..m {
412            all_fitted[(i, j)] = fitted[j];
413            let resid = curve[j] - fitted[j];
414            total_rss += resid * resid;
415        }
416    }
417
418    let total_points = (n * m) as f64;
419
420    let gcv_denom = 1.0 - edf / m as f64;
421    let gcv = if gcv_denom.abs() > 1e-10 {
422        (total_rss / total_points) / (gcv_denom * gcv_denom)
423    } else {
424        f64::INFINITY
425    };
426
427    let mse = total_rss / total_points;
428    let aic = total_points * mse.ln() + 2.0 * edf;
429    let bic = total_points * mse.ln() + total_points.ln() * edf;
430
431    Some(PsplineFitResult {
432        coefficients: all_coefs,
433        fitted: all_fitted,
434        edf,
435        rss: total_rss,
436        gcv,
437        aic,
438        bic,
439        n_basis: actual_nbasis,
440    })
441}
442
443/// Result of Fourier basis fitting.
444pub struct FourierFitResult {
445    /// Coefficient matrix (n x nbasis)
446    pub coefficients: FdMatrix,
447    /// Fitted values (n x m)
448    pub fitted: FdMatrix,
449    /// Effective degrees of freedom (equals nbasis for unpenalized fit)
450    pub edf: f64,
451    /// Residual sum of squares
452    pub rss: f64,
453    /// GCV score
454    pub gcv: f64,
455    /// AIC
456    pub aic: f64,
457    /// BIC
458    pub bic: f64,
459    /// Number of basis functions
460    pub n_basis: usize,
461}
462
/// Compute the `(GCV, AIC, BIC)` triple for a fit over several curves.
///
/// * GCV: `(total_rss / total_points) / (1 - edf / m)^2`, or `f64::INFINITY`
///   when `1 - edf / m` is not larger than `1e-10` in magnitude
/// * AIC: `total_points * ln(total_rss / total_points) + 2 * edf`
/// * BIC: `total_points * ln(total_rss / total_points) + ln(total_points) * edf`
///
/// `m` is the number of evaluation points per curve and `total_points` is the
/// number of observations summed over all curves.
fn compute_fit_criteria(total_rss: f64, total_points: f64, edf: f64, m: usize) -> (f64, f64, f64) {
    let mse = total_rss / total_points;

    let shrink = 1.0 - edf / m as f64;
    let gcv = if shrink.abs() > 1e-10 {
        mse / (shrink * shrink)
    } else {
        f64::INFINITY
    };

    // AIC and BIC share the likelihood term.
    let fit_term = total_points * mse.ln();
    (
        gcv,
        fit_term + 2.0 * edf,
        fit_term + total_points.ln() * edf,
    )
}
476
477/// Fit Fourier basis to functional data using least squares.
478///
479/// Projects data onto Fourier basis and reconstructs fitted values.
480/// Unlike P-splines, this uses unpenalized least squares projection.
481///
482/// # Arguments
483/// * `data` - Column-major FdMatrix (n x m)
484/// * `argvals` - Evaluation points
485/// * `nbasis` - Number of Fourier basis functions (should be odd: 1 constant + pairs of sin/cos)
486///
487/// # Returns
488/// FourierFitResult with coefficients, fitted values, and model selection criteria
489pub fn fourier_fit_1d(data: &FdMatrix, argvals: &[f64], nbasis: usize) -> Option<FourierFitResult> {
490    let n = data.nrows();
491    let m = data.ncols();
492    if n == 0 || m == 0 || nbasis < 3 || argvals.len() != m {
493        return None;
494    }
495
496    // Ensure nbasis is odd (1 constant + pairs of sin/cos)
497    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
498
499    let basis = fourier_basis(argvals, nbasis);
500    let actual_nbasis = basis.len() / m;
501    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
502
503    let btb = &b_mat.transpose() * &b_mat;
504    let btb_inv = svd_pseudoinverse(&btb)?;
505    let proj = &btb_inv * b_mat.transpose();
506    let h_mat = &b_mat * &proj;
507    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
508
509    let mut all_coefs = FdMatrix::zeros(n, actual_nbasis);
510    let mut all_fitted = FdMatrix::zeros(n, m);
511    let mut total_rss = 0.0;
512
513    for i in 0..n {
514        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
515        let curve_vec = DVector::from_vec(curve.clone());
516
517        let bt_y = b_mat.transpose() * &curve_vec;
518        let coefs = &btb_inv * bt_y;
519
520        for k in 0..actual_nbasis {
521            all_coefs[(i, k)] = coefs[k];
522        }
523
524        let fitted = &b_mat * &coefs;
525        for j in 0..m {
526            all_fitted[(i, j)] = fitted[j];
527            let resid = curve[j] - fitted[j];
528            total_rss += resid * resid;
529        }
530    }
531
532    let total_points = (n * m) as f64;
533    let (gcv, aic, bic) = compute_fit_criteria(total_rss, total_points, edf, m);
534
535    Some(FourierFitResult {
536        coefficients: all_coefs,
537        fitted: all_fitted,
538        edf,
539        rss: total_rss,
540        gcv,
541        aic,
542        bic,
543        n_basis: actual_nbasis,
544    })
545}
546
547/// Select optimal number of Fourier basis functions using GCV.
548///
549/// Performs grid search over nbasis values and returns the one with minimum GCV.
550///
551/// # Arguments
552/// * `data` - Column-major FdMatrix (n x m)
553/// * `argvals` - Evaluation points
554/// * `min_nbasis` - Minimum number of basis functions to try
555/// * `max_nbasis` - Maximum number of basis functions to try
556///
557/// # Returns
558/// Optimal number of basis functions
559pub fn select_fourier_nbasis_gcv(
560    data: &FdMatrix,
561    argvals: &[f64],
562    min_nbasis: usize,
563    max_nbasis: usize,
564) -> usize {
565    let m = data.ncols();
566    let min_nb = min_nbasis.max(3);
567    // Ensure max doesn't exceed m (can't have more parameters than data points)
568    let max_nb = max_nbasis.min(m / 2);
569
570    if max_nb <= min_nb {
571        return min_nb;
572    }
573
574    let mut best_nbasis = min_nb;
575    let mut best_gcv = f64::INFINITY;
576
577    // Test odd values only (1 constant + pairs of sin/cos)
578    let mut nbasis = if min_nb % 2 == 0 { min_nb + 1 } else { min_nb };
579    while nbasis <= max_nb {
580        if let Some(result) = fourier_fit_1d(data, argvals, nbasis) {
581            if result.gcv < best_gcv && result.gcv.is_finite() {
582                best_gcv = result.gcv;
583                best_nbasis = nbasis;
584            }
585        }
586        nbasis += 2;
587    }
588
589    best_nbasis
590}
591
/// Result of automatic basis selection for a single curve.
///
/// Derives `Debug` in addition to `Clone` so selections can be logged and
/// inspected in test failures.
#[derive(Debug, Clone)]
pub struct SingleCurveSelection {
    /// Selected basis type: 0 = P-spline, 1 = Fourier
    pub basis_type: i32,
    /// Selected number of basis functions
    pub nbasis: usize,
    /// Best criterion score (GCV, AIC, or BIC)
    pub score: f64,
    /// Coefficients for the selected basis
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom
    pub edf: f64,
    /// Whether seasonal pattern was detected (if use_seasonal_hint)
    pub seasonal_detected: bool,
    /// Lambda value (for P-spline, NaN for Fourier)
    pub lambda: f64,
}
612
613/// Result of automatic basis selection for all curves.
614pub struct BasisAutoSelectionResult {
615    /// Per-curve selection results
616    pub selections: Vec<SingleCurveSelection>,
617    /// Criterion used (0=GCV, 1=AIC, 2=BIC)
618    pub criterion: i32,
619}
620
621/// Detect if a curve has seasonal/periodic pattern using FFT.
622///
623/// Returns true if the peak power in the periodogram is significantly
624/// above the mean power level.
625fn detect_seasonality_fft(curve: &[f64]) -> bool {
626    use rustfft::{num_complex::Complex, FftPlanner};
627
628    let n = curve.len();
629    if n < 8 {
630        return false;
631    }
632
633    // Remove mean
634    let mean: f64 = curve.iter().sum::<f64>() / n as f64;
635    let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
636
637    let mut planner = FftPlanner::new();
638    let fft = planner.plan_fft_forward(n);
639    fft.process(&mut input);
640
641    // Compute power spectrum (skip DC component and Nyquist)
642    let powers: Vec<f64> = input[1..n / 2].iter().map(|c| c.norm_sqr()).collect();
643
644    if powers.is_empty() {
645        return false;
646    }
647
648    let max_power = powers.iter().cloned().fold(0.0_f64, f64::max);
649    let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
650
651    // Seasonal if peak is significantly above mean
652    max_power > 2.0 * mean_power
653}
654
655/// Fit a single curve with Fourier basis and compute criterion score.
656fn fit_curve_fourier(
657    curve: &[f64],
658    m: usize,
659    argvals: &[f64],
660    nbasis: usize,
661    criterion: i32,
662) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
663    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
664
665    let basis = fourier_basis(argvals, nbasis);
666    let actual_nbasis = basis.len() / m;
667    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
668
669    let btb = &b_mat.transpose() * &b_mat;
670    let btb_inv = svd_pseudoinverse(&btb)?;
671    let proj = &btb_inv * b_mat.transpose();
672    let h_mat = &b_mat * &proj;
673    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
674
675    let curve_vec = DVector::from_column_slice(curve);
676    let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
677    let fitted = &b_mat * &coefs;
678
679    let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
680    let score = compute_model_criterion(rss, m as f64, edf, criterion);
681
682    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
683    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
684
685    Some((score, coef_vec, fitted_vec, edf))
686}
687
688/// Fit a single curve with P-spline basis and compute criterion score.
689fn fit_curve_pspline(
690    curve: &[f64],
691    m: usize,
692    argvals: &[f64],
693    nbasis: usize,
694    lambda: f64,
695    order: usize,
696    criterion: i32,
697) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
698    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
699    let actual_nbasis = basis.len() / m;
700    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
701
702    let d = difference_matrix(actual_nbasis, order);
703    let penalty = &d.transpose() * &d;
704    let btb = &b_mat.transpose() * &b_mat;
705    let btb_penalized = &btb + lambda * &penalty;
706
707    let btb_inv = svd_pseudoinverse(&btb_penalized)?;
708    let proj = &btb_inv * b_mat.transpose();
709    let h_mat = &b_mat * &proj;
710    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
711
712    let curve_vec = DVector::from_column_slice(curve);
713    let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
714    let fitted = &b_mat * &coefs;
715
716    let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
717    let score = compute_model_criterion(rss, m as f64, edf, criterion);
718
719    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
720    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
721
722    Some((score, coef_vec, fitted_vec, edf))
723}
724
/// Best candidate found so far while searching basis sizes for one curve.
struct BasisSearchResult {
    /// Criterion score of this candidate (lower is better)
    score: f64,
    /// Number of basis functions requested for this candidate
    nbasis: usize,
    /// Fitted basis coefficients
    coefs: Vec<f64>,
    /// Fitted curve values
    fitted: Vec<f64>,
    /// Effective degrees of freedom
    edf: f64,
    /// Smoothing parameter (NaN for Fourier candidates)
    lambda: f64,
}
734
735/// Search over Fourier basis sizes for the best fit.
736fn search_fourier_basis(
737    curve: &[f64],
738    m: usize,
739    argvals: &[f64],
740    fourier_min: usize,
741    fourier_max: usize,
742    seasonal: bool,
743    criterion: i32,
744) -> Option<BasisSearchResult> {
745    let fourier_start = if seasonal {
746        fourier_min.max(5)
747    } else {
748        fourier_min
749    };
750    let mut nb = if fourier_start % 2 == 0 {
751        fourier_start + 1
752    } else {
753        fourier_start
754    };
755
756    let mut best: Option<BasisSearchResult> = None;
757    while nb <= fourier_max {
758        if let Some((score, coefs, fitted, edf)) =
759            fit_curve_fourier(curve, m, argvals, nb, criterion)
760        {
761            if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
762                best = Some(BasisSearchResult {
763                    score,
764                    nbasis: nb,
765                    coefs,
766                    fitted,
767                    edf,
768                    lambda: f64::NAN,
769                });
770            }
771        }
772        nb += 2;
773    }
774    best
775}
776
777/// Try a single P-spline fit and update best if it improves the score.
778fn try_pspline_fit_update(
779    curve: &[f64],
780    m: usize,
781    argvals: &[f64],
782    nb: usize,
783    lam: f64,
784    criterion: i32,
785    best: &mut Option<BasisSearchResult>,
786) {
787    if let Some((score, coefs, fitted, edf)) =
788        fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
789    {
790        if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
791            *best = Some(BasisSearchResult {
792                score,
793                nbasis: nb,
794                coefs,
795                fitted,
796                edf,
797                lambda: lam,
798            });
799        }
800    }
801}
802
803/// Search over P-spline basis sizes (and optionally lambda) for the best fit.
804fn search_pspline_basis(
805    curve: &[f64],
806    m: usize,
807    argvals: &[f64],
808    pspline_min: usize,
809    pspline_max: usize,
810    lambda_grid: &[f64],
811    auto_lambda: bool,
812    lambda: f64,
813    criterion: i32,
814) -> Option<BasisSearchResult> {
815    let mut best: Option<BasisSearchResult> = None;
816    for nb in pspline_min..=pspline_max {
817        let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
818            Box::new(lambda_grid.iter().copied())
819        } else {
820            Box::new(std::iter::once(lambda))
821        };
822        for lam in lambdas {
823            try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
824        }
825    }
826    best
827}
828
829/// Select optimal basis type and parameters for each curve individually.
830///
831/// This function compares Fourier and P-spline bases for each curve,
832/// selecting the optimal basis type and number of basis functions using
833/// model selection criteria (GCV, AIC, or BIC).
834///
835/// # Arguments
836/// * `data` - Column-major FdMatrix (n x m)
837/// * `argvals` - Evaluation points
838/// * `criterion` - Model selection criterion: 0=GCV, 1=AIC, 2=BIC
839/// * `nbasis_min` - Minimum number of basis functions (0 for auto)
840/// * `nbasis_max` - Maximum number of basis functions (0 for auto)
841/// * `lambda_pspline` - Smoothing parameter for P-spline (negative for auto-select)
842/// * `use_seasonal_hint` - Whether to use FFT to detect seasonality
843///
844/// # Returns
845/// BasisAutoSelectionResult with per-curve selections
846pub fn select_basis_auto_1d(
847    data: &FdMatrix,
848    argvals: &[f64],
849    criterion: i32,
850    nbasis_min: usize,
851    nbasis_max: usize,
852    lambda_pspline: f64,
853    use_seasonal_hint: bool,
854) -> BasisAutoSelectionResult {
855    let n = data.nrows();
856    let m = data.ncols();
857    let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
858    let fourier_max = if nbasis_max > 0 {
859        nbasis_max.min(m / 3).min(25)
860    } else {
861        (m / 3).min(25)
862    };
863    let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
864    let pspline_max = if nbasis_max > 0 {
865        nbasis_max.min(m / 2).min(40)
866    } else {
867        (m / 2).min(40)
868    };
869
870    let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
871    let auto_lambda = lambda_pspline < 0.0;
872
873    let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
874        .map(|i| {
875            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
876            let seasonal_detected = if use_seasonal_hint {
877                detect_seasonality_fft(&curve)
878            } else {
879                false
880            };
881
882            let fourier_best = search_fourier_basis(
883                &curve,
884                m,
885                argvals,
886                fourier_min,
887                fourier_max,
888                seasonal_detected,
889                criterion,
890            );
891            let pspline_best = search_pspline_basis(
892                &curve,
893                m,
894                argvals,
895                pspline_min,
896                pspline_max,
897                &lambda_grid,
898                auto_lambda,
899                lambda_pspline,
900                criterion,
901            );
902
903            // Pick the best overall result
904            let (basis_type, result) = match (fourier_best, pspline_best) {
905                (Some(f), Some(p)) => {
906                    if f.score <= p.score {
907                        (1i32, f)
908                    } else {
909                        (0i32, p)
910                    }
911                }
912                (Some(f), None) => (1, f),
913                (None, Some(p)) => (0, p),
914                (None, None) => {
915                    return SingleCurveSelection {
916                        basis_type: 0,
917                        nbasis: pspline_min,
918                        score: f64::INFINITY,
919                        coefficients: Vec::new(),
920                        fitted: Vec::new(),
921                        edf: 0.0,
922                        seasonal_detected,
923                        lambda: f64::NAN,
924                    };
925                }
926            };
927
928            SingleCurveSelection {
929                basis_type,
930                nbasis: result.nbasis,
931                score: result.score,
932                coefficients: result.coefs,
933                fitted: result.fitted,
934                edf: result.edf,
935                seasonal_detected,
936                lambda: result.lambda,
937            }
938        })
939        .collect();
940
941    BasisAutoSelectionResult {
942        selections,
943        criterion,
944    }
945}
946
947#[cfg(test)]
948mod tests {
949    use super::*;
950    use std::f64::consts::PI;
951
952    /// Generate a uniform grid of points
953    fn uniform_grid(n: usize) -> Vec<f64> {
954        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
955    }
956
957    /// Generate sine wave data
958    fn sine_wave(t: &[f64], freq: f64) -> Vec<f64> {
959        t.iter().map(|&ti| (2.0 * PI * freq * ti).sin()).collect()
960    }
961
962    // ============== B-spline basis tests ==============
963
964    #[test]
965    fn test_bspline_basis_dimensions() {
966        let t = uniform_grid(50);
967        let nknots = 10;
968        let order = 4;
969        let basis = bspline_basis(&t, nknots, order);
970
971        let expected_nbasis = nknots + order;
972        assert_eq!(basis.len(), t.len() * expected_nbasis);
973    }
974
975    #[test]
976    fn test_bspline_basis_partition_of_unity() {
977        // B-splines should sum to 1 at each point (partition of unity)
978        let t = uniform_grid(50);
979        let nknots = 8;
980        let order = 4;
981        let basis = bspline_basis(&t, nknots, order);
982
983        let nbasis = nknots + order;
984        for i in 0..t.len() {
985            let sum: f64 = (0..nbasis).map(|j| basis[i + j * t.len()]).sum();
986            assert!(
987                (sum - 1.0).abs() < 1e-10,
988                "B-spline partition of unity failed at point {}: sum = {}",
989                i,
990                sum
991            );
992        }
993    }
994
995    #[test]
996    fn test_bspline_basis_non_negative() {
997        let t = uniform_grid(50);
998        let basis = bspline_basis(&t, 8, 4);
999
1000        for &val in &basis {
1001            assert!(val >= -1e-10, "B-spline values should be non-negative");
1002        }
1003    }
1004
1005    #[test]
1006    fn test_bspline_basis_boundary() {
1007        // Test that basis functions work at boundary points
1008        let t = vec![0.0, 0.5, 1.0];
1009        let basis = bspline_basis(&t, 5, 4);
1010
1011        // Should have valid output (no NaN or Inf)
1012        for &val in &basis {
1013            assert!(val.is_finite(), "B-spline should produce finite values");
1014        }
1015    }
1016
1017    // ============== Fourier basis tests ==============
1018
1019    #[test]
1020    fn test_fourier_basis_dimensions() {
1021        let t = uniform_grid(50);
1022        let nbasis = 7;
1023        let basis = fourier_basis(&t, nbasis);
1024
1025        assert_eq!(basis.len(), t.len() * nbasis);
1026    }
1027
1028    #[test]
1029    fn test_fourier_basis_constant_first_column() {
1030        let t = uniform_grid(50);
1031        let nbasis = 7;
1032        let basis = fourier_basis(&t, nbasis);
1033
1034        // First column should be constant (DC component = 1)
1035        let first_val = basis[0];
1036        for i in 0..t.len() {
1037            assert!(
1038                (basis[i] - first_val).abs() < 1e-10,
1039                "First Fourier column should be constant"
1040            );
1041        }
1042    }
1043
1044    #[test]
1045    fn test_fourier_basis_sin_cos_range() {
1046        let t = uniform_grid(100);
1047        let nbasis = 11;
1048        let basis = fourier_basis(&t, nbasis);
1049
1050        // Sin and cos values should be in [-1, 1]
1051        for &val in &basis {
1052            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&val));
1053        }
1054    }
1055
1056    #[test]
1057    fn test_fourier_basis_with_period() {
1058        let t = uniform_grid(100);
1059        let nbasis = 5;
1060        let period = 0.5;
1061        let basis = fourier_basis_with_period(&t, nbasis, period);
1062
1063        assert_eq!(basis.len(), t.len() * nbasis);
1064        // Verify first column is constant
1065        let first_val = basis[0];
1066        for i in 0..t.len() {
1067            assert!((basis[i] - first_val).abs() < 1e-10);
1068        }
1069    }
1070
1071    #[test]
1072    fn test_fourier_basis_period_affects_frequency() {
1073        let t = uniform_grid(100);
1074        let nbasis = 5;
1075
1076        let basis1 = fourier_basis_with_period(&t, nbasis, 1.0);
1077        let basis2 = fourier_basis_with_period(&t, nbasis, 0.5);
1078
1079        // Different periods should give different basis matrices
1080        let n = t.len();
1081        let mut any_different = false;
1082        for i in 0..n {
1083            // Compare second column (first sin term)
1084            if (basis1[i + n] - basis2[i + n]).abs() > 1e-10 {
1085                any_different = true;
1086                break;
1087            }
1088        }
1089        assert!(
1090            any_different,
1091            "Different periods should produce different bases"
1092        );
1093    }
1094
1095    // ============== Difference matrix tests ==============
1096
1097    #[test]
1098    fn test_difference_matrix_order_zero() {
1099        let d = difference_matrix(5, 0);
1100        assert_eq!(d.nrows(), 5);
1101        assert_eq!(d.ncols(), 5);
1102
1103        // Should be identity matrix
1104        for i in 0..5 {
1105            for j in 0..5 {
1106                let expected = if i == j { 1.0 } else { 0.0 };
1107                assert!((d[(i, j)] - expected).abs() < 1e-10);
1108            }
1109        }
1110    }
1111
1112    #[test]
1113    fn test_difference_matrix_first_order() {
1114        let d = difference_matrix(5, 1);
1115        assert_eq!(d.nrows(), 4);
1116        assert_eq!(d.ncols(), 5);
1117
1118        // First order differences: D * [1,2,3,4,5] = [1,1,1,1]
1119        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1120        let dx = &d * x;
1121        for i in 0..4 {
1122            assert!((dx[i] - 1.0).abs() < 1e-10);
1123        }
1124    }
1125
1126    #[test]
1127    fn test_difference_matrix_second_order() {
1128        let d = difference_matrix(5, 2);
1129        assert_eq!(d.nrows(), 3);
1130        assert_eq!(d.ncols(), 5);
1131
1132        // Second order differences of linear: D^2 * [1,2,3,4,5] = [0,0,0]
1133        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1134        let dx = &d * x;
1135        for i in 0..3 {
1136            assert!(dx[i].abs() < 1e-10, "Second diff of linear should be zero");
1137        }
1138    }
1139
1140    #[test]
1141    fn test_difference_matrix_quadratic() {
1142        let d = difference_matrix(5, 2);
1143
1144        // Second order differences of quadratic: D^2 * [1,4,9,16,25] = [2,2,2]
1145        let x = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);
1146        let dx = &d * x;
1147        for i in 0..3 {
1148            assert!(
1149                (dx[i] - 2.0).abs() < 1e-10,
1150                "Second diff of squares should be 2"
1151            );
1152        }
1153    }
1154
1155    // ============== Basis projection tests ==============
1156
1157    /// Create an FdMatrix from per-curve data (each curve is one row).
1158    /// The input flat data is in "row of rows" order: all values for curve 0, then curve 1, etc.
1159    /// We need to convert to column-major layout.
1160    fn make_matrix(flat_row_major: &[f64], n: usize, m: usize) -> FdMatrix {
1161        let mut col_major = vec![0.0; n * m];
1162        for i in 0..n {
1163            for j in 0..m {
1164                col_major[i + j * n] = flat_row_major[i * m + j];
1165            }
1166        }
1167        FdMatrix::from_column_major(col_major, n, m).unwrap()
1168    }
1169
1170    #[test]
1171    fn test_fdata_to_basis_1d_bspline() {
1172        let t = uniform_grid(50);
1173        let n = 5;
1174        let m = t.len();
1175
1176        // Create simple data (n curves, each shifted linear)
1177        let flat: Vec<f64> = (0..n)
1178            .flat_map(|i| t.iter().map(move |&ti| ti + i as f64 * 0.1))
1179            .collect();
1180        let data = make_matrix(&flat, n, m);
1181
1182        let result = fdata_to_basis_1d(&data, &t, 10, 0);
1183        assert!(result.is_some());
1184
1185        let res = result.unwrap();
1186        assert!(res.n_basis > 0);
1187        assert_eq!(res.coefficients.nrows(), n);
1188        assert_eq!(res.coefficients.ncols(), res.n_basis);
1189    }
1190
1191    #[test]
1192    fn test_fdata_to_basis_1d_fourier() {
1193        let t = uniform_grid(50);
1194        let n = 5;
1195        let m = t.len();
1196
1197        // Create sine wave data
1198        let flat: Vec<f64> = (0..n).flat_map(|_| sine_wave(&t, 2.0)).collect();
1199        let data = make_matrix(&flat, n, m);
1200
1201        let result = fdata_to_basis_1d(&data, &t, 7, 1);
1202        assert!(result.is_some());
1203
1204        let res = result.unwrap();
1205        assert_eq!(res.n_basis, 7);
1206    }
1207
1208    #[test]
1209    fn test_fdata_to_basis_1d_invalid_input() {
1210        let t = uniform_grid(50);
1211
1212        // Empty data
1213        let empty = FdMatrix::zeros(0, 50);
1214        let result = fdata_to_basis_1d(&empty, &t, 10, 0);
1215        assert!(result.is_none());
1216
1217        // nbasis too small
1218        let data = FdMatrix::zeros(1, 50);
1219        let result = fdata_to_basis_1d(&data, &t, 1, 0);
1220        assert!(result.is_none());
1221    }
1222
1223    #[test]
1224    fn test_basis_roundtrip() {
1225        let t = uniform_grid(100);
1226        let n = 1;
1227        let m = t.len();
1228
1229        // Create smooth sine wave data (Fourier basis should represent exactly)
1230        let raw = sine_wave(&t, 1.0);
1231        let data = FdMatrix::from_column_major(raw.clone(), n, m).unwrap();
1232
1233        // Project to Fourier basis with enough terms
1234        let proj = fdata_to_basis_1d(&data, &t, 5, 1).unwrap();
1235
1236        // Reconstruct
1237        let reconstructed = basis_to_fdata_1d(&proj.coefficients, &t, proj.n_basis, 1);
1238
1239        // Should approximately match original for a simple sine wave
1240        let mut max_error = 0.0;
1241        for j in 0..m {
1242            let err = (raw[j] - reconstructed[(0, j)]).abs();
1243            if err > max_error {
1244                max_error = err;
1245            }
1246        }
1247        assert!(max_error < 0.5, "Roundtrip error too large: {}", max_error);
1248    }
1249
1250    #[test]
1251    fn test_basis_to_fdata_empty_input() {
1252        let empty = FdMatrix::zeros(0, 0);
1253        let result = basis_to_fdata_1d(&empty, &[], 5, 0);
1254        assert!(result.is_empty());
1255    }
1256
1257    // ============== P-spline fitting tests ==============
1258
1259    #[test]
1260    fn test_pspline_fit_1d_basic() {
1261        let t = uniform_grid(50);
1262        let n = 3;
1263        let m = t.len();
1264
1265        // Create noisy data
1266        let flat: Vec<f64> = (0..n)
1267            .flat_map(|i| {
1268                t.iter()
1269                    .enumerate()
1270                    .map(move |(j, &ti)| (2.0 * PI * ti).sin() + 0.1 * (i * j) as f64 % 1.0)
1271            })
1272            .collect();
1273        let data = make_matrix(&flat, n, m);
1274
1275        let result = pspline_fit_1d(&data, &t, 15, 1.0, 2);
1276        assert!(result.is_some());
1277
1278        let res = result.unwrap();
1279        assert!(res.n_basis > 0);
1280        assert_eq!(res.fitted.nrows(), n);
1281        assert_eq!(res.fitted.ncols(), m);
1282        assert!(res.rss >= 0.0);
1283        assert!(res.edf > 0.0);
1284        assert!(res.gcv.is_finite());
1285    }
1286
1287    #[test]
1288    fn test_pspline_fit_1d_smoothness() {
1289        let t = uniform_grid(50);
1290        let n = 1;
1291        let m = t.len();
1292
1293        // Create noisy sine wave
1294        let raw: Vec<f64> = t
1295            .iter()
1296            .enumerate()
1297            .map(|(i, &ti)| (2.0 * PI * ti).sin() + 0.3 * ((i * 17) % 100) as f64 / 100.0)
1298            .collect();
1299        let data = FdMatrix::from_column_major(raw, n, m).unwrap();
1300
1301        let low_lambda = pspline_fit_1d(&data, &t, 15, 0.01, 2).unwrap();
1302        let high_lambda = pspline_fit_1d(&data, &t, 15, 100.0, 2).unwrap();
1303
1304        // Higher lambda should give lower edf (more smoothing)
1305        assert!(high_lambda.edf < low_lambda.edf);
1306    }
1307
1308    #[test]
1309    fn test_pspline_fit_1d_invalid_input() {
1310        let t = uniform_grid(50);
1311        let empty = FdMatrix::zeros(0, 50);
1312        let result = pspline_fit_1d(&empty, &t, 15, 1.0, 2);
1313        assert!(result.is_none());
1314    }
1315
1316    // ============== Fourier fitting tests ==============
1317
1318    #[test]
1319    fn test_fourier_fit_1d_sine_wave() {
1320        let t = uniform_grid(100);
1321        let n = 1;
1322        let m = t.len();
1323
1324        // Create pure sine wave
1325        let raw = sine_wave(&t, 2.0);
1326        let data = FdMatrix::from_column_major(raw, n, m).unwrap();
1327
1328        let result = fourier_fit_1d(&data, &t, 11);
1329        assert!(result.is_some());
1330
1331        let res = result.unwrap();
1332        assert!(res.rss < 1e-6, "Pure sine should have near-zero RSS");
1333    }
1334
1335    #[test]
1336    fn test_fourier_fit_1d_makes_nbasis_odd() {
1337        let t = uniform_grid(50);
1338        let raw = sine_wave(&t, 1.0);
1339        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
1340
1341        // Pass even nbasis
1342        let result = fourier_fit_1d(&data, &t, 6);
1343        assert!(result.is_some());
1344
1345        // Should have been adjusted to odd
1346        let res = result.unwrap();
1347        assert!(res.n_basis % 2 == 1);
1348    }
1349
1350    #[test]
1351    fn test_fourier_fit_1d_criteria() {
1352        let t = uniform_grid(50);
1353        let raw = sine_wave(&t, 2.0);
1354        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
1355
1356        let result = fourier_fit_1d(&data, &t, 9).unwrap();
1357
1358        // All criteria should be finite
1359        assert!(result.gcv.is_finite());
1360        assert!(result.aic.is_finite());
1361        assert!(result.bic.is_finite());
1362    }
1363
1364    #[test]
1365    fn test_fourier_fit_1d_invalid_nbasis() {
1366        let t = uniform_grid(50);
1367        let raw = sine_wave(&t, 1.0);
1368        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
1369
1370        // nbasis < 3 should return None
1371        let result = fourier_fit_1d(&data, &t, 2);
1372        assert!(result.is_none());
1373    }
1374
1375    // ============== GCV selection tests ==============
1376
1377    #[test]
1378    fn test_select_fourier_nbasis_gcv_range() {
1379        let t = uniform_grid(100);
1380        let raw = sine_wave(&t, 3.0);
1381        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
1382
1383        let best = select_fourier_nbasis_gcv(&data, &t, 3, 15);
1384
1385        assert!((3..=15).contains(&best));
1386        assert!(best % 2 == 1, "Selected nbasis should be odd");
1387    }
1388
1389    #[test]
1390    fn test_select_fourier_nbasis_gcv_respects_min() {
1391        let t = uniform_grid(50);
1392        let raw = sine_wave(&t, 1.0);
1393        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
1394
1395        let best = select_fourier_nbasis_gcv(&data, &t, 7, 15);
1396        assert!(best >= 7);
1397    }
1398
1399    // ============== Auto selection tests ==============
1400
1401    #[test]
1402    fn test_select_basis_auto_1d_returns_results() {
1403        let t = uniform_grid(50);
1404        let n = 3;
1405        let m = t.len();
1406
1407        let flat: Vec<f64> = (0..n).flat_map(|i| sine_wave(&t, 1.0 + i as f64)).collect();
1408        let data = make_matrix(&flat, n, m);
1409
1410        let result = select_basis_auto_1d(&data, &t, 0, 5, 15, 1.0, false);
1411
1412        assert_eq!(result.selections.len(), n);
1413        for sel in &result.selections {
1414            assert!(sel.nbasis >= 3);
1415            assert!(!sel.coefficients.is_empty());
1416            assert_eq!(sel.fitted.len(), m);
1417        }
1418    }
1419
1420    #[test]
1421    fn test_select_basis_auto_1d_seasonal_hint() {
1422        let t = uniform_grid(100);
1423        let n = 1;
1424        let m = t.len();
1425
1426        // Strong seasonal pattern
1427        let raw = sine_wave(&t, 5.0);
1428        let data = FdMatrix::from_column_major(raw, n, m).unwrap();
1429
1430        let result = select_basis_auto_1d(&data, &t, 0, 0, 0, -1.0, true);
1431
1432        assert_eq!(result.selections.len(), 1);
1433        assert!(result.selections[0].seasonal_detected);
1434    }
1435
1436    #[test]
1437    fn test_select_basis_auto_1d_non_seasonal() {
1438        let t = uniform_grid(50);
1439        let n = 1;
1440        let m = t.len();
1441
1442        // Constant data (definitely not seasonal)
1443        let raw: Vec<f64> = vec![1.0; m];
1444        let data = FdMatrix::from_column_major(raw, n, m).unwrap();
1445
1446        let result = select_basis_auto_1d(&data, &t, 0, 0, 0, -1.0, true);
1447
1448        // Constant data shouldn't be detected as seasonal
1449        assert!(!result.selections[0].seasonal_detected);
1450    }
1451
1452    #[test]
1453    fn test_select_basis_auto_1d_criterion_options() {
1454        let t = uniform_grid(50);
1455        let raw = sine_wave(&t, 2.0);
1456        let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
1457
1458        // Test all three criteria
1459        let gcv_result = select_basis_auto_1d(&data, &t, 0, 0, 0, 1.0, false);
1460        let aic_result = select_basis_auto_1d(&data, &t, 1, 0, 0, 1.0, false);
1461        let bic_result = select_basis_auto_1d(&data, &t, 2, 0, 0, 1.0, false);
1462
1463        assert_eq!(gcv_result.criterion, 0);
1464        assert_eq!(aic_result.criterion, 1);
1465        assert_eq!(bic_result.criterion, 2);
1466    }
1467}