Skip to main content

fdars_core/alignment/
warp_stats.rs

1//! Warping function statistics: mean, variance, confidence bands.
2//!
3//! After elastic alignment, the warping functions contain information about
4//! phase variation. This module provides summary statistics and uncertainty
5//! quantification for sets of warping functions.
6
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::warping::{
10    exp_map_sphere, gam_to_psi, inv_exp_map_sphere, l2_norm_l2, phase_distance, psi_to_gam,
11};
12
/// Summary of a collection of warping functions that share one evaluation grid.
///
/// Pointwise vectors hold one entry per grid point (length m); per-warp
/// vectors hold one entry per input warp (length n).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct WarpStatistics {
    /// Cross-sectional mean of the warps at each grid point (length m).
    pub mean: Vec<f64>,
    /// Sample variance (denominator `n - 1`) at each grid point (length m).
    pub variance: Vec<f64>,
    /// Square root of `variance` at each grid point (length m).
    pub std_dev: Vec<f64>,
    /// Lower limit of the pointwise confidence interval for the mean warp,
    /// `mean - z * std_dev / sqrt(n)` (length m).
    pub lower_band: Vec<f64>,
    /// Upper limit of the pointwise confidence interval for the mean warp,
    /// `mean + z * std_dev / sqrt(n)` (length m).
    pub upper_band: Vec<f64>,
    /// Karcher mean warp computed on the Hilbert sphere (length m).
    pub karcher_mean_warp: Vec<f64>,
    /// Geodesic distance of each warp from the Karcher mean (length n).
    pub geodesic_distances: Vec<f64>,
}
32
/// Approximate standard-normal quantile (probit) function.
///
/// Uses the rational approximation of Abramowitz & Stegun, formula 26.2.23
/// (stated absolute error below 4.5e-4), applied to the upper-tail mass and
/// reflected for `p < 0.5`. Returns NaN when `p` is NaN or outside the open
/// interval (0, 1), and exactly `0.0` when `p` is within 1e-15 of 0.5.
fn normal_quantile(p: f64) -> f64 {
    // Numerator (c0, c1, c2) and denominator (d1, d2, d3) coefficients.
    const NUM: [f64; 3] = [2.515_517, 0.802_853, 0.010_328];
    const DEN: [f64; 3] = [1.432_788, 0.189_269, 0.001_308];

    let inside_unit_interval = p > 0.0 && p < 1.0;
    if !inside_unit_interval {
        return f64::NAN;
    }
    if (p - 0.5).abs() < 1e-15 {
        return 0.0;
    }

    let lower_half = p < 0.5;
    // Reflect into the upper half, then work with the remaining tail mass.
    let upper_p = if lower_half { 1.0 - p } else { p };
    let tail = 1.0 - upper_p;
    let t = (-2.0 * tail.ln()).sqrt();

    let num = NUM[0] + NUM[1] * t + NUM[2] * t * t;
    let den = 1.0 + DEN[0] * t + DEN[1] * t * t + DEN[2] * t * t * t;
    let magnitude = t - num / den;

    if lower_half {
        -magnitude
    } else {
        magnitude
    }
}
55
56/// Compute summary statistics for a set of warping functions.
57///
58/// Given an n x m matrix of warping functions (one per row) and the common
59/// evaluation grid, computes pointwise statistics, confidence bands, the
60/// Karcher mean warp on the Hilbert sphere, and per-warp geodesic distances.
61///
62/// # Arguments
63/// * `gammas` — Warping functions matrix (n x m), one warp per row.
64/// * `argvals` — Common evaluation points (length m).
65/// * `confidence_level` — Confidence level for the bands (e.g. 0.95).
66///
67/// # Errors
68/// Returns `FdarError::InvalidDimension` if n < 2 or m does not match.
69/// Returns `FdarError::InvalidParameter` if confidence_level is not in (0, 1).
70#[must_use = "expensive computation whose result should not be discarded"]
71pub fn warp_statistics(
72    gammas: &FdMatrix,
73    argvals: &[f64],
74    confidence_level: f64,
75) -> Result<WarpStatistics, FdarError> {
76    let (n, m) = gammas.shape();
77
78    // Validate dimensions
79    if n < 2 {
80        return Err(FdarError::InvalidDimension {
81            parameter: "gammas",
82            expected: "at least 2 rows".to_string(),
83            actual: format!("{n} rows"),
84        });
85    }
86    if m != argvals.len() {
87        return Err(FdarError::InvalidDimension {
88            parameter: "argvals",
89            expected: format!("length {m}"),
90            actual: format!("length {}", argvals.len()),
91        });
92    }
93    if m < 2 {
94        return Err(FdarError::InvalidDimension {
95            parameter: "gammas",
96            expected: "at least 2 columns".to_string(),
97            actual: format!("{m} columns"),
98        });
99    }
100
101    // Validate confidence level
102    if confidence_level <= 0.0 || confidence_level >= 1.0 {
103        return Err(FdarError::InvalidParameter {
104            parameter: "confidence_level",
105            message: format!("must be in (0, 1), got {confidence_level}"),
106        });
107    }
108
109    let nf = n as f64;
110
111    // ── Step 1: Pointwise mean, variance, std_dev ──
112
113    let mut mean = vec![0.0; m];
114    let mut variance = vec![0.0; m];
115
116    for j in 0..m {
117        let col = gammas.column(j);
118        let mu = col.iter().sum::<f64>() / nf;
119        mean[j] = mu;
120        let var = col.iter().map(|&v| (v - mu) * (v - mu)).sum::<f64>() / (nf - 1.0);
121        variance[j] = var;
122    }
123
124    let std_dev: Vec<f64> = variance.iter().map(|&v| v.sqrt()).collect();
125
126    // ── Step 2: Confidence bands ──
127
128    let alpha = 1.0 - confidence_level;
129    let z = normal_quantile(1.0 - alpha / 2.0);
130    let sqrt_n = nf.sqrt();
131
132    let lower_band: Vec<f64> = mean
133        .iter()
134        .zip(std_dev.iter())
135        .map(|(&mu, &sd)| mu - z * sd / sqrt_n)
136        .collect();
137    let upper_band: Vec<f64> = mean
138        .iter()
139        .zip(std_dev.iter())
140        .map(|(&mu, &sd)| mu + z * sd / sqrt_n)
141        .collect();
142
143    // ── Step 3: Karcher mean warp on the Hilbert sphere ──
144
145    let t0 = argvals[0];
146    let t1 = argvals[m - 1];
147    let domain = t1 - t0;
148
149    // Uniform time grid on [0,1]
150    let time_01: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
151    let h = 1.0 / (m - 1) as f64;
152
153    // Convert all warps to psi representation
154    let mut psis: Vec<Vec<f64>> = Vec::with_capacity(n);
155    for i in 0..n {
156        let row = gammas.row(i);
157        let gam_01: Vec<f64> = row.iter().map(|&g| (g - t0) / domain).collect();
158        psis.push(gam_to_psi(&gam_01, h));
159    }
160
161    // Iterative Karcher mean on the sphere
162    let mut psi_mean = psis[0].clone();
163    let max_iter = 20;
164    let step_size = 0.5;
165
166    for _ in 0..max_iter {
167        // Compute mean tangent vector
168        let mut mean_tangent = vec![0.0; m];
169        for psi_i in &psis {
170            let v = inv_exp_map_sphere(&psi_mean, psi_i, &time_01);
171            for (mt, vi) in mean_tangent.iter_mut().zip(v.iter()) {
172                *mt += vi / nf;
173            }
174        }
175
176        // Check convergence
177        let tangent_norm = l2_norm_l2(&mean_tangent, &time_01);
178        if tangent_norm < 1e-10 {
179            break;
180        }
181
182        // Take a step along the mean tangent direction
183        let step_tangent: Vec<f64> = mean_tangent.iter().map(|&v| v * step_size).collect();
184        psi_mean = exp_map_sphere(&psi_mean, &step_tangent, &time_01);
185
186        // Re-normalize to unit sphere
187        let norm = l2_norm_l2(&psi_mean, &time_01);
188        if norm > 1e-10 {
189            for v in &mut psi_mean {
190                *v /= norm;
191            }
192        }
193    }
194
195    // Convert Karcher mean psi back to warping function
196    let karcher_gam_01 = psi_to_gam(&psi_mean, &time_01);
197    let mut karcher_mean_warp: Vec<f64> = karcher_gam_01.iter().map(|&g| t0 + g * domain).collect();
198    crate::warping::normalize_warp(&mut karcher_mean_warp, argvals);
199
200    // ── Step 4: Geodesic distances ──
201
202    let geodesic_distances: Vec<f64> = (0..n)
203        .map(|i| {
204            let row = gammas.row(i);
205            phase_distance(&row, argvals)
206        })
207        .collect();
208
209    Ok(WarpStatistics {
210        mean,
211        variance,
212        std_dev,
213        lower_band,
214        upper_band,
215        karcher_mean_warp,
216        geodesic_distances,
217    })
218}