use super::pairwise::elastic_align_pair;
use super::srsf::compose_warps;
use super::KarcherMeanResult;
use crate::helpers::{gradient_uniform, l2_distance, simpsons_weights};
use crate::matrix::FdMatrix;
/// Summary metrics describing how well an elastic alignment worked.
///
/// Produced by [`alignment_quality`] from the original curves and a
/// Karcher-mean alignment result.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentQuality {
    /// Per-curve warp complexity (one entry per curve); see [`warp_complexity`].
    pub warp_complexity: Vec<f64>,
    /// Mean of `warp_complexity` over all curves.
    pub mean_warp_complexity: f64,
    /// Per-curve warp roughness: integrated squared second derivative of the warp.
    pub warp_smoothness: Vec<f64>,
    /// Mean of `warp_smoothness` over all curves.
    pub mean_warp_smoothness: f64,
    /// Mean squared weighted-L2 distance of the raw curves to their cross-sectional mean.
    pub total_variance: f64,
    /// Same quantity computed on the aligned curves (variance that remains after warping).
    pub amplitude_variance: f64,
    /// `total_variance - amplitude_variance`, clamped at zero.
    pub phase_variance: f64,
    /// `phase_variance / total_variance` (0.0 when `total_variance` is ~zero).
    pub phase_amplitude_ratio: f64,
    /// Per-grid-point ratio of aligned to original cross-sectional variance
    /// (1.0 where the original variance is ~zero).
    pub pointwise_variance_ratio: Vec<f64>,
    /// Mean of `pointwise_variance_ratio` over the grid.
    pub mean_variance_reduction: f64,
}
/// Complexity score for a single warping function `gamma` sampled on `argvals`.
///
/// Thin wrapper over [`crate::warping::phase_distance`]; see that function for
/// the exact metric. NOTE(review): presumably this measures deviation from the
/// identity warp — confirm against `phase_distance`'s definition.
pub fn warp_complexity(gamma: &[f64], argvals: &[f64]) -> f64 {
crate::warping::phase_distance(gamma, argvals)
}
/// Roughness of a warping function: the integrated squared second derivative
/// `∫ γ''(t)² dt`, estimated by two successive finite-difference gradients and
/// the trapezoidal rule on `argvals`.
///
/// Returns `0.0` when fewer than three samples are available, since a second
/// derivative cannot be estimated.
///
/// NOTE(review): the step size is derived from the grid endpoints, so a
/// uniform `argvals` grid is assumed — TODO confirm callers guarantee this.
pub fn warp_smoothness(gamma: &[f64], argvals: &[f64]) -> f64 {
    let n_pts = gamma.len();
    if n_pts < 3 {
        return 0.0;
    }
    // Uniform step inferred from the grid span.
    let step = (argvals[n_pts - 1] - argvals[0]) / (n_pts - 1) as f64;
    // γ'' via two first-difference passes.
    let first_deriv = gradient_uniform(gamma, step);
    let second_deriv = gradient_uniform(&first_deriv, step);
    let squared: Vec<f64> = second_deriv.iter().map(|&v| v * v).collect();
    crate::helpers::trapz(&squared, argvals)
}
/// Comprehensive quality metrics for an elastic alignment result.
///
/// Compares the original curves in `data` against the Karcher-mean alignment
/// in `karcher`, reporting (a) how much warping each curve needed (complexity
/// and smoothness of the warps), (b) the amplitude/phase variance
/// decomposition, and (c) pointwise variance ratios across the grid.
///
/// NOTE(review): assumes `karcher.gammas` and `karcher.aligned_data` share
/// `data`'s `(n, m)` shape and that `argvals.len() == m` — TODO confirm
/// against the Karcher-mean implementation. For `n == 0` the per-curve means
/// divide by zero and yield NaN, as in the original implementation.
pub fn alignment_quality(
    data: &FdMatrix,
    karcher: &KarcherMeanResult,
    argvals: &[f64],
) -> AlignmentQuality {
    let (n, m) = data.shape();
    let weights = simpsons_weights(argvals);

    // Score each warp once: extract the i-th warp row a single time and feed
    // it to both metrics (the previous version rebuilt every row twice).
    let mut wc = Vec::with_capacity(n);
    let mut ws = Vec::with_capacity(n);
    for i in 0..n {
        let gamma: Vec<f64> = (0..m).map(|j| karcher.gammas[(i, j)]).collect();
        wc.push(warp_complexity(&gamma, argvals));
        ws.push(warp_smoothness(&gamma, argvals));
    }
    let mean_wc = wc.iter().sum::<f64>() / n as f64;
    let mean_ws = ws.iter().sum::<f64>() / n as f64;

    // Total variance: mean squared weighted-L2 distance of the raw curves to
    // their cross-sectional mean.
    let orig_mean = crate::fdata::mean_1d(data);
    let total_var: f64 = (0..n)
        .map(|i| {
            let fi = data.row(i);
            let d = l2_distance(&fi, &orig_mean, &weights);
            d * d
        })
        .sum::<f64>()
        / n as f64;

    // Amplitude variance: the same statistic after alignment — the spread
    // that warping could not remove.
    let aligned_mean = crate::fdata::mean_1d(&karcher.aligned_data);
    let amp_var: f64 = (0..n)
        .map(|i| {
            let fi = karcher.aligned_data.row(i);
            let d = l2_distance(&fi, &aligned_mean, &weights);
            d * d
        })
        .sum::<f64>()
        / n as f64;

    // Phase variance is whatever alignment removed; clamp at zero since
    // numerical error can make the difference slightly negative.
    let phase_var = (total_var - amp_var).max(0.0);
    let ratio = if total_var > 1e-10 {
        phase_var / total_var
    } else {
        0.0
    };

    // Pointwise ratio of aligned to original cross-sectional (population)
    // variance at each grid point; 1.0 where the original variance is ~zero.
    let mut pw_ratio = vec![0.0; m];
    for j in 0..m {
        let col_orig = data.column(j);
        let mean_orig_j = col_orig.iter().sum::<f64>() / n as f64;
        let var_orig: f64 = col_orig
            .iter()
            .map(|&v| (v - mean_orig_j).powi(2))
            .sum::<f64>()
            / n as f64;
        let col_aligned = karcher.aligned_data.column(j);
        let mean_aligned_j = col_aligned.iter().sum::<f64>() / n as f64;
        let var_aligned: f64 = col_aligned
            .iter()
            .map(|&v| (v - mean_aligned_j).powi(2))
            .sum::<f64>()
            / n as f64;
        pw_ratio[j] = if var_orig > 1e-15 {
            var_aligned / var_orig
        } else {
            1.0
        };
    }
    let mean_vr = pw_ratio.iter().sum::<f64>() / m as f64;

    AlignmentQuality {
        warp_complexity: wc,
        mean_warp_complexity: mean_wc,
        warp_smoothness: ws,
        mean_warp_smoothness: mean_ws,
        total_variance: total_var,
        amplitude_variance: amp_var,
        phase_variance: phase_var,
        phase_amplitude_ratio: ratio,
        pointwise_variance_ratio: pw_ratio,
        mean_variance_reduction: mean_vr,
    }
}
/// Enumerate index triplets `(i, j, k)` with `i < j < k` drawn from `0..n`,
/// in lexicographic order, capped at `max_triplets` (`0` means "no cap").
///
/// Returns an empty vector for `n < 3`: no triplet exists, and the
/// closed-form count `n*(n-1)*(n-2)/6` below would underflow `usize`
/// (panicking in debug builds) without this guard.
fn triplet_indices(n: usize, max_triplets: usize) -> Vec<(usize, usize, usize)> {
    if n < 3 {
        return Vec::new();
    }
    let total = n * (n - 1) * (n - 2) / 6;
    let cap = if max_triplets > 0 {
        max_triplets.min(total)
    } else {
        total
    };
    (0..n)
        .flat_map(|i| ((i + 1)..n).flat_map(move |j| ((j + 1)..n).map(move |k| (i, j, k))))
        .take(cap)
        .collect()
}
/// Warp-transitivity residual for one index triplet.
///
/// Elastically aligns curve pairs `i→j`, `j→k`, and `i→k`, composes the first
/// two warps, and returns the weighted-L2 distance between that composition
/// and the directly estimated `i→k` warp. A value of zero would mean the
/// pairwise warps are perfectly transitive.
fn triplet_warp_deviation(
    data: &FdMatrix,
    argvals: &[f64],
    weights: &[f64],
    i: usize,
    j: usize,
    k: usize,
    lambda: f64,
) -> f64 {
    let curve_i = data.row(i);
    let curve_j = data.row(j);
    let curve_k = data.row(k);
    let warp_ij = elastic_align_pair(&curve_i, &curve_j, argvals, lambda).gamma;
    let warp_jk = elastic_align_pair(&curve_j, &curve_k, argvals, lambda).gamma;
    let warp_ik = elastic_align_pair(&curve_i, &curve_k, argvals, lambda).gamma;
    let transitive = compose_warps(&warp_ij, &warp_jk, argvals);
    l2_distance(&transitive, &warp_ik, weights)
}
/// Global consistency score for pairwise elastic alignment: the mean
/// transitivity deviation (see [`triplet_warp_deviation`]) over up to
/// `max_triplets` sampled index triplets (`0` means all triplets).
///
/// Lower is more self-consistent. Returns `0.0` when fewer than three curves
/// are available or no triplet can be formed.
pub fn pairwise_consistency(
    data: &FdMatrix,
    argvals: &[f64],
    lambda: f64,
    max_triplets: usize,
) -> f64 {
    let n = data.nrows();
    if n < 3 {
        return 0.0;
    }
    let weights = simpsons_weights(argvals);
    let triplets = triplet_indices(n, max_triplets);
    match triplets.len() {
        0 => 0.0,
        count => {
            let total: f64 = triplets
                .iter()
                .map(|&(i, j, k)| triplet_warp_deviation(data, argvals, &weights, i, j, k, lambda))
                .sum();
            total / count as f64
        }
    }
}