use nalgebra::{DMatrix, DVector};

use super::{
    apply_warps_to_srsfs, beta_converged, init_identity_warps, srsf_fitted_values, ElasticConfig,
};
use crate::alignment::{
    dp_alignment_core, reparameterize_curve, sqrt_mean_inverse, srsf_transform,
};
use crate::basis::bspline_basis;
use crate::helpers::simpsons_weights;
use crate::matrix::FdMatrix;
use crate::smooth_basis::bspline_penalty_matrix;
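/// Result of an elastic functional regression fit.
///
/// Holds the estimated intercept and coefficient function together with the
/// per-curve warps and diagnostics produced by [`elastic_regression`].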
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticRegressionResult {
    /// Intercept term.
    pub alpha: f64,
    /// Regression coefficient function `beta(t)`, evaluated on `argvals`.
    pub beta: Vec<f64>,
    /// Fitted responses, one per curve.
    pub fitted_values: Vec<f64>,
    /// Residuals `y - fitted_values`.
    pub residuals: Vec<f64>,
    /// Residual sum of squares.
    pub sse: f64,
    /// Coefficient of determination.
    pub r_squared: f64,
    /// Estimated warping functions, one row per curve.
    pub gammas: FdMatrix,
    /// SRSFs after applying the estimated warps.
    pub aligned_srsfs: FdMatrix,
    /// Number of iterations actually run.
    pub n_iter: usize,
}
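/// Fits an elastic functional linear regression of a scalar response `y` on
/// functional predictors `data`, using the square-root slope function (SRSF)
/// representation to handle phase variability.
///
/// The fitted model is
///
/// ```text
/// y_i = alpha + \int beta(t) q_i(gam_i(t)) sqrt(gam_i'(t)) dt + e_i,
/// ```
///
/// where `q_i` is the SRSF of curve `i` and `gam_i` is a per-curve warping
/// function. Fitting alternates two steps until the coefficient update
/// stabilizes to within `tol` (or `max_iter` is reached):
///
/// 1. a roughness-penalized least-squares update of `alpha` and the B-spline
///    coefficients of `beta(t)`, with the warps held fixed, and
/// 2. a warp update for each curve against the new coefficients, with the
///    warps re-centered afterwards.
///
/// # Errors
///
/// Returns [`crate::FdarError::InvalidDimension`] for inconsistent input
/// shapes and [`crate::FdarError::ComputationFailed`] when the penalized
/// least-squares system cannot be solved.
///
/// # Example
///
/// A sketch of intended usage only; the `FdMatrix::from_rows` constructor
/// shown here is hypothetical and stands in for however `FdMatrix` is built
/// in this crate:
///
/// ```ignore
/// let data = FdMatrix::from_rows(&curves); // hypothetical constructor
/// let argvals: Vec<f64> = (0..=100).map(|j| j as f64 / 100.0).collect();
/// let fit = elastic_regression(&data, &y, &argvals, 10, 1e-2, 20, 1e-4)?;
/// println!("R^2 = {:.3}, iterations = {}", fit.r_squared, fit.n_iter);
/// ```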
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_regression(
data: &FdMatrix,
y: &[f64],
argvals: &[f64],
ncomp_beta: usize,
lambda: f64,
max_iter: usize,
tol: f64,
) -> Result<ElasticRegressionResult, crate::FdarError> {
let (n, m) = data.shape();
if n < 2 || m < 2 || y.len() != n || argvals.len() != m || ncomp_beta < 2 {
return Err(crate::FdarError::InvalidDimension {
parameter: "data/y/argvals",
expected: "n >= 2, m >= 2, y.len() == n, argvals.len() == m, ncomp_beta >= 2"
.to_string(),
actual: format!(
"n={}, m={}, y.len()={}, argvals.len()={}, ncomp_beta={}",
n,
m,
y.len(),
argvals.len(),
ncomp_beta
),
});
}
    // Quadrature weights for integrals over `argvals`.
    let weights = simpsons_weights(argvals);
    // Work in the SRSF domain, where a warp acts as (q, gam) -> (q o gam) * sqrt(gam').
    let q_all = srsf_transform(data, argvals);
    let (b_mat, r_trimmed, actual_nbasis) = build_basis_and_penalty(argvals, ncomp_beta, m);
    // Start from identity warps and refine them in the loop below.
    let mut gammas = init_identity_warps(n, argvals);
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    // Initialize with the null model: zero slope and the mean response as intercept.
    let mut beta = vec![0.0; m];
    let mut alpha = y_mean;
let mut n_iter = 0;
for iter in 0..max_iter {
n_iter = iter + 1;
let (beta_new, alpha_new) = regression_iteration_step(
&q_all,
&gammas,
argvals,
&b_mat,
&r_trimmed,
&weights,
y,
alpha,
lambda,
n,
m,
actual_nbasis,
)
.ok_or_else(|| crate::FdarError::ComputationFailed {
operation: "regression_iteration",
detail: format!(
"iteration {} failed; try increasing lambda or reducing nbasis",
iter + 1
),
})?;
        // Require at least one completed iteration before testing convergence,
        // since `beta` starts at zero.
        if iter > 0 && beta_converged(&beta_new, &beta, tol) {
            beta = beta_new;
            alpha = alpha_new;
            break;
        }
        beta = beta_new;
        alpha = alpha_new;
        // Re-estimate the warps against the updated model; the warp penalty is
        // deliberately much weaker than the coefficient smoothing penalty.
        update_regression_warps(&mut gammas, &q_all, &beta, argvals, alpha, y, lambda * 0.01);
        // Re-center so the mean of the warps stays near the identity.
        center_warps(&mut gammas, argvals);
}
let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
let fitted_values = srsf_fitted_values(&aligned_srsfs, &beta, &weights, alpha);
let (residuals, sse, r_squared) = compute_regression_residuals(y, &fitted_values, y_mean);
Ok(ElasticRegressionResult {
alpha,
beta,
fitted_values,
residuals,
sse,
r_squared,
gammas,
aligned_srsfs,
n_iter,
})
}
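/// Fits an elastic regression model with tuning parameters taken from an
/// [`ElasticConfig`] instead of individual arguments.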
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_regression_with_config(
data: &FdMatrix,
y: &[f64],
argvals: &[f64],
config: &ElasticConfig,
) -> Result<ElasticRegressionResult, crate::FdarError> {
elastic_regression(
data,
y,
argvals,
config.ncomp_beta,
config.lambda,
config.max_iter,
config.tol,
)
}
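/// Predicts responses for new curves from a fitted model.
///
/// Note that the new curves are *not* warped: their SRSFs are integrated
/// against `beta` directly, which corresponds to assuming identity warps.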
pub fn predict_elastic_regression(
fit: &ElasticRegressionResult,
new_data: &FdMatrix,
argvals: &[f64],
) -> Vec<f64> {
let weights = simpsons_weights(argvals);
let q_new = srsf_transform(new_data, argvals);
srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha)
}
impl ElasticRegressionResult {
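    /// Predicts responses for new curves; delegates to
    /// [`predict_elastic_regression`].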
pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
predict_elastic_regression(self, new_data, argvals)
}
}
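/// Finds a warp for curve `q_i` whose implied prediction matches `y_i`.
///
/// Dynamic programming against `beta` and `-beta` yields two bracketing warps
/// that approximately maximize and minimize the predicted response; a binary
/// search between them then targets `y_i`.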
fn regression_warp(
q_i: &[f64],
beta: &[f64],
argvals: &[f64],
alpha: f64,
y_i: f64,
lambda: f64,
) -> Vec<f64> {
let weights = simpsons_weights(argvals);
let gam_pos = dp_alignment_core(beta, q_i, argvals, lambda);
let neg_beta: Vec<f64> = beta.iter().map(|&b| -b).collect();
let gam_neg = dp_alignment_core(&neg_beta, q_i, argvals, lambda);
let y_pos = compute_predicted_y(q_i, beta, &gam_pos, argvals, alpha, &weights);
let y_neg = compute_predicted_y(q_i, beta, &gam_neg, argvals, alpha, &weights);
if let Some(gam) = check_extreme_warps(&gam_pos, &gam_neg, y_pos, y_neg, y_i) {
return gam;
}
let (gam_lo, gam_hi) = order_warps_by_prediction(gam_pos, gam_neg, y_pos, y_neg);
binary_search_warps(gam_lo, gam_hi, q_i, beta, argvals, alpha, y_i, &weights)
}
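/// Evaluates the model prediction for a single warped curve:
/// `alpha + \int beta(t) q_i(gam(t)) sqrt(gam'(t)) dt`, with the integral
/// approximated by the supplied quadrature weights.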
fn compute_predicted_y(
q_i: &[f64],
beta: &[f64],
gam: &[f64],
argvals: &[f64],
alpha: f64,
weights: &[f64],
) -> f64 {
let m = argvals.len();
    let q_warped = reparameterize_curve(q_i, argvals, gam);
    // `gradient_uniform` assumes an evenly spaced grid with step `h`.
    let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
    let gam_deriv = crate::helpers::gradient_uniform(gam, h);
    let mut y_hat = alpha;
    for j in 0..m {
        // SRSF group action: clamp the derivative at zero to guard against
        // numerical non-monotonicity before taking the square root.
        let q_aligned_j = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
        y_hat += q_aligned_j * beta[j] * weights[j];
}
y_hat
}
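/// Builds the B-spline evaluation matrix for `beta(t)` and the matching
/// second-derivative roughness penalty, returning both together with the
/// basis size actually produced.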
fn build_basis_and_penalty(
argvals: &[f64],
ncomp_beta: usize,
m: usize,
) -> (DMatrix<f64>, DMatrix<f64>, usize) {
    // Knot count implied by the requested number of components for an
    // order-4 (cubic) B-spline basis; clamp to keep the basis well defined.
    let nknots = ncomp_beta.saturating_sub(4).max(2);
    let basis_flat = bspline_basis(argvals, nknots, 4);
    // Recover the basis size actually produced from the flattened
    // column-major output.
    let actual_nbasis = basis_flat.len() / m;
    let b_mat = DMatrix::from_column_slice(m, actual_nbasis, &basis_flat);
    // Second-derivative roughness penalty on the same spline family.
    let penalty_flat = bspline_penalty_matrix(argvals, ncomp_beta, 4, 2);
    let penalty_k = (penalty_flat.len() as f64).sqrt() as usize;
    let r_mat = DMatrix::from_column_slice(penalty_k, penalty_k, &penalty_flat);
    // The penalty may be built at a different size than the basis; align them.
    let r_trimmed = trim_penalty_to_basis(&r_mat, penalty_k, actual_nbasis);
(b_mat, r_trimmed, actual_nbasis)
}
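/// Resizes the penalty matrix to match the basis size actually produced:
/// truncated when too large, zero-padded when too small.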
fn trim_penalty_to_basis(
r_mat: &DMatrix<f64>,
penalty_k: usize,
actual_nbasis: usize,
) -> DMatrix<f64> {
    if penalty_k >= actual_nbasis {
        // Penalty at least as large as the basis: keep the top-left block.
        r_mat
            .view((0, 0), (actual_nbasis, actual_nbasis))
            .into_owned()
    } else {
        // Penalty smaller than the basis: embed it in the top-left corner and
        // leave the remaining coefficients unpenalized.
        let mut r = DMatrix::zeros(actual_nbasis, actual_nbasis);
        for i in 0..penalty_k {
            for j in 0..penalty_k {
                r[(i, j)] = r_mat[(i, j)];
            }
        }
        r
    }
}
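/// Builds the design matrix of the basis-expanded regression:
/// `phi[(i, k)]` approximates `\int q_i(t) B_k(t) dt` via the quadrature
/// weights, where `q_i` is the aligned SRSF of curve `i`.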
fn build_phi_matrix(
q_aligned: &FdMatrix,
b_mat: &DMatrix<f64>,
weights: &[f64],
n: usize,
m: usize,
actual_nbasis: usize,
) -> DMatrix<f64> {
let mut phi = DMatrix::zeros(n, actual_nbasis);
for i in 0..n {
for k in 0..actual_nbasis {
let mut val = 0.0;
for j in 0..m {
val += q_aligned[(i, j)] * b_mat[(j, k)] * weights[j];
}
phi[(i, k)] = val;
}
}
phi
}
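/// Solves the penalized normal equations
/// `(Phi' Phi + lambda R) c = Phi' y` for the basis coefficients, returning
/// `None` if neither the Cholesky factorization nor the SVD fallback
/// produces a solution.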
pub(super) fn solve_penalized_ols(
phi: &DMatrix<f64>,
r_trimmed: &DMatrix<f64>,
y_centered: &[f64],
lambda: f64,
) -> Option<Vec<f64>> {
    let y_vec = DVector::from_vec(y_centered.to_vec());
    let phi_t_phi = phi.transpose() * phi;
    // Ridge-type system: (Phi' Phi + lambda R) c = Phi' y.
    let system = &phi_t_phi + lambda * r_trimmed;
    let rhs = phi.transpose() * &y_vec;
    // Prefer the fast Cholesky path; fall back to an SVD-based solve when the
    // system is not numerically positive definite.
    let coefs = if let Some(chol) = system.clone().cholesky() {
        chol.solve(&rhs)
    } else {
        let svd = nalgebra::SVD::new(system, true, true);
        svd.solve(&rhs, 1e-10).ok()?
    };
Some(coefs.iter().copied().collect())
}
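/// Expands basis coefficients back to a pointwise coefficient function:
/// `beta(t_j) = sum_k c_k B_k(t_j)`.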
fn reconstruct_beta_from_coefs(
coefs: &[f64],
b_mat: &DMatrix<f64>,
m: usize,
actual_nbasis: usize,
) -> Vec<f64> {
let mut beta = vec![0.0; m];
for j in 0..m {
for k in 0..actual_nbasis {
beta[j] += coefs[k] * b_mat[(j, k)];
}
}
beta
}
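/// Recomputes the intercept as the mean gap between the responses and the
/// slope-only predictions.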
fn compute_alpha_from_residuals(
q_aligned: &FdMatrix,
beta: &[f64],
weights: &[f64],
y: &[f64],
) -> f64 {
let (n, m) = q_aligned.shape();
let mut alpha = 0.0;
for i in 0..n {
let mut y_hat_i = 0.0;
for j in 0..m {
y_hat_i += q_aligned[(i, j)] * beta[j] * weights[j];
}
alpha += y[i] - y_hat_i;
}
alpha / n as f64
}
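/// One coefficient update: aligns the SRSFs with the current warps, builds
/// the design matrix, solves the penalized least-squares problem for `beta`,
/// and refreshes the intercept.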
fn regression_iteration_step(
q_all: &FdMatrix,
gammas: &FdMatrix,
argvals: &[f64],
b_mat: &DMatrix<f64>,
r_trimmed: &DMatrix<f64>,
weights: &[f64],
y: &[f64],
alpha: f64,
lambda: f64,
n: usize,
m: usize,
actual_nbasis: usize,
) -> Option<(Vec<f64>, f64)> {
let q_aligned = apply_warps_to_srsfs(q_all, gammas, argvals);
let phi = build_phi_matrix(&q_aligned, b_mat, weights, n, m, actual_nbasis);
let y_centered: Vec<f64> = y.iter().map(|&yi| yi - alpha).collect();
let coefs = solve_penalized_ols(&phi, r_trimmed, &y_centered, lambda)?;
let beta_new = reconstruct_beta_from_coefs(&coefs, b_mat, m, actual_nbasis);
let alpha_new = compute_alpha_from_residuals(&q_aligned, &beta_new, weights, y);
Some((beta_new, alpha_new))
}
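/// Re-estimates the warp of every curve against the current model via
/// [`regression_warp`].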
fn update_regression_warps(
gammas: &mut FdMatrix,
q_all: &FdMatrix,
beta: &[f64],
argvals: &[f64],
alpha: f64,
y: &[f64],
lambda: f64,
) {
let (n, m) = q_all.shape();
for i in 0..n {
let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
let new_gam = regression_warp(&qi, beta, argvals, alpha, y[i], lambda);
for j in 0..m {
gammas[(i, j)] = new_gam[j];
}
}
}
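/// Centers the warps by composing each with the inverse mean warp, so the
/// average warp stays close to the identity across iterations.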
fn center_warps(gammas: &mut FdMatrix, argvals: &[f64]) {
let (n, m) = gammas.shape();
let gam_mu = sqrt_mean_inverse(gammas, argvals);
for i in 0..n {
let gam_i: Vec<f64> = (0..m).map(|j| gammas[(i, j)]).collect();
let composed = crate::alignment::compose_warps(&gam_i, &gam_mu, argvals);
for j in 0..m {
gammas[(i, j)] = composed[j];
}
}
}
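/// Computes residuals, the residual sum of squares, and R-squared (defined
/// as 0 when the responses are constant).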
fn compute_regression_residuals(
y: &[f64],
fitted_values: &[f64],
y_mean: f64,
) -> (Vec<f64>, f64, f64) {
let residuals: Vec<f64> = y
.iter()
.zip(fitted_values.iter())
.map(|(&yi, &yh)| yi - yh)
.collect();
let sse: f64 = residuals.iter().map(|&r| r * r).sum();
let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
let r_squared = if ss_tot > 0.0 {
1.0 - sse / ss_tot
} else {
0.0
};
(residuals, sse, r_squared)
}
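/// Returns a bracketing warp outright when its prediction already matches
/// `y_i` to within `1e-10`, preferring whichever endpoint predicts closer
/// to `y_i`.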
fn check_extreme_warps(
gam_pos: &[f64],
gam_neg: &[f64],
y_pos: f64,
y_neg: f64,
y_i: f64,
) -> Option<Vec<f64>> {
if (y_pos - y_i).abs() <= (y_neg - y_i).abs() {
if (y_pos - y_i).abs() < 1e-10 {
return Some(gam_pos.to_vec());
}
} else if (y_neg - y_i).abs() < 1e-10 {
return Some(gam_neg.to_vec());
}
None
}
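/// Orders the two bracketing warps so the first yields the smaller
/// prediction.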
fn order_warps_by_prediction(
gam_pos: Vec<f64>,
gam_neg: Vec<f64>,
y_pos: f64,
y_neg: f64,
) -> (Vec<f64>, Vec<f64>) {
if y_pos < y_neg {
(gam_pos, gam_neg)
} else {
(gam_neg, gam_pos)
}
}
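/// Bisects between two bracketing warps until the prediction matches `y_i`
/// to within `1e-6`, falling back to the final midpoint after 15 rounds.
/// The element-wise average of two monotone warps with fixed endpoints is
/// itself a valid warp, so every midpoint stays feasible.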
fn binary_search_warps(
mut gam_lo: Vec<f64>,
mut gam_hi: Vec<f64>,
q_i: &[f64],
beta: &[f64],
argvals: &[f64],
alpha: f64,
y_i: f64,
weights: &[f64],
) -> Vec<f64> {
    // `gam_lo` predicts at or below `gam_hi` by construction in
    // `order_warps_by_prediction`, so ordinary bisection applies.
    for _ in 0..15 {
let gam_mid: Vec<f64> = gam_lo
.iter()
.zip(gam_hi.iter())
.map(|(&lo, &hi)| 0.5 * (lo + hi))
.collect();
let y_mid = compute_predicted_y(q_i, beta, &gam_mid, argvals, alpha, weights);
if (y_mid - y_i).abs() < 1e-6 {
return gam_mid;
}
if y_mid < y_i {
gam_lo = gam_mid;
} else {
gam_hi = gam_mid;
}
}
gam_lo
.iter()
.zip(gam_hi.iter())
.map(|(&lo, &hi)| 0.5 * (lo + hi))
.collect()
}