fdars_core/alignment/
warp_stats.rs1use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::warping::{
10 exp_map_sphere, gam_to_psi, inv_exp_map_sphere, l2_norm_l2, phase_distance, psi_to_gam,
11};

/// Summary statistics of a sample of warping functions, as produced by
/// [`warp_statistics`].
///
/// The pointwise vectors (`mean`, `variance`, `std_dev`, `lower_band`,
/// `upper_band`) hold one entry per grid point, i.e. per column of the input
/// warp matrix. `geodesic_distances` holds one entry per input warp (row).
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct WarpStatistics {
    /// Pointwise sample mean of the warps.
    pub mean: Vec<f64>,
    /// Pointwise unbiased sample variance (denominator `n - 1`).
    pub variance: Vec<f64>,
    /// Pointwise sample standard deviation (square root of `variance`).
    pub std_dev: Vec<f64>,
    /// Lower edge of the pointwise confidence band, `mean - z * std_dev / sqrt(n)`.
    pub lower_band: Vec<f64>,
    /// Upper edge of the pointwise confidence band, `mean + z * std_dev / sqrt(n)`.
    pub upper_band: Vec<f64>,
    /// Karcher mean warp, expressed on the original `argvals` domain.
    pub karcher_mean_warp: Vec<f64>,
    /// Per-warp result of `phase_distance` on the original grid.
    pub geodesic_distances: Vec<f64>,
}

/// Quantile function (inverse CDF) of the standard normal distribution.
///
/// Uses Peter Acklam's piecewise rational approximation, whose relative error
/// is below about 1.15e-9 on the open unit interval. The previously used
/// Abramowitz & Stegun 26.2.23 formula has an absolute error of up to about
/// 4.5e-4. For example, it returned ≈1.96039 for `p = 0.975` instead of
/// 1.959964.
///
/// Returns `NaN` when `p` is NaN or lies outside the open interval `(0, 1)`.
fn normal_quantile(p: f64) -> f64 {
    // Central-region coefficients (numerator A, denominator B).
    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.38357751867269e2,
        -3.066479806614716e1,
        2.506628277459239,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    // Tail-region coefficients (numerator C, denominator D).
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838,
        -2.549732539343734,
        4.374664141464968,
        2.938163982698783,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996,
        3.754408661907416,
    ];
    // Probability at which the approximation switches from the lower tail to
    // the central region (mirrored at 1 - P_LOW for the upper tail).
    const P_LOW: f64 = 0.02425;

    // NaN must be rejected explicitly: every comparison with NaN is false.
    if p.is_nan() || p <= 0.0 || p >= 1.0 {
        return f64::NAN;
    }

    // Lower-tail quantile as a rational function of q = sqrt(-2 ln(tail prob)).
    // The result is negative; the upper tail uses its negation by symmetry.
    let tail = |q: f64| -> f64 {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p > 1.0 - P_LOW {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    } else {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    }
}
55
56#[must_use = "expensive computation whose result should not be discarded"]
71pub fn warp_statistics(
72 gammas: &FdMatrix,
73 argvals: &[f64],
74 confidence_level: f64,
75) -> Result<WarpStatistics, FdarError> {
76 let (n, m) = gammas.shape();
77
78 if n < 2 {
80 return Err(FdarError::InvalidDimension {
81 parameter: "gammas",
82 expected: "at least 2 rows".to_string(),
83 actual: format!("{n} rows"),
84 });
85 }
86 if m != argvals.len() {
87 return Err(FdarError::InvalidDimension {
88 parameter: "argvals",
89 expected: format!("length {m}"),
90 actual: format!("length {}", argvals.len()),
91 });
92 }
93 if m < 2 {
94 return Err(FdarError::InvalidDimension {
95 parameter: "gammas",
96 expected: "at least 2 columns".to_string(),
97 actual: format!("{m} columns"),
98 });
99 }
100
101 if confidence_level <= 0.0 || confidence_level >= 1.0 {
103 return Err(FdarError::InvalidParameter {
104 parameter: "confidence_level",
105 message: format!("must be in (0, 1), got {confidence_level}"),
106 });
107 }
108
109 let nf = n as f64;
110
111 let mut mean = vec![0.0; m];
114 let mut variance = vec![0.0; m];
115
116 for j in 0..m {
117 let col = gammas.column(j);
118 let mu = col.iter().sum::<f64>() / nf;
119 mean[j] = mu;
120 let var = col.iter().map(|&v| (v - mu) * (v - mu)).sum::<f64>() / (nf - 1.0);
121 variance[j] = var;
122 }
123
124 let std_dev: Vec<f64> = variance.iter().map(|&v| v.sqrt()).collect();
125
126 let alpha = 1.0 - confidence_level;
129 let z = normal_quantile(1.0 - alpha / 2.0);
130 let sqrt_n = nf.sqrt();
131
132 let lower_band: Vec<f64> = mean
133 .iter()
134 .zip(std_dev.iter())
135 .map(|(&mu, &sd)| mu - z * sd / sqrt_n)
136 .collect();
137 let upper_band: Vec<f64> = mean
138 .iter()
139 .zip(std_dev.iter())
140 .map(|(&mu, &sd)| mu + z * sd / sqrt_n)
141 .collect();
142
143 let t0 = argvals[0];
146 let t1 = argvals[m - 1];
147 let domain = t1 - t0;
148
149 let time_01: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
151 let h = 1.0 / (m - 1) as f64;
152
153 let mut psis: Vec<Vec<f64>> = Vec::with_capacity(n);
155 for i in 0..n {
156 let row = gammas.row(i);
157 let gam_01: Vec<f64> = row.iter().map(|&g| (g - t0) / domain).collect();
158 psis.push(gam_to_psi(&gam_01, h));
159 }
160
161 let mut psi_mean = psis[0].clone();
163 let max_iter = 20;
164 let step_size = 0.5;
165
166 for _ in 0..max_iter {
167 let mut mean_tangent = vec![0.0; m];
169 for psi_i in &psis {
170 let v = inv_exp_map_sphere(&psi_mean, psi_i, &time_01);
171 for (mt, vi) in mean_tangent.iter_mut().zip(v.iter()) {
172 *mt += vi / nf;
173 }
174 }
175
176 let tangent_norm = l2_norm_l2(&mean_tangent, &time_01);
178 if tangent_norm < 1e-10 {
179 break;
180 }
181
182 let step_tangent: Vec<f64> = mean_tangent.iter().map(|&v| v * step_size).collect();
184 psi_mean = exp_map_sphere(&psi_mean, &step_tangent, &time_01);
185
186 let norm = l2_norm_l2(&psi_mean, &time_01);
188 if norm > 1e-10 {
189 for v in &mut psi_mean {
190 *v /= norm;
191 }
192 }
193 }
194
195 let karcher_gam_01 = psi_to_gam(&psi_mean, &time_01);
197 let mut karcher_mean_warp: Vec<f64> = karcher_gam_01.iter().map(|&g| t0 + g * domain).collect();
198 crate::warping::normalize_warp(&mut karcher_mean_warp, argvals);
199
200 let geodesic_distances: Vec<f64> = (0..n)
203 .map(|i| {
204 let row = gammas.row(i);
205 phase_distance(&row, argvals)
206 })
207 .collect();
208
209 Ok(WarpStatistics {
210 mean,
211 variance,
212 std_dev,
213 lower_band,
214 upper_band,
215 karcher_mean_warp,
216 geodesic_distances,
217 })
218}