Skip to main content

fdars_core/basis/
auto_select.rs

1//! Automatic basis type and parameter selection.
2
3use super::bspline::bspline_basis;
4use super::fourier::fourier_basis;
5use super::helpers::{compute_model_criterion, svd_pseudoinverse};
6use super::projection::ProjectionBasisType;
7use super::pspline::difference_matrix;
8use crate::iter_maybe_parallel;
9use crate::matrix::FdMatrix;
10use nalgebra::{DMatrix, DVector};
11#[cfg(feature = "parallel")]
12use rayon::iter::ParallelIterator;
13
14/// Result of automatic basis selection for a single curve.
15#[derive(Debug, Clone)]
16pub struct SingleCurveSelection {
17    /// Selected basis type.
18    pub basis_type: ProjectionBasisType,
19    /// Selected number of basis functions
20    pub nbasis: usize,
21    /// Best criterion score (GCV, AIC, or BIC)
22    pub score: f64,
23    /// Coefficients for the selected basis
24    pub coefficients: Vec<f64>,
25    /// Fitted values
26    pub fitted: Vec<f64>,
27    /// Effective degrees of freedom
28    pub edf: f64,
29    /// Whether seasonal pattern was detected (if use_seasonal_hint)
30    pub seasonal_detected: bool,
31    /// Lambda value (for P-spline, NaN for Fourier)
32    pub lambda: f64,
33}
34
35/// Result of automatic basis selection for all curves.
36#[derive(Debug, Clone)]
37pub struct BasisAutoSelectionResult {
38    /// Per-curve selection results
39    pub selections: Vec<SingleCurveSelection>,
40    /// Criterion used (0=GCV, 1=AIC, 2=BIC)
41    pub criterion: i32,
42}
43
44/// Detect if a curve has seasonal/periodic pattern using FFT.
45///
46/// Returns true if the peak power in the periodogram is significantly
47/// above the mean power level.
48fn detect_seasonality_fft(curve: &[f64]) -> bool {
49    use rustfft::{num_complex::Complex, FftPlanner};
50
51    let n = curve.len();
52    if n < 8 {
53        return false;
54    }
55
56    // Remove mean
57    let mean: f64 = curve.iter().sum::<f64>() / n as f64;
58    let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
59
60    let mut planner = FftPlanner::new();
61    let fft = planner.plan_fft_forward(n);
62    fft.process(&mut input);
63
64    // Compute power spectrum (skip DC component and Nyquist)
65    let powers: Vec<f64> = input[1..n / 2]
66        .iter()
67        .map(nalgebra::Complex::norm_sqr)
68        .collect();
69
70    if powers.is_empty() {
71        return false;
72    }
73
74    let max_power = powers.iter().copied().fold(0.0_f64, f64::max);
75    let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
76
77    // Seasonal if peak is significantly above mean
78    max_power > 2.0 * mean_power
79}
80
81/// Fit a single curve with Fourier basis and compute criterion score.
82fn fit_curve_fourier(
83    curve: &[f64],
84    m: usize,
85    argvals: &[f64],
86    nbasis: usize,
87    criterion: i32,
88) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
89    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
90
91    let basis = fourier_basis(argvals, nbasis);
92    let actual_nbasis = basis.len() / m;
93    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
94
95    let btb = &b_mat.transpose() * &b_mat;
96    let btb_inv = svd_pseudoinverse(&btb)?;
97    let proj = &btb_inv * b_mat.transpose();
98    let h_mat = &b_mat * &proj;
99    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
100
101    let curve_vec = DVector::from_column_slice(curve);
102    let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
103    let fitted = &b_mat * &coefs;
104
105    let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
106    let score = compute_model_criterion(rss, m as f64, edf, criterion);
107
108    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
109    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
110
111    Some((score, coef_vec, fitted_vec, edf))
112}
113
114/// Fit a single curve with P-spline basis and compute criterion score.
115fn fit_curve_pspline(
116    curve: &[f64],
117    m: usize,
118    argvals: &[f64],
119    nbasis: usize,
120    lambda: f64,
121    order: usize,
122    criterion: i32,
123) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
124    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
125    let actual_nbasis = basis.len() / m;
126    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
127
128    let d = difference_matrix(actual_nbasis, order);
129    let penalty = &d.transpose() * &d;
130    let btb = &b_mat.transpose() * &b_mat;
131    let btb_penalized = &btb + lambda * &penalty;
132
133    let btb_inv = svd_pseudoinverse(&btb_penalized)?;
134    let proj = &btb_inv * b_mat.transpose();
135    let h_mat = &b_mat * &proj;
136    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
137
138    let curve_vec = DVector::from_column_slice(curve);
139    let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
140    let fitted = &b_mat * &coefs;
141
142    let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
143    let score = compute_model_criterion(rss, m as f64, edf, criterion);
144
145    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
146    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
147
148    Some((score, coef_vec, fitted_vec, edf))
149}
150
/// Best fit found while scanning one basis family for a single curve.
struct BasisSearchResult {
    /// Criterion value of the fit (lower is better).
    score: f64,
    /// Basis size that produced the fit.
    nbasis: usize,
    /// Fitted basis coefficients.
    coefs: Vec<f64>,
    /// Fitted values at the evaluation points.
    fitted: Vec<f64>,
    /// Effective degrees of freedom of the smoother.
    edf: f64,
    /// Smoothing parameter used (`NaN` for Fourier fits).
    lambda: f64,
}
160
161/// Search over Fourier basis sizes for the best fit.
162fn search_fourier_basis(
163    curve: &[f64],
164    m: usize,
165    argvals: &[f64],
166    fourier_min: usize,
167    fourier_max: usize,
168    seasonal: bool,
169    criterion: i32,
170) -> Option<BasisSearchResult> {
171    let fourier_start = if seasonal {
172        fourier_min.max(5)
173    } else {
174        fourier_min
175    };
176    let mut nb = if fourier_start % 2 == 0 {
177        fourier_start + 1
178    } else {
179        fourier_start
180    };
181
182    let mut best: Option<BasisSearchResult> = None;
183    while nb <= fourier_max {
184        if let Some((score, coefs, fitted, edf)) =
185            fit_curve_fourier(curve, m, argvals, nb, criterion)
186        {
187            if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
188                best = Some(BasisSearchResult {
189                    score,
190                    nbasis: nb,
191                    coefs,
192                    fitted,
193                    edf,
194                    lambda: f64::NAN,
195                });
196            }
197        }
198        nb += 2;
199    }
200    best
201}
202
203/// Try a single P-spline fit and update best if it improves the score.
204fn try_pspline_fit_update(
205    curve: &[f64],
206    m: usize,
207    argvals: &[f64],
208    nb: usize,
209    lam: f64,
210    criterion: i32,
211    best: &mut Option<BasisSearchResult>,
212) {
213    if let Some((score, coefs, fitted, edf)) =
214        fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
215    {
216        if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
217            *best = Some(BasisSearchResult {
218                score,
219                nbasis: nb,
220                coefs,
221                fitted,
222                edf,
223                lambda: lam,
224            });
225        }
226    }
227}
228
229/// Search over P-spline basis sizes (and optionally lambda) for the best fit.
230fn search_pspline_basis(
231    curve: &[f64],
232    m: usize,
233    argvals: &[f64],
234    pspline_min: usize,
235    pspline_max: usize,
236    lambda_grid: &[f64],
237    auto_lambda: bool,
238    lambda: f64,
239    criterion: i32,
240) -> Option<BasisSearchResult> {
241    let mut best: Option<BasisSearchResult> = None;
242    for nb in pspline_min..=pspline_max {
243        let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
244            Box::new(lambda_grid.iter().copied())
245        } else {
246            Box::new(std::iter::once(lambda))
247        };
248        for lam in lambdas {
249            try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
250        }
251    }
252    best
253}
254
255/// Select optimal basis type and parameters for each curve individually.
256///
257/// This function compares Fourier and P-spline bases for each curve,
258/// selecting the optimal basis type and number of basis functions using
259/// model selection criteria (GCV, AIC, or BIC).
260///
261/// # Arguments
262/// * `data` - Column-major FdMatrix (n x m)
263/// * `argvals` - Evaluation points
264/// * `criterion` - Model selection criterion: 0=GCV, 1=AIC, 2=BIC
265/// * `nbasis_min` - Minimum number of basis functions (0 for auto)
266/// * `nbasis_max` - Maximum number of basis functions (0 for auto)
267/// * `lambda_pspline` - Smoothing parameter for P-spline (negative for auto-select)
268/// * `use_seasonal_hint` - Whether to use FFT to detect seasonality
269///
270/// # Returns
271/// BasisAutoSelectionResult with per-curve selections
272pub fn select_basis_auto_1d(
273    data: &FdMatrix,
274    argvals: &[f64],
275    criterion: i32,
276    nbasis_min: usize,
277    nbasis_max: usize,
278    lambda_pspline: f64,
279    use_seasonal_hint: bool,
280) -> BasisAutoSelectionResult {
281    let n = data.nrows();
282    let m = data.ncols();
283    let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
284    let fourier_max = if nbasis_max > 0 {
285        nbasis_max.min(m / 3).min(25)
286    } else {
287        (m / 3).min(25)
288    };
289    let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
290    let pspline_max = if nbasis_max > 0 {
291        nbasis_max.min(m / 2).min(40)
292    } else {
293        (m / 2).min(40)
294    };
295
296    let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
297    let auto_lambda = lambda_pspline < 0.0;
298
299    let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
300        .map(|i| {
301            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
302            let seasonal_detected = if use_seasonal_hint {
303                detect_seasonality_fft(&curve)
304            } else {
305                false
306            };
307
308            let fourier_best = search_fourier_basis(
309                &curve,
310                m,
311                argvals,
312                fourier_min,
313                fourier_max,
314                seasonal_detected,
315                criterion,
316            );
317            let pspline_best = search_pspline_basis(
318                &curve,
319                m,
320                argvals,
321                pspline_min,
322                pspline_max,
323                &lambda_grid,
324                auto_lambda,
325                lambda_pspline,
326                criterion,
327            );
328
329            // Pick the best overall result
330            let (basis_type, result) = match (fourier_best, pspline_best) {
331                (Some(f), Some(p)) => {
332                    if f.score <= p.score {
333                        (ProjectionBasisType::Fourier, f)
334                    } else {
335                        (ProjectionBasisType::Bspline, p)
336                    }
337                }
338                (Some(f), None) => (ProjectionBasisType::Fourier, f),
339                (None, Some(p)) => (ProjectionBasisType::Bspline, p),
340                (None, None) => {
341                    return SingleCurveSelection {
342                        basis_type: ProjectionBasisType::Bspline,
343                        nbasis: pspline_min,
344                        score: f64::INFINITY,
345                        coefficients: Vec::new(),
346                        fitted: Vec::new(),
347                        edf: 0.0,
348                        seasonal_detected,
349                        lambda: f64::NAN,
350                    };
351                }
352            };
353
354            SingleCurveSelection {
355                basis_type,
356                nbasis: result.nbasis,
357                score: result.score,
358                coefficients: result.coefs,
359                fitted: result.fitted,
360                edf: result.edf,
361                seasonal_detected,
362                lambda: result.lambda,
363            }
364        })
365        .collect();
366
367    BasisAutoSelectionResult {
368        selections,
369        criterion,
370    }
371}