use crate::alignment::{align_to_target, karcher_mean};
use crate::error::FdarError;
use crate::matrix::FdMatrix;
use super::phase::{spm_monitor, spm_phase1, SpmChart, SpmConfig, SpmMonitorResult};
/// Configuration for elastic (alignment-aware) statistical process monitoring.
///
/// Consumed by [`elastic_spm_phase1`]. That function aligns the Phase-I curves to
/// their Karcher mean and then fits an SPM chart on the aligned curves (amplitude).
/// Optionally it also fits a chart on the warping functions (phase).
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticSpmConfig {
    /// Settings for the amplitude chart.
    ///
    /// Its `alpha`, `tuning_fraction` and `seed` (offset by one) are also reused
    /// for the phase chart.
    pub spm: SpmConfig,
    /// Alignment parameter forwarded to `karcher_mean` and `align_to_target`.
    ///
    /// Must lie in `[0, 1]`.
    /// NOTE(review): presumably a penalty on warping strength; confirm in the
    /// alignment module.
    pub align_lambda: f64,
    /// When `true`, a phase chart is also fitted on the warping functions.
    pub monitor_phase: bool,
    /// Value used as `ncomp` for the phase chart's `SpmConfig`.
    ///
    /// Must be >= 1 when `monitor_phase` is `true`.
    pub warp_ncomp: usize,
    /// Iteration limit forwarded to `karcher_mean`.
    pub max_karcher_iterations: usize,
    /// Convergence tolerance forwarded to `karcher_mean`.
    pub karcher_tolerance: f64,
}
impl Default for ElasticSpmConfig {
    /// Returns the default elastic SPM configuration.
    ///
    /// The defaults are:
    /// - `spm`: `SpmConfig::default()`
    /// - `align_lambda`: `0.0`
    /// - `monitor_phase`: `true`
    /// - `warp_ncomp`: `3`
    /// - `max_karcher_iterations`: `20`
    /// - `karcher_tolerance`: `1e-4`
    fn default() -> Self {
        let spm = SpmConfig::default();
        let (max_karcher_iterations, karcher_tolerance) = (20, 1e-4);
        Self {
            spm,
            align_lambda: 0.0,
            monitor_phase: true,
            warp_ncomp: 3,
            max_karcher_iterations,
            karcher_tolerance,
        }
    }
}
/// Phase-I elastic SPM chart.
///
/// Produced by [`elastic_spm_phase1`] and consumed by [`elastic_spm_monitor`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticSpmChart {
    /// Karcher mean of the Phase-I curves.
    ///
    /// All curves are aligned to this target. Its length fixes the number of grid
    /// points (columns) that monitoring expects.
    pub karcher_mean: Vec<f64>,
    /// SPM chart fitted on the aligned Phase-I curves.
    pub amplitude_chart: SpmChart,
    /// SPM chart fitted on the Phase-I warping functions.
    ///
    /// `None` when `config.monitor_phase` was `false`.
    pub phase_chart: Option<SpmChart>,
    /// Copy of the configuration used to build the chart.
    ///
    /// Its `align_lambda` is reused when aligning new data during monitoring.
    pub config: ElasticSpmConfig,
    /// Average residual of the Phase-I alignment.
    ///
    /// Computed over the Phase-I observations as the mean squared difference
    /// between each aligned curve and `karcher_mean`. It is `0.0` when the
    /// aligned data has no rows or no columns.
    pub mean_alignment_residual: f64,
}
/// Result of monitoring new curves with [`elastic_spm_monitor`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ElasticSpmMonitorResult {
    /// Monitoring result of the aligned new curves on the amplitude chart.
    pub amplitude: SpmMonitorResult,
    /// Monitoring result of the new warping functions on the phase chart.
    ///
    /// `None` if the chart was built without a phase component.
    pub phase: Option<SpmMonitorResult>,
    /// New curves after alignment to the chart's Karcher mean.
    pub aligned_data: FdMatrix,
    /// Warping functions estimated while aligning the new curves.
    pub warping_functions: FdMatrix,
}
/// Builds a Phase-I elastic SPM chart from a sample of functional observations.
///
/// Each row of `data` is one observation evaluated on the grid `argvals`.
///
/// The steps are:
/// 1. Compute the Karcher mean of the rows.
/// 2. Align every row to that mean.
/// 3. Fit an amplitude chart on the aligned curves.
/// 4. When `config.monitor_phase` is `true`, also fit a phase chart on the warping
///    functions produced by the alignment.
///
/// # Errors
///
/// - [`FdarError::InvalidDimension`] if `data` has fewer than 4 rows, or if
///   `argvals.len()` differs from the number of columns of `data`.
/// - [`FdarError::InvalidParameter`] if `config.monitor_phase` is `true` and
///   `config.warp_ncomp` is 0.
/// - [`FdarError::InvalidParameter`] if `config.align_lambda` is outside
///   `[0, 1]` (NaN included).
/// - Any error returned by [`spm_phase1`] while fitting either chart.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn elastic_spm_phase1(
    data: &FdMatrix,
    argvals: &[f64],
    config: &ElasticSpmConfig,
) -> Result<ElasticSpmChart, FdarError> {
    let (n, m) = data.shape();
    if n < 4 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 4 observations".to_string(),
            actual: format!("{n} observations"),
        });
    }
    if argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("{m}"),
            actual: format!("{}", argvals.len()),
        });
    }
    if config.monitor_phase && config.warp_ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "warp_ncomp",
            message: "warp_ncomp must be >= 1 when monitor_phase is true".to_string(),
        });
    }
    // `RangeInclusive::contains` is false for NaN, so a NaN lambda is rejected
    // here instead of reaching the alignment routines.
    if !(0.0..=1.0).contains(&config.align_lambda) {
        return Err(FdarError::InvalidParameter {
            parameter: "align_lambda",
            message: format!(
                "align_lambda must be in [0, 1], got {}",
                config.align_lambda
            ),
        });
    }

    let km_result = karcher_mean(
        data,
        argvals,
        config.max_karcher_iterations,
        config.karcher_tolerance,
        config.align_lambda,
    );
    let alignment = align_to_target(data, &km_result.mean, argvals, config.align_lambda);
    let mean_alignment_residual = alignment_residual(&alignment.aligned_data, &km_result.mean);

    let amplitude_chart = spm_phase1(&alignment.aligned_data, argvals, &config.spm)?;
    let phase_chart = if config.monitor_phase {
        // Same control settings as the amplitude chart, with the phase-specific
        // component count and a different seed (wrapping on overflow).
        let phase_config = SpmConfig {
            ncomp: config.warp_ncomp,
            alpha: config.spm.alpha,
            tuning_fraction: config.spm.tuning_fraction,
            seed: config.spm.seed.wrapping_add(1),
        };
        Some(spm_phase1(&alignment.gammas, argvals, &phase_config)?)
    } else {
        None
    };

    Ok(ElasticSpmChart {
        karcher_mean: km_result.mean,
        amplitude_chart,
        phase_chart,
        config: config.clone(),
        mean_alignment_residual,
    })
}

