1use crate::basis::{bspline_basis, fourier_basis_with_period};
13use crate::helpers::simpsons_weights;
14use crate::matrix::FdMatrix;
15use nalgebra::DMatrix;
16use std::f64::consts::PI;
17
/// Family of basis functions used to represent smoothed curves.
#[derive(Debug, Clone)]
pub enum BasisType {
    /// B-spline basis of the given `order`.
    ///
    /// NOTE(review): presumably order = degree + 1 (so 4 means cubic), which is
    /// consistent with `bspline_penalty_matrix` treating `lfd_order >= order` as a
    /// vanishing penalty — confirm against `crate::basis::bspline_basis`.
    Bspline { order: usize },
    /// Fourier basis; `period` is the length `T` used for the angular frequencies
    /// `2π·j / T` of the penalty.
    Fourier { period: f64 },
}
28
/// Parameters of a roughness-penalized basis smooth (basis, penalty and smoothing level).
#[derive(Debug, Clone)]
pub struct FdPar {
    /// Which basis family to evaluate at the observation points.
    pub basis_type: BasisType,
    /// Requested number of basis functions. For B-splines the number actually produced
    /// by the basis evaluator may differ; `penalty_matrix` must match the actual count.
    pub nbasis: usize,
    /// Smoothing parameter λ multiplying the penalty matrix.
    pub lambda: f64,
    /// Derivative order the penalty was built for. Informational only: `smooth_basis`
    /// does not read it and uses `penalty_matrix` as given.
    pub lfd_order: usize,
    /// Flat, column-major `k × k` penalty matrix, where `k` is the actual basis size.
    pub penalty_matrix: Vec<f64>,
}
43
/// Output of `smooth_basis` / `smooth_basis_gcv` for a set of `n` curves on `m` points.
#[derive(Debug, Clone)]
pub struct SmoothBasisResult {
    /// Basis coefficients, one row per curve (`n × k`).
    pub coefficients: FdMatrix,
    /// Smoothed values at the observation points, one row per curve (`n × m`).
    pub fitted: FdMatrix,
    /// Effective degrees of freedom of one curve's fit: trace of the shared hat matrix.
    pub edf: f64,
    /// Generalized cross-validation score `(RSS / (n·m)) / (1 − edf/m)²`;
    /// `+∞` when the denominator is (numerically) zero.
    pub gcv: f64,
    /// Akaike information criterion built from the mean squared residual over all
    /// `n·m` points.
    pub aic: f64,
    /// Bayesian information criterion built from the mean squared residual over all
    /// `n·m` points.
    pub bic: f64,
    /// Copy of the flat, column-major `k × k` penalty matrix that was used.
    pub penalty_matrix: Vec<f64>,
    /// Actual number of basis functions `k` (may differ from the requested `nbasis`).
    pub nbasis: usize,
}
64
65pub fn bspline_penalty_matrix(
82 argvals: &[f64],
83 nbasis: usize,
84 order: usize,
85 lfd_order: usize,
86) -> Vec<f64> {
87 if nbasis < 2 || order < 1 || lfd_order >= order {
88 return vec![0.0; nbasis * nbasis];
89 }
90
91 let nknots = nbasis.saturating_sub(order).max(2);
92
93 let n_sub = 10;
95 let t_min = argvals[0];
96 let t_max = argvals[argvals.len() - 1];
97 let n_quad = (argvals.len() - 1) * n_sub + 1;
98 let quad_t: Vec<f64> = (0..n_quad)
99 .map(|i| t_min + (t_max - t_min) * i as f64 / (n_quad - 1) as f64)
100 .collect();
101
102 let basis_fine = bspline_basis(&quad_t, nknots, order);
104 let actual_nbasis = basis_fine.len() / n_quad;
105
106 let h = (t_max - t_min) / (n_quad - 1) as f64;
108 let deriv_basis = differentiate_basis_columns(&basis_fine, n_quad, actual_nbasis, h, lfd_order);
109
110 let weights = simpsons_weights(&quad_t);
112
113 integrate_symmetric_penalty(&deriv_basis, &weights, actual_nbasis, n_quad)
115}
116
/// Diagonal roughness penalty for a Fourier basis.
///
/// Basis index 0 is the constant term (frequency 0). Indices `2f − 1` and `2f` form the
/// sine/cosine pair of frequency `f`, whose angular frequency is `ω = 2π·f / period`.
/// Each of them gets the eigenvalue `ω^(2·lfd_order)` on the diagonal.
///
/// The constant term has a zero penalty for any derivative order ≥ 1. For `lfd_order == 0`
/// the penalty is `∫ f²`, which does not vanish on constants, so it gets the same unit
/// eigenvalue as every other basis function (`ω⁰ = 1`). This also keeps it consistent
/// with `bspline_penalty_matrix`, whose zeroth-order penalty is the full Gram matrix.
///
/// # Returns
/// A flat, column-major `nbasis × nbasis` matrix; empty when `nbasis == 0`.
pub fn fourier_penalty_matrix(nbasis: usize, period: f64, lfd_order: usize) -> Vec<f64> {
    let k = nbasis;
    let mut penalty = vec![0.0; k * k];
    let power = 2 * lfd_order as i32;

    for idx in 0..k {
        // 0 → frequency 0; (1, 2) → 1; (3, 4) → 2; ...
        let freq = (idx + 1) / 2;
        let eigenval = if freq == 0 {
            // Constant basis function: only the zeroth derivative is non-zero.
            if lfd_order == 0 {
                1.0
            } else {
                0.0
            }
        } else {
            let omega = 2.0 * PI * freq as f64 / period;
            omega.powi(power)
        };
        penalty[idx + idx * k] = eigenval;
    }

    penalty
}
157
158pub fn smooth_basis(data: &FdMatrix, argvals: &[f64], fdpar: &FdPar) -> Option<SmoothBasisResult> {
173 let (n, m) = data.shape();
174 if n == 0 || m == 0 || argvals.len() != m || fdpar.nbasis < 2 {
175 return None;
176 }
177
178 let (basis_flat, actual_nbasis) = evaluate_basis(argvals, &fdpar.basis_type, fdpar.nbasis);
180 let k = actual_nbasis;
181
182 let b_mat = DMatrix::from_column_slice(m, k, &basis_flat);
183 let r_mat = DMatrix::from_column_slice(k, k, &fdpar.penalty_matrix);
184
185 let btb = b_mat.transpose() * &b_mat;
187 let ridge_eps = 1e-10;
188 let system: DMatrix<f64> =
189 &btb + fdpar.lambda * &r_mat + ridge_eps * DMatrix::<f64>::identity(k, k);
190
191 let system_inv = invert_penalized_system(&system, k)?;
193
194 let h_mat = &b_mat * &system_inv * b_mat.transpose();
196 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
197
198 let proj = &system_inv * b_mat.transpose();
200 let (all_coefs, all_fitted, total_rss) = project_all_curves(data, &b_mat, &proj, n, m, k);
201
202 let total_points = (n * m) as f64;
203 let gcv = compute_gcv(total_rss, total_points, edf, m);
204 let mse = total_rss / total_points;
205 let aic = total_points * mse.max(1e-300).ln() + 2.0 * edf;
206 let bic = total_points * mse.max(1e-300).ln() + total_points.ln() * edf;
207
208 Some(SmoothBasisResult {
209 coefficients: all_coefs,
210 fitted: all_fitted,
211 edf,
212 gcv,
213 aic,
214 bic,
215 penalty_matrix: fdpar.penalty_matrix.clone(),
216 nbasis: k,
217 })
218}
219
220pub fn smooth_basis_gcv(
233 data: &FdMatrix,
234 argvals: &[f64],
235 basis_type: &BasisType,
236 nbasis: usize,
237 lfd_order: usize,
238 log_lambda_range: (f64, f64),
239 n_grid: usize,
240) -> Option<SmoothBasisResult> {
241 let m = argvals.len();
242 if m == 0 || nbasis < 2 || n_grid < 2 {
243 return None;
244 }
245
246 let penalty = match basis_type {
248 BasisType::Bspline { order } => bspline_penalty_matrix(argvals, nbasis, *order, lfd_order),
249 BasisType::Fourier { period } => fourier_penalty_matrix(nbasis, *period, lfd_order),
250 };
251
252 let (lo, hi) = log_lambda_range;
253 let mut best_gcv = f64::INFINITY;
254 let mut best_result: Option<SmoothBasisResult> = None;
255
256 for i in 0..n_grid {
257 let log_lam = lo + (hi - lo) * i as f64 / (n_grid - 1) as f64;
258 let lam = 10.0_f64.powf(log_lam);
259
260 let fdpar = FdPar {
261 basis_type: basis_type.clone(),
262 nbasis,
263 lambda: lam,
264 lfd_order,
265 penalty_matrix: penalty.clone(),
266 };
267
268 if let Some(result) = smooth_basis(data, argvals, &fdpar) {
269 if result.gcv < best_gcv {
270 best_gcv = result.gcv;
271 best_result = Some(result);
272 }
273 }
274 }
275
276 best_result
277}
278
279fn differentiate_basis_columns(
283 basis: &[f64],
284 n_quad: usize,
285 nbasis: usize,
286 h: f64,
287 lfd_order: usize,
288) -> Vec<f64> {
289 let mut deriv = basis.to_vec();
290 for _ in 0..lfd_order {
291 let mut new_deriv = vec![0.0; n_quad * nbasis];
292 for j in 0..nbasis {
293 let col: Vec<f64> = (0..n_quad).map(|i| deriv[i + j * n_quad]).collect();
294 let grad = crate::helpers::gradient_uniform(&col, h);
295 for i in 0..n_quad {
296 new_deriv[i + j * n_quad] = grad[i];
297 }
298 }
299 deriv = new_deriv;
300 }
301 deriv
302}
303
/// Weighted inner products of basis columns: `P[j, l] = Σ_i D[i, j] · D[i, l] · w[i]`.
///
/// `deriv_basis` is a column-major `n_quad × k` matrix and `weights` holds the quadrature
/// weights. The result is a flat, column-major, symmetric `k × k` matrix. Only the upper
/// triangle is computed; each value is mirrored into the lower triangle.
fn integrate_symmetric_penalty(
    deriv_basis: &[f64],
    weights: &[f64],
    k: usize,
    n_quad: usize,
) -> Vec<f64> {
    let mut out = vec![0.0; k * k];
    for row in 0..k {
        let w = &weights[..n_quad];
        let a = &deriv_basis[row * n_quad..(row + 1) * n_quad];
        for col in row..k {
            let b = &deriv_basis[col * n_quad..(col + 1) * n_quad];
            let inner = a
                .iter()
                .zip(b)
                .zip(w)
                .fold(0.0, |acc, ((&x, &y), &wi)| acc + x * y * wi);
            out[row + col * k] = inner;
            out[col + row * k] = inner;
        }
    }
    out
}
324
325fn evaluate_basis(argvals: &[f64], basis_type: &BasisType, nbasis: usize) -> (Vec<f64>, usize) {
327 let m = argvals.len();
328 match basis_type {
329 BasisType::Bspline { order } => {
330 let nknots = nbasis.saturating_sub(*order).max(2);
331 let basis = bspline_basis(argvals, nknots, *order);
332 let actual = basis.len() / m;
333 (basis, actual)
334 }
335 BasisType::Fourier { period } => {
336 let basis = fourier_basis_with_period(argvals, nbasis, *period);
337 (basis, nbasis)
338 }
339 }
340}
341
342fn invert_penalized_system(system: &DMatrix<f64>, k: usize) -> Option<DMatrix<f64>> {
344 if let Some(chol) = system.clone().cholesky() {
345 return Some(chol.inverse());
346 }
347 let svd = nalgebra::SVD::new(system.clone(), true, true);
349 let u = svd.u.as_ref()?;
350 let v_t = svd.v_t.as_ref()?;
351 let max_sv: f64 = svd.singular_values.iter().cloned().fold(0.0_f64, f64::max);
352 let eps = 1e-10 * max_sv;
353 let mut inv = DMatrix::<f64>::zeros(k, k);
354 for ii in 0..k {
355 for jj in 0..k {
356 let mut sum = 0.0;
357 for s in 0..k.min(svd.singular_values.len()) {
358 if svd.singular_values[s] > eps {
359 sum += v_t[(s, ii)] / svd.singular_values[s] * u[(jj, s)];
360 }
361 }
362 inv[(ii, jj)] = sum;
363 }
364 }
365 Some(inv)
366}
367
368fn project_all_curves(
370 data: &FdMatrix,
371 b_mat: &DMatrix<f64>,
372 proj: &DMatrix<f64>,
373 n: usize,
374 m: usize,
375 k: usize,
376) -> (FdMatrix, FdMatrix, f64) {
377 let mut all_coefs = FdMatrix::zeros(n, k);
378 let mut all_fitted = FdMatrix::zeros(n, m);
379 let mut total_rss = 0.0;
380
381 for i in 0..n {
382 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
383 let y_vec = nalgebra::DVector::from_vec(curve.clone());
384 let coefs = proj * &y_vec;
385
386 for j in 0..k {
387 all_coefs[(i, j)] = coefs[j];
388 }
389 let fitted = b_mat * &coefs;
390 for j in 0..m {
391 all_fitted[(i, j)] = fitted[j];
392 let resid = curve[j] - fitted[j];
393 total_rss += resid * resid;
394 }
395 }
396
397 (all_coefs, all_fitted, total_rss)
398}
399
/// Generalized cross-validation score `(rss / n_points) / (1 − edf / m)²`.
///
/// Returns `f64::INFINITY` when `|1 − edf / m|` is not above `1e-10`, including when it
/// is NaN.
fn compute_gcv(rss: f64, n_points: f64, edf: f64, m: usize) -> f64 {
    let residual_share = 1.0 - edf / m as f64;
    let usable = residual_share.abs() > 1e-10;
    if !usable {
        return f64::INFINITY;
    }
    let squared = residual_share * residual_share;
    rss / n_points / squared
}
409
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points covering [0, 1] (requires `m >= 2`).
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Side length of a flat square matrix; fails the test if the length is not a square.
    fn square_side(flat: &[f64]) -> usize {
        let k = (flat.len() as f64).sqrt().round() as usize;
        assert_eq!(k * k, flat.len(), "matrix is not square");
        k
    }

    #[test]
    fn test_bspline_penalty_matrix_symmetric() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
        let k = square_side(&penalty);
        for i in 0..k {
            for j in 0..k {
                assert!(
                    (penalty[i + j * k] - penalty[j + i * k]).abs() < 1e-10,
                    "Penalty matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_bspline_penalty_matrix_positive_semidefinite() {
        let t = uniform_grid(101);
        let penalty = bspline_penalty_matrix(&t, 10, 4, 2);
        let k = square_side(&penalty);
        // Necessary condition for PSD: non-negative diagonal.
        for i in 0..k {
            assert!(
                penalty[i + i * k] >= -1e-10,
                "Diagonal element {} is negative: {}",
                i,
                penalty[i + i * k]
            );
        }
    }

    #[test]
    fn test_bspline_penalty_matrix_degenerate_inputs_are_zero() {
        let t = uniform_grid(21);
        // Derivative order >= spline order: the penalty vanishes.
        let vanishing = bspline_penalty_matrix(&t, 10, 4, 4);
        assert_eq!(vanishing.len(), 100);
        assert!(vanishing.iter().all(|&v| v == 0.0));
        // Too few basis functions.
        assert_eq!(bspline_penalty_matrix(&t, 1, 4, 2), vec![0.0]);
    }

    #[test]
    fn test_fourier_penalty_diagonal() {
        let penalty = fourier_penalty_matrix(7, 1.0, 2);
        for i in 0..7 {
            for j in 0..7 {
                if i != j {
                    assert!(
                        penalty[i + j * 7].abs() < 1e-10,
                        "Off-diagonal ({},{}) = {}",
                        i,
                        j,
                        penalty[i + j * 7]
                    );
                }
            }
        }
        // Constant term is not penalized for a second-derivative penalty.
        assert!(penalty[0].abs() < 1e-10);
        assert!(penalty[1 + 7] > 0.0);
        // Higher frequencies are penalized more.
        assert!(penalty[3 + 3 * 7] > penalty[1 + 7]);

        // Exact eigenvalues (2π f)^4 for period 1; sine and cosine of a frequency agree.
        let w1 = (2.0 * PI).powi(4);
        let w2 = (4.0 * PI).powi(4);
        let w3 = (6.0 * PI).powi(4);
        for (idx, expected) in [(1, w1), (2, w1), (3, w2), (4, w2), (5, w3), (6, w3)] {
            let got = penalty[idx + idx * 7];
            assert!(
                (got - expected).abs() < 1e-9 * expected,
                "eigenvalue {}: got {}, expected {}",
                idx,
                got,
                expected
            );
        }
    }

    #[test]
    fn test_compute_gcv() {
        assert!((compute_gcv(10.0, 100.0, 5.0, 10) - 0.4).abs() < 1e-12);
        assert_eq!(compute_gcv(1.0, 1.0, 10.0, 10), f64::INFINITY);
    }

    #[test]
    fn test_integrate_symmetric_penalty_small() {
        let p = integrate_symmetric_penalty(&[1.0, 2.0, 3.0, 4.0], &[0.5, 0.5], 2, 2);
        assert_eq!(p, vec![2.5, 5.5, 5.5, 12.5]);
    }

    #[test]
    fn test_smooth_basis_bspline() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + 0.1 * (i as f64 * 0.3 + j as f64 * 0.01);
            }
        }

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);

        let fdpar = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e-4,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let res = smooth_basis(&data, &t, &fdpar).expect("smooth_basis should succeed");
        assert_eq!(res.fitted.shape(), (n, m));
        assert_eq!(res.coefficients.nrows(), n);
        assert_eq!(res.coefficients.shape(), (n, res.nbasis));
        assert_eq!(res.penalty_matrix.len(), res.nbasis * res.nbasis);
        assert!(res.edf > 0.0, "EDF should be positive");
        assert!(res.gcv > 0.0, "GCV should be positive");
    }

    #[test]
    fn test_smooth_basis_fourier() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
            }
        }

        let nbasis = 7;
        let period = 1.0;
        let penalty = fourier_penalty_matrix(nbasis, period, 2);

        let fdpar = FdPar {
            basis_type: BasisType::Fourier { period },
            nbasis,
            lambda: 1e-6,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let res = smooth_basis(&data, &t, &fdpar).expect("smooth_basis should succeed");
        // The signal lies in the span of the basis, so a lightly penalized fit reproduces it.
        for j in 0..m {
            let expected = (2.0 * PI * t[j]).sin() + (4.0 * PI * t[j]).cos();
            assert!(
                (res.fitted[(0, j)] - expected).abs() < 0.1,
                "Fourier fit poor at j={}: got {}, expected {}",
                j,
                res.fitted[(0, j)],
                expected
            );
        }
    }

    #[test]
    fn test_smooth_basis_gcv_selects_reasonable_lambda() {
        let m = 101;
        let n = 5;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 37 + j * 13) % 20) as f64 / 20.0;
            }
        }

        let basis_type = BasisType::Bspline { order: 4 };
        let best = smooth_basis_gcv(&data, &t, &basis_type, 15, 2, (-8.0, 4.0), 25)
            .expect("GCV search should succeed");

        assert!(best.gcv.is_finite(), "selected GCV should be finite: {}", best.gcv);
        assert!(
            best.edf > 0.0 && best.edf < m as f64,
            "selected EDF out of range: {}",
            best.edf
        );

        // The grid includes both endpoints, so the winner can be no worse than either.
        let penalty = bspline_penalty_matrix(&t, 15, 4, 2);
        for &log_lam in &[-8.0_f64, 4.0] {
            let fdpar = FdPar {
                basis_type: basis_type.clone(),
                nbasis: 15,
                lambda: 10.0_f64.powf(log_lam),
                lfd_order: 2,
                penalty_matrix: penalty.clone(),
            };
            if let Some(edge) = smooth_basis(&data, &t, &fdpar) {
                assert!(
                    best.gcv <= edge.gcv * (1.0 + 1e-9),
                    "GCV at log10(lambda)={} beats the selected fit: {} < {}",
                    log_lam,
                    edge.gcv,
                    best.gcv
                );
            }
        }
    }

    #[test]
    fn test_smooth_basis_large_lambda_reduces_edf() {
        let m = 101;
        let n = 3;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = (2.0 * PI * t[j]).sin();
            }
        }

        let nbasis = 15;
        let penalty = bspline_penalty_matrix(&t, nbasis, 4, 2);

        let fdpar_small = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e-8,
            lfd_order: 2,
            penalty_matrix: penalty.clone(),
        };
        let fdpar_large = FdPar {
            basis_type: BasisType::Bspline { order: 4 },
            nbasis,
            lambda: 1e2,
            lfd_order: 2,
            penalty_matrix: penalty,
        };

        let res_small = smooth_basis(&data, &t, &fdpar_small).unwrap();
        let res_large = smooth_basis(&data, &t, &fdpar_large).unwrap();

        assert!(
            res_large.edf < res_small.edf,
            "Larger lambda should reduce EDF: {} vs {}",
            res_large.edf,
            res_small.edf
        );
    }
}