// fdars_core/elastic_regression.rs

//! Elastic regression models (alignment-integrated regression).
//!
//! These models, following fdasrvf, align curves during the regression fit,
//! jointly optimizing the warping functions and the regression coefficients.
//!
//! Key capabilities:
//! - [`elastic_regression`] — Scalar-on-function regression with elastic alignment
//! - [`elastic_logistic`] — Binary classification with elastic alignment
//! - [`elastic_pcr`] — Principal component regression after elastic alignment

use crate::alignment::{
    dp_alignment_core, karcher_mean, reparameterize_curve, sqrt_mean_inverse, srsf_transform,
    KarcherMeanResult,
};
use crate::basis::bspline_basis;
use crate::elastic_fpca::{
    horiz_fpca, joint_fpca, vert_fpca, HorizFpcaResult, JointFpcaResult, VertFpcaResult,
};
use crate::helpers::simpsons_weights;
use crate::matrix::FdMatrix;
use crate::smooth_basis::bspline_penalty_matrix;
use nalgebra::{DMatrix, DVector};

// ─── Types ──────────────────────────────────────────────────────────────────

/// Result of elastic scalar-on-function regression.
#[derive(Debug, Clone)]
pub struct ElasticRegressionResult {
    /// Intercept.
    pub alpha: f64,
    /// Regression function β(t), length m.
    pub beta: Vec<f64>,
    /// Fitted values, length n.
    pub fitted_values: Vec<f64>,
    /// Residuals, length n.
    pub residuals: Vec<f64>,
    /// Residual sum of squares.
    pub sse: f64,
    /// Coefficient of determination.
    pub r_squared: f64,
    /// Final warping functions (n × m).
    pub gammas: FdMatrix,
    /// Aligned SRSFs (n × m).
    pub aligned_srsfs: FdMatrix,
    /// Number of iterations used.
    pub n_iter: usize,
}

/// Result of elastic logistic regression.
#[derive(Debug, Clone)]
pub struct ElasticLogisticResult {
    /// Intercept.
    pub alpha: f64,
    /// Regression function β(t), length m.
    pub beta: Vec<f64>,
    /// Predicted probabilities, length n.
    pub probabilities: Vec<f64>,
    /// Predicted class labels (-1 or 1), length n.
    pub predicted_classes: Vec<i8>,
    /// Classification accuracy.
    pub accuracy: f64,
    /// Logistic loss.
    pub loss: f64,
    /// Final warping functions (n × m).
    pub gammas: FdMatrix,
    /// Aligned SRSFs (n × m).
    pub aligned_srsfs: FdMatrix,
    /// Number of iterations used.
    pub n_iter: usize,
}

/// PCA method for elastic PCR.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PcaMethod {
    Vertical,
    Horizontal,
    Joint,
}

/// Result of elastic principal component regression.
#[derive(Debug, Clone)]
pub struct ElasticPcrResult {
    /// Intercept.
    pub alpha: f64,
    /// Regression coefficients on PC scores, length ncomp.
    pub coefficients: Vec<f64>,
    /// Fitted values, length n.
    pub fitted_values: Vec<f64>,
    /// Residual sum of squares.
    pub sse: f64,
    /// Coefficient of determination.
    pub r_squared: f64,
    /// PCA method used.
    pub pca_method: PcaMethod,
    /// Karcher mean result.
    pub karcher: KarcherMeanResult,
    /// Vertical FPCA result (stored when `pca_method` is `Vertical`).
    pub vert_fpca: Option<VertFpcaResult>,
    /// Horizontal FPCA result (stored when `pca_method` is `Horizontal`).
    pub horiz_fpca: Option<HorizFpcaResult>,
    /// Joint FPCA result (stored when `pca_method` is `Joint`).
    pub joint_fpca: Option<JointFpcaResult>,
}

// ─── Elastic Regression ─────────────────────────────────────────────────────

/// Alternating alignment + penalized regression for scalar-on-function.
///
/// Iterates:
/// 1. Align SRSFs by current warps
/// 2. Build basis inner products Φ\[i,j\] = ∫ q_aligned_i · B_j dt
/// 3. Penalized OLS for β
/// 4. Find optimal warps via `regression_warp`
/// 5. Check convergence
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `y` — Scalar responses (length n)
/// * `argvals` — Evaluation points (length m)
/// * `ncomp_beta` — Number of B-spline basis functions for β
/// * `lambda` — Roughness penalty on β
/// * `max_iter` — Maximum number of iterations (20 is a typical choice)
/// * `tol` — Convergence tolerance on β (e.g. 1e-4)
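///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doc-test);
/// it assumes the crate is available as `fdars_core` with the `matrix` and
/// `elastic_regression` modules publicly exported, and that `data` is filled with
/// one observed curve per row:
///
/// ```ignore
/// use fdars_core::elastic_regression::elastic_regression;
/// use fdars_core::matrix::FdMatrix;
///
/// let (n, m) = (20, 101);
/// let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
/// let mut data = FdMatrix::zeros(n, m);
/// // ... fill `data` with observed curves ...
/// let y: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
///
/// if let Some(fit) = elastic_regression(&data, &y, &t, 10, 1e-3, 20, 1e-4) {
///     println!("R² = {:.3} after {} iterations", fit.r_squared, fit.n_iter);
/// }
/// ```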
pub fn elastic_regression(
    data: &FdMatrix,
    y: &[f64],
    argvals: &[f64],
    ncomp_beta: usize,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Option<ElasticRegressionResult> {
    let (n, m) = data.shape();
    if n < 2 || m < 2 || y.len() != n || argvals.len() != m || ncomp_beta < 2 {
        return None;
    }

    let weights = simpsons_weights(argvals);
    let q_all = srsf_transform(data, argvals);

    let (b_mat, r_trimmed, actual_nbasis) = build_basis_and_penalty(argvals, ncomp_beta, m);

    let mut gammas = init_identity_warps(n, argvals);
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let mut beta = vec![0.0; m];
    let mut alpha = y_mean;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        let (beta_new, alpha_new) = regression_iteration_step(
            &q_all,
            &gammas,
            argvals,
            &b_mat,
            &r_trimmed,
            &weights,
            y,
            alpha,
            lambda,
            n,
            m,
            actual_nbasis,
        )?;

        if beta_converged(&beta_new, &beta, tol) && iter > 0 {
            beta = beta_new;
            alpha = alpha_new;
            break;
        }

        beta = beta_new;
        alpha = alpha_new;

        update_regression_warps(&mut gammas, &q_all, &beta, argvals, alpha, y, lambda * 0.01);
        center_warps(&mut gammas, argvals);
    }

    // Final fitted values
    let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
    let fitted_values = srsf_fitted_values(&aligned_srsfs, &beta, &weights, alpha);
    let (residuals, sse, r_squared) = compute_regression_residuals(y, &fitted_values, y_mean);

    Some(ElasticRegressionResult {
        alpha,
        beta,
        fitted_values,
        residuals,
        sse,
        r_squared,
        gammas,
        aligned_srsfs,
        n_iter,
    })
}

// ─── Elastic Logistic Regression ────────────────────────────────────────────

/// Elastic logistic regression for binary classification.
///
/// Labels should be -1 or 1. Uses gradient descent with Armijo line search.
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `y` — Binary labels (-1 or 1), length n
/// * `argvals` — Evaluation points (length m)
/// * `ncomp_beta` — Number of B-spline basis functions for β (currently unused;
///   β is estimated pointwise on the grid)
/// * `lambda` — L2 penalty on β
/// * `max_iter` — Maximum iterations
/// * `tol` — Convergence tolerance
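///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, not compiled as a doc-test), under the
/// same crate-layout assumptions as the example on [`elastic_regression`]; labels
/// must be coded as -1/1:
///
/// ```ignore
/// use fdars_core::elastic_regression::elastic_logistic;
/// use fdars_core::matrix::FdMatrix;
///
/// let (n, m) = (20, 101);
/// let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
/// let mut data = FdMatrix::zeros(n, m);
/// // ... fill `data` with observed curves ...
/// let labels: Vec<i8> = (0..n).map(|i| if i < n / 2 { -1 } else { 1 }).collect();
///
/// if let Some(fit) = elastic_logistic(&data, &labels, &t, 10, 1e-2, 20, 1e-4) {
///     println!("training accuracy = {:.2}", fit.accuracy);
/// }
/// ```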
pub fn elastic_logistic(
    data: &FdMatrix,
    y: &[i8],
    argvals: &[f64],
    _ncomp_beta: usize,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Option<ElasticLogisticResult> {
    let (n, m) = data.shape();
    if n < 2 || m < 2 || y.len() != n || argvals.len() != m {
        return None;
    }

    let weights = simpsons_weights(argvals);
    let q_all = srsf_transform(data, argvals);
    let mut gammas = init_identity_warps(n, argvals);
    let mut beta = vec![0.0; m];
    let mut alpha = 0.0;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        let q_aligned = apply_warps_to_srsfs(&q_all, &gammas, argvals);
        let (grad_a, grad_beta, prob) =
            logistic_gradients(&q_aligned, &beta, &weights, alpha, y, lambda);

        let loss_current = logistic_loss(&prob, y, &beta, lambda);
        let grad_norm_sq: f64 = grad_a * grad_a + grad_beta.iter().map(|&g| g * g).sum::<f64>();

        let step = armijo_line_search_logistic(
            &q_aligned,
            alpha,
            &beta,
            grad_a,
            &grad_beta,
            &weights,
            y,
            lambda,
            loss_current,
            grad_norm_sq,
        );

        let beta_new: Vec<f64> = beta
            .iter()
            .zip(grad_beta.iter())
            .map(|(&b, &g)| b - step * g)
            .collect();
        let alpha_new = alpha - step * grad_a;

        if beta_converged(&beta_new, &beta, tol) && iter > 0 {
            beta = beta_new;
            alpha = alpha_new;
            break;
        }

        beta = beta_new;
        alpha = alpha_new;

        update_logistic_warps(&mut gammas, &q_all, &beta, y, argvals, lambda * 0.01);
    }

    // Final predictions
    let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
    let (probabilities, predicted_classes, accuracy, loss) =
        compute_logistic_predictions(&aligned_srsfs, &beta, &weights, alpha, y, lambda);

    Some(ElasticLogisticResult {
        alpha,
        beta,
        probabilities,
        predicted_classes,
        accuracy,
        loss,
        gammas,
        aligned_srsfs,
        n_iter,
    })
}

/// Compute logistic loss with L2 penalty.
fn logistic_loss(prob: &[f64], y: &[i8], beta: &[f64], lambda: f64) -> f64 {
    let n = prob.len();
    let mut loss = 0.0;
    for i in 0..n {
        let target = if y[i] == 1 { 1.0 } else { 0.0 };
        let p = prob[i].clamp(1e-15, 1.0 - 1e-15);
        loss -= target * p.ln() + (1.0 - target) * (1.0 - p).ln();
    }
    loss /= n as f64;
    // L2 penalty
    loss += 0.5 * lambda * beta.iter().map(|&b| b * b).sum::<f64>();
    loss
}

// ─── Elastic PCR ────────────────────────────────────────────────────────────

/// Elastic principal component regression.
///
/// Performs Karcher mean alignment, then FPCA (vert/horiz/joint), then OLS
/// on the PC scores.
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `y` — Scalar responses (length n)
/// * `argvals` — Evaluation points (length m)
/// * `ncomp` — Number of PCs to use
/// * `pca_method` — Which FPCA variant to use
/// * `lambda` — Alignment penalty (passed to karcher_mean)
/// * `max_iter` — Maximum iterations for karcher_mean
/// * `tol` — Convergence tolerance for karcher_mean
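///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, not compiled as a doc-test), under the
/// same crate-layout assumptions as the example on [`elastic_regression`]:
///
/// ```ignore
/// use fdars_core::elastic_regression::{elastic_pcr, PcaMethod};
/// use fdars_core::matrix::FdMatrix;
///
/// let (n, m) = (20, 101);
/// let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
/// let data = FdMatrix::zeros(n, m); // replace with observed curves
/// let y = vec![0.0; n];             // scalar responses
///
/// if let Some(fit) = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 20, 1e-3) {
///     println!("R² = {:.3}, PC coefficients: {:?}", fit.r_squared, fit.coefficients);
/// }
/// ```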
pub fn elastic_pcr(
    data: &FdMatrix,
    y: &[f64],
    argvals: &[f64],
    ncomp: usize,
    pca_method: PcaMethod,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Option<ElasticPcrResult> {
    let (n, _m) = data.shape();
    if n < 2 || y.len() != n || ncomp < 1 {
        return None;
    }

    // Karcher mean alignment
    let km = karcher_mean(data, argvals, max_iter, tol, lambda);

    // FPCA
    let mut stored_vert: Option<VertFpcaResult> = None;
    let mut stored_horiz: Option<HorizFpcaResult> = None;
    let mut stored_joint: Option<JointFpcaResult> = None;

    let scores_mat = match pca_method {
        PcaMethod::Vertical => {
            let fpca = vert_fpca(&km, argvals, ncomp)?;
            let scores = fpca.scores.clone();
            stored_vert = Some(fpca);
            scores
        }
        PcaMethod::Horizontal => {
            let fpca = horiz_fpca(&km, argvals, ncomp)?;
            let scores = fpca.scores.clone();
            stored_horiz = Some(fpca);
            scores
        }
        PcaMethod::Joint => {
            let fpca = joint_fpca(&km, argvals, ncomp, None)?;
            let scores = fpca.scores.clone();
            stored_joint = Some(fpca);
            scores
        }
    };

    let actual_ncomp = scores_mat.ncols();
    let (coefs, alpha, fitted_values, sse, r_squared) =
        ols_on_scores(&scores_mat, y, n, actual_ncomp)?;

    Some(ElasticPcrResult {
        alpha,
        coefficients: coefs,
        fitted_values,
        sse,
        r_squared,
        pca_method,
        karcher: km,
        vert_fpca: stored_vert,
        horiz_fpca: stored_horiz,
        joint_fpca: stored_joint,
    })
}

// ─── Internal: Regression Warp ──────────────────────────────────────────────

/// Find optimal warp for a single curve in elastic regression.
///
/// Aligns q_i to both +β and -β via DP, then binary searches between the
/// two extreme warps to find the one giving predicted y closest to actual.
fn regression_warp(
    q_i: &[f64],
    beta: &[f64],
    argvals: &[f64],
    alpha: f64,
    y_i: f64,
    lambda: f64,
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);

    // Align to +β
    let gam_pos = dp_alignment_core(beta, q_i, argvals, lambda);

    // Align to -β
    let neg_beta: Vec<f64> = beta.iter().map(|&b| -b).collect();
    let gam_neg = dp_alignment_core(&neg_beta, q_i, argvals, lambda);

    // Compute predicted y for each extreme
    let y_pos = compute_predicted_y(q_i, beta, &gam_pos, argvals, alpha, &weights);
    let y_neg = compute_predicted_y(q_i, beta, &gam_neg, argvals, alpha, &weights);

    // If already close enough, return the nearest extreme
    if let Some(gam) = check_extreme_warps(&gam_pos, &gam_neg, y_pos, y_neg, y_i) {
        return gam;
    }

    // Binary search between the two warps
    let (gam_lo, gam_hi) = order_warps_by_prediction(gam_pos, gam_neg, y_pos, y_neg);
    binary_search_warps(gam_lo, gam_hi, q_i, beta, argvals, alpha, y_i, &weights)
}

/// Compute predicted y for a warped curve.
fn compute_predicted_y(
    q_i: &[f64],
    beta: &[f64],
    gam: &[f64],
    argvals: &[f64],
    alpha: f64,
    weights: &[f64],
) -> f64 {
    let m = argvals.len();
    let q_warped = reparameterize_curve(q_i, argvals, gam);
    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
    let gam_deriv = crate::helpers::gradient_uniform(gam, h);

    let mut y_hat = alpha;
    for j in 0..m {
        let q_aligned_j = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
        y_hat += q_aligned_j * beta[j] * weights[j];
    }
    y_hat
}

// ─── Shared Helpers ────────────────────────────────────────────────────────

/// Apply warping functions to SRSFs, producing aligned SRSFs with sqrt(γ') factor.
fn apply_warps_to_srsfs(q_all: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
    let (n, m) = q_all.shape();
    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
    let mut q_aligned = FdMatrix::zeros(n, m);
    for i in 0..n {
        let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
        let gam: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
        let q_warped = reparameterize_curve(&qi, argvals, &gam);
        let gam_deriv = crate::helpers::gradient_uniform(&gam, h);
        for j in 0..m {
            q_aligned[(i, j)] = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
        }
    }
    q_aligned
}

/// Initialize warping functions to identity (γ_i(t) = t).
fn init_identity_warps(n: usize, argvals: &[f64]) -> FdMatrix {
    let m = argvals.len();
    let mut gammas = FdMatrix::zeros(n, m);
    for i in 0..n {
        for j in 0..m {
            gammas[(i, j)] = argvals[j];
        }
    }
    gammas
}

/// Trim or pad penalty matrix to match actual basis dimension.
fn trim_penalty_to_basis(
    r_mat: &DMatrix<f64>,
    penalty_k: usize,
    actual_nbasis: usize,
) -> DMatrix<f64> {
    if penalty_k >= actual_nbasis {
        r_mat
            .view((0, 0), (actual_nbasis, actual_nbasis))
            .into_owned()
    } else {
        let mut r = DMatrix::zeros(actual_nbasis, actual_nbasis);
        let dim = penalty_k.min(actual_nbasis);
        for i in 0..dim {
            for j in 0..dim {
                r[(i, j)] = r_mat[(i, j)];
            }
        }
        r
    }
}

/// Build design matrix Φ[i,k] = ∫ q_aligned_i · B_k · w dt.
fn build_phi_matrix(
    q_aligned: &FdMatrix,
    b_mat: &DMatrix<f64>,
    weights: &[f64],
    n: usize,
    m: usize,
    actual_nbasis: usize,
) -> DMatrix<f64> {
    let mut phi = DMatrix::zeros(n, actual_nbasis);
    for i in 0..n {
        for k in 0..actual_nbasis {
            let mut val = 0.0;
            for j in 0..m {
                val += q_aligned[(i, j)] * b_mat[(j, k)] * weights[j];
            }
            phi[(i, k)] = val;
        }
    }
    phi
}

/// Solve penalized OLS: (Φ'Φ + λR)c = Φ'y.
fn solve_penalized_ols(
    phi: &DMatrix<f64>,
    r_trimmed: &DMatrix<f64>,
    y_centered: &[f64],
    lambda: f64,
) -> Option<Vec<f64>> {
    let y_vec = DVector::from_vec(y_centered.to_vec());
    let phi_t_phi = phi.transpose() * phi;
    let system = &phi_t_phi + lambda * r_trimmed;
    let rhs = phi.transpose() * &y_vec;
    let coefs = if let Some(chol) = system.clone().cholesky() {
        chol.solve(&rhs)
    } else {
        let svd = nalgebra::SVD::new(system, true, true);
        svd.solve(&rhs, 1e-10).ok()?
    };
    Some(coefs.iter().cloned().collect())
}

/// Reconstruct β(t) = Σ c_k B_k(t) from B-spline coefficients.
fn reconstruct_beta_from_coefs(
    coefs: &[f64],
    b_mat: &DMatrix<f64>,
    m: usize,
    actual_nbasis: usize,
) -> Vec<f64> {
    let mut beta = vec![0.0; m];
    for j in 0..m {
        for k in 0..actual_nbasis {
            beta[j] += coefs[k] * b_mat[(j, k)];
        }
    }
    beta
}

/// Compute intercept: α̂ = mean(y - ∫ q·β·w dt).
fn compute_alpha_from_residuals(
    q_aligned: &FdMatrix,
    beta: &[f64],
    weights: &[f64],
    y: &[f64],
) -> f64 {
    let (n, m) = q_aligned.shape();
    let mut alpha = 0.0;
    for i in 0..n {
        let mut y_hat_i = 0.0;
        for j in 0..m {
            y_hat_i += q_aligned[(i, j)] * beta[j] * weights[j];
        }
        alpha += y[i] - y_hat_i;
    }
    alpha / n as f64
}

/// Compute fitted values: ŷ_i = α + ∫ q_aligned_i · β · w dt.
fn srsf_fitted_values(q_aligned: &FdMatrix, beta: &[f64], weights: &[f64], alpha: f64) -> Vec<f64> {
    let (n, m) = q_aligned.shape();
    let mut fitted = vec![0.0; n];
    for i in 0..n {
        fitted[i] = alpha;
        for j in 0..m {
            fitted[i] += q_aligned[(i, j)] * beta[j] * weights[j];
        }
    }
    fitted
}

/// Check relative convergence of β.
fn beta_converged(beta_new: &[f64], beta_old: &[f64], tol: f64) -> bool {
    let diff: f64 = beta_new
        .iter()
        .zip(beta_old.iter())
        .map(|(&a, &b)| (a - b).powi(2))
        .sum::<f64>()
        .sqrt();
    let norm: f64 = beta_old
        .iter()
        .map(|&b| b * b)
        .sum::<f64>()
        .sqrt()
        .max(1e-10);
    diff / norm < tol
}

/// Center warping functions using Karcher mean.
fn center_warps(gammas: &mut FdMatrix, argvals: &[f64]) {
    let (n, m) = gammas.shape();
    let gam_mu = sqrt_mean_inverse(gammas, argvals);
    for i in 0..n {
        let gam_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
        let composed = crate::alignment::compose_warps(&gam_i, &gam_mu, argvals);
        for j in 0..m {
            gammas[(i, j)] = composed[j];
        }
    }
}

/// Compute logistic gradients for α and β, returning (grad_a, grad_beta, probabilities).
fn logistic_gradients(
    q_aligned: &FdMatrix,
    beta: &[f64],
    weights: &[f64],
    alpha: f64,
    y: &[i8],
    lambda: f64,
) -> (f64, Vec<f64>, Vec<f64>) {
    let (n, m) = q_aligned.shape();
    let eta = srsf_fitted_values(q_aligned, beta, weights, alpha);
    let prob: Vec<f64> = eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect();

    let mut grad_a = 0.0;
    for i in 0..n {
        let target = if y[i] == 1 { 1.0 } else { 0.0 };
        grad_a += prob[i] - target;
    }
    grad_a /= n as f64;

    let mut grad_beta = vec![0.0; m];
    for j in 0..m {
        for i in 0..n {
            let target = if y[i] == 1 { 1.0 } else { 0.0 };
            grad_beta[j] += (prob[i] - target) * q_aligned[(i, j)] * weights[j];
        }
        grad_beta[j] /= n as f64;
        grad_beta[j] += lambda * beta[j];
    }

    (grad_a, grad_beta, prob)
}

/// OLS regression on PC scores: returns (coefs, alpha, fitted, sse, r²).
fn ols_on_scores(
    scores_mat: &FdMatrix,
    y: &[f64],
    n: usize,
    ncomp: usize,
) -> Option<(Vec<f64>, f64, Vec<f64>, f64, f64)> {
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let mut score_means = vec![0.0; ncomp];
    for k in 0..ncomp {
        for i in 0..n {
            score_means[k] += scores_mat[(i, k)];
        }
        score_means[k] /= n as f64;
    }

    let mut x_cen = DMatrix::zeros(n, ncomp);
    for i in 0..n {
        for k in 0..ncomp {
            x_cen[(i, k)] = scores_mat[(i, k)] - score_means[k];
        }
    }
    let y_cen: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
    let y_vec = DVector::from_vec(y_cen);

    let xtx = x_cen.transpose() * &x_cen;
    let xty = x_cen.transpose() * &y_vec;
    let coefficients = if let Some(chol) = xtx.clone().cholesky() {
        chol.solve(&xty)
    } else {
        let svd = nalgebra::SVD::new(xtx, true, true);
        svd.solve(&xty, 1e-10).ok()?
    };
    let coefs: Vec<f64> = coefficients.iter().cloned().collect();

    let alpha = y_mean
        - coefs
            .iter()
            .zip(score_means.iter())
            .map(|(&c, &sm)| c * sm)
            .sum::<f64>();

    let mut fitted_values = vec![0.0; n];
    for i in 0..n {
        fitted_values[i] = alpha;
        for k in 0..ncomp {
            fitted_values[i] += coefs[k] * scores_mat[(i, k)];
        }
    }

    let sse: f64 = y
        .iter()
        .zip(fitted_values.iter())
        .map(|(&yi, &yh)| (yi - yh).powi(2))
        .sum();
    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let r_squared = if ss_tot > 0.0 {
        1.0 - sse / ss_tot
    } else {
        0.0
    };

    Some((coefs, alpha, fitted_values, sse, r_squared))
}

/// Armijo backtracking line search for logistic regression. Returns a step size
/// accepted by the Armijo condition (or the smallest step tried).
fn armijo_line_search_logistic(
    q_aligned: &FdMatrix,
    alpha: f64,
    beta: &[f64],
    grad_a: f64,
    grad_beta: &[f64],
    weights: &[f64],
    y: &[i8],
    lambda: f64,
    loss_current: f64,
    grad_norm_sq: f64,
) -> f64 {
    let mut step = 1.0;
    for _ in 0..20 {
        let alpha_trial = alpha - step * grad_a;
        let beta_trial: Vec<f64> = beta
            .iter()
            .zip(grad_beta.iter())
            .map(|(&b, &g)| b - step * g)
            .collect();
        let eta_trial = srsf_fitted_values(q_aligned, &beta_trial, weights, alpha_trial);
        let prob_trial: Vec<f64> = eta_trial
            .iter()
            .map(|&e| 1.0 / (1.0 + (-e).exp()))
            .collect();
        let loss_trial = logistic_loss(&prob_trial, y, &beta_trial, lambda);
        if loss_trial <= loss_current - 1e-4 * step * grad_norm_sq {
            break;
        }
        step *= 0.5;
    }
    step
}

/// One iteration step of elastic regression: align, solve OLS, return new (β, α).
fn regression_iteration_step(
    q_all: &FdMatrix,
    gammas: &FdMatrix,
    argvals: &[f64],
    b_mat: &DMatrix<f64>,
    r_trimmed: &DMatrix<f64>,
    weights: &[f64],
    y: &[f64],
    alpha: f64,
    lambda: f64,
    n: usize,
    m: usize,
    actual_nbasis: usize,
) -> Option<(Vec<f64>, f64)> {
    let q_aligned = apply_warps_to_srsfs(q_all, gammas, argvals);
    let phi = build_phi_matrix(&q_aligned, b_mat, weights, n, m, actual_nbasis);
    let y_centered: Vec<f64> = y.iter().map(|&yi| yi - alpha).collect();
    let coefs = solve_penalized_ols(&phi, r_trimmed, &y_centered, lambda)?;
    let beta_new = reconstruct_beta_from_coefs(&coefs, b_mat, m, actual_nbasis);
    let alpha_new = compute_alpha_from_residuals(&q_aligned, &beta_new, weights, y);
    Some((beta_new, alpha_new))
}

/// Build B-spline basis matrix and roughness penalty for β representation.
fn build_basis_and_penalty(
    argvals: &[f64],
    ncomp_beta: usize,
    m: usize,
) -> (DMatrix<f64>, DMatrix<f64>, usize) {
    let nknots = ncomp_beta.saturating_sub(4).max(2);
    let basis_flat = bspline_basis(argvals, nknots, 4);
    let actual_nbasis = basis_flat.len() / m;
    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis_flat);

    let penalty_flat = bspline_penalty_matrix(argvals, ncomp_beta, 4, 2);
    let penalty_k = (penalty_flat.len() as f64).sqrt() as usize;
    let r_mat = DMatrix::from_column_slice(penalty_k, penalty_k, &penalty_flat);
    let r_trimmed = trim_penalty_to_basis(&r_mat, penalty_k, actual_nbasis);

    (b_mat, r_trimmed, actual_nbasis)
}

/// Update warping functions for all curves in elastic regression.
fn update_regression_warps(
    gammas: &mut FdMatrix,
    q_all: &FdMatrix,
    beta: &[f64],
    argvals: &[f64],
    alpha: f64,
    y: &[f64],
    lambda: f64,
) {
    let (n, m) = q_all.shape();
    for i in 0..n {
        let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
        let new_gam = regression_warp(&qi, beta, argvals, alpha, y[i], lambda);
        for j in 0..m {
            gammas[(i, j)] = new_gam[j];
        }
    }
}

/// Update warping functions for all curves in elastic logistic regression.
fn update_logistic_warps(
    gammas: &mut FdMatrix,
    q_all: &FdMatrix,
    beta: &[f64],
    y: &[i8],
    argvals: &[f64],
    lambda: f64,
) {
    let (n, m) = q_all.shape();
    for i in 0..n {
        let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
        let beta_signed: Vec<f64> = beta.iter().map(|&b| b * y[i] as f64).collect();
        let new_gam = dp_alignment_core(&beta_signed, &qi, argvals, lambda);
        for j in 0..m {
            gammas[(i, j)] = new_gam[j];
        }
    }
}

/// Compute residuals, SSE, and R² from y and fitted values.
fn compute_regression_residuals(
    y: &[f64],
    fitted_values: &[f64],
    y_mean: f64,
) -> (Vec<f64>, f64, f64) {
    let residuals: Vec<f64> = y
        .iter()
        .zip(fitted_values.iter())
        .map(|(&yi, &yh)| yi - yh)
        .collect();
    let sse: f64 = residuals.iter().map(|&r| r * r).sum();
    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let r_squared = if ss_tot > 0.0 {
        1.0 - sse / ss_tot
    } else {
        0.0
    };
    (residuals, sse, r_squared)
}

/// Compute final logistic predictions: probabilities, classes, accuracy, loss.
fn compute_logistic_predictions(
    aligned_srsfs: &FdMatrix,
    beta: &[f64],
    weights: &[f64],
    alpha: f64,
    y: &[i8],
    lambda: f64,
) -> (Vec<f64>, Vec<i8>, f64, f64) {
    let n = y.len();
    let eta = srsf_fitted_values(aligned_srsfs, beta, weights, alpha);
    let probabilities: Vec<f64> = eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect();
    let predicted_classes: Vec<i8> = probabilities
        .iter()
        .map(|&p| if p >= 0.5 { 1 } else { -1 })
        .collect();
    let accuracy = predicted_classes
        .iter()
        .zip(y.iter())
        .filter(|(&p, &t)| p == t)
        .count() as f64
        / n as f64;
    let loss = logistic_loss(&probabilities, y, beta, lambda);
    (probabilities, predicted_classes, accuracy, loss)
}

/// Check if either extreme warp is already close enough to target y.
fn check_extreme_warps(
    gam_pos: &[f64],
    gam_neg: &[f64],
    y_pos: f64,
    y_neg: f64,
    y_i: f64,
) -> Option<Vec<f64>> {
    if (y_pos - y_i).abs() <= (y_neg - y_i).abs() {
        if (y_pos - y_i).abs() < 1e-10 {
            return Some(gam_pos.to_vec());
        }
    } else if (y_neg - y_i).abs() < 1e-10 {
        return Some(gam_neg.to_vec());
    }
    None
}

/// Order warps so gam_lo gives lower prediction and gam_hi gives higher.
fn order_warps_by_prediction(
    gam_pos: Vec<f64>,
    gam_neg: Vec<f64>,
    y_pos: f64,
    y_neg: f64,
) -> (Vec<f64>, Vec<f64>) {
    if y_pos < y_neg {
        (gam_pos, gam_neg)
    } else {
        (gam_neg, gam_pos)
    }
}

/// Binary search between two warps to find one giving predicted y closest to target.
fn binary_search_warps(
    mut gam_lo: Vec<f64>,
    mut gam_hi: Vec<f64>,
    q_i: &[f64],
    beta: &[f64],
    argvals: &[f64],
    alpha: f64,
    y_i: f64,
    weights: &[f64],
) -> Vec<f64> {
    for _ in 0..15 {
        let gam_mid: Vec<f64> = gam_lo
            .iter()
            .zip(gam_hi.iter())
            .map(|(&lo, &hi)| 0.5 * (lo + hi))
            .collect();
        let y_mid = compute_predicted_y(q_i, beta, &gam_mid, argvals, alpha, weights);
        if (y_mid - y_i).abs() < 1e-6 {
            return gam_mid;
        }
        if y_mid < y_i {
            gam_lo = gam_mid;
        } else {
            gam_hi = gam_mid;
        }
    }
    gam_lo
        .iter()
        .zip(gam_hi.iter())
        .map(|(&lo, &hi)| 0.5 * (lo + hi))
        .collect()
}

/// Predict new responses using a fitted elastic regression model.
///
/// Transforms new curves to SRSFs (no additional alignment is performed; identity
/// warps are assumed for new data), then applies the fitted regression
/// coefficients.
///
/// # Arguments
/// * `fit` — A fitted [`ElasticRegressionResult`]
/// * `new_data` — New functional data (n_new × m)
/// * `argvals` — Evaluation points (length m)
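///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doc-test); `fit`, `new_data`, and
/// `t` are placeholders for a fitted model, new curves observed on the training
/// grid, and that grid:
///
/// ```ignore
/// let predictions = predict_elastic_regression(&fit, &new_data, &t);
/// // Equivalently, via the convenience method on the result type:
/// let predictions = fit.predict(&new_data, &t);
/// ```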
pub fn predict_elastic_regression(
    fit: &ElasticRegressionResult,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);
    let q_new = srsf_transform(new_data, argvals);
    srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha)
}

/// Predict probabilities for new data using a fitted elastic logistic model.
///
/// Transforms new curves to SRSFs and applies the fitted logistic
/// coefficients to produce P(Y=1).
///
/// # Arguments
/// * `fit` — A fitted [`ElasticLogisticResult`]
/// * `new_data` — New functional data (n_new × m)
/// * `argvals` — Evaluation points (length m)
pub fn predict_elastic_logistic(
    fit: &ElasticLogisticResult,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);
    let q_new = srsf_transform(new_data, argvals);
    let eta = srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha);
    eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect()
}

impl ElasticRegressionResult {
    /// Predict responses for new data. Delegates to [`predict_elastic_regression`].
    pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
        predict_elastic_regression(self, new_data, argvals)
    }
}

impl ElasticLogisticResult {
    /// Predict probabilities for new data. Delegates to [`predict_elastic_logistic`].
    pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
        predict_elastic_logistic(self, new_data, argvals)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0.0; n];

        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * (t[j] + shift)).sin();
            }
            y[i] = amp; // Response related to amplitude
        }
        (data, y, t)
    }

    #[test]
    fn test_elastic_regression_basic() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_regression(&data, &y, &t, 10, 1e-3, 5, 1e-3);
        assert!(result.is_some(), "elastic_regression should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted_values.len(), 15);
        assert_eq!(res.beta.len(), 51);
        assert_eq!(res.gammas.shape(), (15, 51));
        assert!(res.n_iter > 0);
    }

    #[test]
    fn test_elastic_logistic_basic() {
        let n = 20;
        let m = 51;
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0_i8; n];

        for i in 0..n {
            let label = if i < n / 2 { -1_i8 } else { 1_i8 };
            y[i] = label;
            let amp = if label == 1 { 2.0 } else { 1.0 };
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * t[j]).sin();
            }
        }

        let result = elastic_logistic(&data, &y, &t, 10, 1e-2, 5, 1e-3);
        assert!(result.is_some(), "elastic_logistic should succeed");

        let res = result.unwrap();
        assert_eq!(res.probabilities.len(), n);
        assert_eq!(res.predicted_classes.len(), n);
        assert!(res.accuracy >= 0.0 && res.accuracy <= 1.0);
    }

    #[test]
    fn test_elastic_pcr_vertical() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3);
        assert!(result.is_some(), "elastic_pcr (vertical) should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted_values.len(), 15);
        assert_eq!(res.coefficients.len(), 3);
    }

    #[test]
    fn test_elastic_pcr_horizontal() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Horizontal, 0.0, 5, 1e-3);
        assert!(result.is_some(), "elastic_pcr (horizontal) should succeed");
    }

    #[test]
    fn test_elastic_regression_invalid() {
        let data = FdMatrix::zeros(1, 10);
        let y = vec![1.0];
        let t: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
        assert!(elastic_regression(&data, &y, &t, 5, 1e-3, 5, 1e-3).is_none());
    }
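
    // Additional sanity-check sketch (kept deliberately weak): prediction on the
    // training data should return one value per curve. It reuses the synthetic data
    // from `test_elastic_regression_basic`, which is assumed to fit successfully.
    #[test]
    fn test_predict_elastic_regression_length() {
        let (data, y, t) = generate_test_data(15, 51);
        let fit = elastic_regression(&data, &y, &t, 10, 1e-3, 5, 1e-3)
            .expect("elastic_regression should succeed on the synthetic data");
        let preds = predict_elastic_regression(&fit, &data, &t);
        assert_eq!(preds.len(), 15);
        assert_eq!(fit.predict(&data, &t).len(), 15);
    }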
}