1use super::bspline::bspline_basis;
4use super::fourier::fourier_basis;
5use super::helpers::{compute_model_criterion, svd_pseudoinverse};
6use super::projection::ProjectionBasisType;
7use super::pspline::difference_matrix;
8use crate::iter_maybe_parallel;
9use crate::matrix::FdMatrix;
10use nalgebra::{DMatrix, DVector};
11#[cfg(feature = "parallel")]
12use rayon::iter::ParallelIterator;
13
/// Outcome of automatic basis selection for one curve (one row of the input data).
///
/// Produced by `select_basis_auto_1d`. When neither the Fourier nor the P-spline
/// search yields a usable fit, the entry is a placeholder with `basis_type = Bspline`,
/// `score = f64::INFINITY`, empty `coefficients`/`fitted`, `edf = 0.0` and `lambda = NaN`.
#[derive(Debug, Clone)]
pub struct SingleCurveSelection {
    /// Winning basis family: `Fourier`, or `Bspline` for a P-spline fit.
    pub basis_type: ProjectionBasisType,
    /// Basis size that was requested for the winning fit.
    /// NOTE(review): for P-splines the coefficient vector length comes from the
    /// constructed B-spline basis and may differ from this value — confirm.
    pub nbasis: usize,
    /// Model-selection criterion value of the winning fit (lower is better).
    pub score: f64,
    /// Basis coefficients of the winning fit.
    pub coefficients: Vec<f64>,
    /// Fitted values of the winning fit at each evaluation point of the curve.
    pub fitted: Vec<f64>,
    /// Effective degrees of freedom of the winning fit (trace of its hat matrix).
    pub edf: f64,
    /// Whether the FFT seasonality heuristic flagged this curve
    /// (always `false` when the seasonal hint is disabled).
    pub seasonal_detected: bool,
    /// Smoothing parameter of the winning P-spline fit; `NaN` for Fourier fits.
    pub lambda: f64,
}
34
/// Result of `select_basis_auto_1d`: one basis selection per input curve.
#[derive(Debug, Clone)]
pub struct BasisAutoSelectionResult {
    /// Per-curve selections, one for each row of the input data.
    pub selections: Vec<SingleCurveSelection>,
    /// Criterion code that was used to score every fit (forwarded to
    /// `compute_model_criterion`). NOTE(review): the meaning of each code is
    /// defined by that helper — confirm before relying on specific values.
    pub criterion: i32,
}
43
44fn detect_seasonality_fft(curve: &[f64]) -> bool {
49 use rustfft::{num_complex::Complex, FftPlanner};
50
51 let n = curve.len();
52 if n < 8 {
53 return false;
54 }
55
56 let mean: f64 = curve.iter().sum::<f64>() / n as f64;
58 let mut input: Vec<Complex<f64>> = curve.iter().map(|&x| Complex::new(x - mean, 0.0)).collect();
59
60 let mut planner = FftPlanner::new();
61 let fft = planner.plan_fft_forward(n);
62 fft.process(&mut input);
63
64 let powers: Vec<f64> = input[1..n / 2]
66 .iter()
67 .map(nalgebra::Complex::norm_sqr)
68 .collect();
69
70 if powers.is_empty() {
71 return false;
72 }
73
74 let max_power = powers.iter().copied().fold(0.0_f64, f64::max);
75 let mean_power = powers.iter().sum::<f64>() / powers.len() as f64;
76
77 max_power > 2.0 * mean_power
79}
80
81fn fit_curve_fourier(
83 curve: &[f64],
84 m: usize,
85 argvals: &[f64],
86 nbasis: usize,
87 criterion: i32,
88) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
89 let nbasis = if nbasis % 2 == 0 { nbasis + 1 } else { nbasis };
90
91 let basis = fourier_basis(argvals, nbasis);
92 let actual_nbasis = basis.len() / m;
93 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
94
95 let btb = &b_mat.transpose() * &b_mat;
96 let btb_inv = svd_pseudoinverse(&btb)?;
97 let proj = &btb_inv * b_mat.transpose();
98 let h_mat = &b_mat * &proj;
99 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
100
101 let curve_vec = DVector::from_column_slice(curve);
102 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
103 let fitted = &b_mat * &coefs;
104
105 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
106 let score = compute_model_criterion(rss, m as f64, edf, criterion);
107
108 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
109 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
110
111 Some((score, coef_vec, fitted_vec, edf))
112}
113
114fn fit_curve_pspline(
116 curve: &[f64],
117 m: usize,
118 argvals: &[f64],
119 nbasis: usize,
120 lambda: f64,
121 order: usize,
122 criterion: i32,
123) -> Option<(f64, Vec<f64>, Vec<f64>, f64)> {
124 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
125 let actual_nbasis = basis.len() / m;
126 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
127
128 let d = difference_matrix(actual_nbasis, order);
129 let penalty = &d.transpose() * &d;
130 let btb = &b_mat.transpose() * &b_mat;
131 let btb_penalized = &btb + lambda * &penalty;
132
133 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
134 let proj = &btb_inv * b_mat.transpose();
135 let h_mat = &b_mat * &proj;
136 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
137
138 let curve_vec = DVector::from_column_slice(curve);
139 let coefs = &btb_inv * (b_mat.transpose() * &curve_vec);
140 let fitted = &b_mat * &coefs;
141
142 let rss: f64 = (0..m).map(|j| (curve[j] - fitted[j]).powi(2)).sum();
143 let score = compute_model_criterion(rss, m as f64, edf, criterion);
144
145 let coef_vec: Vec<f64> = (0..actual_nbasis).map(|k| coefs[k]).collect();
146 let fitted_vec: Vec<f64> = (0..m).map(|j| fitted[j]).collect();
147
148 Some((score, coef_vec, fitted_vec, edf))
149}
150
/// Best fit found by one of the basis-size searches (Fourier or P-spline).
struct BasisSearchResult {
    /// Criterion value of the fit (lower is better; only finite scores are kept).
    score: f64,
    /// Basis size that was requested for this fit.
    nbasis: usize,
    /// Fitted basis coefficients.
    coefs: Vec<f64>,
    /// Fitted values at the curve's evaluation points.
    fitted: Vec<f64>,
    /// Effective degrees of freedom of the fit.
    edf: f64,
    /// Smoothing parameter used; `NaN` for (unpenalised) Fourier fits.
    lambda: f64,
}
160
161fn search_fourier_basis(
163 curve: &[f64],
164 m: usize,
165 argvals: &[f64],
166 fourier_min: usize,
167 fourier_max: usize,
168 seasonal: bool,
169 criterion: i32,
170) -> Option<BasisSearchResult> {
171 let fourier_start = if seasonal {
172 fourier_min.max(5)
173 } else {
174 fourier_min
175 };
176 let mut nb = if fourier_start % 2 == 0 {
177 fourier_start + 1
178 } else {
179 fourier_start
180 };
181
182 let mut best: Option<BasisSearchResult> = None;
183 while nb <= fourier_max {
184 if let Some((score, coefs, fitted, edf)) =
185 fit_curve_fourier(curve, m, argvals, nb, criterion)
186 {
187 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
188 best = Some(BasisSearchResult {
189 score,
190 nbasis: nb,
191 coefs,
192 fitted,
193 edf,
194 lambda: f64::NAN,
195 });
196 }
197 }
198 nb += 2;
199 }
200 best
201}
202
203fn try_pspline_fit_update(
205 curve: &[f64],
206 m: usize,
207 argvals: &[f64],
208 nb: usize,
209 lam: f64,
210 criterion: i32,
211 best: &mut Option<BasisSearchResult>,
212) {
213 if let Some((score, coefs, fitted, edf)) =
214 fit_curve_pspline(curve, m, argvals, nb, lam, 2, criterion)
215 {
216 if score.is_finite() && best.as_ref().map_or(true, |b| score < b.score) {
217 *best = Some(BasisSearchResult {
218 score,
219 nbasis: nb,
220 coefs,
221 fitted,
222 edf,
223 lambda: lam,
224 });
225 }
226 }
227}
228
229fn search_pspline_basis(
231 curve: &[f64],
232 m: usize,
233 argvals: &[f64],
234 pspline_min: usize,
235 pspline_max: usize,
236 lambda_grid: &[f64],
237 auto_lambda: bool,
238 lambda: f64,
239 criterion: i32,
240) -> Option<BasisSearchResult> {
241 let mut best: Option<BasisSearchResult> = None;
242 for nb in pspline_min..=pspline_max {
243 let lambdas: Box<dyn Iterator<Item = f64>> = if auto_lambda {
244 Box::new(lambda_grid.iter().copied())
245 } else {
246 Box::new(std::iter::once(lambda))
247 };
248 for lam in lambdas {
249 try_pspline_fit_update(curve, m, argvals, nb, lam, criterion, &mut best);
250 }
251 }
252 best
253}
254
255pub fn select_basis_auto_1d(
273 data: &FdMatrix,
274 argvals: &[f64],
275 criterion: i32,
276 nbasis_min: usize,
277 nbasis_max: usize,
278 lambda_pspline: f64,
279 use_seasonal_hint: bool,
280) -> BasisAutoSelectionResult {
281 let n = data.nrows();
282 let m = data.ncols();
283 let fourier_min = if nbasis_min > 0 { nbasis_min.max(3) } else { 3 };
284 let fourier_max = if nbasis_max > 0 {
285 nbasis_max.min(m / 3).min(25)
286 } else {
287 (m / 3).min(25)
288 };
289 let pspline_min = if nbasis_min > 0 { nbasis_min.max(6) } else { 6 };
290 let pspline_max = if nbasis_max > 0 {
291 nbasis_max.min(m / 2).min(40)
292 } else {
293 (m / 2).min(40)
294 };
295
296 let lambda_grid = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0];
297 let auto_lambda = lambda_pspline < 0.0;
298
299 let selections: Vec<SingleCurveSelection> = iter_maybe_parallel!(0..n)
300 .map(|i| {
301 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
302 let seasonal_detected = if use_seasonal_hint {
303 detect_seasonality_fft(&curve)
304 } else {
305 false
306 };
307
308 let fourier_best = search_fourier_basis(
309 &curve,
310 m,
311 argvals,
312 fourier_min,
313 fourier_max,
314 seasonal_detected,
315 criterion,
316 );
317 let pspline_best = search_pspline_basis(
318 &curve,
319 m,
320 argvals,
321 pspline_min,
322 pspline_max,
323 &lambda_grid,
324 auto_lambda,
325 lambda_pspline,
326 criterion,
327 );
328
329 let (basis_type, result) = match (fourier_best, pspline_best) {
331 (Some(f), Some(p)) => {
332 if f.score <= p.score {
333 (ProjectionBasisType::Fourier, f)
334 } else {
335 (ProjectionBasisType::Bspline, p)
336 }
337 }
338 (Some(f), None) => (ProjectionBasisType::Fourier, f),
339 (None, Some(p)) => (ProjectionBasisType::Bspline, p),
340 (None, None) => {
341 return SingleCurveSelection {
342 basis_type: ProjectionBasisType::Bspline,
343 nbasis: pspline_min,
344 score: f64::INFINITY,
345 coefficients: Vec::new(),
346 fitted: Vec::new(),
347 edf: 0.0,
348 seasonal_detected,
349 lambda: f64::NAN,
350 };
351 }
352 };
353
354 SingleCurveSelection {
355 basis_type,
356 nbasis: result.nbasis,
357 score: result.score,
358 coefficients: result.coefs,
359 fitted: result.fitted,
360 edf: result.edf,
361 seasonal_detected,
362 lambda: result.lambda,
363 }
364 })
365 .collect();
366
367 BasisAutoSelectionResult {
368 selections,
369 criterion,
370 }
371}