use crate::alignment::{dp_alignment_core, srsf_transform};
use crate::helpers::simpsons_weights;
use crate::matrix::FdMatrix;
use super::{
apply_warps_to_srsfs, beta_converged, init_identity_warps, srsf_fitted_values, ElasticConfig,
};
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticLogisticResult {
pub alpha: f64,
pub beta: Vec<f64>,
pub probabilities: Vec<f64>,
pub predicted_classes: Vec<usize>,
pub accuracy: f64,
pub loss: f64,
pub gammas: FdMatrix,
pub aligned_srsfs: FdMatrix,
pub n_iter: usize,
}
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_logistic(
data: &FdMatrix,
y: &[i8],
argvals: &[f64],
_ncomp_beta: usize,
lambda: f64,
max_iter: usize,
tol: f64,
) -> Result<ElasticLogisticResult, crate::FdarError> {
let (n, m) = data.shape();
if n < 2 || m < 2 || y.len() != n || argvals.len() != m {
return Err(crate::FdarError::InvalidDimension {
parameter: "data/y/argvals",
expected: "n >= 2, m >= 2, y.len() == n, argvals.len() == m".to_string(),
actual: format!(
"n={}, m={}, y.len()={}, argvals.len()={}",
n,
m,
y.len(),
argvals.len()
),
});
}
let weights = simpsons_weights(argvals);
let q_all = srsf_transform(data, argvals);
let mut gammas = init_identity_warps(n, argvals);
let mut beta = vec![0.0; m];
let mut alpha = 0.0;
let mut n_iter = 0;
for iter in 0..max_iter {
n_iter = iter + 1;
let q_aligned = apply_warps_to_srsfs(&q_all, &gammas, argvals);
let (grad_a, grad_beta, prob) =
logistic_gradients(&q_aligned, &beta, &weights, alpha, y, lambda);
let loss_current = logistic_loss(&prob, y, &beta, lambda);
let grad_norm_sq: f64 = grad_a * grad_a + grad_beta.iter().map(|&g| g * g).sum::<f64>();
let step = armijo_line_search_logistic(
&q_aligned,
alpha,
&beta,
grad_a,
&grad_beta,
&weights,
y,
lambda,
loss_current,
grad_norm_sq,
);
let beta_new: Vec<f64> = beta
.iter()
.zip(grad_beta.iter())
.map(|(&b, &g)| b - step * g)
.collect();
let alpha_new = alpha - step * grad_a;
if beta_converged(&beta_new, &beta, tol) && iter > 0 {
beta = beta_new;
alpha = alpha_new;
break;
}
beta = beta_new;
alpha = alpha_new;
update_logistic_warps(&mut gammas, &q_all, &beta, y, argvals, lambda * 0.01);
}
let aligned_srsfs = apply_warps_to_srsfs(&q_all, &gammas, argvals);
let (probabilities, predicted_classes, accuracy, loss) =
compute_logistic_predictions(&aligned_srsfs, &beta, &weights, alpha, y, lambda);
Ok(ElasticLogisticResult {
alpha,
beta,
probabilities,
predicted_classes,
accuracy,
loss,
gammas,
aligned_srsfs,
n_iter,
})
}
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_logistic_with_config(
data: &FdMatrix,
y: &[i8],
argvals: &[f64],
config: &ElasticConfig,
) -> Result<ElasticLogisticResult, crate::FdarError> {
elastic_logistic(
data,
y,
argvals,
config.ncomp_beta,
config.lambda,
config.max_iter,
config.tol,
)
}
pub fn predict_elastic_logistic(
fit: &ElasticLogisticResult,
new_data: &FdMatrix,
argvals: &[f64],
) -> Vec<f64> {
let weights = simpsons_weights(argvals);
let q_new = srsf_transform(new_data, argvals);
let eta = srsf_fitted_values(&q_new, &fit.beta, &weights, fit.alpha);
eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect()
}
impl ElasticLogisticResult {
pub fn predict(&self, new_data: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
predict_elastic_logistic(self, new_data, argvals)
}
}
fn logistic_loss(prob: &[f64], y: &[i8], beta: &[f64], lambda: f64) -> f64 {
let n = prob.len();
let mut loss = 0.0;
for i in 0..n {
let target = if y[i] == 1 { 1.0 } else { 0.0 };
let p = prob[i].clamp(1e-15, 1.0 - 1e-15);
loss -= target * p.ln() + (1.0 - target) * (1.0 - p).ln();
}
loss /= n as f64;
loss += 0.5 * lambda * beta.iter().map(|&b| b * b).sum::<f64>();
loss
}
fn logistic_gradients(
q_aligned: &FdMatrix,
beta: &[f64],
weights: &[f64],
alpha: f64,
y: &[i8],
lambda: f64,
) -> (f64, Vec<f64>, Vec<f64>) {
let (n, m) = q_aligned.shape();
let eta = srsf_fitted_values(q_aligned, beta, weights, alpha);
let prob: Vec<f64> = eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect();
let mut grad_a = 0.0;
for i in 0..n {
let target = if y[i] == 1 { 1.0 } else { 0.0 };
grad_a += prob[i] - target;
}
grad_a /= n as f64;
let mut grad_beta = vec![0.0; m];
for j in 0..m {
for i in 0..n {
let target = if y[i] == 1 { 1.0 } else { 0.0 };
grad_beta[j] += (prob[i] - target) * q_aligned[(i, j)] * weights[j];
}
grad_beta[j] /= n as f64;
grad_beta[j] += lambda * beta[j];
}
(grad_a, grad_beta, prob)
}
fn armijo_line_search_logistic(
q_aligned: &FdMatrix,
alpha: f64,
beta: &[f64],
grad_a: f64,
grad_beta: &[f64],
weights: &[f64],
y: &[i8],
lambda: f64,
loss_current: f64,
grad_norm_sq: f64,
) -> f64 {
let mut step = 1.0;
for _ in 0..20 {
let alpha_trial = alpha - step * grad_a;
let beta_trial: Vec<f64> = beta
.iter()
.zip(grad_beta.iter())
.map(|(&b, &g)| b - step * g)
.collect();
let eta_trial = srsf_fitted_values(q_aligned, &beta_trial, weights, alpha_trial);
let prob_trial: Vec<f64> = eta_trial
.iter()
.map(|&e| 1.0 / (1.0 + (-e).exp()))
.collect();
let loss_trial = logistic_loss(&prob_trial, y, &beta_trial, lambda);
if loss_trial <= loss_current - 1e-4 * step * grad_norm_sq {
break;
}
step *= 0.5;
}
step
}
fn update_logistic_warps(
gammas: &mut FdMatrix,
q_all: &FdMatrix,
beta: &[f64],
y: &[i8],
argvals: &[f64],
lambda: f64,
) {
let (n, m) = q_all.shape();
for i in 0..n {
let qi: Vec<f64> = (0..m).map(|j| q_all[(i, j)]).collect();
let beta_signed: Vec<f64> = beta.iter().map(|&b| b * f64::from(y[i])).collect();
let new_gam = dp_alignment_core(&beta_signed, &qi, argvals, lambda);
for j in 0..m {
gammas[(i, j)] = new_gam[j];
}
}
}
fn compute_logistic_predictions(
aligned_srsfs: &FdMatrix,
beta: &[f64],
weights: &[f64],
alpha: f64,
y: &[i8],
lambda: f64,
) -> (Vec<f64>, Vec<usize>, f64, f64) {
let n = y.len();
let eta = srsf_fitted_values(aligned_srsfs, beta, weights, alpha);
let probabilities: Vec<f64> = eta.iter().map(|&e| 1.0 / (1.0 + (-e).exp())).collect();
let predicted_classes: Vec<usize> = probabilities
.iter()
.map(|&p| usize::from(p >= 0.5))
.collect();
let accuracy = predicted_classes
.iter()
.zip(y.iter())
.filter(|(&p, &t)| p == usize::from(t == 1))
.count() as f64
/ n as f64;
let loss = logistic_loss(&probabilities, y, beta, lambda);
(probabilities, predicted_classes, accuracy, loss)
}