Skip to main content

fdars_core/alignment/
transfer.rs

1//! Transfer alignment: align curves across populations using a shared reference.
2
3use super::karcher::karcher_mean;
4use super::pairwise::{elastic_align_pair, elastic_distance};
5use super::srsf::{compose_warps, reparameterize_curve};
6use crate::error::FdarError;
7use crate::iter_maybe_parallel;
8use crate::matrix::FdMatrix;
9#[cfg(feature = "parallel")]
10use rayon::iter::ParallelIterator;
11
12// ─── Types ──────────────────────────────────────────────────────────────────
13
/// Tuning parameters for transfer alignment.
///
/// `lambda` is passed to both Karcher-mean computations, to the bridging
/// pairwise alignment, and to the per-curve elastic distance. `max_iter` and
/// `tol` only control the two Karcher-mean computations.
#[derive(Debug)]
#[derive(Clone, PartialEq)]
pub struct TransferAlignConfig {
    /// Roughness penalty used by every elastic alignment / distance call.
    pub lambda: f64,
    /// Upper bound on Karcher-mean iterations (applied to each population).
    pub max_iter: usize,
    /// Stopping tolerance for the Karcher-mean iterations.
    pub tol: f64,
}

impl Default for TransferAlignConfig {
    /// Unpenalized alignment (`lambda = 0.0`), at most 15 Karcher iterations,
    /// and a convergence tolerance of `1e-3`.
    fn default() -> Self {
        let (lambda, max_iter, tol) = (0.0, 15, 1e-3);
        TransferAlignConfig {
            lambda,
            max_iter,
            tol,
        }
    }
}
34
35/// Result of transfer alignment.
36#[derive(Debug, Clone, PartialEq)]
37#[non_exhaustive]
38pub struct TransferAlignResult {
39    /// Source Karcher mean (population A's reference).
40    pub source_mean: Vec<f64>,
41    /// Target curves aligned to source coordinate system (n_target x m).
42    pub aligned_data: FdMatrix,
43    /// Warping functions mapping target curves to source frame (n_target x m).
44    pub gammas: FdMatrix,
45    /// Bridging warp from target mean to source mean.
46    pub bridging_gamma: Vec<f64>,
47    /// Per-curve elastic distances after alignment.
48    pub distances: Vec<f64>,
49}
50
51// ─── Public API ─────────────────────────────────────────────────────────────
52
53/// Align curves from a target population to a source population's coordinate system.
54///
55/// Computes Karcher means for both populations, finds the bridging warp that
56/// aligns the target mean to the source mean, then composes this bridge with
57/// each target curve's within-population warp to produce curves aligned in
58/// the source coordinate frame.
59///
60/// # Arguments
61/// * `source_data` - Source population (n_source x m).
62/// * `target_data` - Target population to align (n_target x m).
63/// * `argvals`     - Evaluation points (length m).
64/// * `config`      - Transfer alignment configuration.
65///
66/// # Errors
67/// Returns [`FdarError::InvalidDimension`] if matrices have different `ncols`,
68/// `argvals` length does not match, or either matrix has 0 rows.
69#[must_use = "expensive computation whose result should not be discarded"]
70pub fn transfer_alignment(
71    source_data: &FdMatrix,
72    target_data: &FdMatrix,
73    argvals: &[f64],
74    config: &TransferAlignConfig,
75) -> Result<TransferAlignResult, FdarError> {
76    let (n_source, m_source) = source_data.shape();
77    let (n_target, m_target) = target_data.shape();
78
79    // ── Validation ──
80    if m_source != m_target {
81        return Err(FdarError::InvalidDimension {
82            parameter: "target_data",
83            expected: format!("{m_source} columns (matching source_data)"),
84            actual: format!("{m_target} columns"),
85        });
86    }
87    let m = m_source;
88    if argvals.len() != m {
89        return Err(FdarError::InvalidDimension {
90            parameter: "argvals",
91            expected: format!("{m}"),
92            actual: format!("{}", argvals.len()),
93        });
94    }
95    if n_source < 1 {
96        return Err(FdarError::InvalidDimension {
97            parameter: "source_data",
98            expected: "at least 1 row".to_string(),
99            actual: format!("{n_source} rows"),
100        });
101    }
102    if n_target < 1 {
103        return Err(FdarError::InvalidDimension {
104            parameter: "target_data",
105            expected: "at least 1 row".to_string(),
106            actual: format!("{n_target} rows"),
107        });
108    }
109
110    // ── Compute source reference ──
111    let source_karcher = karcher_mean(
112        source_data,
113        argvals,
114        config.max_iter,
115        config.tol,
116        config.lambda,
117    );
118
119    // ── Compute target reference ──
120    let target_karcher = karcher_mean(
121        target_data,
122        argvals,
123        config.max_iter,
124        config.tol,
125        config.lambda,
126    );
127
128    // ── Bridging alignment: align target mean to source mean ──
129    let bridge_result = elastic_align_pair(
130        &source_karcher.mean,
131        &target_karcher.mean,
132        argvals,
133        config.lambda,
134    );
135
136    // ── Align target curves ──
137    // For each target curve: compose bridging warp with within-population warp,
138    // then apply to original target curve.
139    let results: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n_target)
140        .map(|i| {
141            // Within-population warp for curve i (from target Karcher computation)
142            let within_gamma = target_karcher.gammas.row(i);
143
144            // Compose: bridge_gamma(within_gamma(t))
145            let gamma_total = compose_warps(&bridge_result.gamma, &within_gamma, argvals);
146
147            // Apply to original target curve
148            let aligned_i = reparameterize_curve(&target_data.row(i), argvals, &gamma_total);
149
150            // Compute distance to source mean
151            let dist_i = elastic_distance(&source_karcher.mean, &aligned_i, argvals, config.lambda);
152
153            (gamma_total, aligned_i, dist_i)
154        })
155        .collect();
156
157    // ── Assemble result ──
158    let mut gammas = FdMatrix::zeros(n_target, m);
159    let mut aligned_data = FdMatrix::zeros(n_target, m);
160    let mut distances = Vec::with_capacity(n_target);
161
162    for (i, (gamma, aligned, dist)) in results.into_iter().enumerate() {
163        for j in 0..m {
164            gammas[(i, j)] = gamma[j];
165            aligned_data[(i, j)] = aligned[j];
166        }
167        distances.push(dist);
168    }
169
170    Ok(TransferAlignResult {
171        source_mean: source_karcher.mean,
172        aligned_data,
173        gammas,
174        bridging_gamma: bridge_result.gamma,
175        distances,
176    })
177}
178
179// ─── Tests ──────────────────────────────────────────────────────────────────
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulated Fourier/exponential functional data with `n` curves on an
    /// `m`-point uniform grid.
    fn make_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(
            n,
            &t,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(seed),
        );
        (data, t)
    }

    #[test]
    fn transfer_same_population() {
        let (data, t) = make_data(8, 20, 42);
        let config = TransferAlignConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&data, &data, &t, &config).unwrap();

        // Bridging warp should be close to identity
        let max_dev: f64 = result
            .bridging_gamma
            .iter()
            .zip(t.iter())
            .map(|(&g, &ti)| (g - ti).abs())
            .fold(0.0_f64, f64::max);
        assert!(
            max_dev < 0.3,
            "bridging warp should be near identity for same population, max_dev={max_dev}"
        );

        // Distances should be small
        for (i, &d) in result.distances.iter().enumerate() {
            assert!(
                d < 5.0,
                "distance[{i}]={d} should be small for same-population transfer"
            );
        }
    }

    #[test]
    fn transfer_shifted_population() {
        let (source, t) = make_data(8, 20, 42);
        let m = t.len();
        let n = source.nrows();

        // Create a vertically shifted version of source
        let mut target = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                target[(i, j)] = source[(i, j)] + 2.0;
            }
        }

        let config = TransferAlignConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&source, &target, &t, &config).unwrap();

        // After alignment, the aligned target curves should be closer to the
        // source mean than the raw target curves
        let source_mean = &result.source_mean;
        let raw_mean_dist: f64 = (0..m)
            .map(|j| {
                let diff = target[(0, j)] - source_mean[j];
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        let aligned_mean_dist: f64 = (0..m)
            .map(|j| {
                let diff = result.aligned_data[(0, j)] - source_mean[j];
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        // The aligned version should not be worse than raw (with some tolerance
        // since the shift is in amplitude and alignment is mainly phase)
        assert!(
            aligned_mean_dist < raw_mean_dist + 1.0,
            "aligned dist ({aligned_mean_dist:.2}) should not be much worse than raw dist ({raw_mean_dist:.2})"
        );
    }

    #[test]
    fn transfer_output_dimensions() {
        let (source, t) = make_data(6, 20, 42);
        let (target, _) = make_data(10, 20, 99);
        let config = TransferAlignConfig {
            max_iter: 3,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&source, &target, &t, &config).unwrap();

        assert_eq!(result.aligned_data.shape(), (10, 20));
        assert_eq!(result.gammas.shape(), (10, 20));
        assert_eq!(result.distances.len(), 10);
        assert_eq!(result.source_mean.len(), 20);
        assert_eq!(result.bridging_gamma.len(), 20);
    }

    #[test]
    fn transfer_config_default() {
        let config = TransferAlignConfig::default();
        assert!((config.lambda - 0.0).abs() < f64::EPSILON);
        assert_eq!(config.max_iter, 15);
        assert!((config.tol - 1e-3).abs() < f64::EPSILON);
    }

    // ── Validation error paths (checked before any alignment work) ──

    #[test]
    fn transfer_rejects_column_mismatch() {
        let (source, t) = make_data(4, 20, 42);
        let (target, _) = make_data(4, 15, 7);
        let err = transfer_alignment(&source, &target, &t, &TransferAlignConfig::default())
            .unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { parameter: "target_data", .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn transfer_rejects_argvals_length_mismatch() {
        let (data, t) = make_data(4, 20, 42);
        let err = transfer_alignment(&data, &data, &t[..10], &TransferAlignConfig::default())
            .unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { parameter: "argvals", .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn transfer_rejects_empty_source() {
        let (target, t) = make_data(4, 20, 42);
        let empty = FdMatrix::zeros(0, 20);
        let err = transfer_alignment(&empty, &target, &t, &TransferAlignConfig::default())
            .unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { parameter: "source_data", .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn transfer_rejects_empty_target() {
        let (source, t) = make_data(4, 20, 42);
        let empty = FdMatrix::zeros(0, 20);
        let err = transfer_alignment(&source, &empty, &t, &TransferAlignConfig::default())
            .unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { parameter: "target_data", .. }),
            "unexpected error: {err:?}"
        );
    }
}