Skip to main content

fdars_core/elastic_regression/
mod.rs

1//! Elastic regression models (alignment-integrated regression).
2//!
3//! These models from fdasrvf align curves during the regression fitting process,
4//! jointly optimizing alignment and regression coefficients.
5//!
6//! Key capabilities:
7//! - [`elastic_regression`] — Scalar-on-function regression with elastic alignment
8//! - [`elastic_logistic`] — Binary classification with elastic alignment
9//! - [`elastic_pcr`] — Principal component regression after elastic alignment
10
11pub mod logistic;
12pub mod pcr;
13pub mod regression;
14
15#[cfg(test)]
16mod tests;
17
18// Re-export all public items
19pub use logistic::{
20    elastic_logistic, elastic_logistic_with_config, predict_elastic_logistic, ElasticLogisticResult,
21};
22pub use pcr::{elastic_pcr, elastic_pcr_with_config, ElasticPcrResult};
23pub use regression::{
24    elastic_regression, elastic_regression_with_config, predict_elastic_regression,
25    ElasticRegressionResult,
26};
27
28use crate::alignment::reparameterize_curve;
29use crate::matrix::FdMatrix;
30
31// ─── Config Structs ─────────────────────────────────────────────────────────
32
/// Tuning parameters shared by [`elastic_regression`] and [`elastic_logistic`].
///
/// Defaults (see the `Default` impl): `ncomp_beta = 10`, `lambda = 0.0`,
/// `max_iter = 20`, `tol = 1e-4`.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticConfig {
    /// How many basis functions represent the β coefficient function
    /// (consumed by `elastic_regression`).
    pub ncomp_beta: usize,
    /// Weight applied to the roughness penalty.
    pub lambda: f64,
    /// Cap on the number of iterations of the iterative alignment loop.
    pub max_iter: usize,
    /// Tolerance used by the convergence check.
    pub tol: f64,
}

impl Default for ElasticConfig {
    fn default() -> Self {
        ElasticConfig {
            ncomp_beta: 10,
            lambda: 0.0,
            max_iter: 20,
            tol: 1e-4,
        }
    }
}
56
57/// Configuration for [`elastic_pcr`].
58#[derive(Debug, Clone, PartialEq)]
59pub struct ElasticPcrConfig {
60    /// Number of principal components to retain.
61    pub ncomp: usize,
62    /// PCA method (vertical, horizontal, or joint).
63    pub pca_method: PcaMethod,
64    /// Roughness penalty weight.
65    pub lambda: f64,
66    /// Maximum iterations for Karcher mean.
67    pub max_iter: usize,
68    /// Convergence tolerance for Karcher mean.
69    pub tol: f64,
70}
71
72impl Default for ElasticPcrConfig {
73    fn default() -> Self {
74        Self {
75            ncomp: 3,
76            pca_method: PcaMethod::Vertical,
77            lambda: 0.0,
78            max_iter: 20,
79            tol: 1e-4,
80        }
81    }
82}
83
84// ─── Types ──────────────────────────────────────────────────────────────────
85
/// PCA method used by [`elastic_pcr`] after elastic alignment.
///
/// The variant names follow fdasrvf's elastic fPCA terminology. The exact projection
/// each variant performs lives in the `pcr` submodule.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the method can be
/// used as a key in hash-based collections, e.g. for caching per-method results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PcaMethod {
    /// Vertical fPCA. Presumably operates on the aligned (amplitude) functions;
    /// confirm in `pcr`.
    Vertical,
    /// Horizontal fPCA. Presumably operates on the warping (phase) functions;
    /// confirm in `pcr`.
    Horizontal,
    /// Joint fPCA. Presumably combines amplitude and phase; confirm in `pcr`.
    Joint,
}
93
94// ─── Shared Helpers ────────────────────────────────────────────────────────
95
96/// Apply warping functions to SRSFs, producing aligned SRSFs with sqrt(γ') factor.
97pub(super) fn apply_warps_to_srsfs(
98    q_all: &FdMatrix,
99    gammas: &FdMatrix,
100    argvals: &[f64],
101) -> FdMatrix {
102    let (n, m) = q_all.shape();
103    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
104    let mut q_aligned = FdMatrix::zeros(n, m);
105    for i in 0..n {
106        let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
107        let gam: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
108        let q_warped = reparameterize_curve(&qi, argvals, &gam);
109        let gam_deriv = crate::helpers::gradient_uniform(&gam, h);
110        for j in 0..m {
111            q_aligned[(i, j)] = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
112        }
113    }
114    q_aligned
115}
116
117/// Initialize warping functions to identity (γ_i(t) = t).
118pub(super) fn init_identity_warps(n: usize, argvals: &[f64]) -> FdMatrix {
119    let m = argvals.len();
120    let mut gammas = FdMatrix::zeros(n, m);
121    for i in 0..n {
122        for j in 0..m {
123            gammas[(i, j)] = argvals[j];
124        }
125    }
126    gammas
127}
128
129/// Compute fitted values: ŷ_i = α + ∫ q_aligned_i · β · w dt.
130pub(super) fn srsf_fitted_values(
131    q_aligned: &FdMatrix,
132    beta: &[f64],
133    weights: &[f64],
134    alpha: f64,
135) -> Vec<f64> {
136    let (n, m) = q_aligned.shape();
137    let mut fitted = vec![0.0; n];
138    for i in 0..n {
139        fitted[i] = alpha;
140        for j in 0..m {
141            fitted[i] += q_aligned[(i, j)] * beta[j] * weights[j];
142        }
143    }
144    fitted
145}
146
147/// Check relative convergence of β.
148pub(super) fn beta_converged(beta_new: &[f64], beta_old: &[f64], tol: f64) -> bool {
149    let diff: f64 = beta_new
150        .iter()
151        .zip(beta_old.iter())
152        .map(|(&a, &b)| (a - b).powi(2))
153        .sum::<f64>()
154        .sqrt();
155    let norm: f64 = beta_old
156        .iter()
157        .map(|&b| b * b)
158        .sum::<f64>()
159        .sqrt()
160        .max(1e-10);
161    diff / norm < tol
162}