1use super::karcher::karcher_mean;
4use super::pairwise::{elastic_align_pair, elastic_distance};
5use super::srsf::{compose_warps, reparameterize_curve};
6use crate::error::FdarError;
7use crate::iter_maybe_parallel;
8use crate::matrix::FdMatrix;
9#[cfg(feature = "parallel")]
10use rayon::iter::ParallelIterator;
11
/// Tuning parameters for `transfer_alignment`.
///
/// The same `max_iter` / `tol` pair drives both Karcher-mean computations
/// (source sample and target sample), and `lambda` is additionally forwarded
/// to the bridging pairwise alignment and to every per-curve elastic distance.
#[derive(Debug, Clone, PartialEq)]
pub struct TransferAlignConfig {
    /// Penalty weight forwarded as the `lambda` argument of the Karcher-mean,
    /// pairwise-alignment and elastic-distance routines (presumably a warp
    /// regularisation strength). Defaults to `0.0`.
    pub lambda: f64,
    /// Iteration cap forwarded to each Karcher-mean computation. Defaults to `15`.
    pub max_iter: usize,
    /// Convergence tolerance forwarded to each Karcher-mean computation.
    /// Defaults to `1e-3`.
    pub tol: f64,
}

impl Default for TransferAlignConfig {
    /// No penalty, at most 15 Karcher iterations, tolerance `1e-3`.
    fn default() -> Self {
        let max_iter = 15;
        let tol = 1e-3;
        let lambda = 0.0;
        TransferAlignConfig {
            max_iter,
            tol,
            lambda,
        }
    }
}
34
/// Output of `transfer_alignment`.
///
/// All per-curve outputs (`aligned_data`, `gammas`, `distances`) have one
/// entry per row of the target sample, in the same row order.
/// Marked `#[non_exhaustive]`, so it cannot be built with a struct literal
/// outside this crate.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct TransferAlignResult {
    /// Karcher mean of the source sample; the curve every target curve is
    /// ultimately aligned towards and measured against.
    pub source_mean: Vec<f64>,
    /// Target curves reparameterized by their composed warps
    /// (one row per target curve).
    pub aligned_data: FdMatrix,
    /// Composed warp for each target curve: its within-target Karcher warp
    /// combined with `bridging_gamma` (one row per target curve).
    pub gammas: FdMatrix,
    /// Warp obtained by elastically aligning the target Karcher mean to the
    /// source Karcher mean.
    pub bridging_gamma: Vec<f64>,
    /// Elastic distance between `source_mean` and each aligned target curve.
    pub distances: Vec<f64>,
}
50
51#[must_use = "expensive computation whose result should not be discarded"]
70pub fn transfer_alignment(
71 source_data: &FdMatrix,
72 target_data: &FdMatrix,
73 argvals: &[f64],
74 config: &TransferAlignConfig,
75) -> Result<TransferAlignResult, FdarError> {
76 let (n_source, m_source) = source_data.shape();
77 let (n_target, m_target) = target_data.shape();
78
79 if m_source != m_target {
81 return Err(FdarError::InvalidDimension {
82 parameter: "target_data",
83 expected: format!("{m_source} columns (matching source_data)"),
84 actual: format!("{m_target} columns"),
85 });
86 }
87 let m = m_source;
88 if argvals.len() != m {
89 return Err(FdarError::InvalidDimension {
90 parameter: "argvals",
91 expected: format!("{m}"),
92 actual: format!("{}", argvals.len()),
93 });
94 }
95 if n_source < 1 {
96 return Err(FdarError::InvalidDimension {
97 parameter: "source_data",
98 expected: "at least 1 row".to_string(),
99 actual: format!("{n_source} rows"),
100 });
101 }
102 if n_target < 1 {
103 return Err(FdarError::InvalidDimension {
104 parameter: "target_data",
105 expected: "at least 1 row".to_string(),
106 actual: format!("{n_target} rows"),
107 });
108 }
109
110 let source_karcher = karcher_mean(
112 source_data,
113 argvals,
114 config.max_iter,
115 config.tol,
116 config.lambda,
117 );
118
119 let target_karcher = karcher_mean(
121 target_data,
122 argvals,
123 config.max_iter,
124 config.tol,
125 config.lambda,
126 );
127
128 let bridge_result = elastic_align_pair(
130 &source_karcher.mean,
131 &target_karcher.mean,
132 argvals,
133 config.lambda,
134 );
135
136 let results: Vec<(Vec<f64>, Vec<f64>, f64)> = iter_maybe_parallel!(0..n_target)
140 .map(|i| {
141 let within_gamma = target_karcher.gammas.row(i);
143
144 let gamma_total = compose_warps(&bridge_result.gamma, &within_gamma, argvals);
146
147 let aligned_i = reparameterize_curve(&target_data.row(i), argvals, &gamma_total);
149
150 let dist_i = elastic_distance(&source_karcher.mean, &aligned_i, argvals, config.lambda);
152
153 (gamma_total, aligned_i, dist_i)
154 })
155 .collect();
156
157 let mut gammas = FdMatrix::zeros(n_target, m);
159 let mut aligned_data = FdMatrix::zeros(n_target, m);
160 let mut distances = Vec::with_capacity(n_target);
161
162 for (i, (gamma, aligned, dist)) in results.into_iter().enumerate() {
163 for j in 0..m {
164 gammas[(i, j)] = gamma[j];
165 aligned_data[(i, j)] = aligned[j];
166 }
167 distances.push(dist);
168 }
169
170 Ok(TransferAlignResult {
171 source_mean: source_karcher.mean,
172 aligned_data,
173 gammas,
174 bridging_gamma: bridge_result.gamma,
175 distances,
176 })
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` Fourier-basis curves on a uniform grid of `m` points.
    fn make_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(
            n,
            &t,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(seed),
        );
        (data, t)
    }

    /// Extract the `parameter` field of an `InvalidDimension` error, failing
    /// the test on any other variant.
    fn invalid_dimension_parameter(err: FdarError) -> String {
        match err {
            FdarError::InvalidDimension { parameter, .. } => parameter.to_string(),
            other => panic!("expected InvalidDimension, got {other:?}"),
        }
    }

    #[test]
    fn transfer_same_population() {
        let (data, t) = make_data(8, 20, 42);
        let config = TransferAlignConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&data, &data, &t, &config).unwrap();

        // Identical populations share a mean, so the bridge should barely warp.
        let max_dev: f64 = result
            .bridging_gamma
            .iter()
            .zip(t.iter())
            .map(|(&g, &ti)| (g - ti).abs())
            .fold(0.0_f64, f64::max);
        assert!(
            max_dev < 0.3,
            "bridging warp should be near identity for same population, max_dev={max_dev}"
        );

        for (i, &d) in result.distances.iter().enumerate() {
            assert!(
                d < 5.0,
                "distance[{i}]={d} should be small for same-population transfer"
            );
        }
    }

    #[test]
    fn transfer_shifted_population() {
        let (source, t) = make_data(8, 20, 42);
        let m = t.len();
        let n = source.nrows();

        // Target = source shifted vertically by a constant.
        let mut target = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                target[(i, j)] = source[(i, j)] + 2.0;
            }
        }

        let config = TransferAlignConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&source, &target, &t, &config).unwrap();

        let source_mean = &result.source_mean;
        let raw_mean_dist: f64 = (0..m)
            .map(|j| {
                let diff = target[(0, j)] - source_mean[j];
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        let aligned_mean_dist: f64 = (0..m)
            .map(|j| {
                let diff = result.aligned_data[(0, j)] - source_mean[j];
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        // Warping cannot remove a vertical shift, so only require that
        // alignment does not make things much worse.
        assert!(
            aligned_mean_dist < raw_mean_dist + 1.0,
            "aligned dist ({aligned_mean_dist:.2}) should not be much worse than raw dist ({raw_mean_dist:.2})"
        );
    }

    #[test]
    fn transfer_output_dimensions() {
        let (source, t) = make_data(6, 20, 42);
        let (target, _) = make_data(10, 20, 99);
        let config = TransferAlignConfig {
            max_iter: 3,
            tol: 1e-2,
            ..Default::default()
        };
        let result = transfer_alignment(&source, &target, &t, &config).unwrap();

        assert_eq!(result.aligned_data.shape(), (10, 20));
        assert_eq!(result.gammas.shape(), (10, 20));
        assert_eq!(result.distances.len(), 10);
        assert_eq!(result.source_mean.len(), 20);
        assert_eq!(result.bridging_gamma.len(), 20);
    }

    #[test]
    fn transfer_rejects_column_mismatch() {
        let (source, t) = make_data(4, 20, 42);
        let (target, _) = make_data(4, 15, 99);
        let err = transfer_alignment(&source, &target, &t, &TransferAlignConfig::default())
            .unwrap_err();
        assert_eq!(invalid_dimension_parameter(err), "target_data");
    }

    #[test]
    fn transfer_rejects_argvals_length_mismatch() {
        let (source, t) = make_data(4, 20, 42);
        let (target, _) = make_data(4, 20, 99);
        let err = transfer_alignment(
            &source,
            &target,
            &t[..t.len() - 1],
            &TransferAlignConfig::default(),
        )
        .unwrap_err();
        assert_eq!(invalid_dimension_parameter(err), "argvals");
    }

    #[test]
    fn transfer_rejects_empty_samples() {
        let (data, t) = make_data(4, 20, 42);
        let empty = FdMatrix::zeros(0, t.len());
        let config = TransferAlignConfig::default();

        let err = transfer_alignment(&empty, &data, &t, &config).unwrap_err();
        assert_eq!(invalid_dimension_parameter(err), "source_data");

        let err = transfer_alignment(&data, &empty, &t, &config).unwrap_err();
        assert_eq!(invalid_dimension_parameter(err), "target_data");
    }

    #[test]
    fn transfer_config_default() {
        let config = TransferAlignConfig::default();
        assert!((config.lambda - 0.0).abs() < f64::EPSILON);
        assert_eq!(config.max_iter, 15);
        assert!((config.tol - 1e-3).abs() < f64::EPSILON);
    }
}