Skip to main content

fdars_core/basis/
auto_select.rs

1//! Automatic basis type and parameter selection.
2
3use super::bspline::bspline_basis;
4use super::fourier::fourier_basis;
5use super::helpers::{compute_model_criterion, svd_pseudoinverse};
6use super::projection::ProjectionBasisType;
7use super::pspline::difference_matrix;
8use crate::iter_maybe_parallel;
9use crate::matrix::FdMatrix;
10use nalgebra::{DMatrix, DVector};
11#[cfg(feature = "parallel")]
12use rayon::iter::ParallelIterator;
13
14/// Result of automatic basis selection for a single curve.
15#[derive(Debug, Clone)]
16#[non_exhaustive]
17pub struct SingleCurveSelection {
18    /// Selected basis type.
19    pub basis_type: ProjectionBasisType,
20    /// Selected number of basis functions
21    pub nbasis: usize,
22    /// Best criterion score (GCV, AIC, or BIC)
23    pub score: f64,
24    /// Coefficients for the selected basis
25    pub coefficients: Vec<f64>,
26    /// Fitted values
27    pub fitted: Vec<f64>,
28    /// Effective degrees of freedom
29    pub edf: f64,
30    /// Whether seasonal pattern was detected (if use_seasonal_hint)
31    pub seasonal_detected: bool,
32    /// Lambda value (for P-spline, NaN for Fourier)
33    pub lambda: f64,
34}
35
36/// Result of automatic basis selection for all curves.
37#[derive(Debug, Clone)]
38#[non_exhaustive]
39pub struct BasisAutoSelectionResult {
40    /// Per-curve selection results
41    pub selections: Vec<SingleCurveSelection>,
42    /// Criterion used (0=GCV, 1=AIC, 2=BIC)
43    pub criterion: i32,
44}
45
46/// Detect if a curve has seasonal/periodic pattern using FFT.
47///
48/// Returns true if the peak power in the periodogram is significantly
49/// above the mean power level.
50fn detect_seasonality_fft(curve: &[f64]) -> bool {
51    use rustfft::{num_complex::Complex, FftPlanner};
52
53    let n = curve.len();
54    if n < 8 {
55        return false;
56    }
57
58    // Remove mean
59    let mean: f64 = curve.iter().sum::<f64>() / n as f64;
60    let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
61
62    let mut planner = FftPlanner::new();
63    let fft = planner.plan_fft_forward(n);
64    fft.process(&mut input);
65
66    // Compute power spectrum (skip DC component and Nyquist)
67    let powers: Vec<f64> = input[1..n / 2]
68        .iter()
69        .map(nalgebra::Complex::norm_sqr)
70        .collect();
71
72    if powers.is_empty() {
73        return false;
74    }
75
76    let max_power = powers.iter().copied().fold(0.0_f64, f64::max);
77    let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
78
79    // Seasonal if peak is significantly above mean
80    max_power > 2.0 * mean_power
81}
82
83/// Fit a single curve with Fourier basis and compute criterion score.
84fn fit_curve_fourier(
85    curve: &[f64],
86    m: usize,
87    argvals: &[f64],
88    nbasis: usize,
89    criterion: i32,
90) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
91    let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
92
93    let basis = fourier_basis(argvals, nbasis);
94    let actual_nbasis = basis.len() / m;
95    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
96
97    let btb = &b_mat.transpose() * &b_mat;
98    let btb_inv = svd_pseudoinverse(&btb)?;
99    let proj = &btb_inv * b_mat.transpose();
100    let h_mat = &b_mat * &proj;
101    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
102
103    let curve_vec = DVector::from_column_slice(curve);
104    let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
105    let fitted = &b_mat * &coefs;
106
107    let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
108    let score = compute_model_criterion(rss, m as f64, edf, criterion);
109
110    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
111    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
112
113    Some((score, coef_vec, fitted_vec, edf))
114}
115
116/// Fit a single curve with P-spline basis and compute criterion score.
117fn fit_curve_pspline(
118    curve: &[f64],
119    m: usize,
120    argvals: &[f64],
121    nbasis: usize,
122    lambda: f64,
123    order: usize,
124    criterion: i32,
125) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
126    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
127    let actual_nbasis = basis.len() / m;
128    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
129
130    let d = difference_matrix(actual_nbasis, order);
131    let penalty = &d.transpose() * &d;
132    let btb = &b_mat.transpose() * &b_mat;
133    let btb_penalized = &btb + lambda * &penalty;
134
135    let btb_inv = svd_pseudoinverse(&btb_penalized)?;
136    let proj = &btb_inv * b_mat.transpose();
137    let h_mat = &b_mat * &proj;
138    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
139
140    let curve_vec = DVector::from_column_slice(curve);
141    let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
142    let fitted = &b_mat * &coefs;
143
144    let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
145    let score = compute_model_criterion(rss, m as f64, edf, criterion);
146
147    let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
148    let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
149
150    Some((score, coef_vec, fitted_vec, edf))
151}
152
/// Best candidate found while searching one basis family for a single curve.
struct BasisSearchResult {
    /// Criterion value of the candidate; lower is better.
    score: f64,
    /// Basis size that was requested for the candidate.
    nbasis: usize,
    /// Fitted basis coefficients.
    coefs: Vec<f64>,
    /// Fitted curve values at the evaluation points.
    fitted: Vec<f64>,
    /// Effective degrees of freedom of the fit.
    edf: f64,
    /// Smoothing parameter of the fit; `NaN` for Fourier candidates.
    lambda: f64,
}
162
163/// Search over Fourier basis sizes for the best fit.
164fn search_fourier_basis(
165    curve: &[f64],
166    m: usize,
167    argvals: &[f64],
168    fourier_min: usize,
169    fourier_max: usize,
170    seasonal: bool,
171    criterion: i32,
172) -> Option<BasisSearchResult> {
173    let fourier_start = if seasonal {
174        fourier_min.max(5)
175    } else {
176        fourier_min
177    };
178    let mut nb = if fourier_start % 2 == 0 {
179        fourier_start + 1
180    } else {
181        fourier_start
182    };
183
184    let mut best: Option<BasisSearchResult> = None;
185    while nb <= fourier_max {
186        if let Some((score, coefs, fitted, edf)) =
187            fit_curve_fourier(curve, m, argvals, nb, criterion)
188        {
189            if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
190                best = Some(BasisSearchResult {
191                    score,
192                    nbasis: nb,
193                    coefs,
194                    fitted,
195                    edf,
196                    lambda: f64::NAN,
197                });
198            }
199        }
200        nb += 2;
201    }
202    best
203}
204
205/// Try a single P-spline fit and update best if it improves the score.
206fn try_pspline_fit_update(
207    curve: &[f64],
208    m: usize,
209    argvals: &[f64],
210    nb: usize,
211    lam: f64,
212    criterion: i32,
213    best: &mut Option<BasisSearchResult>,
214) {
215    if let Some((score, coefs, fitted, edf)) =
216        fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
217    {
218        if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
219            *best = Some(BasisSearchResult {
220                score,
221                nbasis: nb,
222                coefs,
223                fitted,
224                edf,
225                lambda: lam,
226            });
227        }
228    }
229}
230
231/// Search over P-spline basis sizes (and optionally lambda) for the best fit.
232fn search_pspline_basis(
233    curve: &[f64],
234    m: usize,
235    argvals: &[f64],
236    pspline_min: usize,
237    pspline_max: usize,
238    lambda_grid: &[f64],
239    auto_lambda: bool,
240    lambda: f64,
241    criterion: i32,
242) -> Option<BasisSearchResult> {
243    let mut best: Option<BasisSearchResult> = None;
244    for nb in pspline_min..=pspline_max {
245        let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
246            Box::new(lambda_grid.iter().copied())
247        } else {
248            Box::new(std::iter::once(lambda))
249        };
250        for lam in lambdas {
251            try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
252        }
253    }
254    best
255}
256
257/// Select optimal basis type and parameters for each curve individually.
258///
259/// This function compares Fourier and P-spline bases for each curve,
260/// selecting the optimal basis type and number of basis functions using
261/// model selection criteria (GCV, AIC, or BIC).
262///
263/// # Arguments
264/// * `data` - Column-major FdMatrix (n x m)
265/// * `argvals` - Evaluation points
266/// * `criterion` - Model selection criterion: 0=GCV, 1=AIC, 2=BIC
267/// * `nbasis_min` - Minimum number of basis functions (0 for auto)
268/// * `nbasis_max` - Maximum number of basis functions (0 for auto)
269/// * `lambda_pspline` - Smoothing parameter for P-spline (negative for auto-select)
270/// * `use_seasonal_hint` - Whether to use FFT to detect seasonality
271///
272/// # Returns
273/// BasisAutoSelectionResult with per-curve selections
274pub fn select_basis_auto_1d(
275    data: &FdMatrix,
276    argvals: &[f64],
277    criterion: i32,
278    nbasis_min: usize,
279    nbasis_max: usize,
280    lambda_pspline: f64,
281    use_seasonal_hint: bool,
282) -> BasisAutoSelectionResult {
283    let n = data.nrows();
284    let m = data.ncols();
285    let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
286    let fourier_max = if nbasis_max > 0 {
287        nbasis_max.min(m / 3).min(25)
288    } else {
289        (m / 3).min(25)
290    };
291    let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
292    let pspline_max = if nbasis_max > 0 {
293        nbasis_max.min(m / 2).min(40)
294    } else {
295        (m / 2).min(40)
296    };
297
298    let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
299    let auto_lambda = lambda_pspline < 0.0;
300
301    let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
302        .map(|i| {
303            let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
304            let seasonal_detected = if use_seasonal_hint {
305                detect_seasonality_fft(&curve)
306            } else {
307                false
308            };
309
310            let fourier_best = search_fourier_basis(
311                &curve,
312                m,
313                argvals,
314                fourier_min,
315                fourier_max,
316                seasonal_detected,
317                criterion,
318            );
319            let pspline_best = search_pspline_basis(
320                &curve,
321                m,
322                argvals,
323                pspline_min,
324                pspline_max,
325                &lambda_grid,
326                auto_lambda,
327                lambda_pspline,
328                criterion,
329            );
330
331            // Pick the best overall result
332            let (basis_type, result) = match (fourier_best, pspline_best) {
333                (Some(f), Some(p)) => {
334                    if f.score <= p.score {
335                        (ProjectionBasisType::Fourier, f)
336                    } else {
337                        (ProjectionBasisType::Bspline, p)
338                    }
339                }
340                (Some(f), None) => (ProjectionBasisType::Fourier, f),
341                (None, Some(p)) => (ProjectionBasisType::Bspline, p),
342                (None, None) => {
343                    return SingleCurveSelection {
344                        basis_type: ProjectionBasisType::Bspline,
345                        nbasis: pspline_min,
346                        score: f64::INFINITY,
347                        coefficients: Vec::new(),
348                        fitted: Vec::new(),
349                        edf: 0.0,
350                        seasonal_detected,
351                        lambda: f64::NAN,
352                    };
353                }
354            };
355
356            SingleCurveSelection {
357                basis_type,
358                nbasis: result.nbasis,
359                score: result.score,
360                coefficients: result.coefs,
361                fitted: result.fitted,
362                edf: result.edf,
363                seasonal_detected,
364                lambda: result.lambda,
365            }
366        })
367        .collect();
368
369    BasisAutoSelectionResult {
370        selections,
371        criterion,
372    }
373}