1use super::karcher::karcher_mean;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6
/// Outcome of [`peak_persistence`]: peak counts of the Karcher mean across a
/// grid of lambda values, and the lambda judged most stable.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PersistenceDiagramResult {
    /// The lambda values that were evaluated, copied in the order the caller supplied them.
    pub lambdas: Vec<f64>,
    /// Number of peaks counted in the Karcher mean computed for each entry of `lambdas`
    /// (same length and order as `lambdas`).
    pub peak_counts: Vec<usize>,
    /// Maximal runs of consecutive lambda indices that share the same peak count, as
    /// inclusive `(start, end)` index pairs into `lambdas`.
    pub persistence_pairs: Vec<(usize, usize)>,
    /// The lambda value at `optimal_index`.
    pub optimal_lambda: f64,
    /// Index into `lambdas`: the midpoint of the run in `persistence_pairs` whose
    /// lambda span (`lambdas[end] - lambdas[start]`) is largest.
    pub optimal_index: usize,
}
24
/// Count the prominent interior local maxima of a sampled curve.
///
/// A peak is a single sample, or a plateau of equal samples, that is strictly
/// higher than the samples immediately before and after it. Samples at either
/// end of the curve are never peaks.
///
/// A peak's prominence is its height minus a base level:
/// - Walk left from the peak until reaching a strictly higher sample or the
///   start of the curve, and take the lowest value passed. Do the same to the
///   right.
/// - The base is the higher of those two minima.
///
/// A peak is counted when its prominence is greater than
/// `prominence_frac * (max(mean) - min(mean))`.
///
/// Measuring against the surrounding valleys, not just the two adjacent
/// samples, keeps the count independent of grid density. On a fine grid a
/// smooth maximum differs from its neighbours by only O(h^2), so a
/// neighbour-only test would wrongly discard it.
///
/// Returns 0 for curves with fewer than 3 samples.
fn count_peaks(mean: &[f64], prominence_frac: f64) -> usize {
    let m = mean.len();
    if m < 3 {
        return 0;
    }

    let min_val = mean.iter().copied().fold(f64::INFINITY, f64::min);
    let max_val = mean.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let threshold = prominence_frac * (max_val - min_val);

    let mut count = 0;
    let mut j = 1;
    while j < m - 1 {
        // A candidate peak can only begin where the curve strictly rises.
        if mean[j] > mean[j - 1] {
            // Extend over a plateau of equal values so flat tops count once.
            let mut end = j;
            while end + 1 < m && mean[end + 1] == mean[j] {
                end += 1;
            }
            // It is a peak only if the curve strictly falls right after the plateau.
            if end + 1 < m
                && mean[end + 1] < mean[j]
                && peak_prominence(mean, j, end) > threshold
            {
                count += 1;
            }
            // Resume at the first sample after the plateau.
            j = end + 1;
        } else {
            j += 1;
        }
    }
    count
}

