use crate::error::{PeacoQCError, Result};
use crate::qc::peaks::ChannelPeakFrame;
use rayon::prelude::*;
use std::collections::HashMap;
#[cfg(feature = "gpu")]
use crate::gpu::{build_feature_matrix_gpu, is_gpu_available};
/// Tunable parameters for the isolation-tree outlier detection step.
#[derive(Debug, Clone, PartialEq)]
pub struct IsolationTreeConfig {
    /// Initial SD-gain threshold a split must reach to be accepted
    /// (tightened as the tree grows; see `build_isolation_tree_sd`).
    pub it_limit: f64,
    /// Minimum number of bins required before the isolation tree is run;
    /// fewer bins produce an `InsufficientData` error.
    pub force_it: usize,
}
impl Default for IsolationTreeConfig {
    /// Defaults mirroring the reference PeacoQC settings:
    /// a 0.6 gain limit and a 150-bin minimum.
    fn default() -> Self {
        Self {
            it_limit: 0.6,
            force_it: 150,
        }
    }
}
/// Output of `isolation_tree_detect`: the per-bin outlier mask plus the
/// full tree and summary statistics for diagnostics/reporting.
#[derive(Debug, Clone)]
pub struct IsolationTreeResult {
    /// `true` at index `i` means bin `i` is flagged as an outlier
    /// (i.e. it is NOT in the largest leaf node).
    pub outlier_bins: Vec<bool>,
    /// Flat node list of the built tree; children are referenced by index.
    pub tree: Vec<TreeNode>,
    /// Summary statistics about the tree and the detection run.
    pub stats: TreeStats,
}
/// A single node of the isolation tree. Nodes live in a flat `Vec` and are
/// linked by index: `left_child`/`right_child` are positions in that `Vec`.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// This node's index in the flat tree vector.
    pub id: usize,
    /// Index of the left child (`None` for a leaf).
    pub left_child: Option<usize>,
    /// Index of the right child (`None` for a leaf).
    pub right_child: Option<usize>,
    /// SD gain of the split performed at this node (`None` for a leaf).
    pub gain: Option<f64>,
    /// Feature (channel/cluster) name the split was made on.
    pub split_column: Option<String>,
    /// Threshold value: rows with value <= this go left, > go right.
    pub split_value: Option<f64>,
    /// Depth of the node; the root is at depth 0.
    pub depth: usize,
    /// For leaves only: depth plus the expected extra path length for the
    /// points remaining in the leaf (`avg_path_length`).
    pub path_length: Option<f64>,
    /// Number of data points (bins) contained in this node.
    pub n_datapoints: usize,
}
/// Summary statistics of a detection run, for logging/reporting.
#[derive(Debug, Clone)]
pub struct TreeStats {
    /// Total number of bins fed into the tree.
    pub n_bins: usize,
    /// Number of feature columns (channel/cluster combinations).
    pub n_features: usize,
    /// Maximum node depth reached in the built tree.
    pub max_depth: usize,
    /// Size (bin count) of the largest leaf — the "good" population.
    pub largest_node_size: usize,
    /// Id of that largest leaf in the flat tree vector.
    pub largest_node_id: usize,
}
/// Euler–Mascheroni constant γ (truncated), used in the harmonic-number
/// approximation H(i) ≈ ln(i) + γ inside `avg_path_length`.
const EULER_MASCHERONI: f64 = 0.5772156649;
/// Expected path length c(n) of an unsuccessful BST search over `n` points,
/// as used by isolation forests: c(n) = 2·H(n−1) − 2(n−1)/n, with the
/// harmonic number approximated as H(i) ≈ ln(i) + γ.
///
/// Returns 0.0 for n ≤ 1, since a single point needs no further splitting.
fn avg_path_length(n: usize) -> f64 {
    if n <= 1 {
        return 0.0;
    }
    let m = n as f64;
    let harmonic = (m - 1.0).ln() + EULER_MASCHERONI;
    2.0 * harmonic - 2.0 * (m - 1.0) / m
}
/// Sample standard deviation (Bessel-corrected, n−1 denominator).
///
/// Fewer than two points have no measurable spread, so 0.0 is returned.
fn std_dev(data: &[f64]) -> f64 {
    if data.len() < 2 {
        return 0.0;
    }
    let count = data.len() as f64;
    let mean: f64 = data.iter().sum::<f64>() / count;
    let sum_sq: f64 = data.iter().map(|&x| (x - mean).powi(2)).sum();
    (sum_sq / (count - 1.0)).sqrt()
}
/// Run SD-gain isolation-tree outlier detection over the binned peak data.
///
/// Builds a feature matrix (one row per bin, one column per channel/cluster),
/// grows a single isolation tree, and declares every bin OUTSIDE the largest
/// leaf node an outlier.
///
/// # Errors
/// - `InsufficientData` when `n_bins < config.force_it`.
/// - `NoPeaksDetected` when `peak_results` is empty.
/// - Propagates errors from feature-matrix construction and tree building.
pub fn isolation_tree_detect(
    peak_results: &HashMap<String, ChannelPeakFrame>,
    n_bins: usize,
    config: &IsolationTreeConfig,
) -> Result<IsolationTreeResult> {
    // Too few bins makes the tree statistically meaningless — refuse to run.
    if n_bins < config.force_it {
        return Err(PeacoQCError::InsufficientData {
            min: config.force_it,
            actual: n_bins,
        });
    }
    if peak_results.is_empty() {
        return Err(PeacoQCError::NoPeaksDetected);
    }
    // Prefer the GPU path when compiled in and a device is available;
    // both paths produce the same (matrix, feature_names) shape.
    #[cfg(feature = "gpu")]
    let (feature_matrix, feature_names) = if is_gpu_available() {
        build_feature_matrix_gpu(peak_results, n_bins)?
    } else {
        build_feature_matrix(peak_results, n_bins)?
    };
    #[cfg(not(feature = "gpu"))]
    let (feature_matrix, feature_names) = build_feature_matrix(peak_results, n_bins)?;
    // `first()` avoids a panic on an empty matrix (reachable when
    // `force_it == 0` and `n_bins == 0`); the old `feature_matrix[0]` indexed
    // unconditionally.
    let n_features = feature_matrix.first().map_or(0, |row| row.len());
    eprintln!(
        "Running SD-based Isolation Tree: {} bins, {} features (clusters)",
        n_bins, n_features
    );
    let (tree, selection) =
        build_isolation_tree_sd(&feature_matrix, &feature_names, config.it_limit)?;
    // Leaves are the nodes with a path_length; the largest leaf is taken to
    // be the "good" (non-outlier) population.
    let largest_node = tree
        .iter()
        .filter(|node| node.path_length.is_some())
        .max_by_key(|node| node.n_datapoints)
        .ok_or_else(|| PeacoQCError::StatsError("No leaf nodes found".to_string()))?;
    let largest_node_id = largest_node.id;
    let largest_node_size = largest_node.n_datapoints;
    // `selection` is indexed by node id (nodes and masks are pushed in
    // lockstep inside build_isolation_tree_sd).
    let good_bins = &selection[largest_node_id];
    // Outliers are exactly the bins NOT in the largest leaf.
    let outlier_bins: Vec<bool> = good_bins.iter().map(|&in_node| !in_node).collect();
    let n_outliers = outlier_bins.iter().filter(|&&x| x).count();
    eprintln!(
        "IT detected {} outlier bins ({:.1}%), largest node has {} bins",
        n_outliers,
        (n_outliers as f64 / n_bins as f64) * 100.0,
        largest_node_size
    );
    let max_depth = tree.iter().map(|n| n.depth).max().unwrap_or(0);
    Ok(IsolationTreeResult {
        outlier_bins,
        tree,
        stats: TreeStats {
            n_bins,
            n_features,
            max_depth,
            largest_node_size,
            largest_node_id,
        },
    })
}
/// Build the bin × feature matrix for the isolation tree.
///
/// Each feature is one (channel, cluster) combination, named
/// `"{channel}_cluster_{id}"`. Row `b`, column `f` holds the peak value
/// observed in bin `b` for that cluster; bins with no detected peak are
/// imputed with the cluster's median peak value.
///
/// Channels and cluster ids are sorted so feature ordering is deterministic
/// despite `HashMap` iteration order.
///
/// # Errors
/// Propagates any error from `crate::stats::median`.
pub fn build_feature_matrix(
    peak_results: &HashMap<String, ChannelPeakFrame>,
    n_bins: usize,
) -> Result<(Vec<Vec<f64>>, Vec<String>)> {
    let mut channel_names: Vec<String> = peak_results.keys().cloned().collect();
    channel_names.sort();
    let mut feature_names = Vec::new();
    // One entry per feature: the (bin, peak_value) pairs of that cluster.
    // (The old code also stored the channel name and cluster id here, but
    // never read them — and cloned every peak list; `remove` moves instead.)
    let mut cluster_data: Vec<Vec<(usize, f64)>> = Vec::new();
    for channel in &channel_names {
        let peak_frame = &peak_results[channel];
        // Group this channel's peaks by cluster id.
        let mut clusters: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
        for peak in &peak_frame.peaks {
            clusters
                .entry(peak.cluster)
                .or_default()
                .push((peak.bin, peak.peak_value));
        }
        let mut cluster_ids: Vec<usize> = clusters.keys().copied().collect();
        cluster_ids.sort_unstable();
        for cluster_id in cluster_ids {
            feature_names.push(format!("{}_cluster_{}", channel, cluster_id));
            let peaks_in_cluster = clusters
                .remove(&cluster_id)
                .expect("cluster id was taken from this map");
            cluster_data.push(peaks_in_cluster);
        }
    }
    let n_features = feature_names.len();
    let mut matrix = vec![vec![0.0; n_features]; n_bins];
    for (feature_idx, peaks_in_cluster) in cluster_data.iter().enumerate() {
        let peak_values: Vec<f64> = peaks_in_cluster.iter().map(|&(_, v)| v).collect();
        let cluster_median = crate::stats::median(&peak_values)?;
        // First fill the whole column with the median (imputation)...
        for row in matrix.iter_mut() {
            row[feature_idx] = cluster_median;
        }
        // ...then overwrite bins that actually have an observed peak.
        for &(bin_idx, peak_value) in peaks_in_cluster {
            if bin_idx < n_bins {
                matrix[bin_idx][feature_idx] = peak_value;
            }
        }
    }
    Ok((matrix, feature_names))
}
/// Grow a single SD-gain isolation tree over `data` (rows = bins,
/// columns = features).
///
/// Returns the flat node list and, parallel to it, one boolean mask per node
/// marking which bins belong to that node (`selection[node.id][bin]`).
///
/// A node becomes a leaf when it holds <= 3 rows, reaches the depth cap,
/// no split beats the current gain limit, or the best split fails to
/// actually partition the rows (possible with tied values).
fn build_isolation_tree_sd(
    data: &[Vec<f64>],
    feature_names: &[String],
    initial_gain_limit: f64,
) -> Result<(Vec<TreeNode>, Vec<Vec<bool>>)> {
    let n_bins = data.len();
    // Standard isolation-tree height cap: ceil(log2(n)).
    let max_depth = (n_bins as f64).log2().ceil() as usize;
    // Root node covers every bin.
    let mut tree = vec![TreeNode {
        id: 0,
        left_child: None,
        right_child: None,
        gain: None,
        split_column: None,
        split_value: None,
        depth: 0,
        path_length: None,
        n_datapoints: n_bins,
    }];
    // selection[i] is the membership mask for tree[i]; the two vectors are
    // always pushed in lockstep so node ids index both.
    let mut selection: Vec<Vec<bool>> = vec![vec![true; n_bins]];
    // Explicit stack → depth-first expansion order.
    let mut nodes_to_split: Vec<usize> = vec![0];
    let mut gain_limit = initial_gain_limit;
    while let Some(node_idx) = nodes_to_split.pop() {
        let node = &tree[node_idx];
        let depth = node.depth;
        // Indices of the bins contained in this node.
        let rows: Vec<usize> = selection[node_idx]
            .iter()
            .enumerate()
            .filter_map(|(i, &in_node)| if in_node { Some(i) } else { None })
            .collect();
        // Leaf: too few rows or depth cap reached. A leaf's path_length is
        // its depth plus the expected remaining path for its rows.
        if rows.len() <= 3 || depth >= max_depth {
            let path_length = avg_path_length(rows.len()) + depth as f64;
            tree[node_idx].path_length = Some(path_length);
            tree[node_idx].n_datapoints = rows.len();
            continue;
        }
        let best_split = find_best_split_parallel(data, &rows, feature_names, gain_limit);
        match best_split {
            Some((col_idx, split_value, gain)) => {
                // Partition rows: value <= split goes left, > goes right.
                let left_rows: Vec<usize> = rows
                    .iter()
                    .filter(|&&r| data[r][col_idx] <= split_value)
                    .copied()
                    .collect();
                let right_rows: Vec<usize> = rows
                    .iter()
                    .filter(|&&r| data[r][col_idx] > split_value)
                    .copied()
                    .collect();
                // Degenerate split (ties can put everything on one side):
                // keep the node as a leaf instead.
                if left_rows.is_empty()
                    || right_rows.is_empty()
                    || left_rows.len() == rows.len()
                    || right_rows.len() == rows.len()
                {
                    let path_length = avg_path_length(rows.len()) + depth as f64;
                    tree[node_idx].path_length = Some(path_length);
                    tree[node_idx].n_datapoints = rows.len();
                    continue;
                }
                // Children get the next two slots in the flat vectors.
                let left_id = tree.len();
                let right_id = tree.len() + 1;
                tree[node_idx].left_child = Some(left_id);
                tree[node_idx].right_child = Some(right_id);
                tree[node_idx].gain = Some(gain);
                tree[node_idx].split_column = Some(feature_names[col_idx].clone());
                tree[node_idx].split_value = Some(split_value);
                tree[node_idx].n_datapoints = rows.len();
                // NOTE(review): the gain limit tightens to the last accepted
                // split's gain, and because nodes are popped DFS-style this
                // affects ALL subsequently processed nodes, not just this
                // subtree — confirm this matches the reference
                // PeacoQC behavior.
                gain_limit = gain;
                let mut left_selection = vec![false; n_bins];
                let mut right_selection = vec![false; n_bins];
                for &r in &left_rows {
                    left_selection[r] = true;
                }
                for &r in &right_rows {
                    right_selection[r] = true;
                }
                tree.push(TreeNode {
                    id: left_id,
                    left_child: None,
                    right_child: None,
                    gain: None,
                    split_column: None,
                    split_value: None,
                    depth: depth + 1,
                    path_length: None,
                    n_datapoints: left_rows.len(),
                });
                tree.push(TreeNode {
                    id: right_id,
                    left_child: None,
                    right_child: None,
                    gain: None,
                    split_column: None,
                    split_value: None,
                    depth: depth + 1,
                    path_length: None,
                    n_datapoints: right_rows.len(),
                });
                selection.push(left_selection);
                selection.push(right_selection);
                nodes_to_split.push(left_id);
                nodes_to_split.push(right_id);
            }
            // No split reached the gain limit: finalize as a leaf.
            None => {
                let path_length = avg_path_length(rows.len()) + depth as f64;
                tree[node_idx].path_length = Some(path_length);
                tree[node_idx].n_datapoints = rows.len();
            }
        }
    }
    Ok((tree, selection))
}
/// Search every feature column (in parallel via rayon) for the best split of
/// `rows`, returning `(column, split_value, gain)` of the highest-gain split
/// that beats `gain_limit`, or `None` when no column qualifies.
fn find_best_split_parallel(
    data: &[Vec<f64>],
    rows: &[usize],
    _feature_names: &[String],
    gain_limit: f64,
) -> Option<(usize, f64, f64)> {
    let n_features = data[0].len();
    // Evaluate columns in parallel; rayon's collect preserves column order.
    let candidates: Vec<(usize, f64, f64)> = (0..n_features)
        .into_par_iter()
        .filter_map(|col| find_best_split_for_column(data, rows, col, gain_limit))
        .collect();
    // Pick the candidate with the largest gain (field .2).
    candidates
        .into_iter()
        .max_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal))
}
/// Find the best split point within one column using SD gain:
/// gain = (parent_sd − mean(child_sds)) / parent_sd.
///
/// Returns `(col, split_value, gain)` for the best split whose gain is at
/// least `gain_limit`; `None` if the column has < 2 values, zero spread,
/// or no qualifying split. The split value is the largest value on the
/// left side, matching the `<=` partition used by the tree builder.
fn find_best_split_for_column(
    data: &[Vec<f64>],
    rows: &[usize],
    col: usize,
    gain_limit: f64,
) -> Option<(usize, f64, f64)> {
    // Gather this column's values for the node's rows and sort ascending.
    let mut sorted: Vec<f64> = rows.iter().map(|&r| data[r][col]).collect();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = sorted.len();
    if n < 2 {
        return None;
    }
    let parent_sd = std_dev(&sorted);
    // A constant column cannot be split (and would divide by zero below).
    if parent_sd == 0.0 {
        return None;
    }
    let mut best_gain = gain_limit;
    let mut best_value = None;
    // Try every boundary between adjacent sorted values.
    for split_at in 1..n {
        let (left, right) = sorted.split_at(split_at);
        // Singleton sides have sd 0 by definition; skip the computation.
        let left_sd = if split_at == 1 { 0.0 } else { std_dev(left) };
        let right_sd = if split_at == n - 1 { 0.0 } else { std_dev(right) };
        let gain = (parent_sd - (left_sd + right_sd) / 2.0) / parent_sd;
        // `>=` keeps the LAST boundary among equally good ones.
        if gain.is_finite() && gain >= best_gain {
            best_gain = gain;
            best_value = Some(sorted[split_at - 1]);
        }
    }
    best_value.map(|split_value| (col, split_value, best_gain))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::qc::peaks::PeakInfo;

    /// c(n) should be 0 at n=1, ≈0.1544 at n=2, and grow with n.
    #[test]
    fn test_avg_path_length() {
        assert!((avg_path_length(1) - 0.0).abs() < 1e-6);
        let apl_2 = avg_path_length(2);
        // c(2) = 2*(ln(1)+gamma) - 2*1/2 ≈ 0.1544
        assert!((apl_2 - 0.1544).abs() < 0.02, "avgPL(2) = {}", apl_2);
        assert!(avg_path_length(100) > avg_path_length(10));
        let apl_10 = avg_path_length(10);
        let apl_100 = avg_path_length(100);
        assert!(apl_10 > 0.0, "avgPL(10) should be positive: {}", apl_10);
        assert!(apl_100 > apl_10, "avgPL should increase with n");
    }

    /// Sample (n-1) standard deviation against a hand-computed value.
    #[test]
    fn test_std_dev() {
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let sd = std_dev(&data);
        // mean = 5, sum of squared deviations = 32, sqrt(32/7) ≈ 2.138
        assert!((sd - 2.138).abs() < 0.01, "sd = {}", sd);
    }

    /// End-to-end: one channel with a spike in bins 50..60 amid a gentle
    /// trend — most bins should survive as "good".
    #[test]
    fn test_isolation_tree_basic() {
        let mut peaks = Vec::new();
        for bin in 0..200 {
            // Outlier plateau at bins 50..60; slow linear drift elsewhere.
            let peak_value = if bin >= 50 && bin < 60 {
                1000.0
            } else {
                100.0 + (bin as f64) * 0.5
            };
            peaks.push(PeakInfo {
                bin,
                peak_value,
                cluster: 1,
            });
        }
        let mut peak_results = HashMap::new();
        peak_results.insert("FL1-A".to_string(), ChannelPeakFrame { peaks });
        // Lower force_it so 200 bins are enough to run the tree.
        let config = IsolationTreeConfig {
            force_it: 50,
            it_limit: 0.6,
        };
        let result = isolation_tree_detect(&peak_results, 200, &config).unwrap();
        let n_good = result.outlier_bins.iter().filter(|&&x| !x).count();
        assert!(
            n_good > 100,
            "Most bins should be good, but only {} are",
            n_good
        );
    }

    /// Matrix shape and feature naming: two channels, one cluster each
    /// → a 5×2 matrix and two "<channel>_cluster_<id>" feature names.
    #[test]
    fn test_build_feature_matrix_old_behavior() {
        let mut peaks1 = Vec::new();
        let mut peaks2 = Vec::new();
        for bin in 0..5 {
            peaks1.push(PeakInfo {
                bin,
                peak_value: 100.0 + bin as f64,
                cluster: 1,
            });
            peaks2.push(PeakInfo {
                bin,
                peak_value: 200.0 + bin as f64,
                cluster: 1,
            });
        }
        let mut peak_results = HashMap::new();
        peak_results.insert("FL1-A".to_string(), ChannelPeakFrame { peaks: peaks1 });
        peak_results.insert("FL2-A".to_string(), ChannelPeakFrame { peaks: peaks2 });
        let (matrix, names) = build_feature_matrix(&peak_results, 5).unwrap();
        assert_eq!(matrix.len(), 5);
        assert_eq!(
            matrix[0].len(),
            2,
            "Should have 2 features (2 channels × 1 cluster each)"
        );
        assert_eq!(names.len(), 2);
        assert!(matrix[0][0] > 0.0);
        assert!(matrix[0][1] > 0.0);
        assert!(
            names[0].contains("_cluster_"),
            "Feature name should contain '_cluster_'"
        );
        assert!(
            names[1].contains("_cluster_"),
            "Feature name should contain '_cluster_'"
        );
    }
}