// fdars_core/elastic_regression/regression.rs
1//! Elastic scalar-on-function regression.
2
3use crate::alignment::{dp_alignment_core, reparameterize_curve, sqrt_mean_inverse};
4use crate::basis::bspline_basis;
5use crate::helpers::simpsons_weights;
6use crate::matrix::FdMatrix;
7use crate::smooth_basis::bspline_penalty_matrix;
8use nalgebra::{DMatrix, DVector};
9
10use super::{
11    apply_warps_to_srsfs, beta_converged, init_identity_warps, srsf_fitted_values, ElasticConfig,
12};
13
14use crate::alignment::srsf_transform;
15
/// Result of elastic scalar-on-function regression.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticRegressionResult {
    /// Intercept α̂ of the fitted model.
    pub alpha: f64,
    /// Regression function β(t) evaluated on the argvals grid, length m.
    pub beta: Vec<f64>,
    /// Fitted values ŷᵢ, length n (one per curve).
    pub fitted_values: Vec<f64>,
    /// Residuals yᵢ − ŷᵢ, length n.
    pub residuals: Vec<f64>,
    /// Residual sum of squares Σ(yᵢ − ŷᵢ)².
    pub sse: f64,
    /// Coefficient of determination R² = 1 − SSE/SS_tot (0.0 when SS_tot == 0).
    pub r_squared: f64,
    /// Final warping functions γᵢ (n × m).
    pub gammas: FdMatrix,
    /// SRSFs aligned by the final warps (n × m).
    pub aligned_srsfs: FdMatrix,
    /// Number of alternating iterations actually performed.
    pub n_iter: usize,
}
39
40/// Alternating alignment + penalized regression for scalar-on-function.
41///
42/// Iterates:
43/// 1. Align SRSFs by current warps
44/// 2. Build basis inner products Φ\[i,j\] = ∫ q_aligned_i · B_j dt
45/// 3. Penalized OLS for β
46/// 4. Find optimal warps via `regression_warp`
47/// 5. Check convergence
48///
49/// # Arguments
50/// * `data` — Functional data (n × m)
51/// * `y` — Scalar responses (length n)
52/// * `argvals` — Evaluation points (length m)
53/// * `ncomp_beta` — Number of B-spline basis functions for β
54/// * `lambda` — Roughness penalty on β
55/// * `max_iter` — Maximum iterations (default: 20)
56/// * `tol` — Convergence tolerance (default: 1e-4)
57///
58/// # Errors
59///
60/// Returns [`crate::FdarError::InvalidDimension`] if `n < 2`, `m < 2`,
61/// `y.len() != n`, `argvals.len() != m`, or `ncomp_beta < 2`.
62/// Returns [`crate::FdarError::ComputationFailed`] if a regression iteration fails
63/// to converge (e.g., singular penalized system).
64#[must_use = "expensive computation whose result should not be discarded"]
65pub fn elastic_regression(
66    data: &FdMatrix,
67    y: &[f64],
68    argvals: &[f64],
69    ncomp_beta: usize,
70    lambda: f64,
71    max_iter: usize,
72    tol: f64,
73) -> Result<ElasticRegressionResult, crate::FdarError> {
74    let (n, m) = data.shape();
75    if n < 2 || m < 2 || y.len() != n || argvals.len() != m || ncomp_beta < 2 {
76        return Err(crate::FdarError::InvalidDimension {
77            parameter: "data/y/argvals",
78            expected: "n >= 2, m >= 2, y.len() == n, argvals.len() == m, ncomp_beta >= 2"
79                .to_string(),
80            actual: format!(
81                "n={}, m={}, y.len()={}, argvals.len()={}, ncomp_beta={}",
82                n,
83                m,
84                y.len(),
85                argvals.len(),
86                ncomp_beta
87            ),
88        });
89    }
90
91    let weights = simpsons_weights(argvals);
92    let q_all = srsf_transform(data, argvals);
93
94    let (b_mat, r_trimmed, actual_nbasis) = build_basis_and_penalty(argvals, ncomp_beta, m);
95
96    let mut gammas = init_identity_warps(n, argvals);
97    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
98    let mut beta = vec![0.0; m];
99    let mut alpha = y_mean;
100    let mut n_iter = 0;
101
102    for iter in 0..max_iter {
103        n_iter = iter + 1;
104
105        let (beta_new, alpha_new) = regression_iteration_step(
106            &q_all,
107            &gammas,
108            argvals,
109            &b_mat,
110            &r_trimmed,
111            &weights,
112            y,
113            alpha,
114            lambda,
115            n,
116            m,
117            actual_nbasis,
118        )
119        .ok_or_else(|| crate::FdarError::ComputationFailed {
120            operation: "regression_iteration",
121            detail: format!(
122                "iteration {} failed; try increasing lambda or reducing nbasis",
123                iter + 1
124            ),
125        })?;
126
127        if beta_converged(&beta_new, &beta, tol) && iter > 0 {
128            beta = beta_new;
129            alpha = alpha_new;
130            break;
131        }
132
133        beta = beta_new;
134        alpha = alpha_new;
135
136        update_regression_warps(&mut gammas, &q_all, &beta, argvals, alpha, y, lambda * 0.01);
137        center_warps(&mut gammas, argvals);
138    }
139
140    // Final fitted values
141    let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
142    let fitted_values = srsf_fitted_values(&aligned_srsfs, &beta, &weights, alpha);
143    let (residuals, sse, r_squared) = compute_regression_residuals(y, &fitted_values, y_mean);
144
145    Ok(ElasticRegressionResult {
146        alpha,
147        beta,
148        fitted_values,
149        residuals,
150        sse,
151        r_squared,
152        gammas,
153        aligned_srsfs,
154        n_iter,
155    })
156}
157
158/// Elastic scalar-on-function regression using a configuration struct.
159///
160/// Equivalent to [`elastic_regression`] but bundles method parameters in [`ElasticConfig`].
161#[must_use = "expensive computation whose result should not be discarded"]
162pub fn elastic_regression_with_config(
163    data: &FdMatrix,
164    y: &[f64],
165    argvals: &[f64],
166    config: &ElasticConfig,
167) -> Result<ElasticRegressionResult, crate::FdarError> {
168    elastic_regression(
169        data,
170        y,
171        argvals,
172        config.ncomp_beta,
173        config.lambda,
174        config.max_iter,
175        config.tol,
176    )
177}
178
179/// Predict new responses using a fitted elastic regression model.
180///
181/// Transforms new curves to SRSFs, aligns them using the training warps as
182/// a template (identity alignment for new data), then applies the fitted
183/// regression coefficients.
184///
185/// # Arguments
186/// * `fit` — A fitted [`ElasticRegressionResult`]
187/// * `new_data` — New functional data (n_new × m)
188/// * `argvals` — Evaluation points (length m)
189pub fn predict_elastic_regression(
190    fit: &ElasticRegressionResult,
191    new_data: &FdMatrix,
192    argvals: &[f64],
193) -> Vec<f64> {
194    let weights = simpsons_weights(argvals);
195    let q_new = srsf_transform(new_data, argvals);
196    srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha)
197}
198
impl ElasticRegressionResult {
    /// Predict responses for new data. Delegates to [`predict_elastic_regression`].
    ///
    /// NOTE(review): `argvals` is presumably the same grid used at fit time
    /// (β has length m) — confirm with callers.
    pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
        predict_elastic_regression(self, new_data, argvals)
    }
}
205
206// ─── Internal helpers ───────────────────────────────────────────────────────
207
208/// Find optimal warp for a single curve in elastic regression.
209///
210/// Aligns q_i to both +β and -β via DP, then binary searches between the
211/// two extreme warps to find the one giving predicted y closest to actual.
212fn regression_warp(
213    q_i: &[f64],
214    beta: &[f64],
215    argvals: &[f64],
216    alpha: f64,
217    y_i: f64,
218    lambda: f64,
219) -> Vec<f64> {
220    let weights = simpsons_weights(argvals);
221
222    // Align to +β
223    let gam_pos = dp_alignment_core(beta, q_i, argvals, lambda);
224
225    // Align to -β
226    let neg_beta: Vec<f64> = beta.iter().map(|&b| -b).collect();
227    let gam_neg = dp_alignment_core(&neg_beta, q_i, argvals, lambda);
228
229    // Compute predicted y for each extreme
230    let y_pos = compute_predicted_y(q_i, beta, &gam_pos, argvals, alpha, &weights);
231    let y_neg = compute_predicted_y(q_i, beta, &gam_neg, argvals, alpha, &weights);
232
233    // If already close enough, return the nearest extreme
234    if let Some(gam) = check_extreme_warps(&gam_pos, &gam_neg, y_pos, y_neg, y_i) {
235        return gam;
236    }
237
238    // Binary search between the two warps
239    let (gam_lo, gam_hi) = order_warps_by_prediction(gam_pos, gam_neg, y_pos, y_neg);
240    binary_search_warps(gam_lo, gam_hi, q_i, beta, argvals, alpha, y_i, &weights)
241}
242
243/// Compute predicted y for a warped curve.
244fn compute_predicted_y(
245    q_i: &[f64],
246    beta: &[f64],
247    gam: &[f64],
248    argvals: &[f64],
249    alpha: f64,
250    weights: &[f64],
251) -> f64 {
252    let m = argvals.len();
253    let q_warped = reparameterize_curve(q_i, argvals, gam);
254    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
255    let gam_deriv = crate::helpers::gradient_uniform(gam, h);
256
257    let mut y_hat = alpha;
258    for j in 0..m {
259        let q_aligned_j = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
260        y_hat += q_aligned_j * beta[j] * weights[j];
261    }
262    y_hat
263}
264
265/// Build B-spline basis matrix and roughness penalty for β representation.
266fn build_basis_and_penalty(
267    argvals: &[f64],
268    ncomp_beta: usize,
269    m: usize,
270) -> (DMatrix<f64>, DMatrix<f64>, usize) {
271    let nknots = ncomp_beta.saturating_sub(4).max(2);
272    let basis_flat = bspline_basis(argvals, nknots, 4);
273    let actual_nbasis = basis_flat.len() / m;
274    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis_flat);
275
276    let penalty_flat = bspline_penalty_matrix(argvals, ncomp_beta, 4, 2);
277    let penalty_k = (penalty_flat.len() as f64).sqrt() as usize;
278    let r_mat = DMatrix::from_column_slice(penalty_k, penalty_k, &penalty_flat);
279    let r_trimmed = trim_penalty_to_basis(&r_mat, penalty_k, actual_nbasis);
280
281    (b_mat, r_trimmed, actual_nbasis)
282}
283
284/// Trim or pad penalty matrix to match actual basis dimension.
285fn trim_penalty_to_basis(
286    r_mat: &DMatrix<f64>,
287    penalty_k: usize,
288    actual_nbasis: usize,
289) -> DMatrix<f64> {
290    if penalty_k >= actual_nbasis {
291        r_mat
292            .view((0, 0), (actual_nbasis, actual_nbasis))
293            .into_owned()
294    } else {
295        let mut r = DMatrix::zeros(actual_nbasis, actual_nbasis);
296        let dim = penalty_k.min(actual_nbasis);
297        for i in 0..dim {
298            for j in 0..dim {
299                r[(i, j)] = r_mat[(i, j)];
300            }
301        }
302        r
303    }
304}
305
306/// Build design matrix Φ[i,k] = ∫ q_aligned_i · B_k · w dt.
307fn build_phi_matrix(
308    q_aligned: &FdMatrix,
309    b_mat: &DMatrix<f64>,
310    weights: &[f64],
311    n: usize,
312    m: usize,
313    actual_nbasis: usize,
314) -> DMatrix<f64> {
315    let mut phi = DMatrix::zeros(n, actual_nbasis);
316    for i in 0..n {
317        for k in 0..actual_nbasis {
318            let mut val = 0.0;
319            for j in 0..m {
320                val += q_aligned[(i, j)] * b_mat[(j, k)] * weights[j];
321            }
322            phi[(i, k)] = val;
323        }
324    }
325    phi
326}
327
328/// Solve penalized OLS: (Φ'Φ + λR)c = Φ'y.
329pub(super) fn solve_penalized_ols(
330    phi: &DMatrix<f64>,
331    r_trimmed: &DMatrix<f64>,
332    y_centered: &[f64],
333    lambda: f64,
334) -> Option<Vec<f64>> {
335    let y_vec = DVector::from_vec(y_centered.to_vec());
336    let phi_t_phi = phi.transpose() * phi;
337    let system = &phi_t_phi + lambda * r_trimmed;
338    let rhs = phi.transpose() * &y_vec;
339    let coefs = if let Some(chol) = system.clone().cholesky() {
340        chol.solve(&rhs)
341    } else {
342        let svd = nalgebra::SVD::new(system, true, true);
343        svd.solve(&rhs, 1e-10).ok()?
344    };
345    Some(coefs.iter().copied().collect())
346}
347
348/// Reconstruct β(t) = Σ c_k B_k(t) from B-spline coefficients.
349fn reconstruct_beta_from_coefs(
350    coefs: &[f64],
351    b_mat: &DMatrix<f64>,
352    m: usize,
353    actual_nbasis: usize,
354) -> Vec<f64> {
355    let mut beta = vec![0.0; m];
356    for j in 0..m {
357        for k in 0..actual_nbasis {
358            beta[j] += coefs[k] * b_mat[(j, k)];
359        }
360    }
361    beta
362}
363
364/// Compute intercept: α̂ = mean(y - ∫ q·β·w dt).
365fn compute_alpha_from_residuals(
366    q_aligned: &FdMatrix,
367    beta: &[f64],
368    weights: &[f64],
369    y: &[f64],
370) -> f64 {
371    let (n, m) = q_aligned.shape();
372    let mut alpha = 0.0;
373    for i in 0..n {
374        let mut y_hat_i = 0.0;
375        for j in 0..m {
376            y_hat_i += q_aligned[(i, j)] * beta[j] * weights[j];
377        }
378        alpha += y[i] - y_hat_i;
379    }
380    alpha / n as f64
381}
382
383/// One iteration step of elastic regression: align, solve OLS, return new (β, α).
384fn regression_iteration_step(
385    q_all: &FdMatrix,
386    gammas: &FdMatrix,
387    argvals: &[f64],
388    b_mat: &DMatrix<f64>,
389    r_trimmed: &DMatrix<f64>,
390    weights: &[f64],
391    y: &[f64],
392    alpha: f64,
393    lambda: f64,
394    n: usize,
395    m: usize,
396    actual_nbasis: usize,
397) -> Option<(Vec<f64>, f64)> {
398    let q_aligned = apply_warps_to_srsfs(q_all, gammas, argvals);
399    let phi = build_phi_matrix(&q_aligned, b_mat, weights, n, m, actual_nbasis);
400    let y_centered: Vec<f64> = y.iter().map(|&yi| yi - alpha).collect();
401    let coefs = solve_penalized_ols(&phi, r_trimmed, &y_centered, lambda)?;
402    let beta_new = reconstruct_beta_from_coefs(&coefs, b_mat, m, actual_nbasis);
403    let alpha_new = compute_alpha_from_residuals(&q_aligned, &beta_new, weights, y);
404    Some((beta_new, alpha_new))
405}
406
407/// Update warping functions for all curves in elastic regression.
408fn update_regression_warps(
409    gammas: &mut FdMatrix,
410    q_all: &FdMatrix,
411    beta: &[f64],
412    argvals: &[f64],
413    alpha: f64,
414    y: &[f64],
415    lambda: f64,
416) {
417    let (n, m) = q_all.shape();
418    for i in 0..n {
419        let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
420        let new_gam = regression_warp(&qi, beta, argvals, alpha, y[i], lambda);
421        for j in 0..m {
422            gammas[(i, j)] = new_gam[j];
423        }
424    }
425}
426
427/// Center warping functions using Karcher mean.
428fn center_warps(gammas: &mut FdMatrix, argvals: &[f64]) {
429    let (n, m) = gammas.shape();
430    let gam_mu = sqrt_mean_inverse(gammas, argvals);
431    for i in 0..n {
432        let gam_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
433        let composed = crate::alignment::compose_warps(&gam_i, &gam_mu, argvals);
434        for j in 0..m {
435            gammas[(i, j)] = composed[j];
436        }
437    }
438}
439
/// Residuals, SSE, and R² from observed and fitted responses.
///
/// R² is defined as 1 − SSE/SS_tot, or 0.0 when SS_tot is zero
/// (constant response), avoiding a division by zero.
fn compute_regression_residuals(
    y: &[f64],
    fitted_values: &[f64],
    y_mean: f64,
) -> (Vec<f64>, f64, f64) {
    let mut residuals = Vec::with_capacity(y.len());
    let mut sse = 0.0;
    for (&yi, &yh) in y.iter().zip(fitted_values) {
        let r = yi - yh;
        sse += r * r;
        residuals.push(r);
    }
    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let r_squared = if ss_tot > 0.0 { 1.0 - sse / ss_tot } else { 0.0 };
    (residuals, sse, r_squared)
}
460
/// Return the nearer extreme warp if its prediction already matches the
/// target response to within 1e-10; otherwise `None` (bisection needed).
fn check_extreme_warps(
    gam_pos: &[f64],
    gam_neg: &[f64],
    y_pos: f64,
    y_neg: f64,
    y_i: f64,
) -> Option<Vec<f64>> {
    let err_pos = (y_pos - y_i).abs();
    let err_neg = (y_neg - y_i).abs();
    if err_pos <= err_neg {
        (err_pos < 1e-10).then(|| gam_pos.to_vec())
    } else {
        (err_neg < 1e-10).then(|| gam_neg.to_vec())
    }
}
478
/// Order the two extreme warps as (lower-prediction, higher-prediction).
fn order_warps_by_prediction(
    gam_pos: Vec<f64>,
    gam_neg: Vec<f64>,
    y_pos: f64,
    y_neg: f64,
) -> (Vec<f64>, Vec<f64>) {
    match y_pos < y_neg {
        true => (gam_pos, gam_neg),
        false => (gam_neg, gam_pos),
    }
}
492
493/// Binary search between two warps to find one giving predicted y closest to target.
494fn binary_search_warps(
495    mut gam_lo: Vec<f64>,
496    mut gam_hi: Vec<f64>,
497    q_i: &[f64],
498    beta: &[f64],
499    argvals: &[f64],
500    alpha: f64,
501    y_i: f64,
502    weights: &[f64],
503) -> Vec<f64> {
504    for _ in 0..15 {
505        let gam_mid: Vec<f64> = gam_lo
506            .iter()
507            .zip(gam_hi.iter())
508            .map(|(&lo, &hi)| 0.5 * (lo + hi))
509            .collect();
510        let y_mid = compute_predicted_y(q_i, beta, &gam_mid, argvals, alpha, weights);
511        if (y_mid - y_i).abs() < 1e-6 {
512            return gam_mid;
513        }
514        if y_mid < y_i {
515            gam_lo = gam_mid;
516        } else {
517            gam_hi = gam_mid;
518        }
519    }
520    gam_lo
521        .iter()
522        .zip(gam_hi.iter())
523        .map(|(&lo, &hi)| 0.5 * (lo + hi))
524        .collect()
525}