use crate::core::error::{RedicatError, Result};
use crate::core::sparse::SparseOps;
use log::info;
use nalgebra_sparse::CsrMatrix;
use polars::prelude::*;
use rayon::prelude::*;
use super::anndata_ops::AnnDataContainer;
/// Builds the `"coverage"` layer from the per-base count layers and stores it in
/// `adata.layers`.
///
/// When an `"A0"` layer exists, the eight layers `A0, T0, G0, C0, A1, T1, G1, C1`
/// are considered; otherwise only `A1, T1, G1, C1`. Layers from the chosen list
/// that are not present are skipped.
///
/// * no layer present  -> an all-zero `n_obs x n_vars` matrix is stored;
/// * one layer present -> a clone of that layer is stored;
/// * several present   -> the result of `SparseOps::parallel_sum_matrices` is stored.
///
/// # Errors
/// Propagates any error returned by `SparseOps::parallel_sum_matrices`.
pub fn calculate_coverage(adata: &mut AnnDataContainer) -> Result<()> {
    info!("Calculating coverage matrix with parallel sparse operations...");

    const ONE_SUFFIX_LAYERS: [&str; 4] = ["A1", "T1", "G1", "C1"];
    const ALL_BASE_LAYERS: [&str; 8] = ["A0", "T0", "G0", "C0", "A1", "T1", "G1", "C1"];

    // The presence of "A0" decides which set of layer names is looked up.
    let candidates: &[&str] = if adata.layers.contains_key("A0") {
        &ALL_BASE_LAYERS
    } else {
        &ONE_SUFFIX_LAYERS
    };

    let mut present: Vec<&CsrMatrix<u32>> = Vec::with_capacity(candidates.len());
    for name in candidates {
        if let Some(layer) = adata.layers.get(*name) {
            present.push(layer);
        }
    }

    match present.len() {
        0 => {
            let empty = CsrMatrix::<u32>::zeros(adata.n_obs, adata.n_vars);
            adata.layers.insert("coverage".to_string(), empty);
            info!("No base layers found, created empty coverage matrix");
        }
        1 => {
            // Nothing to add up: the single layer is the coverage.
            let only = present[0].clone();
            let nnz = only.nnz();
            adata.layers.insert("coverage".to_string(), only);
            info!("Coverage matrix calculated with {} non-zero elements", nnz);
        }
        _ => {
            let summed = SparseOps::parallel_sum_matrices(&present)?;
            let nnz = summed.nnz();
            info!(
                "Coverage matrix calculated with {} non-zero elements using parallel reduction",
                nnz
            );
            adata.layers.insert("coverage".to_string(), summed);
        }
    }

    Ok(())
}
pub fn filter_sites_by_coverage(
mut adata: AnnDataContainer,
min_coverage: u32,
) -> Result<AnnDataContainer> {
info!("Filtering sites with min_coverage: {}", min_coverage);
calculate_coverage(&mut adata)?;
let site_coverage = adata
.compute_layer_col_sums("coverage")
.unwrap_or_else(|| vec![0; adata.n_vars]);
let filter_mask: Vec<bool> = adata
.var_names
.par_iter()
.zip(site_coverage.par_iter())
.map(|(name, &cov)| name.starts_with("chr") && cov >= min_coverage)
.collect();
let kept_sites = filter_mask.iter().filter(|&&x| x).count();
info!(
"Keeping {} out of {} sites ({}% retained)",
kept_sites,
adata.n_vars,
(kept_sites as f64 / adata.n_vars as f64 * 100.0) as u32
);
if kept_sites == 0 {
return Err(RedicatError::EmptyData(
"No sites passed coverage filter".to_string(),
));
}
apply_site_filter(adata, &filter_mask)
}
/// Restricts `adata` to the sites whose entry in `filter_mask` is `true`.
///
/// Every layer matrix is reduced to the selected columns, `adata.var` is reduced to
/// the selected rows, and `adata.var_names` / `adata.n_vars` are updated to match.
///
/// # Errors
/// * `RedicatError::EmptyData` when the mask selects no site.
/// * `RedicatError::DataProcessing` when the mask selects an index for which no
///   entry exists in `adata.var_names` (instead of panicking on the lookup).
/// * Any error from `SparseOps::filter_columns_u32` or from filtering `adata.var`
///   is propagated.
pub fn apply_site_filter(
    mut adata: AnnDataContainer,
    filter_mask: &[bool],
) -> Result<AnnDataContainer> {
    // Positions of the `true` entries, in ascending order.
    let selected_indices: Vec<usize> = filter_mask
        .par_iter()
        .enumerate()
        .filter_map(|(i, &keep)| keep.then_some(i))
        .collect();

    if selected_indices.is_empty() {
        return Err(RedicatError::EmptyData(
            "No sites selected after filtering".to_string(),
        ));
    }

    // Replace each layer in place; the previous matrix is dropped on assignment.
    for matrix in adata.layers.values_mut() {
        let filtered = SparseOps::filter_columns_u32(&*matrix, &selected_indices)?;
        *matrix = filtered;
    }

    let filtered_var = filter_dataframe_by_indices(&adata.var, &selected_indices)?;

    // The indices are ascending, so validating the largest one covers all of them
    // and keeps the indexed lookup below from panicking.
    let available_names = adata.var_names.len();
    if let Some(&largest) = selected_indices.last() {
        if largest >= available_names {
            return Err(RedicatError::DataProcessing(format!(
                "Filter mask selects site index {} but only {} site names are available",
                largest, available_names
            )));
        }
    }

    let filtered_var_names: Vec<String> = selected_indices
        .par_iter()
        .map(|&i| adata.var_names[i].clone())
        .collect();

    adata.var = filtered_var;
    adata.n_vars = selected_indices.len();
    adata.var_names = filtered_var_names;

    info!("Site filtering completed: {} sites retained", adata.n_vars);
    Ok(adata)
}
/// Appends per-site base counts to `adata.var`.
///
/// For each base in `A, T, G, C` the column sums of the layer named `<base>1` are
/// computed (all zeros when `compute_layer_col_sums` returns `None`) and added as a
/// column named after the base. A final `"Coverage"` column holds the saturating
/// sum of the four base columns.
///
/// # Errors
/// Propagates the error from `DataFrame::hstack_mut` if the columns cannot be
/// appended.
pub fn count_base_levels(mut adata: AnnDataContainer) -> Result<AnnDataContainer> {
    info!("Counting base levels at each site...");

    let site_count = adata.n_vars;
    let per_base: Vec<(char, Vec<u32>)> = ['A', 'T', 'G', 'C']
        .par_iter()
        .map(|&base| {
            let sums = match adata.compute_layer_col_sums(&format!("{}1", base)) {
                Some(sums) => sums,
                None => vec![0u32; site_count],
            };
            (base, sums)
        })
        .collect();

    let mut coverage_totals = vec![0u32; site_count];
    // Four base columns plus the trailing "Coverage" column.
    let mut columns: Vec<Column> = Vec::with_capacity(5);

    for (base, sums) in per_base {
        coverage_totals
            .iter_mut()
            .zip(sums.iter())
            .for_each(|(total, &count)| *total = total.saturating_add(count));
        columns.push(Series::new(base.to_string().into(), sums).into_column());
    }
    columns.push(Series::new("Coverage".into(), coverage_totals).into_column());

    adata.var.hstack_mut(&columns)?;

    info!("Base levels counted successfully");
    Ok(adata)
}
fn filter_dataframe_by_indices(df: &DataFrame, indices: &[usize]) -> Result<DataFrame> {
if indices.is_empty() {
return Err(RedicatError::EmptyData(
"No indices provided for filtering".to_string(),
));
}
let idx_chunked = UInt32Chunked::from_vec(
"idx".into(),
indices.iter().map(|&idx| idx as u32).collect(),
);
df.take(&idx_chunked)
.map_err(|e| RedicatError::DataProcessing(format!("Failed to filter DataFrame: {}", e)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::sparse::SparseOps;
    use polars::prelude::{DataFrame, Series};

    /// Entries passed to `SparseOps::from_triplets_u32`.
    type Triplets = Vec<(usize, usize, u32)>;

    /// Builds a `u32` CSR matrix of the given shape, panicking on invalid input.
    fn make_matrix(nrows: usize, ncols: usize, triplets: Triplets) -> CsrMatrix<u32> {
        SparseOps::from_triplets_u32(nrows, ncols, triplets).unwrap()
    }

    /// Inserts a layer called `name` with the given `(rows, cols)` shape and entries.
    fn add_layer(
        adata: &mut AnnDataContainer,
        name: &str,
        shape: (usize, usize),
        triplets: Triplets,
    ) {
        let (nrows, ncols) = shape;
        adata
            .layers
            .insert(name.to_string(), make_matrix(nrows, ncols, triplets));
    }

    /// Sets `var_names` and rebuilds `var` as a single-column frame of those names.
    fn set_var_names(adata: &mut AnnDataContainer, names: Vec<&str>) {
        let owned: Vec<String> = names.into_iter().map(String::from).collect();
        let frame =
            DataFrame::new(vec![Series::new("var_names".into(), owned.clone()).into()]).unwrap();
        adata.var_names = owned;
        adata.var = frame;
    }

    #[test]
    fn calculate_coverage_sums_existing_layers() {
        let mut container = AnnDataContainer::new(2, 3);
        add_layer(&mut container, "A1", (2, 3), vec![(0, 0, 2), (1, 2, 1)]);
        add_layer(&mut container, "T1", (2, 3), vec![(0, 0, 3), (0, 1, 4)]);

        calculate_coverage(&mut container).unwrap();

        let coverage = container.layers.get("coverage").unwrap();
        assert_eq!(SparseOps::compute_col_sums(coverage), vec![5, 4, 1]);
        assert_eq!(SparseOps::compute_row_sums(coverage), vec![9, 1]);
    }

    #[test]
    fn filter_sites_by_coverage_applies_chr_prefix_and_threshold() {
        let mut container = AnnDataContainer::new(2, 3);
        set_var_names(&mut container, vec!["chr22:10", "chr1:20", "site_3"]);
        add_layer(
            &mut container,
            "A1",
            (2, 3),
            vec![(0, 0, 2), (0, 2, 100), (1, 0, 3)],
        );
        add_layer(&mut container, "T1", (2, 3), vec![(1, 1, 3)]);

        let result = filter_sites_by_coverage(container, 4).unwrap();

        assert_eq!(result.n_vars, 1);
        assert_eq!(result.var_names, vec!["chr22:10".to_string()]);
        assert!(result.layers.values().all(|layer| layer.ncols() == 1));
    }

    #[test]
    fn filter_sites_by_coverage_returns_empty_error() {
        let mut container = AnnDataContainer::new(2, 2);
        set_var_names(&mut container, vec!["site1", "site2"]);
        add_layer(&mut container, "A1", (2, 2), vec![(0, 0, 1), (1, 1, 1)]);

        let failure = filter_sites_by_coverage(container, 10).unwrap_err();

        assert!(failure
            .to_string()
            .contains("No sites passed coverage filter"));
    }

    #[test]
    fn apply_site_filter_filters_layers_var_and_names() {
        let mut container = AnnDataContainer::new(2, 4);
        set_var_names(&mut container, vec!["chr1:1", "chr1:2", "chr1:3", "chr1:4"]);
        let dense_entries = vec![
            (0, 0, 1),
            (0, 1, 2),
            (0, 2, 3),
            (0, 3, 4),
            (1, 0, 5),
            (1, 1, 6),
            (1, 2, 7),
            (1, 3, 8),
        ];
        add_layer(&mut container, "A1", (2, 4), dense_entries);

        let keep = [true, false, true, false];
        let result = apply_site_filter(container, &keep).unwrap();

        assert_eq!(result.n_vars, 2);
        assert_eq!(result.var_names, vec!["chr1:1", "chr1:3"]);
        assert_eq!(result.var.height(), 2);

        let a1 = result.layers.get("A1").unwrap();
        assert_eq!(a1.ncols(), 2);
        let first_row_values: Vec<u32> = a1.row(0).values().to_vec();
        assert_eq!(first_row_values, vec![1, 3]);
    }

    #[test]
    fn apply_site_filter_all_false_returns_error() {
        let container = AnnDataContainer::new(2, 3);
        let keep = [false; 3];

        let outcome = apply_site_filter(container, &keep);

        assert!(outcome.is_err());
    }

    #[test]
    fn apply_site_filter_all_true_keeps_everything() {
        let mut container = AnnDataContainer::new(2, 3);
        set_var_names(&mut container, vec!["chr1:1", "chr1:2", "chr1:3"]);
        add_layer(&mut container, "G1", (2, 3), vec![(0, 0, 42)]);

        let keep = [true; 3];
        let result = apply_site_filter(container, &keep).unwrap();

        assert_eq!(result.n_vars, 3);
        assert_eq!(result.var_names.len(), 3);
    }

    #[test]
    fn calculate_coverage_handles_stranded_data() {
        let mut container = AnnDataContainer::new(2, 2);
        add_layer(&mut container, "A0", (2, 2), vec![(0, 0, 1)]);
        add_layer(&mut container, "A1", (2, 2), vec![(0, 0, 2)]);
        add_layer(&mut container, "T0", (2, 2), vec![(1, 1, 3)]);
        add_layer(&mut container, "T1", (2, 2), vec![(1, 1, 4)]);
        for empty_layer in ["G0", "G1", "C0", "C1"] {
            add_layer(&mut container, empty_layer, (2, 2), vec![]);
        }

        calculate_coverage(&mut container).unwrap();

        let coverage = container.layers.get("coverage").unwrap();
        assert_eq!(SparseOps::compute_col_sums(coverage), vec![3, 7]);
    }

    #[test]
    fn calculate_coverage_no_layers_creates_empty() {
        let mut container = AnnDataContainer::new(3, 2);

        calculate_coverage(&mut container).unwrap();

        let coverage = container.layers.get("coverage").unwrap();
        assert_eq!(coverage.nnz(), 0);
        assert_eq!(coverage.nrows(), 3);
        assert_eq!(coverage.ncols(), 2);
    }

    #[test]
    fn count_base_levels_uses_saturating_add() {
        let mut container = AnnDataContainer::new(1, 1);
        add_layer(&mut container, "A1", (1, 1), vec![(0, 0, u32::MAX - 1)]);
        add_layer(&mut container, "T1", (1, 1), vec![(0, 0, 5)]);
        add_layer(&mut container, "G1", (1, 1), vec![]);
        add_layer(&mut container, "C1", (1, 1), vec![]);

        let result = count_base_levels(container).unwrap();

        let totals = result.var.column("Coverage").unwrap().u32().unwrap();
        assert_eq!(totals.get(0), Some(u32::MAX));
    }

    #[test]
    fn count_base_levels_adds_base_columns_and_total_coverage() {
        let mut container = AnnDataContainer::new(2, 3);
        add_layer(&mut container, "A1", (2, 3), vec![(0, 0, 2)]);
        add_layer(&mut container, "T1", (2, 3), vec![(1, 0, 1), (1, 1, 2)]);
        add_layer(&mut container, "G1", (2, 3), vec![(0, 2, 5)]);
        add_layer(&mut container, "C1", (2, 3), vec![]);

        let result = count_base_levels(container).unwrap();

        let totals = result.var.column("Coverage").unwrap().u32().unwrap();
        assert_eq!(totals.get(0), Some(3));
        assert_eq!(totals.get(1), Some(2));
        assert_eq!(totals.get(2), Some(5));
        for base in ["A", "T", "G", "C"] {
            assert!(result.var.column(base).is_ok());
        }
    }
}