fdars_core/
basis.rs

1//! Basis representation functions for functional data.
2//!
3//! This module provides B-spline and Fourier basis expansions for representing
4//! functional data in a finite-dimensional basis.
5
6use crate::iter_maybe_parallel;
7use nalgebra::{DMatrix, DVector, SVD};
8#[cfg(feature = "parallel")]
9use rayon::iter::ParallelIterator;
10use std::f64::consts::PI;
11
/// Evaluate a B-spline basis with equally spaced knots on the points `t`.
///
/// The knot vector is made of `order` knots to the left of the data range,
/// `nknots` interior knots starting at `min(t)`, and `order` knots to the right
/// of `max(t)`, all placed with the same spacing. With two or more interior
/// knots the spacing is `(max(t) - min(t)) / (nknots - 1)`, so the last interior
/// knot is `max(t)`. With fewer than two interior knots the spacing falls back
/// to the full data range, which keeps every knot finite.
///
/// The basis values are obtained with the Cox-de Boor recursion.
///
/// # Arguments
/// * `t` - Evaluation points
/// * `nknots` - Number of interior knots
/// * `order` - Spline order (degree + 1); must be at least 1
///
/// # Returns
/// Column-major matrix (`t.len()` x `nknots + order`) stored as a flat vector:
/// basis function `j` at point `i` is at index `i + j * t.len()`.
/// An empty `t` yields an empty vector.
///
/// # Panics
/// Panics if `order` is 0 and `t` is not empty.
pub fn bspline_basis(t: &[f64], nknots: usize, order: usize) -> Vec<f64> {
    let n = t.len();
    if n == 0 {
        return Vec::new();
    }
    assert!(order >= 1, "B-spline order must be at least 1");

    // Total knots: order (left) + nknots (interior) + order (right).
    // Number of basis functions: total knots - order = nknots + order.
    let nbasis = nknots + order;

    let t_min = t.iter().copied().fold(f64::INFINITY, f64::min);
    let t_max = t.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // `nknots - 1` would underflow for 0 interior knots and be a zero divisor
    // for 1; in both cases treat the data range as a single interval.
    let n_intervals = nknots.saturating_sub(1).max(1);
    let dt = (t_max - t_min) / n_intervals as f64;

    let mut knots = Vec::with_capacity(nknots + 2 * order);
    knots.extend((0..order).map(|i| t_min - (order - i) as f64 * dt));
    knots.extend((0..nknots).map(|i| t_min + i as f64 * dt));
    knots.extend((1..=order).map(|i| t_max + i as f64 * dt));

    // The interval ending at the last interior knot (t_max) is treated as closed
    // so that t_max itself is assigned to it. This only applies when t_max really
    // is an interior knot, i.e. when there are at least two interior knots.
    let closed_interval = if nknots >= 2 {
        Some(order + nknots - 2)
    } else {
        None
    };

    let mut basis = vec![0.0; n * nbasis];

    for (ti, &t_val) in t.iter().enumerate() {
        // Order-1 (piecewise constant) basis: indicator of the first knot
        // interval containing t_val.
        let mut b = vec![0.0; knots.len() - 1];
        for (j, pair) in knots.windows(2).enumerate() {
            let below_upper = if closed_interval == Some(j) {
                t_val <= pair[1]
            } else {
                t_val < pair[1]
            };
            if t_val >= pair[0] && below_upper {
                b[j] = 1.0;
                break;
            }
        }

        // Cox-de Boor recursion: raise the order one step at a time.
        for k in 2..=order {
            let mut b_new = vec![0.0; knots.len() - k];
            for (j, slot) in b_new.iter_mut().enumerate() {
                let d1 = knots[j + k - 1] - knots[j];
                let d2 = knots[j + k] - knots[j + 1];

                // Terms over a (numerically) empty knot span contribute nothing.
                let left = if d1.abs() > 1e-10 {
                    (t_val - knots[j]) / d1 * b[j]
                } else {
                    0.0
                };
                let right = if d2.abs() > 1e-10 {
                    (knots[j + k] - t_val) / d2 * b[j + 1]
                } else {
                    0.0
                };
                *slot = left + right;
            }
            b = b_new;
        }

        for (j, &value) in b.iter().take(nbasis).enumerate() {
            basis[ti + j * n] = value;
        }
    }

    basis
}
92
93/// Compute Fourier basis matrix.
94///
95/// The period is automatically set to the range of evaluation points (t_max - t_min).
96/// For explicit period control, use `fourier_basis_with_period`.
97pub fn fourier_basis(t: &[f64], nbasis: usize) -> Vec<f64> {
98    let t_min = t.iter().cloned().fold(f64::INFINITY, f64::min);
99    let t_max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
100    let period = t_max - t_min;
101    fourier_basis_with_period(t, nbasis, period)
102}
103
/// Compute Fourier basis matrix with explicit period.
///
/// The period can be chosen independently of the evaluation range, which is
/// what seasonal analysis needs when the seasonal period differs from the
/// observation window. Phases are measured from `min(t)`.
///
/// Column 0 is the constant 1. The following columns alternate
/// `sin(f * x)`, `cos(f * x)` for `f = 1, 2, ...`, where
/// `x = 2 * PI * (t - min(t)) / period`. When `nbasis` is even the last
/// frequency only contributes its sine term.
///
/// # Arguments
/// * `t` - Evaluation points
/// * `nbasis` - Number of basis functions (1 constant + pairs of sin/cos)
/// * `period` - The period for the Fourier basis; a value of zero (or NaN)
///   produces non-finite sine/cosine entries
///
/// # Returns
/// Column-major matrix (n_points x nbasis) stored as flat vector. The result is
/// empty when `t` is empty or `nbasis` is 0.
pub fn fourier_basis_with_period(t: &[f64], nbasis: usize, period: f64) -> Vec<f64> {
    let n = t.len();
    // Without this guard the write to the constant column would index into an
    // empty vector when `nbasis` is 0.
    if n == 0 || nbasis == 0 {
        return Vec::new();
    }

    let t_min = t.iter().copied().fold(f64::INFINITY, f64::min);
    let mut basis = vec![0.0; n * nbasis];

    for (i, &ti) in t.iter().enumerate() {
        let x = 2.0 * PI * (ti - t_min) / period;

        basis[i] = 1.0;

        for k in 1..nbasis {
            // Columns 1,2 use frequency 1; columns 3,4 use frequency 2; ...
            let freq = ((k + 1) / 2) as f64;
            let angle = freq * x;
            basis[i + k * n] = if k % 2 == 1 { angle.sin() } else { angle.cos() };
        }
    }

    basis
}
145
146/// Compute difference matrix for P-spline penalty.
147pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
148    if order == 0 {
149        return DMatrix::identity(n, n);
150    }
151
152    let mut d = DMatrix::zeros(n - 1, n);
153    for i in 0..(n - 1) {
154        d[(i, i)] = -1.0;
155        d[(i, i + 1)] = 1.0;
156    }
157
158    let mut result = d;
159    for _ in 1..order {
160        if result.nrows() <= 1 {
161            break;
162        }
163        let rows = result.nrows() - 1;
164        let cols = result.ncols();
165        let mut d_next = DMatrix::zeros(rows, cols);
166        for i in 0..rows {
167            for j in 0..cols {
168                d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
169            }
170        }
171        result = d_next;
172    }
173
174    result
175}
176
/// Result of basis projection.
#[derive(Debug, Clone)]
pub struct BasisProjectionResult {
    /// Basis coefficients of all samples, flattened into one vector.
    ///
    /// As filled in by `fdata_to_basis_1d`, the `n_basis` coefficients of each
    /// sample are stored contiguously: coefficient `k` of sample `i` is at
    /// index `i * n_basis + k`.
    pub coefficients: Vec<f64>,
    /// Number of basis functions used
    pub n_basis: usize,
}
184
185/// Project functional data to basis coefficients.
186///
187/// # Arguments
188/// * `data` - Column-major matrix (n x m)
189/// * `n` - Number of samples
190/// * `m` - Number of evaluation points
191/// * `argvals` - Evaluation points
192/// * `nbasis` - Number of basis functions
193/// * `basis_type` - 0 = B-spline, 1 = Fourier
194pub fn fdata_to_basis_1d(
195    data: &[f64],
196    n: usize,
197    m: usize,
198    argvals: &[f64],
199    nbasis: usize,
200    basis_type: i32,
201) -> Option<BasisProjectionResult> {
202    if n == 0 || m == 0 || argvals.len() != m || nbasis < 2 {
203        return None;
204    }
205
206    let basis = if basis_type == 1 {
207        fourier_basis(argvals, nbasis)
208    } else {
209        // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
210        bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
211    };
212
213    let actual_nbasis = basis.len() / m;
214    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
215
216    let btb = &b_mat.transpose() * &b_mat;
217    let btb_svd = SVD::new(btb, true, true);
218
219    let max_sv = btb_svd.singular_values.iter().cloned().fold(0.0, f64::max);
220    let eps = 1e-10 * max_sv;
221
222    let s_inv: Vec<f64> = btb_svd
223        .singular_values
224        .iter()
225        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
226        .collect();
227
228    let v = btb_svd.v_t.as_ref()?.transpose();
229    let u_t = btb_svd.u.as_ref()?.transpose();
230
231    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
232    for i in 0..actual_nbasis {
233        for j in 0..actual_nbasis {
234            let mut sum = 0.0;
235            for k in 0..actual_nbasis.min(s_inv.len()) {
236                sum += v[(i, k)] * s_inv[k] * u_t[(k, j)];
237            }
238            btb_inv[(i, j)] = sum;
239        }
240    }
241
242    let proj = btb_inv * b_mat.transpose();
243
244    let coefs: Vec<f64> = iter_maybe_parallel!(0..n)
245        .flat_map(|i| {
246            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
247            (0..actual_nbasis)
248                .map(|k| {
249                    let mut sum = 0.0;
250                    for j in 0..m {
251                        sum += proj[(k, j)] * curve[j];
252                    }
253                    sum
254                })
255                .collect::<Vec<_>>()
256        })
257        .collect();
258
259    Some(BasisProjectionResult {
260        coefficients: coefs,
261        n_basis: actual_nbasis,
262    })
263}
264
265/// Reconstruct functional data from basis coefficients.
266pub fn basis_to_fdata_1d(
267    coefs: &[f64],
268    n: usize,
269    coefs_ncols: usize,
270    argvals: &[f64],
271    nbasis: usize,
272    basis_type: i32,
273) -> Vec<f64> {
274    let m = argvals.len();
275    if n == 0 || m == 0 || nbasis < 2 {
276        return Vec::new();
277    }
278
279    let basis = if basis_type == 1 {
280        fourier_basis(argvals, nbasis)
281    } else {
282        // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
283        bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4)
284    };
285
286    let actual_nbasis = basis.len() / m;
287
288    iter_maybe_parallel!(0..n)
289        .flat_map(|i| {
290            (0..m)
291                .map(|j| {
292                    let mut sum = 0.0;
293                    for k in 0..actual_nbasis.min(coefs_ncols) {
294                        sum += coefs[i + k * n] * basis[j + k * m];
295                    }
296                    sum
297                })
298                .collect::<Vec<_>>()
299        })
300        .collect()
301}
302
/// Result of P-spline fitting.
#[derive(Debug, Clone)]
pub struct PsplineFitResult {
    /// Coefficient matrix
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom
    pub edf: f64,
    /// Residual sum of squares
    pub rss: f64,
    /// GCV score
    pub gcv: f64,
    /// AIC
    pub aic: f64,
    /// BIC
    pub bic: f64,
    /// Number of basis functions
    pub n_basis: usize,
}
322
323/// Fit P-splines to functional data.
324pub fn pspline_fit_1d(
325    data: &[f64],
326    n: usize,
327    m: usize,
328    argvals: &[f64],
329    nbasis: usize,
330    lambda: f64,
331    order: usize,
332) -> Option<PsplineFitResult> {
333    if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
334        return None;
335    }
336
337    // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
338    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
339    let actual_nbasis = basis.len() / m;
340    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
341
342    let d = difference_matrix(actual_nbasis, order);
343    let penalty = &d.transpose() * &d;
344
345    let btb = &b_mat.transpose() * &b_mat;
346    let btb_penalized = &btb + lambda * &penalty;
347
348    // Use SVD pseudoinverse for robustness with singular matrices
349    let svd = SVD::new(btb_penalized.clone(), true, true);
350    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
351    let eps = 1e-10 * max_sv;
352
353    let u = svd.u.as_ref()?;
354    let v_t = svd.v_t.as_ref()?;
355
356    let s_inv: Vec<f64> = svd
357        .singular_values
358        .iter()
359        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
360        .collect();
361
362    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
363    for i in 0..actual_nbasis {
364        for j in 0..actual_nbasis {
365            let mut sum = 0.0;
366            for k in 0..actual_nbasis.min(s_inv.len()) {
367                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
368            }
369            btb_inv[(i, j)] = sum;
370        }
371    }
372
373    let proj = &btb_inv * b_mat.transpose();
374    let h_mat = &b_mat * &proj;
375    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
376
377    let mut all_coefs = vec![0.0; n * actual_nbasis];
378    let mut all_fitted = vec![0.0; n * m];
379    let mut total_rss = 0.0;
380
381    for i in 0..n {
382        let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
383        let curve_vec = DVector::from_vec(curve.clone());
384
385        let bt_y = b_mat.transpose() * &curve_vec;
386        let coefs = &btb_inv * bt_y;
387
388        for k in 0..actual_nbasis {
389            all_coefs[i + k * n] = coefs[k];
390        }
391
392        let fitted = &b_mat * &coefs;
393        for j in 0..m {
394            all_fitted[i + j * n] = fitted[j];
395            let resid = curve[j] - fitted[j];
396            total_rss += resid * resid;
397        }
398    }
399
400    let total_points = (n * m) as f64;
401
402    let gcv_denom = 1.0 - edf / m as f64;
403    let gcv = if gcv_denom.abs() > 1e-10 {
404        (total_rss / total_points) / (gcv_denom * gcv_denom)
405    } else {
406        f64::INFINITY
407    };
408
409    let mse = total_rss / total_points;
410    let aic = total_points * mse.ln() + 2.0 * edf;
411    let bic = total_points * mse.ln() + total_points.ln() * edf;
412
413    Some(PsplineFitResult {
414        coefficients: all_coefs,
415        fitted: all_fitted,
416        edf,
417        rss: total_rss,
418        gcv,
419        aic,
420        bic,
421        n_basis: actual_nbasis,
422    })
423}
424
/// Result of Fourier basis fitting.
#[derive(Debug, Clone)]
pub struct FourierFitResult {
    /// Coefficient matrix
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom (equals nbasis for unpenalized fit)
    pub edf: f64,
    /// Residual sum of squares
    pub rss: f64,
    /// GCV score
    pub gcv: f64,
    /// AIC
    pub aic: f64,
    /// BIC
    pub bic: f64,
    /// Number of basis functions
    pub n_basis: usize,
}
444
445/// Fit Fourier basis to functional data using least squares.
446///
447/// Projects data onto Fourier basis and reconstructs fitted values.
448/// Unlike P-splines, this uses unpenalized least squares projection.
449///
450/// # Arguments
451/// * `data` - Column-major matrix (n x m)
452/// * `n` - Number of samples
453/// * `m` - Number of evaluation points
454/// * `argvals` - Evaluation points
455/// * `nbasis` - Number of Fourier basis functions (should be odd: 1 constant + pairs of sin/cos)
456///
457/// # Returns
458/// FourierFitResult with coefficients, fitted values, and model selection criteria
459pub fn fourier_fit_1d(
460    data: &[f64],
461    n: usize,
462    m: usize,
463    argvals: &[f64],
464    nbasis: usize,
465) -> Option<FourierFitResult> {
466    if n == 0 || m == 0 || nbasis < 3 || argvals.len() != m {
467        return None;
468    }
469
470    // Ensure nbasis is odd (1 constant + pairs of sin/cos)
471    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
472
473    let basis = fourier_basis(argvals, nbasis);
474    let actual_nbasis = basis.len() / m;
475    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
476
477    let btb = &b_mat.transpose() * &b_mat;
478
479    // Use SVD pseudoinverse for robustness
480    let svd = SVD::new(btb.clone(), true, true);
481    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
482    let eps = 1e-10 * max_sv;
483
484    let u = svd.u.as_ref()?;
485    let v_t = svd.v_t.as_ref()?;
486
487    let s_inv: Vec<f64> = svd
488        .singular_values
489        .iter()
490        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
491        .collect();
492
493    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
494    for i in 0..actual_nbasis {
495        for j in 0..actual_nbasis {
496            let mut sum = 0.0;
497            for k in 0..actual_nbasis.min(s_inv.len()) {
498                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
499            }
500            btb_inv[(i, j)] = sum;
501        }
502    }
503
504    let proj = &btb_inv * b_mat.transpose();
505
506    // For unpenalized fit, hat matrix H = B * (B'B)^{-1} * B'
507    let h_mat = &b_mat * &proj;
508    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
509
510    let mut all_coefs = vec![0.0; n * actual_nbasis];
511    let mut all_fitted = vec![0.0; n * m];
512    let mut total_rss = 0.0;
513
514    for i in 0..n {
515        let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
516        let curve_vec = DVector::from_vec(curve.clone());
517
518        let bt_y = b_mat.transpose() * &curve_vec;
519        let coefs = &btb_inv * bt_y;
520
521        for k in 0..actual_nbasis {
522            all_coefs[i + k * n] = coefs[k];
523        }
524
525        let fitted = &b_mat * &coefs;
526        for j in 0..m {
527            all_fitted[i + j * n] = fitted[j];
528            let resid = curve[j] - fitted[j];
529            total_rss += resid * resid;
530        }
531    }
532
533    let total_points = (n * m) as f64;
534
535    // GCV: RSS / n * (1 - edf/m)^2
536    let gcv_denom = 1.0 - edf / m as f64;
537    let gcv = if gcv_denom.abs() > 1e-10 {
538        (total_rss / total_points) / (gcv_denom * gcv_denom)
539    } else {
540        f64::INFINITY
541    };
542
543    let mse = total_rss / total_points;
544    let aic = total_points * mse.ln() + 2.0 * edf;
545    let bic = total_points * mse.ln() + total_points.ln() * edf;
546
547    Some(FourierFitResult {
548        coefficients: all_coefs,
549        fitted: all_fitted,
550        edf,
551        rss: total_rss,
552        gcv,
553        aic,
554        bic,
555        n_basis: actual_nbasis,
556    })
557}
558
559/// Select optimal number of Fourier basis functions using GCV.
560///
561/// Performs grid search over nbasis values and returns the one with minimum GCV.
562///
563/// # Arguments
564/// * `data` - Column-major matrix (n x m)
565/// * `n` - Number of samples
566/// * `m` - Number of evaluation points
567/// * `argvals` - Evaluation points
568/// * `min_nbasis` - Minimum number of basis functions to try
569/// * `max_nbasis` - Maximum number of basis functions to try
570///
571/// # Returns
572/// Optimal number of basis functions
573pub fn select_fourier_nbasis_gcv(
574    data: &[f64],
575    n: usize,
576    m: usize,
577    argvals: &[f64],
578    min_nbasis: usize,
579    max_nbasis: usize,
580) -> usize {
581    let min_nb = min_nbasis.max(3);
582    // Ensure max doesn't exceed m (can't have more parameters than data points)
583    let max_nb = max_nbasis.min(m / 2);
584
585    if max_nb <= min_nb {
586        return min_nb;
587    }
588
589    let mut best_nbasis = min_nb;
590    let mut best_gcv = f64::INFINITY;
591
592    // Test odd values only (1 constant + pairs of sin/cos)
593    let mut nbasis = if min_nb % 2 == 0 { min_nb + 1 } else { min_nb };
594    while nbasis <= max_nb {
595        if let Some(result) = fourier_fit_1d(data, n, m, argvals, nbasis) {
596            if result.gcv < best_gcv && result.gcv.is_finite() {
597                best_gcv = result.gcv;
598                best_nbasis = nbasis;
599            }
600        }
601        nbasis += 2;
602    }
603
604    best_nbasis
605}
606
/// Result of automatic basis selection for a single curve.
#[derive(Debug, Clone)]
pub struct SingleCurveSelection {
    /// Selected basis type: 0 = P-spline, 1 = Fourier
    pub basis_type: i32,
    /// Selected number of basis functions
    pub nbasis: usize,
    /// Best criterion score (GCV, AIC, or BIC)
    pub score: f64,
    /// Coefficients for the selected basis
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom
    pub edf: f64,
    /// Whether seasonal pattern was detected (if use_seasonal_hint)
    pub seasonal_detected: bool,
    /// Lambda value (for P-spline, NaN for Fourier)
    pub lambda: f64,
}

