use super::karcher::karcher_mean;
use super::pairwise::{elastic_align_pair, elastic_distance};
use super::srsf::{compose_warps, reparameterize_curve};
use crate::error::FdarError;
use crate::iter_maybe_parallel;
use crate::matrix::FdMatrix;
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
/// Tuning parameters for [`transfer_alignment`].
///
/// The same `max_iter` / `tol` / `lambda` triple drives both Karcher-mean
/// computations (source and target). `lambda` is additionally forwarded to the
/// bridging pairwise alignment and to the per-curve elastic distances.
#[derive(Clone, Debug, PartialEq)]
pub struct TransferAlignConfig {
    /// Penalty weight passed as the `lambda` argument of the Karcher-mean,
    /// pairwise-alignment and elastic-distance routines (presumably a warp
    /// regularisation strength — confirm against those routines). Default `0.0`.
    pub lambda: f64,
    /// Iteration cap for each Karcher-mean computation. Default `15`.
    pub max_iter: usize,
    /// Convergence tolerance for each Karcher-mean computation. Default `1e-3`.
    pub tol: f64,
}

impl Default for TransferAlignConfig {
    /// Returns `lambda = 0.0`, `max_iter = 15`, `tol = 1e-3`.
    fn default() -> Self {
        let (lambda, max_iter, tol) = (0.0, 15, 1e-3);
        Self {
            lambda,
            max_iter,
            tol,
        }
    }
}
/// Output of [`transfer_alignment`].
///
/// `m` below is the number of grid points (columns) shared by the source and
/// target data; `n_target` is the number of target curves (rows).
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct TransferAlignResult {
    /// Karcher mean of the source population (length `m`); each target curve
    /// is compared against this template.
    pub source_mean: Vec<f64>,
    /// Target curves after reparameterisation by the matching row of
    /// `gammas` (`n_target x m`).
    pub aligned_data: FdMatrix,
    /// Per-curve composite warps, each combining the bridging warp with that
    /// curve's within-target Karcher-mean warp (`n_target x m`).
    pub gammas: FdMatrix,
    /// Warp obtained by elastically aligning the target Karcher mean against
    /// the source Karcher mean (length `m`).
    pub bridging_gamma: Vec<f64>,
    /// Elastic distance from `source_mean` to each aligned target curve
    /// (length `n_target`).
    pub distances: Vec<f64>,
}
/// Transfer a target population of curves onto the elastic template of a source population.
///
/// The computation has three stages:
/// 1. A Karcher mean (with per-curve warps) is computed separately for
///    `source_data` and `target_data`, both using `config.max_iter`,
///    `config.tol` and `config.lambda`.
/// 2. The target mean is elastically aligned against the source mean; the
///    resulting warp is the *bridging* warp.
/// 3. For every target curve, its within-target warp is composed with the
///    bridging warp. The curve is reparameterized by that composite, and its
///    elastic distance to the source mean is recorded. This stage goes through
///    `iter_maybe_parallel!` (presumably parallel under the `parallel` feature).
///
/// # Arguments
///
/// * `source_data` - source curves, one per row (`n_source x m`).
/// * `target_data` - target curves, one per row (`n_target x m`).
/// * `argvals` - evaluation grid shared by both populations; must have length `m`.
/// * `config` - Karcher-mean iteration settings and the `lambda` penalty.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] when any of the following holds.
/// The conditions are checked in this order:
/// * `target_data` does not have as many columns as `source_data`;
/// * `argvals.len()` differs from that column count;
/// * `source_data` has no rows;
/// * `target_data` has no rows.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn transfer_alignment(
    source_data: &FdMatrix,
    target_data: &FdMatrix,
    argvals: &[f64],
    config: &TransferAlignConfig,
) -> Result<TransferAlignResult, FdarError> {
    let (n_source, m) = source_data.shape();
    let (n_target, target_cols) = target_data.shape();

    // --- Validation: the order of these checks decides which error callers see. ---
    if target_cols != m {
        return Err(FdarError::InvalidDimension {
            parameter: "target_data",
            expected: format!("{m} columns (matching source_data)"),
            actual: format!("{target_cols} columns"),
        });
    }
    let grid_len = argvals.len();
    if grid_len != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: m.to_string(),
            actual: grid_len.to_string(),
        });
    }
    for (parameter, rows) in [("source_data", n_source), ("target_data", n_target)] {
        if rows == 0 {
            return Err(FdarError::InvalidDimension {
                parameter,
                expected: "at least 1 row".to_string(),
                actual: format!("{rows} rows"),
            });
        }
    }

    // --- Stage 1: one Karcher mean per population, identical settings for both. ---
    let karcher_of = |data: &FdMatrix| {
        karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda)
    };
    let source_karcher = karcher_of(source_data);
    let target_karcher = karcher_of(target_data);

    // --- Stage 2: bridging warp between the two templates. ---
    let bridge = elastic_align_pair(
        &source_karcher.mean,
        &target_karcher.mean,
        argvals,
        config.lambda,
    );

    // --- Stage 3: compose, reparameterize and score every target curve. ---
    // NOTE(review): the intended composite is presumably "within-target warp, then bridge"
    // (curve -> target mean -> source mean). Whether `compose_warps(bridge, within, _)`
    // yields that order depends on its argument convention, which is not visible here;
    // confirm against its definition.
    let per_curve: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n_target)
        .map(|row| {
            let composite =
                compose_warps(&bridge.gamma, &target_karcher.gammas.row(row), argvals);
            let warped = reparameterize_curve(&target_data.row(row), argvals, &composite);
            let dist = elastic_distance(&source_karcher.mean, &warped, argvals, config.lambda);
            (composite, warped, dist)
        })
        .collect();

    // Scatter the per-curve vectors into row-per-curve matrices.
    let distances: Vec<f64> = per_curve.iter().map(|&(_, _, d)| d).collect();
    let mut gammas = FdMatrix::zeros(n_target, m);
    let mut aligned_data = FdMatrix::zeros(n_target, m);
    for (row, (composite, warped, _)) in per_curve.iter().enumerate() {
        for col in 0..m {
            gammas[(row, col)] = composite[col];
            aligned_data[(row, col)] = warped[col];
        }
    }

    Ok(TransferAlignResult {
        source_mean: source_karcher.mean,
        aligned_data,
        gammas,
        bridging_gamma: bridge.gamma,
        distances,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` curves on a uniform grid of `m` points with `sim_fundata`
    /// (Fourier `EFunType`, exponential `EValType`), seeded for reproducibility.
    fn make_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(
            n,
            &t,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(seed),
        );
        (data, t)
    }

    /// Assert that `result` failed with `FdarError::InvalidDimension` naming `expected_param`.
    fn assert_dimension_error(
        result: Result<TransferAlignResult, FdarError>,
        expected_param: &str,
    ) {
        let err = result.unwrap_err();
        assert!(
            matches!(
                &err,
                FdarError::InvalidDimension { parameter, .. } if *parameter == expected_param
            ),
            "expected InvalidDimension for `{expected_param}`, got {err:?}"
        );
    }

    #[test]
    fn transfer_same_population() {
        let (data, t) = make_data(8, 20, 42);
        let config = TransferAlignConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&data, &data, &t, &config).unwrap();
        let max_dev: f64 = result
            .bridging_gamma
            .iter()
            .zip(t.iter())
            .map(|(&g, &ti)| (g - ti).abs())
            .fold(0.0_f64, f64::max);
        assert!(
            max_dev < 0.3,
            "bridging warp should be near identity for same population, max_dev={max_dev}"
        );
        for (i, &d) in result.distances.iter().enumerate() {
            assert!(
                d < 5.0,
                "distance[{i}]={d} should be small for same-population transfer"
            );
        }
    }

    #[test]
    fn transfer_shifted_population() {
        let (source, t) = make_data(8, 20, 42);
        let m = t.len();
        let n = source.nrows();
        let mut target = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                target[(i, j)] = source[(i, j)] + 2.0;
            }
        }
        let config = TransferAlignConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&source, &target, &t, &config).unwrap();
        let source_mean = &result.source_mean;
        let raw_mean_dist: f64 = (0..m)
            .map(|j| {
                let diff = target[(0, j)] - source_mean[j];
                diff * diff
            })
            .sum::<f64>()
            .sqrt();
        let aligned_mean_dist: f64 = (0..m)
            .map(|j| {
                let diff = result.aligned_data[(0, j)] - source_mean[j];
                diff * diff
            })
            .sum::<f64>()
            .sqrt();
        // Reparameterizing (f + 2) gives (f o gamma) + 2, so warping alone cannot
        // remove a constant vertical offset. This only guards against alignment
        // making the curve markedly worse.
        assert!(
            aligned_mean_dist < raw_mean_dist + 1.0,
            "aligned dist ({aligned_mean_dist:.2}) should not be much worse than raw dist ({raw_mean_dist:.2})"
        );
    }

    #[test]
    fn transfer_output_dimensions() {
        let (source, t) = make_data(6, 20, 42);
        let (target, _) = make_data(10, 20, 99);
        let config = TransferAlignConfig {
            max_iter: 3,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&source, &target, &t, &config).unwrap();
        assert_eq!(result.aligned_data.shape(), (10, 20));
        assert_eq!(result.gammas.shape(), (10, 20));
        assert_eq!(result.distances.len(), 10);
        assert_eq!(result.source_mean.len(), 20);
        assert_eq!(result.bridging_gamma.len(), 20);
    }

    #[test]
    fn transfer_rejects_column_mismatch() {
        let (source, t) = make_data(4, 20, 42);
        let (target, _) = make_data(4, 15, 7);
        let config = TransferAlignConfig::default();
        assert_dimension_error(
            transfer_alignment(&source, &target, &t, &config),
            "target_data",
        );
    }

    #[test]
    fn transfer_rejects_argvals_length_mismatch() {
        let (source, _) = make_data(4, 20, 42);
        let (target, _) = make_data(4, 20, 7);
        let short_grid = uniform_grid(10);
        let config = TransferAlignConfig::default();
        assert_dimension_error(
            transfer_alignment(&source, &target, &short_grid, &config),
            "argvals",
        );
    }

    #[test]
    fn transfer_rejects_empty_source() {
        let (target, t) = make_data(4, 20, 7);
        let source = FdMatrix::zeros(0, t.len());
        let config = TransferAlignConfig::default();
        assert_dimension_error(
            transfer_alignment(&source, &target, &t, &config),
            "source_data",
        );
    }

    #[test]
    fn transfer_rejects_empty_target() {
        let (source, t) = make_data(4, 20, 42);
        let target = FdMatrix::zeros(0, t.len());
        let config = TransferAlignConfig::default();
        assert_dimension_error(
            transfer_alignment(&source, &target, &t, &config),
            "target_data",
        );
    }

    #[test]
    fn transfer_config_default() {
        // The defaults are exact literals, so the derived `PartialEq` can pin them directly.
        assert_eq!(
            TransferAlignConfig::default(),
            TransferAlignConfig {
                lambda: 0.0,
                max_iter: 15,
                tol: 1e-3,
            }
        );
    }
}