1use crate::alignment::{dp_alignment_core, reparameterize_curve, sqrt_mean_inverse};
4use crate::basis::bspline_basis;
5use crate::helpers::simpsons_weights;
6use crate::matrix::FdMatrix;
7use crate::smooth_basis::bspline_penalty_matrix;
8use nalgebra::{DMatrix, DVector};
9
10use super::{
11 apply_warps_to_srsfs, beta_converged, init_identity_warps, srsf_fitted_values, ElasticConfig,
12};
13
14use crate::alignment::srsf_transform;
15
/// Result of fitting an elastic functional regression model.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticRegressionResult {
    /// Fitted intercept.
    pub alpha: f64,
    /// Fitted coefficient function evaluated on the input grid (length `m`).
    pub beta: Vec<f64>,
    /// In-sample predictions, one per input curve (length `n`).
    pub fitted_values: Vec<f64>,
    /// Residuals `y[i] - fitted_values[i]` (length `n`).
    pub residuals: Vec<f64>,
    /// Sum of squared residuals.
    pub sse: f64,
    /// Coefficient of determination: `1 - sse / ss_tot` (0.0 when `ss_tot <= 0`).
    pub r_squared: f64,
    /// Estimated warping functions, one row per curve.
    pub gammas: FdMatrix,
    /// SRSFs of the input curves after applying `gammas`.
    pub aligned_srsfs: FdMatrix,
    /// Number of alternating iterations actually performed.
    pub n_iter: usize,
}
39
/// Fits an elastic functional regression of scalar responses `y` on curves `data`.
///
/// Alternates between (1) a penalized least-squares update of the intercept
/// `alpha` and the coefficient function `beta` (expanded in a B-spline basis
/// with roughness penalty scaled by `lambda`) and (2) a per-curve update of
/// the warping functions, stopping when `beta` changes by less than `tol`
/// (per `beta_converged`) or after `max_iter` iterations.
///
/// # Errors
///
/// * [`crate::FdarError::InvalidDimension`] if `n < 2`, `m < 2`,
///   `y.len() != n`, `argvals.len() != m`, or `ncomp_beta < 2`.
/// * [`crate::FdarError::ComputationFailed`] if the penalized solve fails
///   during some iteration.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_regression(
    data: &FdMatrix,
    y: &[f64],
    argvals: &[f64],
    ncomp_beta: usize,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Result<ElasticRegressionResult, crate::FdarError> {
    let (n, m) = data.shape();
    // Validate all dimension relationships up front.
    if n < 2 || m < 2 || y.len() != n || argvals.len() != m || ncomp_beta < 2 {
        return Err(crate::FdarError::InvalidDimension {
            parameter: "data/y/argvals",
            expected: "n >= 2, m >= 2, y.len() == n, argvals.len() == m, ncomp_beta >= 2"
                .to_string(),
            actual: format!(
                "n={}, m={}, y.len()={}, argvals.len()={}, ncomp_beta={}",
                n,
                m,
                y.len(),
                argvals.len(),
                ncomp_beta
            ),
        });
    }

    // Quadrature weights and SRSF representation of every input curve.
    let weights = simpsons_weights(argvals);
    let q_all = srsf_transform(data, argvals);

    let (b_mat, r_trimmed, actual_nbasis) = build_basis_and_penalty(argvals, ncomp_beta, m);

    // Initial state: identity warps, mean-level intercept, zero coefficient function.
    let mut gammas = init_identity_warps(n, argvals);
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let mut beta = vec![0.0; m];
    let mut alpha = y_mean;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        // Step 1: refit (beta, alpha) with the warps held fixed.
        let (beta_new, alpha_new) = regression_iteration_step(
            &q_all,
            &gammas,
            argvals,
            &b_mat,
            &r_trimmed,
            &weights,
            y,
            alpha,
            lambda,
            n,
            m,
            actual_nbasis,
        )
        .ok_or_else(|| crate::FdarError::ComputationFailed {
            operation: "regression_iteration",
            detail: format!(
                "iteration {} failed; try increasing lambda or reducing nbasis",
                iter + 1
            ),
        })?;

        // Convergence is only tested from the second iteration onward, so at
        // least one warp update always happens.
        if beta_converged(&beta_new, &beta, tol) && iter > 0 {
            beta = beta_new;
            alpha = alpha_new;
            break;
        }

        beta = beta_new;
        alpha = alpha_new;

        // Step 2: refit the warps given (beta, alpha); the warp-update penalty
        // is a small fraction of lambda. Then recenter the warp collection.
        update_regression_warps(&mut gammas, &q_all, &beta, argvals, alpha, y, lambda * 0.01);
        center_warps(&mut gammas, argvals);
    }

    // Final aligned SRSFs, in-sample predictions, and fit summaries.
    let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
    let fitted_values = srsf_fitted_values(&aligned_srsfs, &beta, &weights, alpha);
    let (residuals, sse, r_squared) = compute_regression_residuals(y, &fitted_values, y_mean);

    Ok(ElasticRegressionResult {
        alpha,
        beta,
        fitted_values,
        residuals,
        sse,
        r_squared,
        gammas,
        aligned_srsfs,
        n_iter,
    })
}
157
158#[must_use = "expensive computation whose result should not be discarded"]
162pub fn elastic_regression_with_config(
163 data: &FdMatrix,
164 y: &[f64],
165 argvals: &[f64],
166 config: &ElasticConfig,
167) -> Result<ElasticRegressionResult, crate::FdarError> {
168 elastic_regression(
169 data,
170 y,
171 argvals,
172 config.ncomp_beta,
173 config.lambda,
174 config.max_iter,
175 config.tol,
176 )
177}
178
179pub fn predict_elastic_regression(
190 fit: &ElasticRegressionResult,
191 new_data: &FdMatrix,
192 argvals: &[f64],
193) -> Vec<f64> {
194 let weights = simpsons_weights(argvals);
195 let q_new = srsf_transform(new_data, argvals);
196 srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha)
197}
198
199impl ElasticRegressionResult {
200 pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
202 predict_elastic_regression(self, new_data, argvals)
203 }
204}
205
/// Finds a warp for curve `q_i` whose implied prediction is as close as
/// possible to the observed response `y_i`.
///
/// Computes the warps obtained by aligning `q_i` against `+beta` and `-beta`,
/// returns one of them early if it already reproduces `y_i`, and otherwise
/// bisects pointwise between the lower- and higher-prediction warp.
fn regression_warp(
    q_i: &[f64],
    beta: &[f64],
    argvals: &[f64],
    alpha: f64,
    y_i: f64,
    lambda: f64,
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);

    // Warp from aligning q_i to +beta.
    let gam_pos = dp_alignment_core(beta, q_i, argvals, lambda);

    // Warp from aligning q_i to -beta.
    let neg_beta: Vec<f64> = beta.iter().map(|&b| -b).collect();
    let gam_neg = dp_alignment_core(&neg_beta, q_i, argvals, lambda);

    // Predictions implied by each candidate warp (note: both use +beta).
    let y_pos = compute_predicted_y(q_i, beta, &gam_pos, argvals, alpha, &weights);
    let y_neg = compute_predicted_y(q_i, beta, &gam_neg, argvals, alpha, &weights);

    // If either extreme already matches y_i to tolerance, no search is needed.
    if let Some(gam) = check_extreme_warps(&gam_pos, &gam_neg, y_pos, y_neg, y_i) {
        return gam;
    }

    // Otherwise bisect between the low- and high-prediction warps.
    let (gam_lo, gam_hi) = order_warps_by_prediction(gam_pos, gam_neg, y_pos, y_neg);
    binary_search_warps(gam_lo, gam_hi, q_i, beta, argvals, alpha, y_i, &weights)
}
242
243fn compute_predicted_y(
245 q_i: &[f64],
246 beta: &[f64],
247 gam: &[f64],
248 argvals: &[f64],
249 alpha: f64,
250 weights: &[f64],
251) -> f64 {
252 let m = argvals.len();
253 let q_warped = reparameterize_curve(q_i, argvals, gam);
254 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
255 let gam_deriv = crate::helpers::gradient_uniform(gam, h);
256
257 let mut y_hat = alpha;
258 for j in 0..m {
259 let q_aligned_j = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
260 y_hat += q_aligned_j * beta[j] * weights[j];
261 }
262 y_hat
263}
264
265fn build_basis_and_penalty(
267 argvals: &[f64],
268 ncomp_beta: usize,
269 m: usize,
270) -> (DMatrix<f64>, DMatrix<f64>, usize) {
271 let nknots = ncomp_beta.saturating_sub(4).max(2);
272 let basis_flat = bspline_basis(argvals, nknots, 4);
273 let actual_nbasis = basis_flat.len() / m;
274 let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis_flat);
275
276 let penalty_flat = bspline_penalty_matrix(argvals, ncomp_beta, 4, 2);
277 let penalty_k = (penalty_flat.len() as f64).sqrt() as usize;
278 let r_mat = DMatrix::from_column_slice(penalty_k, penalty_k, &penalty_flat);
279 let r_trimmed = trim_penalty_to_basis(&r_mat, penalty_k, actual_nbasis);
280
281 (b_mat, r_trimmed, actual_nbasis)
282}
283
284fn trim_penalty_to_basis(
286 r_mat: &DMatrix<f64>,
287 penalty_k: usize,
288 actual_nbasis: usize,
289) -> DMatrix<f64> {
290 if penalty_k >= actual_nbasis {
291 r_mat
292 .view((0, 0), (actual_nbasis, actual_nbasis))
293 .into_owned()
294 } else {
295 let mut r = DMatrix::zeros(actual_nbasis, actual_nbasis);
296 let dim = penalty_k.min(actual_nbasis);
297 for i in 0..dim {
298 for j in 0..dim {
299 r[(i, j)] = r_mat[(i, j)];
300 }
301 }
302 r
303 }
304}
305
306fn build_phi_matrix(
308 q_aligned: &FdMatrix,
309 b_mat: &DMatrix<f64>,
310 weights: &[f64],
311 n: usize,
312 m: usize,
313 actual_nbasis: usize,
314) -> DMatrix<f64> {
315 let mut phi = DMatrix::zeros(n, actual_nbasis);
316 for i in 0..n {
317 for k in 0..actual_nbasis {
318 let mut val = 0.0;
319 for j in 0..m {
320 val += q_aligned[(i, j)] * b_mat[(j, k)] * weights[j];
321 }
322 phi[(i, k)] = val;
323 }
324 }
325 phi
326}
327
328pub(super) fn solve_penalized_ols(
330 phi: &DMatrix<f64>,
331 r_trimmed: &DMatrix<f64>,
332 y_centered: &[f64],
333 lambda: f64,
334) -> Option<Vec<f64>> {
335 let y_vec = DVector::from_vec(y_centered.to_vec());
336 let phi_t_phi = phi.transpose() * phi;
337 let system = &phi_t_phi + lambda * r_trimmed;
338 let rhs = phi.transpose() * &y_vec;
339 let coefs = if let Some(chol) = system.clone().cholesky() {
340 chol.solve(&rhs)
341 } else {
342 let svd = nalgebra::SVD::new(system, true, true);
343 svd.solve(&rhs, 1e-10).ok()?
344 };
345 Some(coefs.iter().copied().collect())
346}
347
348fn reconstruct_beta_from_coefs(
350 coefs: &[f64],
351 b_mat: &DMatrix<f64>,
352 m: usize,
353 actual_nbasis: usize,
354) -> Vec<f64> {
355 let mut beta = vec![0.0; m];
356 for j in 0..m {
357 for k in 0..actual_nbasis {
358 beta[j] += coefs[k] * b_mat[(j, k)];
359 }
360 }
361 beta
362}
363
364fn compute_alpha_from_residuals(
366 q_aligned: &FdMatrix,
367 beta: &[f64],
368 weights: &[f64],
369 y: &[f64],
370) -> f64 {
371 let (n, m) = q_aligned.shape();
372 let mut alpha = 0.0;
373 for i in 0..n {
374 let mut y_hat_i = 0.0;
375 for j in 0..m {
376 y_hat_i += q_aligned[(i, j)] * beta[j] * weights[j];
377 }
378 alpha += y[i] - y_hat_i;
379 }
380 alpha / n as f64
381}
382
/// One coefficient update with the warps held fixed: warps the SRSFs, builds
/// the basis-projected design matrix, solves the penalized least-squares
/// problem for the basis coefficients, reconstructs `beta` on the grid, and
/// re-estimates the intercept from the mean residual.
///
/// Returns `None` when the penalized solve fails (propagated from
/// [`solve_penalized_ols`]).
fn regression_iteration_step(
    q_all: &FdMatrix,
    gammas: &FdMatrix,
    argvals: &[f64],
    b_mat: &DMatrix<f64>,
    r_trimmed: &DMatrix<f64>,
    weights: &[f64],
    y: &[f64],
    alpha: f64,
    lambda: f64,
    n: usize,
    m: usize,
    actual_nbasis: usize,
) -> Option<(Vec<f64>, f64)> {
    let q_aligned = apply_warps_to_srsfs(q_all, gammas, argvals);
    let phi = build_phi_matrix(&q_aligned, b_mat, weights, n, m, actual_nbasis);
    // Center the response by the current intercept before solving for beta.
    let y_centered: Vec<f64> = y.iter().map(|&yi| yi - alpha).collect();
    let coefs = solve_penalized_ols(&phi, r_trimmed, &y_centered, lambda)?;
    let beta_new = reconstruct_beta_from_coefs(&coefs, b_mat, m, actual_nbasis);
    let alpha_new = compute_alpha_from_residuals(&q_aligned, &beta_new, weights, y);
    Some((beta_new, alpha_new))
}
406
407fn update_regression_warps(
409 gammas: &mut FdMatrix,
410 q_all: &FdMatrix,
411 beta: &[f64],
412 argvals: &[f64],
413 alpha: f64,
414 y: &[f64],
415 lambda: f64,
416) {
417 let (n, m) = q_all.shape();
418 for i in 0..n {
419 let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
420 let new_gam = regression_warp(&qi, beta, argvals, alpha, y[i], lambda);
421 for j in 0..m {
422 gammas[(i, j)] = new_gam[j];
423 }
424 }
425}
426
427fn center_warps(gammas: &mut FdMatrix, argvals: &[f64]) {
429 let (n, m) = gammas.shape();
430 let gam_mu = sqrt_mean_inverse(gammas, argvals);
431 for i in 0..n {
432 let gam_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
433 let composed = crate::alignment::compose_warps(&gam_i, &gam_mu, argvals);
434 for j in 0..m {
435 gammas[(i, j)] = composed[j];
436 }
437 }
438}
439
/// Computes residuals, the residual sum of squares, and R² for the fit.
///
/// R² is defined as `1 - sse / ss_tot` and reported as 0.0 when the total
/// sum of squares is not positive (constant response).
fn compute_regression_residuals(
    y: &[f64],
    fitted_values: &[f64],
    y_mean: f64,
) -> (Vec<f64>, f64, f64) {
    let mut residuals = Vec::with_capacity(y.len());
    let mut sse = 0.0;
    for (&observed, &predicted) in y.iter().zip(fitted_values) {
        let r = observed - predicted;
        sse += r * r;
        residuals.push(r);
    }
    let ss_tot = y
        .iter()
        .fold(0.0, |acc, &observed| acc + (observed - y_mean).powi(2));
    let r_squared = match ss_tot > 0.0 {
        true => 1.0 - sse / ss_tot,
        false => 0.0,
    };
    (residuals, sse, r_squared)
}
460
/// Returns the closer of the two extreme warps if its prediction already
/// matches `y_i` within 1e-10; otherwise `None`.
///
/// Only the warp whose prediction error is smaller (ties favor `gam_pos`)
/// is ever considered for the early return.
fn check_extreme_warps(
    gam_pos: &[f64],
    gam_neg: &[f64],
    y_pos: f64,
    y_neg: f64,
    y_i: f64,
) -> Option<Vec<f64>> {
    let err_pos = (y_pos - y_i).abs();
    let err_neg = (y_neg - y_i).abs();
    if err_pos <= err_neg {
        (err_pos < 1e-10).then(|| gam_pos.to_vec())
    } else {
        (err_neg < 1e-10).then(|| gam_neg.to_vec())
    }
}
478
/// Orders the two candidate warps so the first yields the lower prediction
/// and the second the higher (ties place `gam_neg` first).
fn order_warps_by_prediction(
    gam_pos: Vec<f64>,
    gam_neg: Vec<f64>,
    y_pos: f64,
    y_neg: f64,
) -> (Vec<f64>, Vec<f64>) {
    match y_pos < y_neg {
        true => (gam_pos, gam_neg),
        false => (gam_neg, gam_pos),
    }
}
492
493fn binary_search_warps(
495 mut gam_lo: Vec<f64>,
496 mut gam_hi: Vec<f64>,
497 q_i: &[f64],
498 beta: &[f64],
499 argvals: &[f64],
500 alpha: f64,
501 y_i: f64,
502 weights: &[f64],
503) -> Vec<f64> {
504 for _ in 0..15 {
505 let gam_mid: Vec<f64> = gam_lo
506 .iter()
507 .zip(gam_hi.iter())
508 .map(|(&lo, &hi)| 0.5 * (lo + hi))
509 .collect();
510 let y_mid = compute_predicted_y(q_i, beta, &gam_mid, argvals, alpha, weights);
511 if (y_mid - y_i).abs() < 1e-6 {
512 return gam_mid;
513 }
514 if y_mid < y_i {
515 gam_lo = gam_mid;
516 } else {
517 gam_hi = gam_mid;
518 }
519 }
520 gam_lo
521 .iter()
522 .zip(gam_hi.iter())
523 .map(|(&lo, &hi)| 0.5 * (lo + hi))
524 .collect()
525}