use super::karcher::karcher_mean;
use crate::error::FdarError;
use crate::matrix::FdMatrix;
/// Result of [`peak_persistence`]: peak counts of the Karcher mean across a sweep of
/// `lambda` values, grouped into runs of constant peak count.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PersistenceDiagramResult {
/// Copy of the `lambdas` passed in, in the caller's original order.
pub lambdas: Vec<f64>,
/// Number of prominent peaks in the Karcher mean computed for each lambda;
/// `peak_counts[i]` corresponds to `lambdas[i]`.
pub peak_counts: Vec<usize>,
/// Maximal runs of consecutive indices with identical peak count, as inclusive
/// `(first_index, last_index)` pairs in index order. Together they cover every index.
pub persistence_pairs: Vec<(usize, usize)>,
/// The lambda value at `optimal_index`.
pub optimal_lambda: f64,
/// Middle index (rounded down) of the run judged most persistent, i.e. the run whose
/// lambda values span the widest range.
pub optimal_index: usize,
}
/// Count the prominent local maxima of a sampled curve.
///
/// A local maximum is an interior sample, or a flat run of equal samples, whose immediate
/// neighbours on both sides are strictly lower. Each maximum's prominence is measured
/// topographically. Walk outwards on each side until a strictly higher sample (or the end
/// of the curve) is reached, take the lowest value seen on each side, and subtract the
/// higher of those two bases from the peak height. A peak is counted only when its
/// prominence exceeds `prominence_frac * (max - min)` of the whole curve.
///
/// The alternative is to use the drop to the adjacent samples. On a fine grid a smooth
/// peak differs from its neighbours by only O(h²), so it would never clear the threshold.
/// The topographic measure avoids that grid-resolution dependence.
///
/// Returns 0 for curves with fewer than 3 samples.
fn count_peaks(mean: &[f64], prominence_frac: f64) -> usize {
    let m = mean.len();
    if m < 3 {
        return 0;
    }
    let (min_val, max_val) = mean
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let threshold = prominence_frac * (max_val - min_val);

    let mut count = 0;
    let mut j = 1;
    while j < m - 1 {
        if mean[j] > mean[j - 1] {
            // Extend across a flat top so a plateau peak is detected exactly once.
            let mut k = j;
            while k + 1 < m && mean[k + 1] == mean[j] {
                k += 1;
            }
            // A plateau that runs into the last sample is not a peak.
            if k + 1 < m && mean[k + 1] < mean[j] && peak_prominence(mean, j, k) > threshold {
                count += 1;
            }
            j = k + 1;
        } else {
            j += 1;
        }
    }
    count
}