/// Averages the per-row alignment residual over all rows of `aligned`.
///
/// The per-row residual is the mean squared difference between that row and
/// `mean`. Returns `0.0` when `aligned` has no rows or no columns.
/// `mean` must have at least as many entries as `aligned` has columns.
fn alignment_residual(aligned: &FdMatrix, mean: &[f64]) -> f64 {
    let (n, m) = aligned.shape();
    if n == 0 || m == 0 {
        return 0.0;
    }
    let total: f64 = (0..n)
        .map(|i| {
            let sq_diff: f64 = (0..m)
                .map(|j| {
                    let d = aligned[(i, j)] - mean[j];
                    d * d
                })
                .sum();
            sq_diff / m as f64
        })
        .sum();
    total / n as f64
}
/// Monitors new functional observations against a Phase-I elastic SPM chart.
///
/// Each row of `new_data` is aligned to `chart.karcher_mean`, using the
/// `align_lambda` stored in the chart's configuration. The aligned curves are then
/// scored on the amplitude chart. If the chart has a phase component, the
/// resulting warping functions are also scored on the phase chart.
///
/// # Errors
///
/// - [`FdarError::InvalidDimension`] if `new_data` does not have as many columns
///   as `chart.karcher_mean` has entries.
/// - [`FdarError::InvalidDimension`] if `argvals.len()` differs from that
///   number of columns.
/// - Any error returned by [`spm_monitor`].
#[must_use = "monitoring result should not be discarded"]
pub fn elastic_spm_monitor(
    chart: &ElasticSpmChart,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Result<ElasticSpmMonitorResult, FdarError> {
    let m = chart.karcher_mean.len();
    if new_data.ncols() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "new_data",
            expected: format!("{m} columns"),
            actual: format!("{} columns", new_data.ncols()),
        });
    }
    // Phase I rejects a grid/data length mismatch; apply the same check here
    // before handing `argvals` to the alignment and monitoring routines.
    if argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("{m}"),
            actual: format!("{}", argvals.len()),
        });
    }

    let alignment = align_to_target(
        new_data,
        &chart.karcher_mean,
        argvals,
        chart.config.align_lambda,
    );
    let amplitude = spm_monitor(&chart.amplitude_chart, &alignment.aligned_data, argvals)?;
    let phase = chart
        .phase_chart
        .as_ref()
        .map(|phase_chart| spm_monitor(phase_chart, &alignment.gammas, argvals))
        .transpose()?;

    Ok(ElasticSpmMonitorResult {
        amplitude,
        phase,
        aligned_data: alignment.aligned_data,
        warping_functions: alignment.gammas,
    })
}