fdars_core/basis/
pspline.rs1use super::bspline::bspline_basis;
4use super::helpers::svd_pseudoinverse;
5use crate::matrix::FdMatrix;
6use nalgebra::{DMatrix, DVector};
7
8pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
10 if order == 0 {
11 return DMatrix::identity(n, n);
12 }
13
14 let mut d = DMatrix::zeros(n - 1, n);
15 for i in 0..(n - 1) {
16 d[(i, i)] = -1.0;
17 d[(i, i + 1)] = 1.0;
18 }
19
20 let mut result = d;
21 for _ in 1..order {
22 if result.nrows() <= 1 {
23 break;
24 }
25 let rows = result.nrows() - 1;
26 let cols = result.ncols();
27 let mut d_next = DMatrix::zeros(rows, cols);
28 for i in 0..rows {
29 for j in 0..cols {
30 d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
31 }
32 }
33 result = d_next;
34 }
35
36 result
37}
38
39#[derive(Debug, Clone, PartialEq)]
41#[non_exhaustive]
42pub struct PsplineFitResult {
43 pub coefficients: FdMatrix,
45 pub fitted: FdMatrix,
47 pub edf: f64,
49 pub rss: f64,
51 pub gcv: f64,
53 pub aic: f64,
55 pub bic: f64,
57 pub n_basis: usize,
59}
60
61pub fn pspline_fit_1d(
63 data: &FdMatrix,
64 argvals: &[f64],
65 nbasis: usize,
66 lambda: f64,
67 order: usize,
68) -> Option<PsplineFitResult> {
69 let n = data.nrows();
70 let m = data.ncols();
71 if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
72 return None;
73 }
74
75 let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
77 let actual_nbasis = basis.len() / m;
78 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
79
80 let d = difference_matrix(actual_nbasis, order);
81 let penalty = &d.transpose() * &d;
82
83 let btb = &b_mat.transpose() * &b_mat;
84 let btb_penalized = &btb + lambda * &penalty;
85
86 let btb_inv = svd_pseudoinverse(&btb_penalized)?;
87 let proj = &btb_inv * b_mat.transpose();
88 let h_mat = &b_mat * &proj;
89 let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
90
91 let mut all_coefs = FdMatrix::zeros(n, actual_nbasis);
92 let mut all_fitted = FdMatrix::zeros(n, m);
93 let mut total_rss = 0.0;
94
95 for i in 0..n {
96 let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
97 let curve_vec = DVector::from_vec(curve.clone());
98
99 let bt_y = b_mat.transpose() * &curve_vec;
100 let coefs = &btb_inv * bt_y;
101
102 for k in 0..actual_nbasis {
103 all_coefs[(i, k)] = coefs[k];
104 }
105
106 let fitted = &b_mat * &coefs;
107 for j in 0..m {
108 all_fitted[(i, j)] = fitted[j];
109 let resid = curve[j] - fitted[j];
110 total_rss += resid * resid;
111 }
112 }
113
114 let total_points = (n * m) as f64;
115
116 let gcv_denom = 1.0 - edf / m as f64;
117 let gcv = if gcv_denom.abs() > 1e-10 {
118 (total_rss / total_points) / (gcv_denom * gcv_denom)
119 } else {
120 f64::INFINITY
121 };
122
123 let mse = total_rss / total_points;
124 let aic = total_points * mse.ln() + 2.0 * edf;
125 let bic = total_points * mse.ln() + total_points.ln() * edf;
126
127 Some(PsplineFitResult {
128 coefficients: all_coefs,
129 fitted: all_fitted,
130 edf,
131 rss: total_rss,
132 gcv,
133 aic,
134 bic,
135 n_basis: actual_nbasis,
136 })
137}