use crate::PeacoQCData;
use crate::error::{PeacoQCError, Result};
use crate::stats::density::KernelDensity;
use crate::stats::median;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, PartialEq)]
pub struct PeakDetectionConfig {
/// Target number of events per (overlapping) bin; also the minimum
/// event count below which peak detection is refused.
pub events_per_bin: usize,
/// Fraction of the maximum density below which KDE peaks are discarded.
pub peak_removal: f64,
/// Percentage (0-100) of bins a peak count must occur in before that
/// count is accepted as the reference number of clusters.
pub min_nr_bins_peakdetection: f64,
/// If true, zero-valued events are dropped from each bin before KDE.
pub remove_zeros: bool,
}
impl Default for PeakDetectionConfig {
fn default() -> Self {
Self {
events_per_bin: 1000,
peak_removal: 1.0 / 3.0,
min_nr_bins_peakdetection: 10.0,
remove_zeros: false,
}
}
}
/// A single density peak found in one bin of one channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeakInfo {
/// Index of the bin (position in the `create_breaks` output) the peak came from.
pub bin: usize,
/// Channel value at which the density peak sits.
pub peak_value: f64,
/// Cluster id assigned by `cluster_peaks`; 1-based, 0 means "unassigned".
pub cluster: usize,
}
/// All surviving peaks for one channel, after clustering and small-cluster removal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelPeakFrame {
pub peaks: Vec<PeakInfo>,
}
/// Runs per-bin density peak detection for every requested channel.
///
/// Events are split into overlapping bins (see [`create_breaks`]); each
/// channel is then processed in parallel via rayon. Channels whose data
/// cannot be read, or which yield no surviving peaks, are silently omitted
/// from the result map.
///
/// # Errors
/// Returns `PeacoQCError::InsufficientData` when no bins can be formed
/// (i.e. the file contains no events).
pub fn determine_peaks_all_channels<T: PeacoQCData>(
    fcs: &T,
    channels: &[String],
    config: &PeakDetectionConfig,
) -> Result<HashMap<String, ChannelPeakFrame>> {
    let n_events = fcs.n_events();
    let breaks = create_breaks(n_events, config.events_per_bin);
    if breaks.is_empty() {
        return Err(PeacoQCError::InsufficientData {
            min: config.events_per_bin,
            actual: n_events,
        });
    }
    eprintln!("Calculating peaks for {} channels...", channels.len());
    // Materialise channel data sequentially first: unreadable channels are
    // skipped here rather than failing the whole run.
    let channel_data: Vec<(String, Vec<f64>)> = channels
        .iter()
        .filter_map(|ch| fcs.get_channel_f64(ch).ok().map(|data| (ch.clone(), data)))
        .collect();
    // Peak detection itself is parallel; channels with no peak frame drop out.
    let results: HashMap<String, ChannelPeakFrame> = channel_data
        .par_iter()
        .filter_map(|(name, values)| {
            determine_channel_peaks_from_data(values, &breaks, config)
                .map(|frame| (name.clone(), frame))
        })
        .collect();
    Ok(results)
}
/// Splits `n_events` event indices into ~50%-overlapping half-open ranges of
/// (at most) `events_per_bin` events each.
///
/// Consecutive bins start `events_per_bin - ceil(events_per_bin / 2)` events
/// apart, so each bin overlaps roughly half of the previous one. The final
/// bins are truncated at `n_events`.
///
/// Returns an empty vector when `n_events == 0` or `events_per_bin == 0`
/// (the latter would otherwise yield zero-width bins).
pub fn create_breaks(n_events: usize, events_per_bin: usize) -> Vec<(usize, usize)> {
    if events_per_bin == 0 {
        return Vec::new();
    }
    let overlap = (events_per_bin + 1) / 2;
    // BUGFIX: for events_per_bin == 1 the raw step is 0, which previously
    // made this loop spin forever; clamp so the window always advances.
    let step = (events_per_bin - overlap).max(1);
    let mut breaks = Vec::new();
    let mut start = 0;
    while start < n_events {
        let end = (start + events_per_bin).min(n_events);
        breaks.push((start, end));
        start += step;
    }
    breaks
}
/// Detects, clusters, and filters density peaks for one channel's raw values.
///
/// For every bin a kernel density estimate is computed and its peaks above
/// `config.peak_removal` are collected (sorted ascending). The pooled peaks
/// are then clustered across bins and small clusters are pruned. Returns
/// `None` when no peak survives or any intermediate step fails.
fn determine_channel_peaks_from_data(
    data: &[f64],
    breaks: &[(usize, usize)],
    config: &PeakDetectionConfig,
) -> Option<ChannelPeakFrame> {
    // One sorted peak list per bin, computed in parallel.
    let bin_peaks: Vec<Vec<f64>> = breaks
        .par_iter()
        .map(|&(lo, hi)| {
            let values: Vec<f64> = if config.remove_zeros {
                data[lo..hi].iter().copied().filter(|&v| v != 0.0).collect()
            } else {
                data[lo..hi].to_vec()
            };
            // A KDE over fewer than 3 points is meaningless; report no peaks.
            if values.len() < 3 {
                return Vec::new();
            }
            let mut found = KernelDensity::estimate(&values, 1.0, 512)
                .map(|kde| kde.find_peaks(config.peak_removal))
                .unwrap_or_default();
            found.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            found
        })
        .collect();
    // Flatten into per-peak records; cluster ids are assigned below.
    let mut all_peaks: Vec<PeakInfo> = bin_peaks
        .iter()
        .enumerate()
        .flat_map(|(bin, peaks)| {
            peaks.iter().map(move |&peak_value| PeakInfo {
                bin,
                peak_value,
                cluster: 0,
            })
        })
        .collect();
    if all_peaks.is_empty() {
        return None;
    }
    cluster_peaks(&mut all_peaks, &bin_peaks, config).ok()?;
    remove_small_clusters(&mut all_peaks, breaks.len()).ok()?;
    if all_peaks.is_empty() {
        None
    } else {
        Some(ChannelPeakFrame { peaks: all_peaks })
    }
}
/// Assigns each peak to a cluster by nearest reference median.
///
/// The reference cluster count is the largest per-bin peak count that occurs
/// in at least `min_nr_bins_peakdetection`% of the bins (falling back to 1).
/// Bins showing exactly that count supply the reference peak positions, whose
/// per-position medians become the cluster centres. Every peak is then tagged
/// with the 1-based index of its nearest centre; if no reference bins exist
/// everything goes into cluster 1, and if no medians can be computed the
/// peaks are left with cluster 0.
///
/// # Errors
/// Propagates any error from the `median` computation.
fn cluster_peaks(
    all_peaks: &mut [PeakInfo],
    bin_peaks: &[Vec<f64>],
    config: &PeakDetectionConfig,
) -> Result<()> {
    // How many bins produced each distinct peak count.
    let counts: Vec<usize> = bin_peaks.iter().map(Vec::len).collect();
    let mut freq_of_count: HashMap<usize, usize> = HashMap::new();
    for &c in &counts {
        *freq_of_count.entry(c).or_insert(0) += 1;
    }
    let min_bins =
        (config.min_nr_bins_peakdetection / 100.0 * counts.len() as f64).ceil() as usize;
    let most_common_count = freq_of_count
        .iter()
        .filter(|&(_, &freq)| freq >= min_bins)
        .map(|(&count, _)| count)
        .max()
        .unwrap_or(1);
    // Bins whose peak count matches the reference count define the clusters.
    let reference_peaks: Vec<&Vec<f64>> = bin_peaks
        .iter()
        .filter(|peaks| peaks.len() == most_common_count)
        .collect();
    if reference_peaks.is_empty() {
        // No reference bins: lump every peak into a single cluster.
        for peak in all_peaks.iter_mut() {
            peak.cluster = 1;
        }
        return Ok(());
    }
    // Median of the k-th peak across reference bins = centre of cluster k.
    let mut cluster_medians: Vec<f64> = Vec::with_capacity(most_common_count);
    for idx in 0..most_common_count {
        let column: Vec<f64> = reference_peaks
            .iter()
            .filter_map(|peaks| peaks.get(idx).copied())
            .collect();
        if !column.is_empty() {
            cluster_medians.push(median(&column)?);
        }
    }
    if cluster_medians.is_empty() {
        return Ok(());
    }
    for peak in all_peaks.iter_mut() {
        // Nearest centre wins; cluster ids are 1-based.
        let mut best = (f64::INFINITY, 0usize);
        for (idx, &centre) in cluster_medians.iter().enumerate() {
            let dist = (peak.peak_value - centre).abs();
            if dist < best.0 {
                best = (dist, idx + 1);
            }
        }
        peak.cluster = best.1;
    }
    Ok(())
}
/// Drops peaks belonging to clusters present in fewer than half of all bins.
///
/// A cluster counts once per distinct bin it appears in; clusters covering
/// at least `ceil(n_bins / 2)` bins survive, the rest are removed in place.
/// Always returns `Ok(())` (the `Result` keeps the signature uniform with
/// the other clustering steps).
fn remove_small_clusters(all_peaks: &mut Vec<PeakInfo>, n_bins: usize) -> Result<()> {
    use std::collections::HashSet;
    // Distinct bins each cluster shows up in.
    let mut bins_per_cluster: HashMap<usize, HashSet<usize>> = HashMap::new();
    for peak in all_peaks.iter() {
        bins_per_cluster
            .entry(peak.cluster)
            .or_default()
            .insert(peak.bin);
    }
    let threshold = (n_bins as f64 * 0.5).ceil() as usize;
    let keep: HashSet<usize> = bins_per_cluster
        .into_iter()
        .filter(|(_, bins)| bins.len() >= threshold)
        .map(|(cluster, _)| cluster)
        .collect();
    all_peaks.retain(|peak| keep.contains(&peak.cluster));
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fcs::SimpleFcs;
    use polars::df;
    use std::collections::HashMap as StdHashMap;
    use std::sync::Arc;

    /// Smoke test: a single unimodal, noisy channel with enough events
    /// should run through peak detection without error.
    #[test]
    fn test_peak_detection_basic() {
        // 5000 events clustered around 100 with uniform noise.
        let data: Vec<f64> = (0..5000)
            .map(|_| 100.0 + rand::random::<f64>() * 10.0)
            .collect();
        let frame = Arc::new(df!["FL1-A" => data].unwrap());
        let fcs = SimpleFcs {
            data_frame: frame,
            parameter_metadata: StdHashMap::new(),
        };
        let config = PeakDetectionConfig {
            events_per_bin: 1000,
            ..Default::default()
        };
        assert!(determine_peaks_all_channels(&fcs, &["FL1-A".to_string()], &config).is_ok());
    }
}