Skip to main content

fdars_core/alignment/
persistence.rs

1//! Peak persistence diagram for choosing the alignment regularisation parameter.
2
3use super::karcher::karcher_mean;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6
/// Outcome of sweeping the alignment penalty and tracking how many peaks the
/// Karcher mean exhibits at each candidate value.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct PersistenceDiagramResult {
    /// The candidate lambda values, in the order they were evaluated.
    pub lambdas: Vec<f64>,
    /// Peak count of the Karcher mean for each entry of `lambdas`
    /// (same length, same order).
    pub peak_counts: Vec<usize>,
    /// Persistence pairs, stored as `(birth_lambda_index, death_lambda_index)`.
    ///
    /// Every pair is a maximal run of consecutive lambda indices (both ends
    /// inclusive) over which the peak count does not change.
    pub persistence_pairs: Vec<(usize, usize)>,
    /// Lambda selected as optimal: the one at the centre of the longest
    /// stable interval.
    pub optimal_lambda: f64,
    /// Position of `optimal_lambda` within `lambdas`.
    pub optimal_index: usize,
}
24
/// Count the prominent peaks of a sampled curve.
///
/// A candidate peak is a strict interior local maximum, i.e. an index `j`
/// with `1 <= j < m - 1` and `mean[j-1] < mean[j] && mean[j] > mean[j+1]`
/// (flat-topped plateaus are therefore not counted).
///
/// A candidate is kept when its *prominence* exceeds
/// `prominence_frac * (max - min)` of the curve. The prominence is the height
/// of the peak above the higher of its two flanking bases, where each base is
/// the lowest sample met while walking away from the peak until a strictly
/// higher sample (or the end of the curve) is reached.
///
/// Measuring prominence this way, rather than against the two adjacent
/// samples only, keeps the count independent of the grid resolution: on a
/// fine grid a smooth peak rises only marginally above its direct neighbours
/// and would otherwise be rejected. The prominence is never smaller than the
/// drop to the adjacent samples, so every peak that clears the threshold on
/// an adjacent-sample basis is still counted.
///
/// Returns `0` for curves with fewer than three samples.
fn count_peaks(mean: &[f64], prominence_frac: f64) -> usize {
    let m = mean.len();
    if m < 3 {
        return 0;
    }

    // `f64::min` / `f64::max` skip NaN operands, so stray NaNs do not poison the range.
    let (min_val, max_val) = mean
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let threshold = prominence_frac * (max_val - min_val);

    (1..m - 1)
        .filter(|&j| mean[j] > mean[j - 1] && mean[j] > mean[j + 1])
        .filter(|&j| {
            let peak = mean[j];
            // Lowest point on each side before the curve climbs above the peak.
            // Both walks see at least the adjacent sample (which is strictly
            // lower), so neither fold can return its INFINITY seed.
            // Each walk is linear in the curve length, which is fine for the
            // handful of local maxima a mean curve has.
            let left_base = mean[..j]
                .iter()
                .rev()
                .copied()
                .take_while(|&v| v <= peak)
                .fold(f64::INFINITY, f64::min);
            let right_base = mean[j + 1..]
                .iter()
                .copied()
                .take_while(|&v| v <= peak)
                .fold(f64::INFINITY, f64::min);
            peak - left_base.max(right_base) > threshold
        })
        .count()
}
53
/// Split `peak_counts` into maximal runs of equal values.
///
/// Returns one `(first_index, last_index)` pair (both ends inclusive) per
/// run, in order of appearance. An empty slice yields an empty vector.
fn build_persistence_pairs(peak_counts: &[usize]) -> Vec<(usize, usize)> {
    let len = peak_counts.len();
    if len == 0 {
        return Vec::new();
    }

    // Indices whose value differs from the predecessor: each one opens a new run.
    let breaks: Vec<usize> = (1..len)
        .filter(|&i| peak_counts[i] != peak_counts[i - 1])
        .collect();

    // Runs start at 0 and at every break; they end just before every break
    // and at the final index.
    let run_starts = std::iter::once(0).chain(breaks.iter().copied());
    let run_ends = breaks
        .iter()
        .map(|&b| b - 1)
        .chain(std::iter::once(len - 1));

    run_starts.zip(run_ends).collect()
}
70
71/// Analyse the stability of peak count in the Karcher mean across a sweep
72/// of alignment penalty values.
73///
74/// For each candidate lambda the Karcher mean is computed and its peaks are
75/// counted. The longest interval of constant peak count is identified as the
76/// most stable configuration, and the midpoint lambda of that interval is
77/// returned as the optimal choice.
78///
79/// # Arguments
80/// * `data`     - Functional data matrix (n x m).
81/// * `argvals`  - Evaluation grid (length m).
82/// * `lambdas`  - Candidate lambda values to sweep (must be non-empty, all >= 0).
83/// * `max_iter` - Maximum Karcher iterations per lambda.
84/// * `tol`      - Karcher convergence tolerance.
85///
86/// # Errors
87/// Returns `FdarError::InvalidDimension` if `data` has fewer than 2 rows or
88/// `argvals` length mismatches.
89/// Returns `FdarError::InvalidParameter` if `lambdas` is empty, any lambda
90/// is negative, or `max_iter` is 0.
91#[must_use = "expensive computation whose result should not be discarded"]
92pub fn peak_persistence(
93    data: &FdMatrix,
94    argvals: &[f64],
95    lambdas: &[f64],
96    max_iter: usize,
97    tol: f64,
98) -> Result<PersistenceDiagramResult, FdarError> {
99    let n = data.nrows();
100    let m = data.ncols();
101
102    // ── Validation ──────────────────────────────────────────────────────
103    if n < 2 {
104        return Err(FdarError::InvalidDimension {
105            parameter: "data",
106            expected: "at least 2 rows".to_string(),
107            actual: format!("{n} rows"),
108        });
109    }
110    if argvals.len() != m {
111        return Err(FdarError::InvalidDimension {
112            parameter: "argvals",
113            expected: format!("{m}"),
114            actual: format!("{}", argvals.len()),
115        });
116    }
117    if lambdas.is_empty() {
118        return Err(FdarError::InvalidParameter {
119            parameter: "lambdas",
120            message: "must be non-empty".to_string(),
121        });
122    }
123    if lambdas.iter().any(|&l| l < 0.0) {
124        return Err(FdarError::InvalidParameter {
125            parameter: "lambdas",
126            message: "all lambda values must be >= 0".to_string(),
127        });
128    }
129    if max_iter == 0 {
130        return Err(FdarError::InvalidParameter {
131            parameter: "max_iter",
132            message: "must be > 0".to_string(),
133        });
134    }
135
136    // ── Lambda sweep ────────────────────────────────────────────────────
137    let mut peak_counts = Vec::with_capacity(lambdas.len());
138
139    for &lam in lambdas {
140        let result = karcher_mean(data, argvals, max_iter, tol, lam);
141        let count = count_peaks(&result.mean, 0.001);
142        peak_counts.push(count);
143    }
144
145    // ── Build persistence pairs ─────────────────────────────────────────
146    let persistence_pairs = build_persistence_pairs(&peak_counts);
147
148    // ── Find optimal lambda (longest stable interval) ───────────────────
149    let (best_start, best_end) = persistence_pairs
150        .iter()
151        .copied()
152        .max_by_key(|&(s, e)| {
153            // Use the span in lambda space as the primary criterion,
154            // discretised to avoid floating-point comparison issues.
155            let span = lambdas[e] - lambdas[s];
156            // Convert to an integer score (nano-units) for stable ordering
157            (span * 1e9) as u64
158        })
159        .unwrap_or((0, 0));
160
161    let optimal_index = (best_start + best_end) / 2;
162    let optimal_lambda = lambdas[optimal_index];
163
164    Ok(PersistenceDiagramResult {
165        lambdas: lambdas.to_vec(),
166        peak_counts,
167        persistence_pairs,
168        optimal_lambda,
169        optimal_index,
170    })
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Build a small dataset with one clear sine peak per curve.
    ///
    /// Curve `i` is `sin(pi * (t + 0.05 * i))` evaluated on the grid returned
    /// by `uniform_grid(m)`; the values are laid out column-major (n x m).
    fn single_peak_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data_vec = vec![0.0; n * m];
        for i in 0..n {
            let shift = 0.05 * i as f64;
            for j in 0..m {
                // sin(pi * t) has exactly one peak on [0,1]
                data_vec[i + j * n] = (std::f64::consts::PI * (t[j] + shift)).sin();
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    #[test]
    fn persistence_single_peak_stable() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.01, 0.1, 1.0];

        let result = peak_persistence(&data, &t, &lambdas, 5, 1e-2).unwrap();

        // All (or most) peak counts should be 1
        let count_one = result.peak_counts.iter().filter(|&&c| c == 1).count();
        assert!(
            count_one >= lambdas.len() / 2,
            "Expected most peak counts to be 1, got {:?}",
            result.peak_counts
        );
    }

    #[test]
    fn persistence_optimal_in_range() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.01, 0.1, 1.0, 10.0];

        let result = peak_persistence(&data, &t, &lambdas, 5, 1e-2).unwrap();

        assert!(
            result.optimal_lambda >= lambdas[0],
            "optimal_lambda {} below range",
            result.optimal_lambda
        );
        assert!(
            result.optimal_lambda <= *lambdas.last().unwrap(),
            "optimal_lambda {} above range",
            result.optimal_lambda
        );
    }

    #[test]
    fn persistence_peak_counts_length() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.5, 1.0];

        let result = peak_persistence(&data, &t, &lambdas, 3, 1e-2).unwrap();
        assert_eq!(result.peak_counts.len(), lambdas.len());
    }

    #[test]
    fn persistence_result_is_internally_consistent() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.5, 1.0];

        let result = peak_persistence(&data, &t, &lambdas, 3, 1e-2).unwrap();

        // The intervals must tile the lambda indices without gaps or overlap.
        let pairs = &result.persistence_pairs;
        assert_eq!(pairs.first().map(|p| p.0), Some(0));
        assert_eq!(pairs.last().map(|p| p.1), Some(lambdas.len() - 1));
        for w in pairs.windows(2) {
            assert_eq!(w[0].1 + 1, w[1].0, "intervals not contiguous: {pairs:?}");
        }

        assert!(result.optimal_index < lambdas.len());
        assert_eq!(result.optimal_lambda, lambdas[result.optimal_index]);
        assert_eq!(result.lambdas, lambdas);
    }

    #[test]
    fn persistence_rejects_empty_lambdas() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[], 5, 1e-3);
        assert!(result.is_err(), "Empty lambdas should produce an error");
    }

    #[test]
    fn persistence_rejects_negative_lambda() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[0.0, -0.1], 5, 1e-3);
        assert!(
            matches!(
                result,
                Err(FdarError::InvalidParameter {
                    parameter: "lambdas",
                    ..
                })
            ),
            "Negative lambda should produce InvalidParameter"
        );
    }

    #[test]
    fn persistence_rejects_zero_max_iter() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[0.0, 0.1], 0, 1e-3);
        assert!(
            matches!(
                result,
                Err(FdarError::InvalidParameter {
                    parameter: "max_iter",
                    ..
                })
            ),
            "max_iter == 0 should produce InvalidParameter"
        );
    }

    #[test]
    fn persistence_rejects_argvals_length_mismatch() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t[..t.len() - 1], &[0.0, 0.1], 5, 1e-3);
        assert!(
            matches!(
                result,
                Err(FdarError::InvalidDimension {
                    parameter: "argvals",
                    ..
                })
            ),
            "argvals length mismatch should produce InvalidDimension"
        );
    }

    #[test]
    fn build_persistence_pairs_groups_constant_runs() {
        assert!(build_persistence_pairs(&[]).is_empty());
        assert_eq!(build_persistence_pairs(&[3]), vec![(0, 0)]);
        assert_eq!(
            build_persistence_pairs(&[1, 1, 2, 2, 2, 1]),
            vec![(0, 1), (2, 4), (5, 5)]
        );
    }

    #[test]
    fn count_peaks_handles_basic_shapes() {
        assert_eq!(count_peaks(&[1.0, 2.0], 0.001), 0);
        assert_eq!(count_peaks(&[1.0, 1.0, 1.0], 0.001), 0);
        assert_eq!(count_peaks(&[0.0, 1.0, 0.0], 0.001), 1);
        assert_eq!(count_peaks(&[0.0, 5.0, 1.0, 4.0, 0.0], 0.001), 2);
    }
}
243}