1use crate::alignment::{
12 dp_alignment_core, karcher_mean, reparameterize_curve, sqrt_mean_inverse, srsf_transform,
13 KarcherMeanResult,
14};
15use crate::basis::bspline_basis;
16use crate::elastic_fpca::{
17 horiz_fpca, joint_fpca, vert_fpca, HorizFpcaResult, JointFpcaResult, VertFpcaResult,
18};
19use crate::helpers::simpsons_weights;
20use crate::matrix::FdMatrix;
21use crate::smooth_basis::bspline_penalty_matrix;
22use nalgebra::{DMatrix, DVector};
23
/// Result of [`elastic_regression`]: a functional linear model fitted on
/// SRSF-aligned curves.
#[derive(Debug, Clone)]
pub struct ElasticRegressionResult {
    /// Intercept term.
    pub alpha: f64,
    /// Coefficient function beta(t), evaluated on the `argvals` grid.
    pub beta: Vec<f64>,
    /// Fitted responses, one per input curve.
    pub fitted_values: Vec<f64>,
    /// Residuals `y[i] - fitted_values[i]`.
    pub residuals: Vec<f64>,
    /// Sum of squared residuals.
    pub sse: f64,
    /// Coefficient of determination (0.0 when total variance is zero).
    pub r_squared: f64,
    /// Estimated warping functions, one row per curve.
    pub gammas: FdMatrix,
    /// SRSFs after applying the final warps.
    pub aligned_srsfs: FdMatrix,
    /// Number of outer iterations actually performed.
    pub n_iter: usize,
}
48
/// Result of [`elastic_logistic`]: a functional logistic model on
/// SRSF-aligned curves with +/-1 class labels.
#[derive(Debug, Clone)]
pub struct ElasticLogisticResult {
    /// Intercept term.
    pub alpha: f64,
    /// Pointwise coefficient function on the `argvals` grid.
    pub beta: Vec<f64>,
    /// Fitted probabilities of the positive class, one per curve.
    pub probabilities: Vec<f64>,
    /// Predicted labels: `1` when probability >= 0.5, otherwise `-1`.
    pub predicted_classes: Vec<i8>,
    /// Training accuracy in [0, 1].
    pub accuracy: f64,
    /// Final penalized cross-entropy loss.
    pub loss: f64,
    /// Estimated warping functions, one row per curve.
    pub gammas: FdMatrix,
    /// SRSFs after applying the final warps.
    pub aligned_srsfs: FdMatrix,
    /// Number of outer iterations actually performed.
    pub n_iter: usize,
}
71
/// Which flavor of elastic functional PCA supplies the scores for
/// [`elastic_pcr`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PcaMethod {
    /// Amplitude (vertical) variation.
    Vertical,
    /// Phase (horizontal) variation.
    Horizontal,
    /// Combined amplitude + phase variation.
    Joint,
}
79
/// Result of [`elastic_pcr`]: OLS regression of a scalar response on
/// elastic FPCA scores. Exactly one of the three FPCA fields is `Some`,
/// matching `pca_method`.
#[derive(Debug, Clone)]
pub struct ElasticPcrResult {
    /// Intercept term.
    pub alpha: f64,
    /// Regression coefficients, one per retained PCA component.
    pub coefficients: Vec<f64>,
    /// Fitted responses, one per input curve.
    pub fitted_values: Vec<f64>,
    /// Sum of squared residuals.
    pub sse: f64,
    /// Coefficient of determination (0.0 when total variance is zero).
    pub r_squared: f64,
    /// The FPCA variant used to compute the scores.
    pub pca_method: PcaMethod,
    /// Karcher mean / alignment result the FPCA was built on.
    pub karcher: KarcherMeanResult,
    /// Vertical FPCA, present when `pca_method == Vertical`.
    pub vert_fpca: Option<VertFpcaResult>,
    /// Horizontal FPCA, present when `pca_method == Horizontal`.
    pub horiz_fpca: Option<HorizFpcaResult>,
    /// Joint FPCA, present when `pca_method == Joint`.
    pub joint_fpca: Option<JointFpcaResult>,
}
104
/// Fits an elastic functional linear regression: scalar responses `y` are
/// regressed on curves `data` while per-curve warping functions are
/// estimated jointly with the coefficient function.
///
/// Each outer iteration (1) refits `(beta, alpha)` by penalized OLS on the
/// currently aligned SRSFs (`regression_iteration_step`), then (2) updates
/// every curve's warp against the new `beta` and re-centers the warp set.
///
/// # Arguments
/// * `data` - `n x m` matrix of curves sampled on `argvals`.
/// * `y` - `n` scalar responses.
/// * `argvals` - common evaluation grid of length `m`.
/// * `ncomp_beta` - requested number of B-spline basis functions for `beta`.
/// * `lambda` - roughness-penalty weight; the warp update uses `lambda * 0.01`.
/// * `max_iter` / `tol` - outer-iteration cap and relative convergence
///   tolerance on `beta`.
///
/// # Returns
/// `None` on inconsistent inputs (`n < 2`, `m < 2`, length mismatches,
/// `ncomp_beta < 2`) or if the inner linear solve fails.
pub fn elastic_regression(
    data: &FdMatrix,
    y: &[f64],
    argvals: &[f64],
    ncomp_beta: usize,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Option<ElasticRegressionResult> {
    let (n, m) = data.shape();
    if n < 2 || m < 2 || y.len() != n || argvals.len() != m || ncomp_beta < 2 {
        return None;
    }

    let weights = simpsons_weights(argvals);
    let q_all = srsf_transform(data, argvals);

    let (b_mat, r_trimmed, actual_nbasis) = build_basis_and_penalty(argvals, ncomp_beta, m);

    // Start from identity warps, beta = 0, and alpha = mean(y).
    let mut gammas = init_identity_warps(n, argvals);
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let mut beta = vec![0.0; m];
    let mut alpha = y_mean;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        let (beta_new, alpha_new) = regression_iteration_step(
            &q_all,
            &gammas,
            argvals,
            &b_mat,
            &r_trimmed,
            &weights,
            y,
            alpha,
            lambda,
            n,
            m,
            actual_nbasis,
        )?;

        // Convergence is only declared after the first iteration, since
        // beta starts at zero.
        if beta_converged(&beta_new, &beta, tol) && iter > 0 {
            beta = beta_new;
            alpha = alpha_new;
            break;
        }

        beta = beta_new;
        alpha = alpha_new;

        // Re-estimate per-curve warps against the new beta (with a much
        // smaller penalty), then center the warp set.
        update_regression_warps(&mut gammas, &q_all, &beta, argvals, alpha, y, lambda * 0.01);
        center_warps(&mut gammas, argvals);
    }

    // Final alignment, fit, and goodness-of-fit summaries.
    let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
    let fitted_values = srsf_fitted_values(&aligned_srsfs, &beta, &weights, alpha);
    let (residuals, sse, r_squared) = compute_regression_residuals(y, &fitted_values, y_mean);

    Some(ElasticRegressionResult {
        alpha,
        beta,
        fitted_values,
        residuals,
        sse,
        r_squared,
        gammas,
        aligned_srsfs,
        n_iter,
    })
}
197
/// Fits an elastic functional logistic regression: labels `y` (a value of
/// `1` is the positive class; anything else is treated as negative, and
/// predictions are reported as +/-1) are modeled from curves `data` while
/// per-curve warps are re-estimated each iteration.
///
/// Each outer iteration takes one gradient step on `(alpha, beta)` with an
/// Armijo backtracking line search, then updates the warps by DP alignment
/// of each SRSF against `beta` signed by that curve's label.
///
/// `_ncomp_beta` is accepted only for signature symmetry with
/// [`elastic_regression`] and is currently unused: `beta` is estimated
/// pointwise on the grid here, not in a basis.
///
/// Returns `None` for inconsistent inputs (`n < 2`, `m < 2`, or length
/// mismatches).
pub fn elastic_logistic(
    data: &FdMatrix,
    y: &[i8],
    argvals: &[f64],
    _ncomp_beta: usize,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Option<ElasticLogisticResult> {
    let (n, m) = data.shape();
    if n < 2 || m < 2 || y.len() != n || argvals.len() != m {
        return None;
    }

    let weights = simpsons_weights(argvals);
    let q_all = srsf_transform(data, argvals);
    let mut gammas = init_identity_warps(n, argvals);
    let mut beta = vec![0.0; m];
    let mut alpha = 0.0;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        // Gradient of the penalized cross-entropy at the current iterate,
        // computed on the currently aligned SRSFs.
        let q_aligned = apply_warps_to_srsfs(&q_all, &gammas, argvals);
        let (grad_a, grad_beta, prob) =
            logistic_gradients(&q_aligned, &beta, &weights, alpha, y, lambda);

        let loss_current = logistic_loss(&prob, y, &beta, lambda);
        let grad_norm_sq: f64 = grad_a * grad_a + grad_beta.iter().map(|&g| g * g).sum::<f64>();

        // Backtracking step size satisfying an Armijo decrease condition.
        let step = armijo_line_search_logistic(
            &q_aligned,
            alpha,
            &beta,
            grad_a,
            &grad_beta,
            &weights,
            y,
            lambda,
            loss_current,
            grad_norm_sq,
        );

        let beta_new: Vec<f64> = beta
            .iter()
            .zip(grad_beta.iter())
            .map(|(&b, &g)| b - step * g)
            .collect();
        let alpha_new = alpha - step * grad_a;

        // Convergence only after the first iteration (beta starts at zero).
        if beta_converged(&beta_new, &beta, tol) && iter > 0 {
            beta = beta_new;
            alpha = alpha_new;
            break;
        }

        beta = beta_new;
        alpha = alpha_new;

        update_logistic_warps(&mut gammas, &q_all, &beta, y, argvals, lambda * 0.01);
    }

    // Final alignment and summary statistics at the converged iterate.
    let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
    let (probabilities, predicted_classes, accuracy, loss) =
        compute_logistic_predictions(&aligned_srsfs, &beta, &weights, alpha, y, lambda);

    Some(ElasticLogisticResult {
        alpha,
        beta,
        probabilities,
        predicted_classes,
        accuracy,
        loss,
        gammas,
        aligned_srsfs,
        n_iter,
    })
}
292
/// Mean cross-entropy loss of predicted probabilities against labels, plus
/// an L2 (ridge) penalty on the coefficient function.
///
/// `prob[i]` is the predicted probability of the positive class; a label
/// `y[i] == 1` is the positive class and every other value is negative.
/// Probabilities are clamped away from 0 and 1 so the logarithms stay
/// finite.
///
/// Fix: the original divided by `prob.len()` unconditionally, producing
/// `NaN` (0/0) for empty input; an empty slice now yields just the penalty
/// term.
fn logistic_loss(prob: &[f64], y: &[i8], beta: &[f64], lambda: f64) -> f64 {
    let penalty = 0.5 * lambda * beta.iter().map(|&b| b * b).sum::<f64>();
    if prob.is_empty() {
        return penalty;
    }
    // Sum of per-sample log-likelihoods (negated below for the loss).
    let log_lik: f64 = prob
        .iter()
        .zip(y)
        .map(|(&p, &label)| {
            let target = if label == 1 { 1.0 } else { 0.0 };
            let p = p.clamp(1e-15, 1.0 - 1e-15);
            target * p.ln() + (1.0 - target) * (1.0 - p).ln()
        })
        .sum();
    -log_lik / prob.len() as f64 + penalty
}
307
/// Elastic principal-component regression: aligns `data` via the Karcher
/// mean, extracts up to `ncomp` FPCA scores with the requested method, and
/// regresses `y` on the (centered) scores by OLS.
///
/// `lambda`, `max_iter`, `tol` are forwarded to the Karcher-mean alignment.
/// Note that the FPCA may return fewer components than requested; the
/// regression uses however many columns the score matrix actually has.
///
/// Returns `None` for inconsistent inputs (`n < 2`, length mismatch,
/// `ncomp < 1`) or when the FPCA / linear solve fails.
pub fn elastic_pcr(
    data: &FdMatrix,
    y: &[f64],
    argvals: &[f64],
    ncomp: usize,
    pca_method: PcaMethod,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Option<ElasticPcrResult> {
    let (n, _m) = data.shape();
    if n < 2 || y.len() != n || ncomp < 1 {
        return None;
    }

    let km = karcher_mean(data, argvals, max_iter, tol, lambda);

    // Keep the full FPCA result for the chosen method so callers can
    // project new data later; the other two stay `None`.
    let mut stored_vert: Option<VertFpcaResult> = None;
    let mut stored_horiz: Option<HorizFpcaResult> = None;
    let mut stored_joint: Option<JointFpcaResult> = None;

    let scores_mat = match pca_method {
        PcaMethod::Vertical => {
            let fpca = vert_fpca(&km, argvals, ncomp)?;
            let scores = fpca.scores.clone();
            stored_vert = Some(fpca);
            scores
        }
        PcaMethod::Horizontal => {
            let fpca = horiz_fpca(&km, argvals, ncomp)?;
            let scores = fpca.scores.clone();
            stored_horiz = Some(fpca);
            scores
        }
        PcaMethod::Joint => {
            let fpca = joint_fpca(&km, argvals, ncomp, None)?;
            let scores = fpca.scores.clone();
            stored_joint = Some(fpca);
            scores
        }
    };

    // The FPCA may have retained fewer components than requested.
    let actual_ncomp = scores_mat.ncols();
    let (coefs, alpha, fitted_values, sse, r_squared) =
        ols_on_scores(&scores_mat, y, n, actual_ncomp)?;

    Some(ElasticPcrResult {
        alpha,
        coefficients: coefs,
        fitted_values,
        sse,
        r_squared,
        pca_method,
        karcher: km,
        vert_fpca: stored_vert,
        horiz_fpca: stored_horiz,
        joint_fpca: stored_joint,
    })
}
385
/// Chooses a warp for one curve so its predicted response moves toward the
/// observed `y_i`.
///
/// DP alignment of `q_i` against `beta` yields the warp pushing the
/// prediction toward one extreme; alignment against `-beta` yields the
/// opposite extreme. If either extreme prediction already matches `y_i`
/// (within 1e-10) that warp is returned directly; otherwise the warp is
/// found by bisecting between the two extremes.
fn regression_warp(
    q_i: &[f64],
    beta: &[f64],
    argvals: &[f64],
    alpha: f64,
    y_i: f64,
    lambda: f64,
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);

    // Warp maximizing alignment with +beta ...
    let gam_pos = dp_alignment_core(beta, q_i, argvals, lambda);

    // ... and with -beta (the opposite extreme of the prediction range).
    let neg_beta: Vec<f64> = beta.iter().map(|&b| -b).collect();
    let gam_neg = dp_alignment_core(&neg_beta, q_i, argvals, lambda);

    let y_pos = compute_predicted_y(q_i, beta, &gam_pos, argvals, alpha, &weights);
    let y_neg = compute_predicted_y(q_i, beta, &gam_neg, argvals, alpha, &weights);

    // Shortcut: one of the extremes already hits y_i exactly.
    if let Some(gam) = check_extreme_warps(&gam_pos, &gam_neg, y_pos, y_neg, y_i) {
        return gam;
    }

    let (gam_lo, gam_hi) = order_warps_by_prediction(gam_pos, gam_neg, y_pos, y_neg);
    binary_search_warps(gam_lo, gam_hi, q_i, beta, argvals, alpha, y_i, &weights)
}
422
/// Predicted scalar response for one curve under warp `gam`:
/// `alpha + integral of (q_i o gam)(t) * sqrt(gam'(t)) * beta(t) dt`,
/// with the integral approximated by the supplied quadrature `weights`.
///
/// NOTE(review): the derivative step `h` assumes `argvals` is uniformly
/// spaced — confirm upstream grids are uniform before relying on this.
fn compute_predicted_y(
    q_i: &[f64],
    beta: &[f64],
    gam: &[f64],
    argvals: &[f64],
    alpha: f64,
    weights: &[f64],
) -> f64 {
    let m = argvals.len();
    let q_warped = reparameterize_curve(q_i, argvals, gam);
    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
    let gam_deriv = crate::helpers::gradient_uniform(gam, h);

    let mut y_hat = alpha;
    for j in 0..m {
        // Group action on SRSFs: (q o gam) * sqrt(gam'); negative
        // derivative estimates are clipped to keep sqrt defined.
        let q_aligned_j = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
        y_hat += q_aligned_j * beta[j] * weights[j];
    }
    y_hat
}
444
445fn apply_warps_to_srsfs(q_all: &FdMatrix, gammas: &FdMatrix, argvals: &[f64]) -> FdMatrix {
449 let (n, m) = q_all.shape();
450 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
451 let mut q_aligned = FdMatrix::zeros(n, m);
452 for i in 0..n {
453 let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
454 let gam: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
455 let q_warped = reparameterize_curve(&qi, argvals, &gam);
456 let gam_deriv = crate::helpers::gradient_uniform(&gam, h);
457 for j in 0..m {
458 q_aligned[(i, j)] = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
459 }
460 }
461 q_aligned
462}
463
464fn init_identity_warps(n: usize, argvals: &[f64]) -> FdMatrix {
466 let m = argvals.len();
467 let mut gammas = FdMatrix::zeros(n, m);
468 for i in 0..n {
469 for j in 0..m {
470 gammas[(i, j)] = argvals[j];
471 }
472 }
473 gammas
474}
475
476fn trim_penalty_to_basis(
478 r_mat: &DMatrix<f64>,
479 penalty_k: usize,
480 actual_nbasis: usize,
481) -> DMatrix<f64> {
482 if penalty_k >= actual_nbasis {
483 r_mat
484 .view((0, 0), (actual_nbasis, actual_nbasis))
485 .into_owned()
486 } else {
487 let mut r = DMatrix::zeros(actual_nbasis, actual_nbasis);
488 let dim = penalty_k.min(actual_nbasis);
489 for i in 0..dim {
490 for j in 0..dim {
491 r[(i, j)] = r_mat[(i, j)];
492 }
493 }
494 r
495 }
496}
497
498fn build_phi_matrix(
500 q_aligned: &FdMatrix,
501 b_mat: &DMatrix<f64>,
502 weights: &[f64],
503 n: usize,
504 m: usize,
505 actual_nbasis: usize,
506) -> DMatrix<f64> {
507 let mut phi = DMatrix::zeros(n, actual_nbasis);
508 for i in 0..n {
509 for k in 0..actual_nbasis {
510 let mut val = 0.0;
511 for j in 0..m {
512 val += q_aligned[(i, j)] * b_mat[(j, k)] * weights[j];
513 }
514 phi[(i, k)] = val;
515 }
516 }
517 phi
518}
519
/// Solves the penalized normal equations
/// `(Phi^T Phi + lambda * R) c = Phi^T y` for the basis coefficients `c`.
///
/// Tries a Cholesky factorization first (fast path for a positive-definite
/// system) and falls back to an SVD least-squares solve with tolerance
/// 1e-10 otherwise. Returns `None` only when the SVD solve also fails.
fn solve_penalized_ols(
    phi: &DMatrix<f64>,
    r_trimmed: &DMatrix<f64>,
    y_centered: &[f64],
    lambda: f64,
) -> Option<Vec<f64>> {
    let y_vec = DVector::from_vec(y_centered.to_vec());
    let phi_t_phi = phi.transpose() * phi;
    let system = &phi_t_phi + lambda * r_trimmed;
    let rhs = phi.transpose() * &y_vec;
    let coefs = if let Some(chol) = system.clone().cholesky() {
        chol.solve(&rhs)
    } else {
        // Not positive definite (e.g. lambda = 0 and collinear basis):
        // fall back to a pseudo-inverse solve.
        let svd = nalgebra::SVD::new(system, true, true);
        svd.solve(&rhs, 1e-10).ok()?
    };
    Some(coefs.iter().cloned().collect())
}
539
540fn reconstruct_beta_from_coefs(
542 coefs: &[f64],
543 b_mat: &DMatrix<f64>,
544 m: usize,
545 actual_nbasis: usize,
546) -> Vec<f64> {
547 let mut beta = vec![0.0; m];
548 for j in 0..m {
549 for k in 0..actual_nbasis {
550 beta[j] += coefs[k] * b_mat[(j, k)];
551 }
552 }
553 beta
554}
555
556fn compute_alpha_from_residuals(
558 q_aligned: &FdMatrix,
559 beta: &[f64],
560 weights: &[f64],
561 y: &[f64],
562) -> f64 {
563 let (n, m) = q_aligned.shape();
564 let mut alpha = 0.0;
565 for i in 0..n {
566 let mut y_hat_i = 0.0;
567 for j in 0..m {
568 y_hat_i += q_aligned[(i, j)] * beta[j] * weights[j];
569 }
570 alpha += y[i] - y_hat_i;
571 }
572 alpha / n as f64
573}
574
575fn srsf_fitted_values(q_aligned: &FdMatrix, beta: &[f64], weights: &[f64], alpha: f64) -> Vec<f64> {
577 let (n, m) = q_aligned.shape();
578 let mut fitted = vec![0.0; n];
579 for i in 0..n {
580 fitted[i] = alpha;
581 for j in 0..m {
582 fitted[i] += q_aligned[(i, j)] * beta[j] * weights[j];
583 }
584 }
585 fitted
586}
587
/// Relative-change convergence test:
/// `||beta_new - beta_old|| / max(||beta_old||, 1e-10) < tol`.
fn beta_converged(beta_new: &[f64], beta_old: &[f64], tol: f64) -> bool {
    let diff_sq: f64 = beta_new
        .iter()
        .zip(beta_old)
        .map(|(&a, &b)| {
            let d = a - b;
            d * d
        })
        .sum();
    let norm_sq: f64 = beta_old.iter().map(|&b| b * b).sum();
    // Floor the norm so the very first iteration (beta_old == 0) is safe.
    diff_sq.sqrt() / norm_sq.sqrt().max(1e-10) < tol
}
604
/// Centers the warp set in place by composing every warp with
/// `sqrt_mean_inverse` of the set.
///
/// NOTE(review): presumably `sqrt_mean_inverse` returns the inverse of the
/// Karcher mean warp, making the centered set average to the identity —
/// confirm against its definition in `crate::alignment`.
fn center_warps(gammas: &mut FdMatrix, argvals: &[f64]) {
    let (n, m) = gammas.shape();
    let gam_mu = sqrt_mean_inverse(gammas, argvals);
    for i in 0..n {
        let gam_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
        let composed = crate::alignment::compose_warps(&gam_i, &gam_mu, argvals);
        for j in 0..m {
            gammas[(i, j)] = composed[j];
        }
    }
}
617
/// Gradients of the mean cross-entropy loss with ridge penalty, with
/// respect to the intercept `alpha` and the pointwise coefficients `beta`.
///
/// A label of `1` is the positive class; any other value is negative.
/// Returns `(grad_alpha, grad_beta, prob)` where `prob` holds the fitted
/// probabilities at the current iterate (reused by the caller for the
/// loss, avoiding a second forward pass).
fn logistic_gradients(
    q_aligned: &FdMatrix,
    beta: &[f64],
    weights: &[f64],
    alpha: f64,
    y: &[i8],
    lambda: f64,
) -> (f64, Vec<f64>, Vec<f64>) {
    let (n, m) = q_aligned.shape();
    // Linear predictor and sigmoid probabilities.
    let eta = srsf_fitted_values(q_aligned, beta, weights, alpha);
    let prob: Vec<f64> = eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect();

    // d loss / d alpha = mean(prob - target).
    let mut grad_a = 0.0;
    for i in 0..n {
        let target = if y[i] == 1 { 1.0 } else { 0.0 };
        grad_a += prob[i] - target;
    }
    grad_a /= n as f64;

    // d loss / d beta_j = mean_i((prob_i - target_i) q_ij w_j) + lambda beta_j.
    let mut grad_beta = vec![0.0; m];
    for j in 0..m {
        for i in 0..n {
            let target = if y[i] == 1 { 1.0 } else { 0.0 };
            grad_beta[j] += (prob[i] - target) * q_aligned[(i, j)] * weights[j];
        }
        grad_beta[j] /= n as f64;
        grad_beta[j] += lambda * beta[j];
    }

    (grad_a, grad_beta, prob)
}
650
/// Ordinary least squares of `y` on the (column-centered) PCA scores.
///
/// Returns `(coefficients, intercept, fitted_values, sse, r_squared)`.
/// The system is solved by Cholesky with an SVD fallback; `None` means
/// even the SVD solve failed. `r_squared` is 0.0 when `y` has no variance.
fn ols_on_scores(
    scores_mat: &FdMatrix,
    y: &[f64],
    n: usize,
    ncomp: usize,
) -> Option<(Vec<f64>, f64, Vec<f64>, f64, f64)> {
    // Column means of the scores, used for centering and the intercept.
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let mut score_means = vec![0.0; ncomp];
    for k in 0..ncomp {
        for i in 0..n {
            score_means[k] += scores_mat[(i, k)];
        }
        score_means[k] /= n as f64;
    }

    // Centered design matrix and centered response.
    let mut x_cen = DMatrix::zeros(n, ncomp);
    for i in 0..n {
        for k in 0..ncomp {
            x_cen[(i, k)] = scores_mat[(i, k)] - score_means[k];
        }
    }
    let y_cen: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
    let y_vec = DVector::from_vec(y_cen);

    // Normal equations with Cholesky fast path, SVD fallback.
    let xtx = x_cen.transpose() * &x_cen;
    let xty = x_cen.transpose() * &y_vec;
    let coefficients = if let Some(chol) = xtx.clone().cholesky() {
        chol.solve(&xty)
    } else {
        let svd = nalgebra::SVD::new(xtx, true, true);
        svd.solve(&xty, 1e-10).ok()?
    };
    let coefs: Vec<f64> = coefficients.iter().cloned().collect();

    // Intercept recovers the uncentered model: alpha = y_mean - c . means.
    let alpha = y_mean
        - coefs
            .iter()
            .zip(score_means.iter())
            .map(|(&c, &sm)| c * sm)
            .sum::<f64>();

    // Fitted values on the ORIGINAL (uncentered) scores.
    let mut fitted_values = vec![0.0; n];
    for i in 0..n {
        fitted_values[i] = alpha;
        for k in 0..ncomp {
            fitted_values[i] += coefs[k] * scores_mat[(i, k)];
        }
    }

    let sse: f64 = y
        .iter()
        .zip(fitted_values.iter())
        .map(|(&yi, &yh)| (yi - yh).powi(2))
        .sum();
    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let r_squared = if ss_tot > 0.0 {
        1.0 - sse / ss_tot
    } else {
        0.0
    };

    Some((coefs, alpha, fitted_values, sse, r_squared))
}
715
/// Backtracking (Armijo) line search for the logistic gradient step.
///
/// Starting from step = 1 and halving at most 20 times, accepts the first
/// step whose trial point satisfies the sufficient-decrease condition
/// `loss(x - step*g) <= loss(x) - 1e-4 * step * ||g||^2`.
/// Returns the last step tried even if the condition was never met.
fn armijo_line_search_logistic(
    q_aligned: &FdMatrix,
    alpha: f64,
    beta: &[f64],
    grad_a: f64,
    grad_beta: &[f64],
    weights: &[f64],
    y: &[i8],
    lambda: f64,
    loss_current: f64,
    grad_norm_sq: f64,
) -> f64 {
    let mut step = 1.0;
    for _ in 0..20 {
        // Trial point: gradient-descent step of the current size.
        let alpha_trial = alpha - step * grad_a;
        let beta_trial: Vec<f64> = beta
            .iter()
            .zip(grad_beta.iter())
            .map(|(&b, &g)| b - step * g)
            .collect();
        let eta_trial = srsf_fitted_values(q_aligned, &beta_trial, weights, alpha_trial);
        let prob_trial: Vec<f64> = eta_trial
            .iter()
            .map(|&e| 1.0 / (1.0 + (-e).exp()))
            .collect();
        let loss_trial = logistic_loss(&prob_trial, y, &beta_trial, lambda);
        if loss_trial <= loss_current - 1e-4 * step * grad_norm_sq {
            break;
        }
        step *= 0.5;
    }
    step
}
750
/// One outer-loop step of [`elastic_regression`]: aligns the SRSFs under
/// the current warps, fits the B-spline coefficients by penalized OLS on
/// responses centered at the current intercept, reconstructs `beta` on the
/// grid, and recomputes the intercept as the mean residual.
///
/// Returns `None` when the penalized solve fails.
fn regression_iteration_step(
    q_all: &FdMatrix,
    gammas: &FdMatrix,
    argvals: &[f64],
    b_mat: &DMatrix<f64>,
    r_trimmed: &DMatrix<f64>,
    weights: &[f64],
    y: &[f64],
    alpha: f64,
    lambda: f64,
    n: usize,
    m: usize,
    actual_nbasis: usize,
) -> Option<(Vec<f64>, f64)> {
    let q_aligned = apply_warps_to_srsfs(q_all, gammas, argvals);
    let phi = build_phi_matrix(&q_aligned, b_mat, weights, n, m, actual_nbasis);
    let y_centered: Vec<f64> = y.iter().map(|&yi| yi - alpha).collect();
    let coefs = solve_penalized_ols(&phi, r_trimmed, &y_centered, lambda)?;
    let beta_new = reconstruct_beta_from_coefs(&coefs, b_mat, m, actual_nbasis);
    let alpha_new = compute_alpha_from_residuals(&q_aligned, &beta_new, weights, y);
    Some((beta_new, alpha_new))
}
774
/// Builds the order-4 (cubic) B-spline design matrix over `argvals` plus a
/// matching second-derivative roughness penalty, reconciled to the basis
/// size the basis routine actually produced.
///
/// Returns `(basis matrix m x nbasis, penalty nbasis x nbasis, nbasis)`.
///
/// NOTE(review): the basis is built with `nknots = max(ncomp_beta - 4, 2)`
/// interior knots, while the penalty is built with `ncomp_beta` directly —
/// the two sizes can disagree and are reconciled by
/// `trim_penalty_to_basis`. Confirm the penalty was not meant to use
/// `nknots` as well.
fn build_basis_and_penalty(
    argvals: &[f64],
    ncomp_beta: usize,
    m: usize,
) -> (DMatrix<f64>, DMatrix<f64>, usize) {
    let nknots = ncomp_beta.saturating_sub(4).max(2);
    let basis_flat = bspline_basis(argvals, nknots, 4);
    // Recover the number of basis functions from the flat length.
    let actual_nbasis = basis_flat.len() / m;
    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis_flat);

    let penalty_flat = bspline_penalty_matrix(argvals, ncomp_beta, 4, 2);
    // The penalty is square; recover its side length from the flat length.
    let penalty_k = (penalty_flat.len() as f64).sqrt() as usize;
    let r_mat = DMatrix::from_column_slice(penalty_k, penalty_k, &penalty_flat);
    let r_trimmed = trim_penalty_to_basis(&r_mat, penalty_k, actual_nbasis);

    (b_mat, r_trimmed, actual_nbasis)
}
793
794fn update_regression_warps(
796 gammas: &mut FdMatrix,
797 q_all: &FdMatrix,
798 beta: &[f64],
799 argvals: &[f64],
800 alpha: f64,
801 y: &[f64],
802 lambda: f64,
803) {
804 let (n, m) = q_all.shape();
805 for i in 0..n {
806 let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
807 let new_gam = regression_warp(&qi, beta, argvals, alpha, y[i], lambda);
808 for j in 0..m {
809 gammas[(i, j)] = new_gam[j];
810 }
811 }
812}
813
814fn update_logistic_warps(
816 gammas: &mut FdMatrix,
817 q_all: &FdMatrix,
818 beta: &[f64],
819 y: &[i8],
820 argvals: &[f64],
821 lambda: f64,
822) {
823 let (n, m) = q_all.shape();
824 for i in 0..n {
825 let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
826 let beta_signed: Vec<f64> = beta.iter().map(|&b| b * y[i] as f64).collect();
827 let new_gam = dp_alignment_core(&beta_signed, &qi, argvals, lambda);
828 for j in 0..m {
829 gammas[(i, j)] = new_gam[j];
830 }
831 }
832}
833
/// Residuals plus goodness-of-fit summaries for a regression fit.
/// Returns `(residuals, sse, r_squared)`; `r_squared` is 0.0 when the
/// total sum of squares is zero (constant response).
fn compute_regression_residuals(
    y: &[f64],
    fitted_values: &[f64],
    y_mean: f64,
) -> (Vec<f64>, f64, f64) {
    let residuals: Vec<f64> = y
        .iter()
        .zip(fitted_values)
        .map(|(&obs, &fit)| obs - fit)
        .collect();
    let sse = residuals.iter().map(|&r| r * r).sum::<f64>();
    let ss_tot = y.iter().map(|&obs| (obs - y_mean).powi(2)).sum::<f64>();
    let r_squared = match ss_tot > 0.0 {
        true => 1.0 - sse / ss_tot,
        false => 0.0,
    };
    (residuals, sse, r_squared)
}
854
855fn compute_logistic_predictions(
857 aligned_srsfs: &FdMatrix,
858 beta: &[f64],
859 weights: &[f64],
860 alpha: f64,
861 y: &[i8],
862 lambda: f64,
863) -> (Vec<f64>, Vec<i8>, f64, f64) {
864 let n = y.len();
865 let eta = srsf_fitted_values(aligned_srsfs, beta, weights, alpha);
866 let probabilities: Vec<f64> = eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect();
867 let predicted_classes: Vec<i8> = probabilities
868 .iter()
869 .map(|&p| if p >= 0.5 { 1 } else { -1 })
870 .collect();
871 let accuracy = predicted_classes
872 .iter()
873 .zip(y.iter())
874 .filter(|(&p, &t)| p == t)
875 .count() as f64
876 / n as f64;
877 let loss = logistic_loss(&probabilities, y, beta, lambda);
878 (probabilities, predicted_classes, accuracy, loss)
879}
880
/// Checks whether either extreme warp already reproduces the target
/// response `y_i` (within 1e-10); if so returns that warp, else `None`.
/// Only the closer of the two extremes is eligible.
fn check_extreme_warps(
    gam_pos: &[f64],
    gam_neg: &[f64],
    y_pos: f64,
    y_neg: f64,
    y_i: f64,
) -> Option<Vec<f64>> {
    let err_pos = (y_pos - y_i).abs();
    let err_neg = (y_neg - y_i).abs();
    if err_pos <= err_neg && err_pos < 1e-10 {
        Some(gam_pos.to_vec())
    } else if err_pos > err_neg && err_neg < 1e-10 {
        Some(gam_neg.to_vec())
    } else {
        None
    }
}
898
/// Orders the two extreme warps by the prediction each produces, returning
/// `(lower-prediction warp, higher-prediction warp)` for bisection.
fn order_warps_by_prediction(
    gam_pos: Vec<f64>,
    gam_neg: Vec<f64>,
    y_pos: f64,
    y_neg: f64,
) -> (Vec<f64>, Vec<f64>) {
    match y_pos < y_neg {
        true => (gam_pos, gam_neg),
        false => (gam_neg, gam_pos),
    }
}
912
/// Bisects between a low-prediction and a high-prediction warp (pointwise
/// midpoint of the two warp functions) until the predicted response is
/// within 1e-6 of `y_i`, for at most 15 iterations; returns the final
/// midpoint if no exact hit occurs.
///
/// NOTE(review): this assumes the prediction varies monotonically along
/// the pointwise interpolation between the two warps, which is not
/// guaranteed in general — TODO confirm this holds for the warps produced
/// by `dp_alignment_core`.
fn binary_search_warps(
    mut gam_lo: Vec<f64>,
    mut gam_hi: Vec<f64>,
    q_i: &[f64],
    beta: &[f64],
    argvals: &[f64],
    alpha: f64,
    y_i: f64,
    weights: &[f64],
) -> Vec<f64> {
    for _ in 0..15 {
        let gam_mid: Vec<f64> = gam_lo
            .iter()
            .zip(gam_hi.iter())
            .map(|(&lo, &hi)| 0.5 * (lo + hi))
            .collect();
        let y_mid = compute_predicted_y(q_i, beta, &gam_mid, argvals, alpha, weights);
        if (y_mid - y_i).abs() < 1e-6 {
            return gam_mid;
        }
        // Keep the half-interval that still brackets y_i.
        if y_mid < y_i {
            gam_lo = gam_mid;
        } else {
            gam_hi = gam_mid;
        }
    }
    // No convergence within the iteration budget: return the midpoint.
    gam_lo
        .iter()
        .zip(gam_hi.iter())
        .map(|(&lo, &hi)| 0.5 * (lo + hi))
        .collect()
}
946
/// Predicts responses for new curves from a fitted elastic regression
/// model: the curves are converted to SRSFs and passed through the linear
/// predictor with the fitted `beta` and `alpha`.
///
/// NOTE(review): the new curves are NOT warped/aligned before prediction,
/// although training aligns curves to `beta` — confirm skipping alignment
/// at prediction time is intentional.
pub fn predict_elastic_regression(
    fit: &ElasticRegressionResult,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);
    let q_new = srsf_transform(new_data, argvals);
    srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha)
}
966
/// Predicts positive-class probabilities for new curves from a fitted
/// elastic logistic model: SRSF transform, linear predictor, sigmoid.
///
/// NOTE(review): as with [`predict_elastic_regression`], new curves are
/// not aligned before prediction — confirm this is intentional.
pub fn predict_elastic_logistic(
    fit: &ElasticLogisticResult,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Vec<f64> {
    let weights = simpsons_weights(argvals);
    let q_new = srsf_transform(new_data, argvals);
    let eta = srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha);
    eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect()
}
986
impl ElasticRegressionResult {
    /// Convenience wrapper around [`predict_elastic_regression`].
    pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
        predict_elastic_regression(self, new_data, argvals)
    }
}
993
impl ElasticLogisticResult {
    /// Convenience wrapper around [`predict_elastic_logistic`].
    pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
        predict_elastic_logistic(self, new_data, argvals)
    }
}
1000
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Synthetic curves: sine waves with per-curve amplitude and phase
    /// shift; the response is the amplitude, so regression should recover
    /// an amplitude-driven relationship.
    fn generate_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0.0; n];

        for i in 0..n {
            let amp = 1.0 + 0.5 * (i as f64 / n as f64);
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * (t[j] + shift)).sin();
            }
            y[i] = amp;
        }
        (data, y, t)
    }

    // Smoke test: shapes and iteration count of a small regression fit.
    #[test]
    fn test_elastic_regression_basic() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_regression(&data, &y, &t, 10, 1e-3, 5, 1e-3);
        assert!(result.is_some(), "elastic_regression should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted_values.len(), 15);
        assert_eq!(res.beta.len(), 51);
        assert_eq!(res.gammas.shape(), (15, 51));
        assert!(res.n_iter > 0);
    }

    // Smoke test: two amplitude-separated classes with +/-1 labels.
    #[test]
    fn test_elastic_logistic_basic() {
        let n = 20;
        let m = 51;
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        let mut y = vec![0_i8; n];

        for i in 0..n {
            let label = if i < n / 2 { -1_i8 } else { 1_i8 };
            y[i] = label;
            let amp = if label == 1 { 2.0 } else { 1.0 };
            for j in 0..m {
                data[(i, j)] = amp * (2.0 * PI * t[j]).sin();
            }
        }

        let result = elastic_logistic(&data, &y, &t, 10, 1e-2, 5, 1e-3);
        assert!(result.is_some(), "elastic_logistic should succeed");

        let res = result.unwrap();
        assert_eq!(res.probabilities.len(), n);
        assert_eq!(res.predicted_classes.len(), n);
        assert!(res.accuracy >= 0.0 && res.accuracy <= 1.0);
    }

    #[test]
    fn test_elastic_pcr_vertical() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Vertical, 0.0, 5, 1e-3);
        assert!(result.is_some(), "elastic_pcr (vertical) should succeed");

        let res = result.unwrap();
        assert_eq!(res.fitted_values.len(), 15);
        assert_eq!(res.coefficients.len(), 3);
    }

    #[test]
    fn test_elastic_pcr_horizontal() {
        let (data, y, t) = generate_test_data(15, 51);
        let result = elastic_pcr(&data, &y, &t, 3, PcaMethod::Horizontal, 0.0, 5, 1e-3);
        assert!(result.is_some(), "elastic_pcr (horizontal) should succeed");
    }

    // n = 1 violates the n >= 2 precondition and must be rejected.
    #[test]
    fn test_elastic_regression_invalid() {
        let data = FdMatrix::zeros(1, 10);
        let y = vec![1.0];
        let t: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
        assert!(elastic_regression(&data, &y, &t, 5, 1e-3, 5, 1e-3).is_none());
    }
}