1use super::bspline::bspline_basis;
4use super::fourier::fourier_basis;
5use super::helpers::{compute_model_criterion, svd_pseudoinverse};
6use super::projection::ProjectionBasisType;
7use super::pspline::difference_matrix;
8use crate::iter_maybe_parallel;
9use crate::matrix::FdMatrix;
10use nalgebra::{DMatrix, DVector};
11#[cfg(feature = "parallel")]
12use rayon::iter::ParallelIterator;
13
/// Basis choice and fit for a single curve, as produced by `select_basis_auto_1d`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SingleCurveSelection {
    /// Winning basis family (`Fourier` or `Bspline`); `Bspline` when no candidate fit succeeded.
    pub basis_type: ProjectionBasisType,
    /// Basis size at which the winning score was obtained: the (odd) size passed to the
    /// Fourier fitter, or the requested P-spline size. When no fit succeeded this is the
    /// smallest P-spline size searched.
    /// NOTE(review): for P-splines this is the *requested* size, which may differ from
    /// `coefficients.len()` — confirm against `bspline_basis`.
    pub nbasis: usize,
    /// Model-selection score from `compute_model_criterion`; lower is better.
    /// `f64::INFINITY` when no candidate produced a finite score.
    pub score: f64,
    /// Basis coefficients of the winning fit; empty when no fit succeeded.
    pub coefficients: Vec<f64>,
    /// Fitted values at each evaluation point of the curve; empty when no fit succeeded.
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom (trace of the smoother/hat matrix); `0.0` when no fit
    /// succeeded.
    pub edf: f64,
    /// Whether the FFT seasonality heuristic fired; always `false` when the seasonal hint is
    /// disabled.
    pub seasonal_detected: bool,
    /// Smoothing parameter of the winning P-spline fit; `f64::NAN` for Fourier fits or when no
    /// fit succeeded.
    pub lambda: f64,
}
35
/// Result of `select_basis_auto_1d`: one basis selection per input curve.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct BasisAutoSelectionResult {
    /// One entry per curve (row of the input data matrix).
    pub selections: Vec<SingleCurveSelection>,
    /// Criterion code used for scoring, echoed from the caller. Its meaning is defined by
    /// `compute_model_criterion` (presumably GCV/AIC/BIC-style codes — TODO confirm).
    pub criterion: i32,
}
45
46fn detect_seasonality_fft(curve: &[f64]) -> bool {
51 use rustfft::{num_complex::Complex, FftPlanner};
52
53 let n = curve.len();
54 if n < 8 {
55 return false;
56 }
57
58 let mean: f64 = curve.iter().sum::<f64>() / n as f64;
60 let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
61
62 let mut planner = FftPlanner::new();
63 let fft = planner.plan_fft_forward(n);
64 fft.process(&mut input);
65
66 let powers: Vec<f64> = input[1..n / 2]
68 .iter()
69 .map(nalgebra::Complex::norm_sqr)
70 .collect();
71
72 if powers.is_empty() {
73 return false;
74 }
75
76 let max_power = powers.iter().copied().fold(0.0_f64, f64::max);
77 let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
78
79 max_power > 2.0 * mean_power
81}
82
83fn fit_curve_fourier(
85 curve: &[f64],
86 m: usize,
87 argvals: &[f64],
88 nbasis: usize,
89 criterion: i32,
90) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
91 let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
92
93 let basis = fourier_basis(argvals, nbasis);
94 let actual_nbasis = basis.len() / m;
95 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
96
97 let btb = &b_mat.transpose() * &b_mat;
98 let btb_inv = svd_pseudoinverse(&btb)?;
99 let proj = &btb_inv * b_mat.transpose();
100 let h_mat = &b_mat * &proj;
101 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
102
103 let curve_vec = DVector::from_column_slice(curve);
104 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
105 let fitted = &b_mat * &coefs;
106
107 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
108 let score = compute_model_criterion(rss, m as f64, edf, criterion);
109
110 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
111 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
112
113 Some((score, coef_vec, fitted_vec, edf))
114}
115
116fn fit_curve_pspline(
118 curve: &[f64],
119 m: usize,
120 argvals: &[f64],
121 nbasis: usize,
122 lambda: f64,
123 order: usize,
124 criterion: i32,
125) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
126 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
127 let actual_nbasis = basis.len() / m;
128 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
129
130 let d = difference_matrix(actual_nbasis, order);
131 let penalty = &d.transpose() * &d;
132 let btb = &b_mat.transpose() * &b_mat;
133 let btb_penalized = &btb + lambda * &penalty;
134
135 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
136 let proj = &btb_inv * b_mat.transpose();
137 let h_mat = &b_mat * &proj;
138 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
139
140 let curve_vec = DVector::from_column_slice(curve);
141 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
142 let fitted = &b_mat * &coefs;
143
144 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
145 let score = compute_model_criterion(rss, m as f64, edf, criterion);
146
147 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
148 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
149
150 Some((score, coef_vec, fitted_vec, edf))
151}
152
/// Best candidate found by one of the per-family basis searches
/// (`search_fourier_basis` / `search_pspline_basis`).
struct BasisSearchResult {
    /// Model-selection score; only finite scores are ever stored. Lower is better.
    score: f64,
    /// Basis size that was passed to the fitter for this candidate.
    nbasis: usize,
    /// Fitted basis coefficients.
    coefs: Vec<f64>,
    /// Fitted values at the curve's evaluation points.
    fitted: Vec<f64>,
    /// Effective degrees of freedom of the fit.
    edf: f64,
    /// Smoothing parameter for P-spline candidates; `f64::NAN` for Fourier candidates.
    lambda: f64,
}
162
163fn search_fourier_basis(
165 curve: &[f64],
166 m: usize,
167 argvals: &[f64],
168 fourier_min: usize,
169 fourier_max: usize,
170 seasonal: bool,
171 criterion: i32,
172) -> Option<BasisSearchResult> {
173 let fourier_start = if seasonal {
174 fourier_min.max(5)
175 } else {
176 fourier_min
177 };
178 let mut nb = if fourier_start % 2 == 0 {
179 fourier_start + 1
180 } else {
181 fourier_start
182 };
183
184 let mut best: Option<BasisSearchResult> = None;
185 while nb <= fourier_max {
186 if let Some((score, coefs, fitted, edf)) =
187 fit_curve_fourier(curve, m, argvals, nb, criterion)
188 {
189 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
190 best = Some(BasisSearchResult {
191 score,
192 nbasis: nb,
193 coefs,
194 fitted,
195 edf,
196 lambda: f64::NAN,
197 });
198 }
199 }
200 nb += 2;
201 }
202 best
203}
204
205fn try_pspline_fit_update(
207 curve: &[f64],
208 m: usize,
209 argvals: &[f64],
210 nb: usize,
211 lam: f64,
212 criterion: i32,
213 best: &mut Option<BasisSearchResult>,
214) {
215 if let Some((score, coefs, fitted, edf)) =
216 fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
217 {
218 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
219 *best = Some(BasisSearchResult {
220 score,
221 nbasis: nb,
222 coefs,
223 fitted,
224 edf,
225 lambda: lam,
226 });
227 }
228 }
229}
230
231fn search_pspline_basis(
233 curve: &[f64],
234 m: usize,
235 argvals: &[f64],
236 pspline_min: usize,
237 pspline_max: usize,
238 lambda_grid: &[f64],
239 auto_lambda: bool,
240 lambda: f64,
241 criterion: i32,
242) -> Option<BasisSearchResult> {
243 let mut best: Option<BasisSearchResult> = None;
244 for nb in pspline_min..=pspline_max {
245 let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
246 Box::new(lambda_grid.iter().copied())
247 } else {
248 Box::new(std::iter::once(lambda))
249 };
250 for lam in lambdas {
251 try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
252 }
253 }
254 best
255}
256
257pub fn select_basis_auto_1d(
275 data: &FdMatrix,
276 argvals: &[f64],
277 criterion: i32,
278 nbasis_min: usize,
279 nbasis_max: usize,
280 lambda_pspline: f64,
281 use_seasonal_hint: bool,
282) -> BasisAutoSelectionResult {
283 let n = data.nrows();
284 let m = data.ncols();
285 let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
286 let fourier_max = if nbasis_max > 0 {
287 nbasis_max.min(m / 3).min(25)
288 } else {
289 (m / 3).min(25)
290 };
291 let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
292 let pspline_max = if nbasis_max > 0 {
293 nbasis_max.min(m / 2).min(40)
294 } else {
295 (m / 2).min(40)
296 };
297
298 let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
299 let auto_lambda = lambda_pspline < 0.0;
300
301 let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
302 .map(|i| {
303 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
304 let seasonal_detected = if use_seasonal_hint {
305 detect_seasonality_fft(&curve)
306 } else {
307 false
308 };
309
310 let fourier_best = search_fourier_basis(
311 &curve,
312 m,
313 argvals,
314 fourier_min,
315 fourier_max,
316 seasonal_detected,
317 criterion,
318 );
319 let pspline_best = search_pspline_basis(
320 &curve,
321 m,
322 argvals,
323 pspline_min,
324 pspline_max,
325 &lambda_grid,
326 auto_lambda,
327 lambda_pspline,
328 criterion,
329 );
330
331 let (basis_type, result) = match (fourier_best, pspline_best) {
333 (Some(f), Some(p)) => {
334 if f.score <= p.score {
335 (ProjectionBasisType::Fourier, f)
336 } else {
337 (ProjectionBasisType::Bspline, p)
338 }
339 }
340 (Some(f), None) => (ProjectionBasisType::Fourier, f),
341 (None, Some(p)) => (ProjectionBasisType::Bspline, p),
342 (None, None) => {
343 return SingleCurveSelection {
344 basis_type: ProjectionBasisType::Bspline,
345 nbasis: pspline_min,
346 score: f64::INFINITY,
347 coefficients: Vec::new(),
348 fitted: Vec::new(),
349 edf: 0.0,
350 seasonal_detected,
351 lambda: f64::NAN,
352 };
353 }
354 };
355
356 SingleCurveSelection {
357 basis_type,
358 nbasis: result.nbasis,
359 score: result.score,
360 coefficients: result.coefs,
361 fitted: result.fitted,
362 edf: result.edf,
363 seasonal_detected,
364 lambda: result.lambda,
365 }
366 })
367 .collect();
368
369 BasisAutoSelectionResult {
370 selections,
371 criterion,
372 }
373}