/// Topographic prominence of the peak occupying `mean[left..=right]` (a sample or plateau).
fn peak_prominence(mean: &[f64], left: usize, right: usize) -> f64 {
    let height = mean[left];
    // Lowest value reached on each side before climbing above the peak (or hitting an end).
    let left_base = mean[..left]
        .iter()
        .rev()
        .take_while(|&&v| v <= height)
        .fold(height, |acc, &v| acc.min(v));
    let right_base = mean[right + 1..]
        .iter()
        .take_while(|&&v| v <= height)
        .fold(height, |acc, &v| acc.min(v));
    height - left_base.max(right_base)
}
/// Collapse `peak_counts` into maximal runs of identical consecutive values.
///
/// Each run is reported as an inclusive `(first_index, last_index)` pair, in index order,
/// so the runs partition `0..peak_counts.len()`. An empty slice yields no runs.
fn build_persistence_pairs(peak_counts: &[usize]) -> Vec<(usize, usize)> {
    let total = peak_counts.len();
    let mut runs = Vec::new();
    let mut run_start = 0;
    while run_start < total {
        let value = peak_counts[run_start];
        // Always >= 1 because the first element of the tail equals `value`.
        let run_len = peak_counts[run_start..]
            .iter()
            .take_while(|&&c| c == value)
            .count();
        let run_end = run_start + run_len - 1;
        runs.push((run_start, run_end));
        run_start = run_end + 1;
    }
    runs
}
/// Sweep candidate `lambda` values and measure how persistent the Karcher mean's peak count is.
///
/// For every entry of `lambdas`, [`karcher_mean`] is run on `data` with that value as its
/// `lambda` argument (presumably the elastic alignment penalty; confirm against
/// `karcher_mean`). The prominent peaks of the resulting mean curve are then counted, with a
/// prominence threshold of 0.1% of the curve's range.
///
/// Consecutive lambdas that produce the same peak count are grouped into runs. The run whose
/// lambda values cover the widest range is taken as the most persistent, and the lambda at its
/// middle index (rounded down) is reported as optimal. Ties between equally wide runs go to
/// the later run.
///
/// # Arguments
/// * `data` - input curves; must have at least 2 rows.
/// * `argvals` - evaluation grid; must have one entry per column of `data`.
/// * `lambdas` - candidate lambda values; non-empty, finite and non-negative. They need not be
///   sorted, but runs are formed in the order given.
/// * `max_iter`, `tol` - forwarded unchanged to [`karcher_mean`] for every lambda;
///   `max_iter` must be > 0.
///
/// # Errors
/// * [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows or
///   `argvals.len() != data.ncols()`.
/// * [`FdarError::InvalidParameter`] if `lambdas` is empty, if it contains a NaN, infinite
///   or negative value, or if `max_iter == 0`.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn peak_persistence(
    data: &FdMatrix,
    argvals: &[f64],
    lambdas: &[f64],
    max_iter: usize,
    tol: f64,
) -> Result<PersistenceDiagramResult, FdarError> {
    // Minimum peak prominence, as a fraction of the mean curve's range.
    const PROMINENCE_FRAC: f64 = 0.001;

    let n = data.nrows();
    let m = data.ncols();
    if n < 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 2 rows".to_string(),
            actual: format!("{n} rows"),
        });
    }
    if argvals.len() != m {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: format!("{m}"),
            actual: format!("{}", argvals.len()),
        });
    }
    if lambdas.is_empty() {
        return Err(FdarError::InvalidParameter {
            parameter: "lambdas",
            message: "must be non-empty".to_string(),
        });
    }
    // A plain `l < 0.0` test lets NaN through. Non-finite values also make run widths
    // meaningless, so reject them explicitly.
    if lambdas.iter().any(|&l| !l.is_finite() || l < 0.0) {
        return Err(FdarError::InvalidParameter {
            parameter: "lambdas",
            message: "all lambda values must be finite and >= 0".to_string(),
        });
    }
    if max_iter == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "max_iter",
            message: "must be > 0".to_string(),
        });
    }

    let peak_counts: Vec<usize> = lambdas
        .iter()
        .map(|&lam| {
            let result = karcher_mean(data, argvals, max_iter, tol, lam);
            count_peaks(&result.mean, PROMINENCE_FRAC)
        })
        .collect();

    let persistence_pairs = build_persistence_pairs(&peak_counts);
    let (best_start, best_end) = most_persistent_run(lambdas, &persistence_pairs);
    let optimal_index = (best_start + best_end) / 2;
    let optimal_lambda = lambdas[optimal_index];
    Ok(PersistenceDiagramResult {
        lambdas: lambdas.to_vec(),
        peak_counts,
        persistence_pairs,
        optimal_lambda,
        optimal_index,
    })
}