/// Result of automatic basis selection for all curves.
#[derive(Debug, Clone)]
pub struct BasisAutoSelectionResult {
    /// Per-curve selection results
    pub selections: Vec<SingleCurveSelection>,
    /// Criterion used (0=GCV, 1=AIC, 2=BIC)
    pub criterion: i32,
}
635
636/// Detect if a curve has seasonal/periodic pattern using FFT.
637///
638/// Returns true if the peak power in the periodogram is significantly
639/// above the mean power level.
640fn detect_seasonality_fft(curve: &[f64]) -> bool {
641    use rustfft::{num_complex::Complex, FftPlanner};
642
643    let n = curve.len();
644    if n < 8 {
645        return false;
646    }
647
648    // Remove mean
649    let mean: f64 = curve.iter().sum::<f64>() / n as f64;
650    let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
651
652    let mut planner = FftPlanner::new();
653    let fft = planner.plan_fft_forward(n);
654    fft.process(&mut input);
655
656    // Compute power spectrum (skip DC component and Nyquist)
657    let powers: Vec<f64> = input[1..n / 2].iter().map(|c| c.norm_sqr()).collect();
658
659    if powers.is_empty() {
660        return false;
661    }
662
663    let max_power = powers.iter().cloned().fold(0.0_f64, f64::max);
664    let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
665
666    // Seasonal if peak is significantly above mean
667    max_power > 2.0 * mean_power
668}
669
670/// Fit a single curve with Fourier basis and compute criterion score.
671fn fit_curve_fourier(
672    curve: &[f64],
673    m: usize,
674    argvals: &[f64],
675    nbasis: usize,
676    criterion: i32,
677) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
678    // Ensure nbasis is odd
679    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
680
681    let basis = fourier_basis(argvals, nbasis);
682    let actual_nbasis = basis.len() / m;
683    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
684
685    let btb = &b_mat.transpose() * &b_mat;
686    let svd = SVD::new(btb.clone(), true, true);
687    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
688    let eps = 1e-10 * max_sv;
689
690    let u = svd.u.as_ref()?;
691    let v_t = svd.v_t.as_ref()?;
692
693    let s_inv: Vec<f64> = svd
694        .singular_values
695        .iter()
696        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
697        .collect();
698
699    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
700    for i in 0..actual_nbasis {
701        for j in 0..actual_nbasis {
702            let mut sum = 0.0;
703            for k in 0..actual_nbasis.min(s_inv.len()) {
704                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
705            }
706            btb_inv[(i, j)] = sum;
707        }
708    }
709
710    let proj = &btb_inv * b_mat.transpose();
711    let h_mat = &b_mat * &proj;
712    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
713
714    let curve_vec = DVector::from_column_slice(curve);
715    let bt_y = b_mat.transpose() * &curve_vec;
716    let coefs = &btb_inv * bt_y;
717
718    let fitted = &b_mat * &coefs;
719    let mut rss = 0.0;
720    for j in 0..m {
721        let resid = curve[j] - fitted[j];
722        rss += resid * resid;
723    }
724
725    let n_points = m as f64;
726    let score = match criterion {
727        0 => {
728            // GCV
729            let gcv_denom = 1.0 - edf / n_points;
730            if gcv_denom.abs() > 1e-10 {
731                (rss / n_points) / (gcv_denom * gcv_denom)
732            } else {
733                f64::INFINITY
734            }
735        }
736        1 => {
737            // AIC
738            let mse = rss / n_points;
739            n_points * mse.ln() + 2.0 * edf
740        }
741        _ => {
742            // BIC
743            let mse = rss / n_points;
744            n_points * mse.ln() + n_points.ln() * edf
745        }
746    };
747
748    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
749    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
750
751    Some((score, coef_vec, fitted_vec, edf))
752}
753
754/// Fit a single curve with P-spline basis and compute criterion score.
755fn fit_curve_pspline(
756    curve: &[f64],
757    m: usize,
758    argvals: &[f64],
759    nbasis: usize,
760    lambda: f64,
761    order: usize,
762    criterion: i32,
763) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
764    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
765    let actual_nbasis = basis.len() / m;
766    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
767
768    let d = difference_matrix(actual_nbasis, order);
769    let penalty = &d.transpose() * &d;
770
771    let btb = &b_mat.transpose() * &b_mat;
772    let btb_penalized = &btb + lambda * &penalty;
773
774    let svd = SVD::new(btb_penalized.clone(), true, true);
775    let max_sv = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
776    let eps = 1e-10 * max_sv;
777
778    let u = svd.u.as_ref()?;
779    let v_t = svd.v_t.as_ref()?;
780
781    let s_inv: Vec<f64> = svd
782        .singular_values
783        .iter()
784        .map(|&s| if s > eps { 1.0 / s } else { 0.0 })
785        .collect();
786
787    let mut btb_inv = DMatrix::zeros(actual_nbasis, actual_nbasis);
788    for i in 0..actual_nbasis {
789        for j in 0..actual_nbasis {
790            let mut sum = 0.0;
791            for k in 0..actual_nbasis.min(s_inv.len()) {
792                sum += v_t[(k, i)] * s_inv[k] * u[(j, k)];
793            }
794            btb_inv[(i, j)] = sum;
795        }
796    }
797
798    let proj = &btb_inv * b_mat.transpose();
799    let h_mat = &b_mat * &proj;
800    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
801
802    let curve_vec = DVector::from_column_slice(curve);
803    let bt_y = b_mat.transpose() * &curve_vec;
804    let coefs = &btb_inv * bt_y;
805
806    let fitted = &b_mat * &coefs;
807    let mut rss = 0.0;
808    for j in 0..m {
809        let resid = curve[j] - fitted[j];
810        rss += resid * resid;
811    }
812
813    let n_points = m as f64;
814    let score = match criterion {
815        0 => {
816            // GCV
817            let gcv_denom = 1.0 - edf / n_points;
818            if gcv_denom.abs() > 1e-10 {
819                (rss / n_points) / (gcv_denom * gcv_denom)
820            } else {
821                f64::INFINITY
822            }
823        }
824        1 => {
825            // AIC
826            let mse = rss / n_points;
827            n_points * mse.ln() + 2.0 * edf
828        }
829        _ => {
830            // BIC
831            let mse = rss / n_points;
832            n_points * mse.ln() + n_points.ln() * edf
833        }
834    };
835
836    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
837    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
838
839    Some((score, coef_vec, fitted_vec, edf))
840}
841
842/// Select optimal basis type and parameters for each curve individually.
843///
844/// This function compares Fourier and P-spline bases for each curve,
845/// selecting the optimal basis type and number of basis functions using
846/// model selection criteria (GCV, AIC, or BIC).
847///
848/// # Arguments
849/// * `data` - Column-major matrix (n x m)
850/// * `n` - Number of curves
851/// * `m` - Number of evaluation points per curve
852/// * `argvals` - Evaluation points
853/// * `criterion` - Model selection criterion: 0=GCV, 1=AIC, 2=BIC
854/// * `nbasis_min` - Minimum number of basis functions (0 for auto)
855/// * `nbasis_max` - Maximum number of basis functions (0 for auto)
856/// * `lambda_pspline` - Smoothing parameter for P-spline (negative for auto-select)
857/// * `use_seasonal_hint` - Whether to use FFT to detect seasonality
858///
859/// # Returns
860/// BasisAutoSelectionResult with per-curve selections
861pub fn select_basis_auto_1d(
862    data: &[f64],
863    n: usize,
864    m: usize,
865    argvals: &[f64],
866    criterion: i32,
867    nbasis_min: usize,
868    nbasis_max: usize,
869    lambda_pspline: f64,
870    use_seasonal_hint: bool,
871) -> BasisAutoSelectionResult {
872    // Determine nbasis ranges
873    let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
874    let fourier_max = if nbasis_max > 0 {
875        nbasis_max.min(m / 3).min(25)
876    } else {
877        (m / 3).min(25)
878    };
879
880    let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
881    let pspline_max = if nbasis_max > 0 {
882        nbasis_max.min(m / 2).min(40)
883    } else {
884        (m / 2).min(40)
885    };
886
887    // Lambda grid for P-spline when auto-selecting
888    let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
889    let auto_lambda = lambda_pspline < 0.0;
890
891    let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
892        .map(|i| {
893            // Extract single curve
894            let curve: Vec<f64> = (0..m).map(|j| data[i + j * n]).collect();
895
896            // Detect seasonality if requested
897            let seasonal_detected = if use_seasonal_hint {
898                detect_seasonality_fft(&curve)
899            } else {
900                false
901            };
902
903            let mut best_score = f64::INFINITY;
904            let mut best_basis_type = 0i32; // P-spline
905            let mut best_nbasis = pspline_min;
906            let mut best_coefs = Vec::new();
907            let mut best_fitted = Vec::new();
908            let mut best_edf = 0.0;
909            let mut best_lambda = f64::NAN;
910
911            // Try Fourier bases
912            let fourier_start = if seasonal_detected {
913                fourier_min.max(5)
914            } else {
915                fourier_min
916            };
917            let mut nb = if fourier_start % 2 == 0 {
918                fourier_start + 1
919            } else {
920                fourier_start
921            };
922            while nb <= fourier_max {
923                if let Some((score, coefs, fitted, edf)) =
924                    fit_curve_fourier(&curve, m, argvals, nb, criterion)
925                {
926                    if score < best_score && score.is_finite() {
927                        best_score = score;
928                        best_basis_type = 1; // Fourier
929                        best_nbasis = nb;
930                        best_coefs = coefs;
931                        best_fitted = fitted;
932                        best_edf = edf;
933                        best_lambda = f64::NAN;
934                    }
935                }
936                nb += 2;
937            }
938
939            // Try P-spline bases
940            for nb in pspline_min..=pspline_max {
941                if auto_lambda {
942                    // Search over lambda grid
943                    for &lam in &lambda_grid {
944                        if let Some((score, coefs, fitted, edf)) =
945                            fit_curve_pspline(&curve, m, argvals, nb, lam, 2, criterion)
946                        {
947                            if score < best_score && score.is_finite() {
948                                best_score = score;
949                                best_basis_type = 0; // P-spline
950                                best_nbasis = nb;
951                                best_coefs = coefs;
952                                best_fitted = fitted;
953                                best_edf = edf;
954                                best_lambda = lam;
955                            }
956                        }
957                    }
958                } else {
959                    // Use fixed lambda
960                    if let Some((score, coefs, fitted, edf)) =
961                        fit_curve_pspline(&curve, m, argvals, nb, lambda_pspline, 2, criterion)
962                    {
963                        if score < best_score && score.is_finite() {
964                            best_score = score;
965                            best_basis_type = 0; // P-spline
966                            best_nbasis = nb;
967                            best_coefs = coefs;
968                            best_fitted = fitted;
969                            best_edf = edf;
970                            best_lambda = lambda_pspline;
971                        }
972                    }
973                }
974            }
975
976            SingleCurveSelection {
977                basis_type: best_basis_type,
978                nbasis: best_nbasis,
979                score: best_score,
980                coefficients: best_coefs,
981                fitted: best_fitted,
982                edf: best_edf,
983                seasonal_detected,
984                lambda: best_lambda,
985            }
986        })
987        .collect();
988
989    BasisAutoSelectionResult {
990        selections,
991        criterion,
992    }
993}
994
995#[cfg(test)]
996mod tests {
997    use super::*;
998    use std::f64::consts::PI;
999
1000    /// Generate a uniform grid of points
1001    fn uniform_grid(n: usize) -> Vec<f64> {
1002        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
1003    }
1004
1005    /// Generate sine wave data
1006    fn sine_wave(t: &[f64], freq: f64) -> Vec<f64> {
1007        t.iter().map(|&ti| (2.0 * PI * freq * ti).sin()).collect()
1008    }
1009
1010    // ============== B-spline basis tests ==============
1011
1012    #[test]
1013    fn test_bspline_basis_dimensions() {
1014        let t = uniform_grid(50);
1015        let nknots = 10;
1016        let order = 4;
1017        let basis = bspline_basis(&t, nknots, order);
1018
1019        let expected_nbasis = nknots + order;
1020        assert_eq!(basis.len(), t.len() * expected_nbasis);
1021    }
1022
1023    #[test]
1024    fn test_bspline_basis_partition_of_unity() {
1025        // B-splines should sum to 1 at each point (partition of unity)
1026        let t = uniform_grid(50);
1027        let nknots = 8;
1028        let order = 4;
1029        let basis = bspline_basis(&t, nknots, order);
1030
1031        let nbasis = nknots + order;
1032        for i in 0..t.len() {
1033            let sum: f64 = (0..nbasis).map(|j| basis[i + j * t.len()]).sum();
1034            assert!(
1035                (sum - 1.0).abs() < 1e-10,
1036                "B-spline partition of unity failed at point {}: sum = {}",
1037                i,
1038                sum
1039            );
1040        }
1041    }
1042
1043    #[test]
1044    fn test_bspline_basis_non_negative() {
1045        let t = uniform_grid(50);
1046        let basis = bspline_basis(&t, 8, 4);
1047
1048        for &val in &basis {
1049            assert!(val >= -1e-10, "B-spline values should be non-negative");
1050        }
1051    }
1052
1053    #[test]
1054    fn test_bspline_basis_boundary() {
1055        // Test that basis functions work at boundary points
1056        let t = vec![0.0, 0.5, 1.0];
1057        let basis = bspline_basis(&t, 5, 4);
1058
1059        // Should have valid output (no NaN or Inf)
1060        for &val in &basis {
1061            assert!(val.is_finite(), "B-spline should produce finite values");
1062        }
1063    }
1064
1065    // ============== Fourier basis tests ==============
1066
1067    #[test]
1068    fn test_fourier_basis_dimensions() {
1069        let t = uniform_grid(50);
1070        let nbasis = 7;
1071        let basis = fourier_basis(&t, nbasis);
1072
1073        assert_eq!(basis.len(), t.len() * nbasis);
1074    }
1075
1076    #[test]
1077    fn test_fourier_basis_constant_first_column() {
1078        let t = uniform_grid(50);
1079        let nbasis = 7;
1080        let basis = fourier_basis(&t, nbasis);
1081
1082        // First column should be constant (DC component = 1)
1083        let first_val = basis[0];
1084        for i in 0..t.len() {
1085            assert!(
1086                (basis[i] - first_val).abs() < 1e-10,
1087                "First Fourier column should be constant"
1088            );
1089        }
1090    }
1091
1092    #[test]
1093    fn test_fourier_basis_sin_cos_range() {
1094        let t = uniform_grid(100);
1095        let nbasis = 11;
1096        let basis = fourier_basis(&t, nbasis);
1097
1098        // Sin and cos values should be in [-1, 1]
1099        for &val in &basis {
1100            assert!((-1.0 - 1e-10..=1.0 + 1e-10).contains(&val));
1101        }
1102    }
1103
1104    #[test]
1105    fn test_fourier_basis_with_period() {
1106        let t = uniform_grid(100);
1107        let nbasis = 5;
1108        let period = 0.5;
1109        let basis = fourier_basis_with_period(&t, nbasis, period);
1110
1111        assert_eq!(basis.len(), t.len() * nbasis);
1112        // Verify first column is constant
1113        let first_val = basis[0];
1114        for i in 0..t.len() {
1115            assert!((basis[i] - first_val).abs() < 1e-10);
1116        }
1117    }
1118
1119    #[test]
1120    fn test_fourier_basis_period_affects_frequency() {
1121        let t = uniform_grid(100);
1122        let nbasis = 5;
1123
1124        let basis1 = fourier_basis_with_period(&t, nbasis, 1.0);
1125        let basis2 = fourier_basis_with_period(&t, nbasis, 0.5);
1126
1127        // Different periods should give different basis matrices
1128        let n = t.len();
1129        let mut any_different = false;
1130        for i in 0..n {
1131            // Compare second column (first sin term)
1132            if (basis1[i + n] - basis2[i + n]).abs() > 1e-10 {
1133                any_different = true;
1134                break;
1135            }
1136        }
1137        assert!(
1138            any_different,
1139            "Different periods should produce different bases"
1140        );
1141    }
1142
1143    // ============== Difference matrix tests ==============
1144
1145    #[test]
1146    fn test_difference_matrix_order_zero() {
1147        let d = difference_matrix(5, 0);
1148        assert_eq!(d.nrows(), 5);
1149        assert_eq!(d.ncols(), 5);
1150
1151        // Should be identity matrix
1152        for i in 0..5 {
1153            for j in 0..5 {
1154                let expected = if i == j { 1.0 } else { 0.0 };
1155                assert!((d[(i, j)] - expected).abs() < 1e-10);
1156            }
1157        }
1158    }
1159
1160    #[test]
1161    fn test_difference_matrix_first_order() {
1162        let d = difference_matrix(5, 1);
1163        assert_eq!(d.nrows(), 4);
1164        assert_eq!(d.ncols(), 5);
1165
1166        // First order differences: D * [1,2,3,4,5] = [1,1,1,1]
1167        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1168        let dx = &d * x;
1169        for i in 0..4 {
1170            assert!((dx[i] - 1.0).abs() < 1e-10);
1171        }
1172    }
1173
1174    #[test]
1175    fn test_difference_matrix_second_order() {
1176        let d = difference_matrix(5, 2);
1177        assert_eq!(d.nrows(), 3);
1178        assert_eq!(d.ncols(), 5);
1179
1180        // Second order differences of linear: D^2 * [1,2,3,4,5] = [0,0,0]
1181        let x = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1182        let dx = &d * x;
1183        for i in 0..3 {
1184            assert!(dx[i].abs() < 1e-10, "Second diff of linear should be zero");
1185        }
1186    }
1187
1188    #[test]
1189    fn test_difference_matrix_quadratic() {
1190        let d = difference_matrix(5, 2);
1191
1192        // Second order differences of quadratic: D^2 * [1,4,9,16,25] = [2,2,2]
1193        let x = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);
1194        let dx = &d * x;
1195        for i in 0..3 {
1196            assert!(
1197                (dx[i] - 2.0).abs() < 1e-10,
1198                "Second diff of squares should be 2"
1199            );
1200        }
1201    }
1202
1203    // ============== Basis projection tests ==============
1204
1205    #[test]
1206    fn test_fdata_to_basis_1d_bspline() {
1207        let t = uniform_grid(50);
1208        let n = 5;
1209        let m = t.len();
1210
1211        // Create simple data
1212        let data: Vec<f64> = (0..n)
1213            .flat_map(|i| t.iter().map(move |&ti| ti + i as f64 * 0.1))
1214            .collect();
1215
1216        let result = fdata_to_basis_1d(&data, n, m, &t, 10, 0);
1217        assert!(result.is_some());
1218
1219        let res = result.unwrap();
1220        assert!(res.n_basis > 0);
1221        assert_eq!(res.coefficients.len(), n * res.n_basis);
1222    }
1223
1224    #[test]
1225    fn test_fdata_to_basis_1d_fourier() {
1226        let t = uniform_grid(50);
1227        let n = 5;
1228        let m = t.len();
1229
1230        // Create sine wave data
1231        let data: Vec<f64> = (0..n).flat_map(|_| sine_wave(&t, 2.0)).collect();
1232
1233        let result = fdata_to_basis_1d(&data, n, m, &t, 7, 1);
1234        assert!(result.is_some());
1235
1236        let res = result.unwrap();
1237        assert_eq!(res.n_basis, 7);
1238    }
1239
1240    #[test]
1241    fn test_fdata_to_basis_1d_invalid_input() {
1242        let t = uniform_grid(50);
1243
1244        // Empty data
1245        let result = fdata_to_basis_1d(&[], 0, 50, &t, 10, 0);
1246        assert!(result.is_none());
1247
1248        // nbasis too small
1249        let data = vec![0.0; 50];
1250        let result = fdata_to_basis_1d(&data, 1, 50, &t, 1, 0);
1251        assert!(result.is_none());
1252    }
1253
1254    #[test]
1255    fn test_basis_roundtrip() {
1256        let t = uniform_grid(100);
1257        let n = 1;
1258        let m = t.len();
1259
1260        // Create smooth sine wave data (Fourier basis should represent exactly)
1261        let data = sine_wave(&t, 1.0);
1262
1263        // Project to Fourier basis with enough terms
1264        let proj = fdata_to_basis_1d(&data, n, m, &t, 5, 1).unwrap();
1265
1266        // Reconstruct
1267        let reconstructed =
1268            basis_to_fdata_1d(&proj.coefficients, n, proj.n_basis, &t, proj.n_basis, 1);
1269
1270        // Should approximately match original for a simple sine wave
1271        let mut max_error = 0.0;
1272        for i in 0..m {
1273            let err = (data[i] - reconstructed[i]).abs();
1274            if err > max_error {
1275                max_error = err;
1276            }
1277        }
1278        assert!(max_error < 0.5, "Roundtrip error too large: {}", max_error);
1279    }
1280
1281    #[test]
1282    fn test_basis_to_fdata_empty_input() {
1283        let result = basis_to_fdata_1d(&[], 0, 0, &[], 5, 0);
1284        assert!(result.is_empty());
1285    }
1286
1287    // ============== P-spline fitting tests ==============
1288
1289    #[test]
1290    fn test_pspline_fit_1d_basic() {
1291        let t = uniform_grid(50);
1292        let n = 3;
1293        let m = t.len();
1294
1295        // Create noisy data
1296        let data: Vec<f64> = (0..n)
1297            .flat_map(|i| {
1298                t.iter()
1299                    .enumerate()
1300                    .map(move |(j, &ti)| (2.0 * PI * ti).sin() + 0.1 * (i * j) as f64 % 1.0)
1301            })
1302            .collect();
1303
1304        let result = pspline_fit_1d(&data, n, m, &t, 15, 1.0, 2);
1305        assert!(result.is_some());
1306
1307        let res = result.unwrap();
1308        assert!(res.n_basis > 0);
1309        assert_eq!(res.fitted.len(), n * m);
1310        assert!(res.rss >= 0.0);
1311        assert!(res.edf > 0.0);
1312        assert!(res.gcv.is_finite());
1313    }
1314
1315    #[test]
1316    fn test_pspline_fit_1d_smoothness() {
1317        let t = uniform_grid(50);
1318        let n = 1;
1319        let m = t.len();
1320
1321        // Create noisy sine wave
1322        let data: Vec<f64> = t
1323            .iter()
1324            .enumerate()
1325            .map(|(i, &ti)| (2.0 * PI * ti).sin() + 0.3 * ((i * 17) % 100) as f64 / 100.0)
1326            .collect();
1327
1328        let low_lambda = pspline_fit_1d(&data, n, m, &t, 15, 0.01, 2).unwrap();
1329        let high_lambda = pspline_fit_1d(&data, n, m, &t, 15, 100.0, 2).unwrap();
1330
1331        // Higher lambda should give lower edf (more smoothing)
1332        assert!(high_lambda.edf < low_lambda.edf);
1333    }
1334
1335    #[test]
1336    fn test_pspline_fit_1d_invalid_input() {
1337        let t = uniform_grid(50);
1338        let result = pspline_fit_1d(&[], 0, 50, &t, 15, 1.0, 2);
1339        assert!(result.is_none());
1340    }
1341
1342    // ============== Fourier fitting tests ==============
1343
1344    #[test]
1345    fn test_fourier_fit_1d_sine_wave() {
1346        let t = uniform_grid(100);
1347        let n = 1;
1348        let m = t.len();
1349
1350        // Create pure sine wave
1351        let data = sine_wave(&t, 2.0);
1352
1353        let result = fourier_fit_1d(&data, n, m, &t, 11);
1354        assert!(result.is_some());
1355
1356        let res = result.unwrap();
1357        assert!(res.rss < 1e-6, "Pure sine should have near-zero RSS");
1358    }
1359
1360    #[test]
1361    fn test_fourier_fit_1d_makes_nbasis_odd() {
1362        let t = uniform_grid(50);
1363        let data = sine_wave(&t, 1.0);
1364
1365        // Pass even nbasis
1366        let result = fourier_fit_1d(&data, 1, t.len(), &t, 6);
1367        assert!(result.is_some());
1368
1369        // Should have been adjusted to odd
1370        let res = result.unwrap();
1371        assert!(res.n_basis % 2 == 1);
1372    }
1373
1374    #[test]
1375    fn test_fourier_fit_1d_criteria() {
1376        let t = uniform_grid(50);
1377        let data = sine_wave(&t, 2.0);
1378
1379        let result = fourier_fit_1d(&data, 1, t.len(), &t, 9).unwrap();
1380
1381        // All criteria should be finite
1382        assert!(result.gcv.is_finite());
1383        assert!(result.aic.is_finite());
1384        assert!(result.bic.is_finite());
1385    }
1386
1387    #[test]
1388    fn test_fourier_fit_1d_invalid_nbasis() {
1389        let t = uniform_grid(50);
1390        let data = sine_wave(&t, 1.0);
1391
1392        // nbasis < 3 should return None
1393        let result = fourier_fit_1d(&data, 1, t.len(), &t, 2);
1394        assert!(result.is_none());
1395    }
1396
1397    // ============== GCV selection tests ==============
1398
1399    #[test]
1400    fn test_select_fourier_nbasis_gcv_range() {
1401        let t = uniform_grid(100);
1402        let data = sine_wave(&t, 3.0);
1403
1404        let best = select_fourier_nbasis_gcv(&data, 1, t.len(), &t, 3, 15);
1405
1406        assert!((3..=15).contains(&best));
1407        assert!(best % 2 == 1, "Selected nbasis should be odd");
1408    }
1409
1410    #[test]
1411    fn test_select_fourier_nbasis_gcv_respects_min() {
1412        let t = uniform_grid(50);
1413        let data = sine_wave(&t, 1.0);
1414
1415        let best = select_fourier_nbasis_gcv(&data, 1, t.len(), &t, 7, 15);
1416        assert!(best >= 7);
1417    }
1418
1419    // ============== Auto selection tests ==============
1420
1421    #[test]
1422    fn test_select_basis_auto_1d_returns_results() {
1423        let t = uniform_grid(50);
1424        let n = 3;
1425        let m = t.len();
1426
1427        let data: Vec<f64> = (0..n).flat_map(|i| sine_wave(&t, 1.0 + i as f64)).collect();
1428
1429        let result = select_basis_auto_1d(&data, n, m, &t, 0, 5, 15, 1.0, false);
1430
1431        assert_eq!(result.selections.len(), n);
1432        for sel in &result.selections {
1433            assert!(sel.nbasis >= 3);
1434            assert!(!sel.coefficients.is_empty());
1435            assert_eq!(sel.fitted.len(), m);
1436        }
1437    }
1438
1439    #[test]
1440    fn test_select_basis_auto_1d_seasonal_hint() {
1441        let t = uniform_grid(100);
1442        let n = 1;
1443        let m = t.len();
1444
1445        // Strong seasonal pattern
1446        let data = sine_wave(&t, 5.0);
1447
1448        let result = select_basis_auto_1d(&data, n, m, &t, 0, 0, 0, -1.0, true);
1449
1450        assert_eq!(result.selections.len(), 1);
1451        assert!(result.selections[0].seasonal_detected);
1452    }
1453
1454    #[test]
1455    fn test_select_basis_auto_1d_non_seasonal() {
1456        let t = uniform_grid(50);
1457        let n = 1;
1458        let m = t.len();
1459
1460        // Constant data (definitely not seasonal)
1461        let data: Vec<f64> = vec![1.0; m];
1462
1463        let result = select_basis_auto_1d(&data, n, m, &t, 0, 0, 0, -1.0, true);
1464
1465        // Constant data shouldn't be detected as seasonal
1466        assert!(!result.selections[0].seasonal_detected);
1467    }
1468
1469    #[test]
1470    fn test_select_basis_auto_1d_criterion_options() {
1471        let t = uniform_grid(50);
1472        let data = sine_wave(&t, 2.0);
1473
1474        // Test all three criteria
1475        let gcv_result = select_basis_auto_1d(&data, 1, t.len(), &t, 0, 0, 0, 1.0, false);
1476        let aic_result = select_basis_auto_1d(&data, 1, t.len(), &t, 1, 0, 0, 1.0, false);
1477        let bic_result = select_basis_auto_1d(&data, 1, t.len(), &t, 2, 0, 0, 1.0, false);
1478
1479        assert_eq!(gcv_result.criterion, 0);
1480        assert_eq!(aic_result.criterion, 1);
1481        assert_eq!(bic_result.criterion, 2);
1482    }
1483}