/// Topographic prominence of the peak whose equal-valued top spans
/// `mean[start..=end]`.
///
/// Each side is scanned outward until a strictly higher sample (or the
/// boundary) is hit. The prominence is the peak height minus the higher of the
/// two minima found.
fn peak_prominence(mean: &[f64], start: usize, end: usize) -> f64 {
    let height = mean[start];

    let mut left_min = height;
    for &v in mean[..start].iter().rev() {
        if v > height {
            break;
        }
        left_min = left_min.min(v);
    }

    let mut right_min = height;
    for &v in &mean[end + 1..] {
        if v > height {
            break;
        }
        right_min = right_min.min(v);
    }

    height - left_min.max(right_min)
}
53
/// Split `peak_counts` into maximal runs of identical consecutive values.
///
/// Each run is returned as an inclusive `(first_index, last_index)` pair, in
/// order. Together the runs cover every index exactly once. An empty input
/// yields no runs.
fn build_persistence_pairs(peak_counts: &[usize]) -> Vec<(usize, usize)> {
    let len = peak_counts.len();
    if len == 0 {
        return Vec::new();
    }

    let mut runs = Vec::new();
    let mut run_start = 0;
    // A run ends wherever two adjacent counts differ.
    for (idx, pair) in peak_counts.windows(2).enumerate() {
        if pair[0] != pair[1] {
            runs.push((run_start, idx));
            run_start = idx + 1;
        }
    }
    // Close the trailing run, which always reaches the last index.
    runs.push((run_start, len - 1));
    runs
}
70
71#[must_use = "expensive computation whose result should not be discarded"]
92pub fn peak_persistence(
93 data: &FdMatrix,
94 argvals: &[f64],
95 lambdas: &[f64],
96 max_iter: usize,
97 tol: f64,
98) -> Result<PersistenceDiagramResult, FdarError> {
99 let n = data.nrows();
100 let m = data.ncols();
101
102 if n < 2 {
104 return Err(FdarError::InvalidDimension {
105 parameter: "data",
106 expected: "at least 2 rows".to_string(),
107 actual: format!("{n} rows"),
108 });
109 }
110 if argvals.len() != m {
111 return Err(FdarError::InvalidDimension {
112 parameter: "argvals",
113 expected: format!("{m}"),
114 actual: format!("{}", argvals.len()),
115 });
116 }
117 if lambdas.is_empty() {
118 return Err(FdarError::InvalidParameter {
119 parameter: "lambdas",
120 message: "must be non-empty".to_string(),
121 });
122 }
123 if lambdas.iter().any(|&l| l < 0.0) {
124 return Err(FdarError::InvalidParameter {
125 parameter: "lambdas",
126 message: "all lambda values must be >= 0".to_string(),
127 });
128 }
129 if max_iter == 0 {
130 return Err(FdarError::InvalidParameter {
131 parameter: "max_iter",
132 message: "must be > 0".to_string(),
133 });
134 }
135
136 let mut peak_counts = Vec::with_capacity(lambdas.len());
138
139 for &lam in lambdas {
140 let result = karcher_mean(data, argvals, max_iter, tol, lam);
141 let count = count_peaks(&result.mean, 0.001);
142 peak_counts.push(count);
143 }
144
145 let persistence_pairs = build_persistence_pairs(&peak_counts);
147
148 let (best_start, best_end) = persistence_pairs
150 .iter()
151 .copied()
152 .max_by_key(|&(s, e)| {
153 let span = lambdas[e] - lambdas[s];
156 (span * 1e9) as u64
158 })
159 .unwrap_or((0, 0));
160
161 let optimal_index = (best_start + best_end) / 2;
162 let optimal_lambda = lambdas[optimal_index];
163
164 Ok(PersistenceDiagramResult {
165 lambdas: lambdas.to_vec(),
166 peak_counts,
167 persistence_pairs,
168 optimal_lambda,
169 optimal_index,
170 })
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// Build `n` phase-shifted copies of `sin(pi * t)` on `uniform_grid(m)`.
    ///
    /// Row `i` is shifted by `0.05 * i`. Values are stored column-major
    /// (`i + j * n`) to match `FdMatrix::from_column_major`.
    fn single_peak_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data_vec = vec![0.0; n * m];
        for i in 0..n {
            let shift = 0.05 * i as f64;
            for j in 0..m {
                data_vec[i + j * n] = (std::f64::consts::PI * (t[j] + shift)).sin();
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    #[test]
    fn persistence_single_peak_stable() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.01, 0.1, 1.0];

        let result = peak_persistence(&data, &t, &lambdas, 5, 1e-2).unwrap();

        // Single-hump input: at least half of the scanned lambdas should see one peak.
        let count_one = result.peak_counts.iter().filter(|&&c| c == 1).count();
        assert!(
            count_one >= lambdas.len() / 2,
            "Expected most peak counts to be 1, got {:?}",
            result.peak_counts
        );
    }

    #[test]
    fn persistence_optimal_in_range() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.01, 0.1, 1.0, 10.0];

        let result = peak_persistence(&data, &t, &lambdas, 5, 1e-2).unwrap();

        assert!(
            result.optimal_lambda >= lambdas[0],
            "optimal_lambda {} below range",
            result.optimal_lambda
        );
        assert!(
            result.optimal_lambda <= *lambdas.last().unwrap(),
            "optimal_lambda {} above range",
            result.optimal_lambda
        );
    }

    #[test]
    fn persistence_peak_counts_length() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.5, 1.0];

        let result = peak_persistence(&data, &t, &lambdas, 3, 1e-2).unwrap();
        assert_eq!(result.peak_counts.len(), lambdas.len());
    }

    #[test]
    fn persistence_rejects_empty_lambdas() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[], 5, 1e-3);
        assert!(result.is_err(), "Empty lambdas should produce an error");
    }

    #[test]
    fn persistence_rejects_negative_lambdas() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[0.0, -1.0], 5, 1e-3);
        assert!(result.is_err(), "Negative lambdas should produce an error");
    }

    #[test]
    fn persistence_rejects_too_few_rows() {
        let (data, t) = single_peak_data(1, 21);
        let result = peak_persistence(&data, &t, &[0.0], 5, 1e-3);
        assert!(result.is_err(), "A single curve should produce an error");
    }

    #[test]
    fn persistence_rejects_mismatched_argvals() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t[..20], &[0.0], 5, 1e-3);
        assert!(result.is_err(), "argvals length mismatch should produce an error");
    }

    #[test]
    fn persistence_rejects_zero_max_iter() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[0.0], 0, 1e-3);
        assert!(result.is_err(), "max_iter == 0 should produce an error");
    }

    #[test]
    fn persistence_pairs_cover_runs() {
        assert_eq!(build_persistence_pairs(&[]), Vec::<(usize, usize)>::new());
        assert_eq!(build_persistence_pairs(&[2]), vec![(0, 0)]);
        assert_eq!(
            build_persistence_pairs(&[1, 1, 2, 2, 2, 1]),
            vec![(0, 1), (2, 4), (5, 5)]
        );
    }

    #[test]
    fn count_peaks_basic_cases() {
        assert_eq!(count_peaks(&[0.0, 1.0, 0.0, 2.0, 0.0], 0.001), 2);
        assert_eq!(count_peaks(&[0.0, 0.5, 0.4999, 1.0], 0.001), 0);
        assert_eq!(count_peaks(&[1.0, 2.0], 0.001), 0);
        assert_eq!(count_peaks(&[3.0; 5], 0.001), 0);
    }
}