Skip to main content

fdars_core/basis/
pspline.rs

1//! P-spline fitting and penalty (difference) matrices.
2
3use super::bspline::bspline_basis;
4use super::helpers::svd_pseudoinverse;
5use crate::matrix::FdMatrix;
6use nalgebra::{DMatrix, DVector};
7
8/// Compute difference matrix for P-spline penalty.
9pub fn difference_matrix(n: usize, order: usize) -> DMatrix<f64> {
10    if order == 0 {
11        return DMatrix::identity(n, n);
12    }
13
14    let mut d = DMatrix::zeros(n - 1, n);
15    for i in 0..(n - 1) {
16        d[(i, i)] = -1.0;
17        d[(i, i + 1)] = 1.0;
18    }
19
20    let mut result = d;
21    for _ in 1..order {
22        if result.nrows() <= 1 {
23            break;
24        }
25        let rows = result.nrows() - 1;
26        let cols = result.ncols();
27        let mut d_next = DMatrix::zeros(rows, cols);
28        for i in 0..rows {
29            for j in 0..cols {
30                d_next[(i, j)] = -result[(i, j)] + result[(i + 1, j)];
31            }
32        }
33        result = d_next;
34    }
35
36    result
37}
38
39/// Result of P-spline fitting.
40#[derive(Debug, Clone, PartialEq)]
41#[non_exhaustive]
42pub struct PsplineFitResult {
43    /// Coefficient matrix (n x nbasis)
44    pub coefficients: FdMatrix,
45    /// Fitted values (n x m)
46    pub fitted: FdMatrix,
47    /// Effective degrees of freedom
48    pub edf: f64,
49    /// Residual sum of squares
50    pub rss: f64,
51    /// GCV score
52    pub gcv: f64,
53    /// AIC
54    pub aic: f64,
55    /// BIC
56    pub bic: f64,
57    /// Number of basis functions
58    pub n_basis: usize,
59}
60
61/// Fit P-splines to functional data.
62pub fn pspline_fit_1d(
63    data: &FdMatrix,
64    argvals: &[f64],
65    nbasis: usize,
66    lambda: f64,
67    order: usize,
68) -> Option<PsplineFitResult> {
69    let n = data.nrows();
70    let m = data.ncols();
71    if n == 0 || m == 0 || nbasis < 2 || argvals.len() != m {
72        return None;
73    }
74
75    // For order 4 B-splines: nbasis = nknots + order, so nknots = nbasis - 4
76    let basis = bspline_basis(argvals, nbasis.saturating_sub(4).max(2), 4);
77    let actual_nbasis = basis.len() / m;
78    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis);
79
80    let d = difference_matrix(actual_nbasis, order);
81    let penalty = &d.transpose() * &d;
82
83    let btb = &b_mat.transpose() * &b_mat;
84    let btb_penalized = &btb + lambda * &penalty;
85
86    let btb_inv = svd_pseudoinverse(&btb_penalized)?;
87    let proj = &btb_inv * b_mat.transpose();
88    let h_mat = &b_mat * &proj;
89    let edf: f64 = (0..m).map(|i| h_mat[(i, i)]).sum();
90
91    let mut all_coefs = FdMatrix::zeros(n, actual_nbasis);
92    let mut all_fitted = FdMatrix::zeros(n, m);
93    let mut total_rss = 0.0;
94
95    for i in 0..n {
96        let curve: Vec<f64> = (0..m).map(|j| data[(i, j)]).collect();
97        let curve_vec = DVector::from_vec(curve.clone());
98
99        let bt_y = b_mat.transpose() * &curve_vec;
100        let coefs = &btb_inv * bt_y;
101
102        for k in 0..actual_nbasis {
103            all_coefs[(i, k)] = coefs[k];
104        }
105
106        let fitted = &b_mat * &coefs;
107        for j in 0..m {
108            all_fitted[(i, j)] = fitted[j];
109            let resid = curve[j] - fitted[j];
110            total_rss += resid * resid;
111        }
112    }
113
114    let total_points = (n * m) as f64;
115
116    let gcv_denom = 1.0 - edf / m as f64;
117    let gcv = if gcv_denom.abs() > 1e-10 {
118        (total_rss / total_points) / (gcv_denom * gcv_denom)
119    } else {
120        f64::INFINITY
121    };
122
123    let mse = total_rss / total_points;
124    let aic = total_points * mse.ln() + 2.0 * edf;
125    let bic = total_points * mse.ln() + total_points.ln() * edf;
126
127    Some(PsplineFitResult {
128        coefficients: all_coefs,
129        fitted: all_fitted,
130        edf,
131        rss: total_rss,
132        gcv,
133        aic,
134        bic,
135        n_basis: actual_nbasis,
136    })
137}