/// Select the run whose lambda values span the widest range; ties go to the later run.
fn most_persistent_run(lambdas: &[f64], runs: &[(usize, usize)]) -> (usize, usize) {
    // Width is max - min over the run rather than `last - first`, so descending or unsorted
    // lambdas still give a non-negative width. Comparing the widths as f64 also avoids the
    // saturating, lossy fixed-point `as u64` conversion.
    let width = |&(start, end): &(usize, usize)| {
        let run = &lambdas[start..=end];
        let hi = run.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let lo = run.iter().copied().fold(f64::INFINITY, f64::min);
        hi - lo
    };
    // `max_by` returns the last of several equal maxima, matching the previous tie-breaking.
    runs.iter()
        .copied()
        .max_by(|a, b| {
            width(a)
                .partial_cmp(&width(b))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .unwrap_or((0, 0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;

    /// `n` phase-shifted half-sine curves on a uniform grid of `m` points (column-major).
    fn single_peak_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let mut data_vec = vec![0.0; n * m];
        for i in 0..n {
            let shift = 0.05 * i as f64;
            for j in 0..m {
                data_vec[i + j * n] = (std::f64::consts::PI * (t[j] + shift)).sin();
            }
        }
        let data = FdMatrix::from_column_major(data_vec, n, m).unwrap();
        (data, t)
    }

    #[test]
    fn persistence_single_peak_stable() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.01, 0.1, 1.0];
        let result = peak_persistence(&data, &t, &lambdas, 5, 1e-2).unwrap();
        let count_one = result.peak_counts.iter().filter(|&&c| c == 1).count();
        assert!(
            count_one >= lambdas.len() / 2,
            "Expected most peak counts to be 1, got {:?}",
            result.peak_counts
        );
    }

    #[test]
    fn persistence_optimal_in_range() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.01, 0.1, 1.0, 10.0];
        let result = peak_persistence(&data, &t, &lambdas, 5, 1e-2).unwrap();
        assert!(
            result.optimal_lambda >= lambdas[0],
            "optimal_lambda {} below range",
            result.optimal_lambda
        );
        assert!(
            result.optimal_lambda <= *lambdas.last().unwrap(),
            "optimal_lambda {} above range",
            result.optimal_lambda
        );
    }

    #[test]
    fn persistence_peak_counts_length() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.5, 1.0];
        let result = peak_persistence(&data, &t, &lambdas, 3, 1e-2).unwrap();
        assert_eq!(result.peak_counts.len(), lambdas.len());
    }

    #[test]
    fn persistence_pairs_partition_lambdas() {
        let (data, t) = single_peak_data(6, 31);
        let lambdas = vec![0.0, 0.1, 1.0];
        let result = peak_persistence(&data, &t, &lambdas, 3, 1e-2).unwrap();
        // Runs must be contiguous, cover every index, and hold a constant peak count.
        let mut expected_start = 0;
        for &(s, e) in &result.persistence_pairs {
            assert_eq!(s, expected_start);
            assert!(e >= s);
            assert!(result.peak_counts[s..=e]
                .iter()
                .all(|&c| c == result.peak_counts[s]));
            expected_start = e + 1;
        }
        assert_eq!(expected_start, lambdas.len());
        assert!(result.optimal_index < lambdas.len());
        assert_eq!(result.optimal_lambda, lambdas[result.optimal_index]);
        assert!(result
            .persistence_pairs
            .iter()
            .any(|&(s, e)| s <= result.optimal_index && result.optimal_index <= e));
    }

    #[test]
    fn persistence_rejects_empty_lambdas() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[], 5, 1e-3);
        assert!(result.is_err(), "Empty lambdas should produce an error");
    }

    #[test]
    fn persistence_rejects_negative_lambda() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[0.0, -1.0], 5, 1e-3);
        assert!(matches!(
            result,
            Err(FdarError::InvalidParameter { parameter: "lambdas", .. })
        ));
    }

    #[test]
    fn persistence_rejects_zero_max_iter() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t, &[0.0], 0, 1e-3);
        assert!(matches!(
            result,
            Err(FdarError::InvalidParameter { parameter: "max_iter", .. })
        ));
    }

    #[test]
    fn persistence_rejects_argvals_mismatch() {
        let (data, t) = single_peak_data(4, 21);
        let result = peak_persistence(&data, &t[..t.len() - 1], &[0.0], 5, 1e-3);
        assert!(matches!(
            result,
            Err(FdarError::InvalidDimension { parameter: "argvals", .. })
        ));
    }

    #[test]
    fn persistence_rejects_single_row() {
        let (data, t) = single_peak_data(1, 21);
        let result = peak_persistence(&data, &t, &[0.0], 5, 1e-3);
        assert!(matches!(
            result,
            Err(FdarError::InvalidDimension { parameter: "data", .. })
        ));
    }

    #[test]
    fn count_peaks_basic_shapes() {
        assert_eq!(count_peaks(&[], 0.001), 0);
        assert_eq!(count_peaks(&[1.0, 2.0], 0.001), 0);
        assert_eq!(count_peaks(&[1.0; 5], 0.001), 0);
        assert_eq!(count_peaks(&[0.0, 2.0, 0.0, 1.0, 0.0], 0.001), 2);
        // A bump far below 0.1% of the range is ignored.
        assert_eq!(count_peaks(&[0.0, 10.0, 0.0, 0.001, 0.0], 0.001), 1);
    }

    #[test]
    fn persistence_pairs_group_equal_runs() {
        assert!(build_persistence_pairs(&[]).is_empty());
        assert_eq!(build_persistence_pairs(&[4]), vec![(0, 0)]);
        assert_eq!(
            build_persistence_pairs(&[1, 1, 2, 2, 2, 1]),
            vec![(0, 1), (2, 4), (5, 5)]
        );
    }
}