use super::anndata_ops::AnnDataContainer;
use super::base_matrix::*;
use super::{EditingType, ReferenceGenome};
use crate::core::error::{RedicatError, Result};
use crate::core::sparse::SparseOps;
use log::info;
use nalgebra_sparse::CsrMatrix;
use polars::prelude::*;
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Builds strand-aware `ref`, `alt`, and `others` count layers for `editing_type`.
///
/// Steps:
/// 1. Drops sites whose reference base is not accepted by the editing type
///    (`filter_by_editing_type_strand_aware`).
/// 2. Merges the per-strand base layers and routes each site's counts into
///    the `ref`, `alt`, or `others` layer according to its reference base.
/// 3. Stores the three layers, sets `X` to an `f64` copy of `alt`, and adds
///    observation-level sums to `obs`.
///
/// # Errors
/// Returns `RedicatError::EmptyData` when no sites survive the editing-type
/// filter, and propagates errors from the helper steps (e.g. a missing `ref`
/// column in `var` or missing base layers).
pub fn calculate_ref_alt_matrices(
    mut adata: AnnDataContainer,
    editing_type: &EditingType,
) -> Result<AnnDataContainer> {
    info!(
        "Calculating strand-aware ref/alt matrices for editing type: {:?}",
        editing_type
    );
    adata = filter_by_editing_type_strand_aware(adata, editing_type)?;
    if adata.n_vars == 0 {
        return Err(RedicatError::EmptyData(
            "No sites remain after strand-aware editing type filtering".to_string(),
        ));
    }
    let ref_bases = extract_reference_bases(&adata.var)?;
    let (ref_matrix, alt_matrix, others_matrix) =
        compute_editing_matrices_vectorized(&mut adata, &ref_bases, editing_type)?;
    // Build X from a borrow of `alt` before moving the matrix into `layers`;
    // this avoids deep-copying the entire alt matrix.
    adata.x = Some(convert_u32_to_f64_csr(&alt_matrix));
    adata.layers.insert("ref".to_string(), ref_matrix);
    adata.layers.insert("alt".to_string(), alt_matrix);
    adata.layers.insert("others".to_string(), others_matrix);
    adata = calculate_observation_sums_vectorized(adata)?;
    info!("Strand-aware ref/alt matrices calculated using vectorized operations");
    Ok(adata)
}
/// Splits the merged per-base count matrices into `ref`, `alt`, and `others`
/// matrices using per-site column masks.
///
/// `ref_bases[i]` is the reference base of site (column) `i`. The strand
/// layers (`A1`/`A0`, `T1`/`T0`, ...) are removed from `adata.layers`, merged
/// per base, masked per bucket, and the per-base contributions are summed.
///
/// # Errors
/// - `DataProcessing` if `ref_bases.len()` differs from `adata.n_vars`
///   (checked before any layers are removed).
/// - `DataProcessing` if no base layers are present.
/// - Any error from `SparseOps` while merging or summing matrices.
fn compute_editing_matrices_vectorized(
    adata: &mut AnnDataContainer,
    ref_bases: &[char],
    editing_type: &EditingType,
) -> Result<(CsrMatrix<u32>, CsrMatrix<u32>, CsrMatrix<u32>)> {
    info!("Computing editing matrices with optimized mask-based sparse operations");
    let n_vars = adata.n_vars;
    // The masks are indexed by site position; a length mismatch would make
    // mask construction index out of bounds and panic.
    if ref_bases.len() != n_vars {
        return Err(RedicatError::DataProcessing(format!(
            "Reference base count ({}) does not match number of sites ({})",
            ref_bases.len(),
            n_vars
        )));
    }
    let base_matrices = collect_strand_aware_base_layers(adata)?;
    if base_matrices.is_empty() {
        return Err(RedicatError::DataProcessing(
            "No base matrices found for editing calculation".to_string(),
        ));
    }
    let (ref_masks, alt_masks, others_masks) = build_onehot_masks(ref_bases, editing_type, n_vars);
    let base_contributions: Vec<_> = base_matrices
        .par_iter()
        .map(|(base, base_matrix)| {
            info!(
                "Processing base {} with {} non-zeros",
                base,
                base_matrix.nnz()
            );
            // Masks exist for every one of A/T/G/C, and base layers are only
            // collected for those bases, so these lookups cannot fail.
            let ref_mask = ref_masks.get(base).expect("Missing ref mask");
            let alt_mask = alt_masks.get(base).expect("Missing alt mask");
            let others_mask = others_masks.get(base).expect("Missing others mask");
            let ref_contribution = apply_column_mask(base_matrix, ref_mask);
            let alt_contribution = apply_column_mask(base_matrix, alt_mask);
            let others_contribution = apply_column_mask(base_matrix, others_mask);
            (ref_contribution, alt_contribution, others_contribution)
        })
        .collect();
    let ref_contribs: Vec<&CsrMatrix<u32>> = base_contributions.iter().map(|(r, _, _)| r).collect();
    let alt_contribs: Vec<&CsrMatrix<u32>> = base_contributions.iter().map(|(_, a, _)| a).collect();
    let others_contribs: Vec<&CsrMatrix<u32>> =
        base_contributions.iter().map(|(_, _, o)| o).collect();
    let ref_matrix = SparseOps::parallel_sum_matrices(&ref_contribs)?;
    let alt_matrix = SparseOps::parallel_sum_matrices(&alt_contribs)?;
    let others_matrix = SparseOps::parallel_sum_matrices(&others_contribs)?;
    info!(
        "Completed optimized editing matrix computation: ref_nnz={}, alt_nnz={}, others_nnz={}",
        ref_matrix.nnz(),
        alt_matrix.nnz(),
        others_matrix.nnz()
    );
    Ok((ref_matrix, alt_matrix, others_matrix))
}
/// Builds per-base column masks that route each site's counts into the
/// `ref`, `alt`, or `others` bucket.
///
/// For site `i`:
/// - the mask of base `ref_bases[i]` is set in the `ref` map;
/// - the mask of `editing_type.get_alt_base_for_ref(ref_bases[i])` is set in
///   the `alt` map, unless that alt base is `'N'`;
/// - every base among A/T/G/C that is neither the ref nor the alt base is set
///   in the `others` map.
///
/// Each map holds one `Vec<bool>` of length `n_vars` for each of `A`, `T`,
/// `G`, `C`. Indexing assumes `ref_bases.len() <= n_vars`.
fn build_onehot_masks(
    ref_bases: &[char],
    editing_type: &EditingType,
    n_vars: usize,
) -> (
    HashMap<char, Vec<bool>>,
    HashMap<char, Vec<bool>>,
    HashMap<char, Vec<bool>>,
) {
    const BASES: [char; 4] = ['A', 'T', 'G', 'C'];
    let blank = || -> HashMap<char, Vec<bool>> {
        BASES.iter().map(|&b| (b, vec![false; n_vars])).collect()
    };
    let (mut ref_masks, mut alt_masks, mut others_masks) = (blank(), blank(), blank());

    // Resolve every site's alt base up front; the mask writes below stay
    // sequential because they mutate the shared maps.
    let alt_bases: Vec<char> = ref_bases
        .par_iter()
        .map(|&r| editing_type.get_alt_base_for_ref(r))
        .collect();

    for (site, (&ref_base, &alt_base)) in ref_bases.iter().zip(&alt_bases).enumerate() {
        if let Some(mask) = ref_masks.get_mut(&ref_base) {
            mask[site] = true;
        }
        if alt_base != 'N' {
            if let Some(mask) = alt_masks.get_mut(&alt_base) {
                mask[site] = true;
            }
        }
        for other in BASES.iter().filter(|&&b| b != ref_base && b != alt_base) {
            if let Some(mask) = others_masks.get_mut(other) {
                mask[site] = true;
            }
        }
    }
    (ref_masks, alt_masks, others_masks)
}
fn apply_column_mask(matrix: &CsrMatrix<u32>, col_mask: &[bool]) -> CsrMatrix<u32> {
let n_rows = matrix.nrows();
let n_cols = matrix.ncols();
let counts: Vec<usize> = (0..n_rows)
.into_par_iter()
.map(|row_idx| {
let row = matrix.row(row_idx);
row.col_indices()
.iter()
.zip(row.values())
.filter(|(&c, &v)| c < col_mask.len() && col_mask[c] && v > 0)
.count()
})
.collect();
let mut row_offsets = Vec::with_capacity(n_rows + 1);
row_offsets.push(0usize);
for &cnt in &counts {
row_offsets.push(row_offsets.last().unwrap() + cnt);
}
let total_nnz = *row_offsets.last().unwrap();
if total_nnz == 0 {
return CsrMatrix::zeros(n_rows, n_cols);
}
let mut col_indices = vec![0usize; total_nnz];
let mut values = vec![0u32; total_nnz];
struct SendPtr<T>(*mut T);
unsafe impl<T> Send for SendPtr<T> {}
unsafe impl<T> Sync for SendPtr<T> {}
let col_ptr = SendPtr(col_indices.as_mut_ptr());
let val_ptr = SendPtr(values.as_mut_ptr());
(0..n_rows).into_par_iter().for_each(|row_idx| {
let row = matrix.row(row_idx);
let start = row_offsets[row_idx];
let mut pos = 0usize;
for (&col_idx, &value) in row.col_indices().iter().zip(row.values()) {
if col_idx < col_mask.len() && col_mask[col_idx] && value > 0 {
unsafe {
*col_ptr.0.add(start + pos) = col_idx;
*val_ptr.0.add(start + pos) = value;
}
pos += 1;
}
}
debug_assert_eq!(pos, counts[row_idx]);
});
CsrMatrix::try_from_csr_data(n_rows, n_cols, row_offsets, col_indices, values)
.unwrap_or_else(|_| CsrMatrix::zeros(n_rows, n_cols))
}
/// Collects one count matrix per nucleotide by merging its strand layers.
///
/// For each base in `A`, `T`, `G`, `C` (in that order), the layers
/// `"{base}1"` and `"{base}0"` are removed from `adata.layers`:
/// - if both exist, they are added with `SparseOps::add_matrices`;
/// - if only one exists, it is used unchanged;
/// - if neither exists, the base is omitted from the result.
///
/// # Errors
/// Propagates any error returned by `SparseOps::add_matrices`.
fn collect_strand_aware_base_layers(
    adata: &mut AnnDataContainer,
) -> Result<Vec<(char, CsrMatrix<u32>)>> {
    let mut merged: Vec<(char, CsrMatrix<u32>)> = Vec::with_capacity(4);
    for base in ['A', 'T', 'G', 'C'] {
        let positive = adata.layers.remove(&format!("{}1", base));
        let negative = adata.layers.remove(&format!("{}0", base));
        let combined = match (positive, negative) {
            (Some(p), Some(n)) => SparseOps::add_matrices(&p, &n)?,
            (Some(single), None) | (None, Some(single)) => single,
            (None, None) => continue,
        };
        merged.push((base, combined));
    }
    Ok(merged)
}
/// Writes per-observation totals of the `ref`, `alt`, and `others` layers
/// into `obs` columns of the same names.
///
/// Only sites where both `is_editing_site` and `filter_pass` are true in
/// `var` contribute. The sums come from `SparseOps::compute_masked_row_sums`.
/// When no site passes, every total is 0. Layers missing from `adata.layers`
/// are skipped. Any existing `obs` column with the same name is replaced.
///
/// # Errors
/// Propagates mask-building errors (missing or non-boolean filter columns)
/// and any `hstack_mut` failure.
fn calculate_observation_sums_vectorized(mut adata: AnnDataContainer) -> Result<AnnDataContainer> {
    info!("Calculating observation-level sums using vectorized operations");
    let site_mask = combined_filter_mask(&adata.var)?;
    let contributing = site_mask.iter().filter(|&&keep| keep).count();
    info!(
        "{} sites contribute to observation-level metrics",
        contributing
    );

    let n_obs = adata.n_obs;
    let layers = &adata.layers;
    let per_layer: Vec<(String, Vec<u32>)> = ["ref", "alt", "others"]
        .par_iter()
        .filter_map(|&name| {
            let matrix = layers.get(name)?;
            let sums = if contributing == 0 {
                vec![0; n_obs]
            } else {
                SparseOps::compute_masked_row_sums(matrix, &site_mask)
            };
            Some((name.to_string(), sums))
        })
        .collect();

    if !per_layer.is_empty() {
        let mut new_columns: Vec<Column> = Vec::with_capacity(per_layer.len());
        for (name, sums) in per_layer {
            // Replace rather than duplicate any column left by an earlier run.
            let _ = adata.obs.drop_in_place(&name);
            new_columns.push(Series::new(name.into(), sums).into_column());
        }
        adata.obs.hstack_mut(&new_columns)?;
    }
    info!("Vectorized observation-level sums calculated");
    Ok(adata)
}
/// Returns a per-site mask that is true only where both the
/// `is_editing_site` and `filter_pass` columns of `var_df` are true.
///
/// Null entries in either column count as false.
///
/// # Errors
/// Fails if either column is missing or not boolean, or if the two masks
/// have different lengths.
fn combined_filter_mask(var_df: &DataFrame) -> Result<Vec<bool>> {
    let editing = bool_mask_from_column(var_df, "is_editing_site")?;
    let passing = bool_mask_from_column(var_df, "filter_pass")?;
    if editing.len() == passing.len() {
        Ok(editing
            .par_iter()
            .zip(passing.par_iter())
            .map(|(&is_editing, &passed)| is_editing && passed)
            .collect())
    } else {
        Err(RedicatError::DataProcessing(
            "Mismatched mask lengths for editing site filters".to_string(),
        ))
    }
}
/// Reads boolean column `column` of `var_df` as a `Vec<bool>`, with nulls
/// mapped to `false`.
///
/// # Errors
/// `DataProcessing` if the column is missing or is not of boolean type.
fn bool_mask_from_column(var_df: &DataFrame, column: &str) -> Result<Vec<bool>> {
    let flags = var_df
        .column(column)
        .map_err(|e| {
            RedicatError::DataProcessing(format!(
                "Expected column '{}' for filtering but it was missing: {}",
                column, e
            ))
        })?
        .bool()
        .map_err(|_| {
            RedicatError::DataProcessing(format!(
                "Expected boolean column '{}' for filtering",
                column
            ))
        })?;
    Ok(flags.into_iter().map(Option::unwrap_or_default).collect())
}
/// Runs the full site-annotation pipeline on `adata`.
///
/// The steps run in this order:
/// 1. `mark_editing_sites` adds `is_editing_site` from `editing_sites`.
/// 2. `filter_sites_by_coverage` applies `min_coverage`.
/// 3. `count_base_levels` runs.
/// 4. `add_reference_bases` looks up reference bases in `reference`.
/// 5. `apply_mismatch_filtering` adds the `Mismatch` and `filter_pass` columns.
///
/// # Errors
/// Propagates the first error returned by any step.
pub fn annotate_variants_pipeline(
    adata: AnnDataContainer,
    editing_sites: Arc<HashSet<String>>,
    reference: Arc<ReferenceGenome>,
    editing_type: &EditingType,
    max_other_threshold: f32,
    min_edited_threshold: f32,
    min_ref_threshold: f32,
    min_coverage: u32,
) -> Result<AnnDataContainer> {
    info!("Starting variant annotation pipeline...");
    let marked = mark_editing_sites(adata, &editing_sites)?;
    let covered = filter_sites_by_coverage(marked, min_coverage)?;
    let counted = count_base_levels(covered)?;
    let with_reference = add_reference_bases(counted, reference)?;
    let annotated = apply_mismatch_filtering(
        with_reference,
        editing_type,
        max_other_threshold,
        min_edited_threshold,
        min_ref_threshold,
        min_coverage,
    )?;
    info!("Variant annotation completed");
    Ok(annotated)
}
pub fn calculate_cei(mut adata: AnnDataContainer) -> Result<AnnDataContainer> {
info!("Calculating Cell Editing Index (CEI)...");
let cei_expr = col("alt").cast(DataType::Float32)
/ (col("ref").cast(DataType::Float32) + col("alt").cast(DataType::Float32));
let cei_series = adata
.obs
.clone()
.lazy()
.with_columns([cei_expr.fill_null(0.0).alias("CEI")])
.collect()?
.column("CEI")?
.clone();
let _ = adata.obs.drop_in_place("CEI");
adata.obs.hstack_mut(&[cei_series.into_column()])?;
info!("CEI calculated");
Ok(adata)
}
/// Adds per-site column totals of the `ref`, `alt`, and `others` layers to
/// `var` as `{ref_base}{alt_base}_ref`, `{ref_base}{alt_base}_alt`, and
/// `{ref_base}{alt_base}_others`.
///
/// Existing columns with those names are replaced, so the function can be
/// re-run on the same container.
///
/// # Errors
/// `DataProcessing` if any of the three layers is missing. Propagates
/// `hstack_mut` failures.
pub fn calculate_site_mismatch_stats(
    mut adata: AnnDataContainer,
    ref_base: char,
    alt_base: char,
) -> Result<AnnDataContainer> {
    info!(
        "Calculating site-level mismatch stats for {}>{} using efficient sparse column sums",
        ref_base, alt_base
    );
    let ref_layer = adata
        .layers
        .get("ref")
        .ok_or_else(|| RedicatError::DataProcessing("Missing 'ref' layer".to_string()))?;
    let alt_layer = adata
        .layers
        .get("alt")
        .ok_or_else(|| RedicatError::DataProcessing("Missing 'alt' layer".to_string()))?;
    let others_layer = adata
        .layers
        .get("others")
        .ok_or_else(|| RedicatError::DataProcessing("Missing 'others' layer".to_string()))?;
    let ref_counts = SparseOps::compute_col_sums(ref_layer);
    let alt_counts = SparseOps::compute_col_sums(alt_layer);
    let others_counts = SparseOps::compute_col_sums(others_layer);
    let ref_col_name = format!("{}{}_ref", ref_base, alt_base);
    let alt_col_name = format!("{}{}_alt", ref_base, alt_base);
    let others_col_name = format!("{}{}_others", ref_base, alt_base);
    // `hstack_mut` rejects duplicate names, so drop stats from a previous run.
    for name in [&ref_col_name, &alt_col_name, &others_col_name] {
        let _ = adata.var.drop_in_place(name);
    }
    let mismatch_columns: Vec<Column> = vec![
        Series::new(ref_col_name.into(), ref_counts).into_column(),
        Series::new(alt_col_name.into(), alt_counts).into_column(),
        Series::new(others_col_name.into(), others_counts).into_column(),
    ];
    adata.var.hstack_mut(&mismatch_columns)?;
    info!("Site-level mismatch stats calculated using efficient sparse column sums");
    Ok(adata)
}
/// Keeps only sites whose reference base is accepted by `editing_type`.
///
/// The first character of each `ref` entry in `var` is upper-cased and
/// tested against `editing_type.get_strand_aware_ref_bases()`. Null or empty
/// entries are rejected. The resulting boolean mask is handed to
/// `apply_site_filter` (presumably keeping the `true` sites — verify there).
///
/// # Errors
/// Propagates polars errors if `ref` is missing or not a string column, and
/// any error from `apply_site_filter`.
fn filter_by_editing_type_strand_aware(
    adata: AnnDataContainer,
    editing_type: &EditingType,
) -> Result<AnnDataContainer> {
    info!(
        "Filtering sites by strand-aware editing type: {:?}",
        editing_type
    );
    let permitted = editing_type.get_strand_aware_ref_bases();
    let filter_mask: Vec<bool> = adata
        .var
        .column("ref")?
        .str()?
        .par_iter()
        .map(|entry| {
            entry
                .and_then(|s| s.chars().next())
                .is_some_and(|first| permitted.contains(&first.to_ascii_uppercase()))
        })
        .collect();
    let kept_count = filter_mask.iter().filter(|&&keep| keep).count();
    info!(
        "Keeping {} sites after strand-aware editing type filtering",
        kept_count
    );
    apply_site_filter(adata, &filter_mask)
}
/// Returns the upper-cased first character of every `ref` entry in
/// `var_df`. Null or empty entries become `'N'`.
///
/// # Errors
/// Propagates polars errors if `ref` is missing or not a string column.
fn extract_reference_bases(var_df: &DataFrame) -> Result<Vec<char>> {
    let bases: Vec<char> = var_df
        .column("ref")?
        .str()?
        .par_iter()
        .map(|entry| match entry.and_then(|s| s.chars().next()) {
            Some(first) => first.to_ascii_uppercase(),
            None => 'N',
        })
        .collect();
    Ok(bases)
}
/// Returns an `f64` copy of a `u32` CSR matrix with an identical sparsity
/// pattern. Every `u32` is exactly representable as `f64`.
///
/// # Panics
/// Only if the copied CSR data fails validation. It cannot, because the
/// pattern is taken unchanged from a valid matrix.
fn convert_u32_to_f64_csr(matrix: &CsrMatrix<u32>) -> CsrMatrix<f64> {
    let (offsets, indices, counts) = matrix.csr_data();
    let widened: Vec<f64> = counts.par_iter().copied().map(f64::from).collect();
    CsrMatrix::try_from_csr_data(
        matrix.nrows(),
        matrix.ncols(),
        offsets.to_vec(),
        indices.to_vec(),
        widened,
    )
    .expect("Failed to convert u32 to f64 CSR matrix")
}
/// Adds a boolean `is_editing_site` column to `var` that is true for every
/// site whose name in `adata.var_names` appears in `editing_sites`.
///
/// An existing `is_editing_site` column is replaced, so re-annotating a
/// container does not fail on a duplicate column name.
///
/// # Errors
/// Propagates `hstack_mut` failures (e.g. a length mismatch with `var`).
fn mark_editing_sites(
    mut adata: AnnDataContainer,
    editing_sites: &HashSet<String>,
) -> Result<AnnDataContainer> {
    info!("Marking known editing sites...");
    let is_editing_site: Vec<bool> = adata
        .var_names
        .par_iter()
        .map(|name| editing_sites.contains(name))
        .collect();
    let marked_count = is_editing_site.par_iter().filter(|&&x| x).count();
    info!(
        "Marked {} editing sites out of {}",
        marked_count, adata.n_vars
    );
    // `hstack_mut` rejects duplicate names; replace any earlier annotation.
    let _ = adata.var.drop_in_place("is_editing_site");
    let filter_column = Series::new("is_editing_site".into(), is_editing_site).into_column();
    adata.var.hstack_mut(&[filter_column])?;
    Ok(adata)
}
/// Looks up the reference base of every site in `adata.var_names` and stores
/// it in `var` as a string column named `ref`.
///
/// Bases reported as `'N'` are counted as unknown in the log. An existing
/// `ref` column is replaced rather than causing a duplicate-column error.
///
/// # Errors
/// Propagates errors from `ReferenceGenome::get_multiple_refs_batched` and
/// from `hstack_mut`.
fn add_reference_bases(
    mut adata: AnnDataContainer,
    reference: Arc<ReferenceGenome>,
) -> Result<AnnDataContainer> {
    info!("Adding reference bases...");
    let ref_bases_chars = reference.get_multiple_refs_batched(&adata.var_names)?;
    let ref_bases: Vec<String> = ref_bases_chars.iter().map(|c| c.to_string()).collect();
    let n_count = ref_bases_chars.par_iter().filter(|&&c| c == 'N').count();
    info!(
        "Retrieved {} valid reference bases, {} unknown",
        ref_bases.len() - n_count,
        n_count
    );
    // `hstack_mut` rejects duplicate names; replace any earlier `ref` column.
    let _ = adata.var.drop_in_place("ref");
    let ref_column = Series::new("ref".into(), ref_bases).into_column();
    adata.var.hstack_mut(&[ref_column])?;
    Ok(adata)
}
/// Outcome of mismatch filtering for a single site.
#[derive(Debug, Clone)]
struct MismatchClassification {
    /// Label written to the `Mismatch` column: the reference base followed by
    /// the expected alt base (e.g. `"AG"`) for passing sites, `"-"` otherwise.
    label: String,
    /// Whether the site satisfied every coverage and count threshold.
    filter_pass: bool,
}

impl Default for MismatchClassification {
    /// A rejected site: label `"-"` and `filter_pass == false`.
    fn default() -> Self {
        MismatchClassification {
            filter_pass: false,
            label: String::from("-"),
        }
    }
}
/// Classifies every site against the mismatch thresholds and writes the
/// results to `var`.
///
/// Reads `Coverage`, `ref`, `A`, `T`, `G`, `C` from `var` and classifies each
/// row with `classify_mismatch_fast`. It then writes two columns, replacing
/// any existing ones:
/// - `Mismatch`: a label such as `"AG"`, or `"-"` for rejected sites;
/// - `filter_pass`: a boolean pass flag.
///
/// # Errors
/// Propagates column-extraction errors (missing or wrongly typed columns)
/// and `hstack_mut` failures.
fn apply_mismatch_filtering(
    mut adata: AnnDataContainer,
    editing_type: &EditingType,
    max_other_threshold: f32,
    min_edited_threshold: f32,
    min_ref_threshold: f32,
    min_coverage: u32,
) -> Result<AnnDataContainer> {
    info!("Applying mismatch filtering...");
    let coverage_slice = extract_u32_column(&adata.var, "Coverage")?;
    let ref_strings = extract_str_column(&adata.var, "ref")?;
    let a_slice = extract_u32_column(&adata.var, "A")?;
    let t_slice = extract_u32_column(&adata.var, "T")?;
    let g_slice = extract_u32_column(&adata.var, "G")?;
    let c_slice = extract_u32_column(&adata.var, "C")?;
    // Iterate over the rows actually extracted from `var` (all columns share
    // its height) so indexing can never run past the extracted data even if
    // `n_vars` disagrees with the DataFrame.
    let n_sites = ref_strings.len();
    let classifications: Vec<MismatchClassification> = (0..n_sites)
        .into_par_iter()
        .map(|site_idx| {
            classify_mismatch_fast(
                site_idx,
                &coverage_slice,
                &ref_strings,
                &a_slice,
                &t_slice,
                &g_slice,
                &c_slice,
                editing_type,
                max_other_threshold,
                min_edited_threshold,
                min_ref_threshold,
                min_coverage,
            )
        })
        .collect();
    let valid_count = classifications
        .iter()
        .filter(|classification| classification.filter_pass)
        .count();
    info!(
        "Found {} valid mismatches out of {} sites",
        valid_count, n_sites
    );
    // Move labels out of the classifications instead of cloning each String.
    let (mismatch_labels, filter_pass_values): (Vec<String>, Vec<bool>) = classifications
        .into_iter()
        .map(|c| (c.label, c.filter_pass))
        .unzip();
    let _ = adata.var.drop_in_place("Mismatch");
    let _ = adata.var.drop_in_place("filter_pass");
    let mismatch_column = Series::new("Mismatch".into(), mismatch_labels).into_column();
    let filter_pass_column = Series::new("filter_pass".into(), filter_pass_values).into_column();
    adata
        .var
        .hstack_mut(&[mismatch_column, filter_pass_column])?;
    Ok(adata)
}
fn extract_u32_column(df: &DataFrame, col_name: &str) -> Result<Vec<u32>> {
let col = df.column(col_name).map_err(|e| {
RedicatError::DataProcessing(format!("Missing column '{}': {}", col_name, e))
})?;
Ok(col
.u32()
.map(|ca| ca.into_iter().map(|v| v.unwrap_or(0)).collect())
.or_else(|_| {
col.i32().map(|ca| {
ca.into_iter()
.map(|v| v.unwrap_or(0).max(0) as u32)
.collect()
})
})
.or_else(|_| {
col.u64().map(|ca| {
ca.into_iter()
.map(|v| v.unwrap_or(0).min(u32::MAX as u64) as u32)
.collect()
})
})
.or_else(|_| {
col.i64().map(|ca| {
ca.into_iter()
.map(|v| v.unwrap_or(0).max(0).min(u32::MAX as i64) as u32)
.collect()
})
})
.unwrap_or_else(|_| vec![0u32; df.height()]))
}
/// Reads string column `col_name` of `df` as owned `String`s, with nulls
/// replaced by `"N"`.
///
/// # Errors
/// `DataProcessing` if the column is missing or is not a string column.
fn extract_str_column(df: &DataFrame, col_name: &str) -> Result<Vec<String>> {
    let strings = df
        .column(col_name)
        .map_err(|e| {
            RedicatError::DataProcessing(format!("Missing column '{}': {}", col_name, e))
        })?
        .str()
        .map_err(|_| {
            RedicatError::DataProcessing(format!("Column '{}' is not a string type", col_name))
        })?;
    let mut out = Vec::with_capacity(strings.len());
    for entry in strings {
        out.push(entry.unwrap_or("N").to_owned());
    }
    Ok(out)
}
/// Decides whether site `site_idx` is a valid mismatch for `editing_type`.
///
/// A site is rejected (returning the default `"-"` / `false`) if any of these hold:
/// - its reference string is `"N"` or empty;
/// - its coverage is below `min_coverage` or below 1;
/// - the editing type has no alt base for the reference (`'N'`);
/// - the ref count is below `max(ceil(min_ref_threshold * coverage), 1)`;
/// - the alt count is below `max(ceil(min_edited_threshold * coverage), 1)`;
/// - the combined count of the two remaining bases exceeds
///   `max(ceil(max_other_threshold * coverage), 2)`.
///
/// Otherwise it passes with label `"{ref}{alt}"`.
///
/// All slices are indexed by `site_idx` and must be at least that long.
fn classify_mismatch_fast(
    site_idx: usize,
    coverage_slice: &[u32],
    ref_strings: &[String],
    a_slice: &[u32],
    t_slice: &[u32],
    g_slice: &[u32],
    c_slice: &[u32],
    editing_type: &EditingType,
    max_other_threshold: f32,
    min_edited_threshold: f32,
    min_ref_threshold: f32,
    min_coverage: u32,
) -> MismatchClassification {
    let mut classification = MismatchClassification::default();
    let coverage = coverage_slice[site_idx] as f32;
    let ref_base_str = ref_strings[site_idx].as_str();
    if ref_base_str == "N" || coverage < min_coverage as f32 || coverage < 1.0 {
        return classification;
    }
    // An empty reference string previously panicked on `.unwrap()` inside a
    // parallel worker; treat it as an unknown base instead.
    let Some(first_char) = ref_base_str.chars().next() else {
        return classification;
    };
    let ref_char = first_char.to_ascii_uppercase();
    let expected_alt = editing_type.get_alt_base_for_ref(ref_char);
    if expected_alt == 'N' {
        return classification;
    }
    let other_max = (max_other_threshold * coverage).ceil().max(2.0) as u32;
    let edited_min = (min_edited_threshold * coverage).ceil().max(1.0) as u32;
    let ref_min = (min_ref_threshold * coverage).ceil().max(1.0) as u32;
    let base_count = |base: char| -> u32 {
        match base {
            'A' => a_slice[site_idx],
            'T' => t_slice[site_idx],
            'G' => g_slice[site_idx],
            'C' => c_slice[site_idx],
            _ => 0,
        }
    };
    let ref_count = base_count(ref_char);
    if ref_count < ref_min {
        return classification;
    }
    let alt_count = base_count(expected_alt);
    if alt_count < edited_min {
        return classification;
    }
    let others_count: u32 = ['A', 'T', 'G', 'C']
        .iter()
        .filter(|&&b| b != ref_char && b != expected_alt)
        .map(|&b| base_count(b))
        .sum();
    if others_count > other_max {
        return classification;
    }
    classification.filter_pass = true;
    classification.label = format!("{}{}", ref_char, expected_alt);
    classification
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::sparse::SparseOps;
fn csr_from_triplets(
n_rows: usize,
n_cols: usize,
triplets: &[(usize, usize, u32)],
) -> CsrMatrix<u32> {
SparseOps::from_triplets_u32(n_rows, n_cols, triplets.to_vec())
.expect("Failed to build CSR matrix for test")
}
fn matrix_value(matrix: &CsrMatrix<u32>, row: usize, col: usize) -> u32 {
let row_view = matrix.row(row);
row_view
.col_indices()
.iter()
.zip(row_view.values())
.find_map(|(&col_idx, &value)| (col_idx == col).then_some(value))
.unwrap_or(0)
}
fn build_adata(ref_base: &str, layers: Vec<(&str, CsrMatrix<u32>)>) -> AnnDataContainer {
let mut adata = AnnDataContainer::new(1, 1);
adata.obs_names = vec!["cell_0".into()];
adata.var_names = vec!["chr1:1".into()];
adata.obs = DataFrame::new(vec![
Series::new("obs_names".into(), &["cell_0"]).into_column()
])
.expect("Failed to build obs dataframe for test");
adata.var = DataFrame::new(vec![
Series::new("var_names".into(), &["chr1:1"]).into_column(),
Series::new("ref".into(), &[ref_base]).into_column(),
Series::new("is_editing_site".into(), &[true]).into_column(),
Series::new("filter_pass".into(), &[true]).into_column(),
])
.expect("Failed to build var dataframe for test");
for (name, matrix) in layers {
adata.layers.insert(name.to_string(), matrix);
}
adata
}
#[test]
fn strand_layers_are_summed_prior_to_assignment() {
let a1 = csr_from_triplets(1, 1, &[(0, 0, 2)]);
let a0 = csr_from_triplets(1, 1, &[(0, 0, 3)]);
let g1 = csr_from_triplets(1, 1, &[(0, 0, 4)]);
let g0 = csr_from_triplets(1, 1, &[(0, 0, 5)]);
let adata = build_adata("A", vec![("A1", a1), ("A0", a0), ("G1", g1), ("G0", g0)]);
let result = calculate_ref_alt_matrices(adata, &EditingType::AG)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
let others_layer = result.layers.get("others").expect("missing others layer");
assert_eq!(matrix_value(ref_layer, 0, 0), 5);
assert_eq!(matrix_value(alt_layer, 0, 0), 9);
assert_eq!(others_layer.nnz(), 0);
let x_matrix = result.x.expect("missing X matrix");
assert_eq!(x_matrix.nrows(), 1);
assert_eq!(x_matrix.ncols(), 1);
assert_eq!(x_matrix.csr_data().2[0], 9f64);
}
#[test]
fn negative_only_layers_are_handled() {
let a0 = csr_from_triplets(1, 1, &[(0, 0, 7)]);
let adata = build_adata("A", vec![("A0", a0)]);
let result = calculate_ref_alt_matrices(adata, &EditingType::AG)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
assert_eq!(matrix_value(ref_layer, 0, 0), 7);
assert_eq!(alt_layer.nnz(), 0);
}
#[test]
fn observation_sums_respect_filter_pass_mask() {
let mut adata = AnnDataContainer::new(2, 2);
adata.obs_names = vec!["cell0".into(), "cell1".into()];
adata.var_names = vec!["chr1:1".into(), "chr1:2".into()];
adata.obs = DataFrame::new(vec![
Series::new("obs_names".into(), &["cell0", "cell1"]).into_column()
])
.expect("Failed to build obs dataframe for test");
adata.var = DataFrame::new(vec![
Series::new("var_names".into(), &["chr1:1", "chr1:2"]).into_column(),
Series::new("ref".into(), &["A", "A"]).into_column(),
Series::new("is_editing_site".into(), &[true, true]).into_column(),
Series::new("filter_pass".into(), &[true, false]).into_column(),
])
.expect("Failed to build var dataframe for test");
let ref_triplets = vec![(0, 0, 8), (0, 1, 5), (1, 0, 4), (1, 1, 6)];
let alt_triplets = vec![(0, 0, 2), (0, 1, 5), (1, 0, 1), (1, 1, 3)];
let others_triplets: Vec<(usize, usize, u32)> = Vec::new();
adata
.layers
.insert("ref".into(), csr_from_triplets(2, 2, &ref_triplets));
adata
.layers
.insert("alt".into(), csr_from_triplets(2, 2, &alt_triplets));
adata
.layers
.insert("others".into(), csr_from_triplets(2, 2, &others_triplets));
let adata = calculate_observation_sums_vectorized(adata)
.expect("observation sums calculation failed");
let ref_values = adata
.obs
.column("ref")
.expect("missing ref column")
.u32()
.expect("ref column not u32");
let alt_values = adata
.obs
.column("alt")
.expect("missing alt column")
.u32()
.expect("alt column not u32");
assert_eq!(ref_values.get(0), Some(8));
assert_eq!(ref_values.get(1), Some(4));
assert_eq!(alt_values.get(0), Some(2));
assert_eq!(alt_values.get(1), Some(1));
let adata = calculate_cei(adata).expect("CEI calculation failed");
let cei_col = adata.obs.column("CEI").expect("missing CEI column");
let cei_values: Vec<f32> = cei_col
.f32()
.expect("CEI column not float")
.into_iter()
.map(|value| value.unwrap_or(0.0))
.collect();
assert!((cei_values[0] - 0.2).abs() < f32::EPSILON);
assert!((cei_values[1] - 0.2).abs() < f32::EPSILON);
}
#[test]
fn complementary_ref_alt_mapping_for_ag() {
let mut adata = AnnDataContainer::new(1, 1);
adata.var_names = vec!["chr1:1".into()];
adata.var = DataFrame::new(vec![
Series::new("var_names".into(), &["chr1:1"]).into_column(),
Series::new("ref".into(), &["T"]).into_column(),
Series::new("Coverage".into(), &[10u32]).into_column(),
Series::new("A".into(), &[1u32]).into_column(),
Series::new("G".into(), &[0u32]).into_column(),
Series::new("T".into(), &[6u32]).into_column(),
Series::new("C".into(), &[4u32]).into_column(),
])
.expect("Failed to construct var dataframe for test");
let adata = apply_mismatch_filtering(adata, &EditingType::AG, 0.4, 0.2, 0.5, 5)
.expect("mismatch filtering failed");
let filter_pass = adata
.var
.column("filter_pass")
.expect("missing filter_pass column")
.bool()
.expect("filter_pass not boolean")
.into_iter()
.map(|value| value.unwrap_or(false))
.collect::<Vec<bool>>();
assert_eq!(filter_pass, vec![true]);
let mismatch = adata
.var
.column("Mismatch")
.expect("missing mismatch column")
.str()
.expect("Mismatch column not utf8")
.into_iter()
.map(|value| value.unwrap_or(""))
.collect::<Vec<&str>>();
assert_eq!(mismatch, vec!["TC"]);
}
#[test]
fn mismatch_filter_rejects_unexpected_alt_base() {
let mut adata = AnnDataContainer::new(1, 1);
adata.var_names = vec!["chr1:1".into()];
adata.var = DataFrame::new(vec![
Series::new("var_names".into(), &["chr1:1"]).into_column(),
Series::new("ref".into(), &["A"]).into_column(),
Series::new("Coverage".into(), &[20u32]).into_column(),
Series::new("A".into(), &[12u32]).into_column(),
Series::new("G".into(), &[0u32]).into_column(),
Series::new("T".into(), &[0u32]).into_column(),
Series::new("C".into(), &[8u32]).into_column(),
])
.expect("Failed to construct var dataframe for test");
let adata = apply_mismatch_filtering(adata, &EditingType::AG, 0.6, 0.1, 0.4, 5)
.expect("mismatch filtering failed");
let filter_pass = adata
.var
.column("filter_pass")
.expect("missing filter_pass column")
.bool()
.expect("filter_pass not boolean")
.into_iter()
.map(|value| value.unwrap_or(false))
.collect::<Vec<bool>>();
assert_eq!(filter_pass, vec![false]);
let mismatch = adata
.var
.column("Mismatch")
.expect("missing mismatch column")
.str()
.expect("Mismatch column not utf8")
.into_iter()
.map(|value| value.unwrap_or(""))
.collect::<Vec<&str>>();
assert_eq!(mismatch, vec!["-"]);
}
#[test]
fn filter_pass_column_reflects_thresholds() {
let mut adata = AnnDataContainer::new(1, 2);
adata.var_names = vec!["chr1:1".into(), "chr1:2".into()];
adata.var = DataFrame::new(vec![
Series::new("var_names".into(), &["chr1:1", "chr1:2"]).into_column(),
Series::new("ref".into(), &["A", "A"]).into_column(),
Series::new("Coverage".into(), &[20u32, 20u32]).into_column(),
Series::new("A".into(), &[15u32, 19u32]).into_column(),
Series::new("G".into(), &[5u32, 1u32]).into_column(),
Series::new("T".into(), &[0u32, 0u32]).into_column(),
Series::new("C".into(), &[0u32, 0u32]).into_column(),
])
.expect("Failed to construct var dataframe for test");
let adata = apply_mismatch_filtering(adata, &EditingType::AG, 0.1, 0.2, 0.5, 10)
.expect("mismatch filtering failed");
let filter_pass = adata
.var
.column("filter_pass")
.expect("missing filter_pass column")
.bool()
.expect("filter_pass not boolean")
.into_iter()
.map(|value| value.unwrap_or(false))
.collect::<Vec<bool>>();
assert_eq!(filter_pass, vec![true, false]);
let mismatch = adata
.var
.column("Mismatch")
.expect("missing mismatch column")
.str()
.expect("Mismatch column not utf8")
.into_iter()
.map(|value| value.unwrap_or(""))
.collect::<Vec<&str>>();
assert_eq!(mismatch, vec!["AG", "-"]);
}
#[test]
fn test_ag_editing_ref_a_correct_assignment() {
let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 100)]);
let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 10)]);
let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 3)]);
let adata = build_adata(
"A",
vec![
("A1", a_matrix),
("G1", g_matrix),
("T1", t_matrix),
("C1", c_matrix),
],
);
let result = calculate_ref_alt_matrices(adata, &EditingType::AG)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
let others_layer = result.layers.get("others").expect("missing others layer");
assert_eq!(matrix_value(ref_layer, 0, 0), 100, "ref should be A count");
assert_eq!(matrix_value(alt_layer, 0, 0), 10, "alt should be G count");
assert_eq!(
matrix_value(others_layer, 0, 0),
5,
"obs should be T+C count"
);
}
#[test]
fn test_ag_editing_ref_t_correct_assignment() {
let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 3)]);
let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 100)]);
let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 10)]);
let adata = build_adata(
"T",
vec![
("A1", a_matrix),
("G1", g_matrix),
("T1", t_matrix),
("C1", c_matrix),
],
);
let result = calculate_ref_alt_matrices(adata, &EditingType::AG)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
let others_layer = result.layers.get("others").expect("missing others layer");
assert_eq!(
matrix_value(ref_layer, 0, 0),
100,
"ref should be T count (not A!)"
);
assert_eq!(
matrix_value(alt_layer, 0, 0),
10,
"alt should be C count (not G!)"
);
assert_eq!(
matrix_value(others_layer, 0, 0),
5,
"obs should be A+G count"
);
}
#[test]
fn test_ac_editing_ref_a_correct_assignment() {
let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 80)]);
let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 15)]);
let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 3)]);
let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
let adata = build_adata(
"A",
vec![
("A1", a_matrix),
("C1", c_matrix),
("T1", t_matrix),
("G1", g_matrix),
],
);
let result = calculate_ref_alt_matrices(adata, &EditingType::AC)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
let others_layer = result.layers.get("others").expect("missing others layer");
assert_eq!(matrix_value(ref_layer, 0, 0), 80, "ref should be A count");
assert_eq!(matrix_value(alt_layer, 0, 0), 15, "alt should be C count");
assert_eq!(
matrix_value(others_layer, 0, 0),
5,
"obs should be T+G count"
);
}
#[test]
fn test_ac_editing_ref_t_correct_assignment() {
let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 3)]);
let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 80)]);
let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 15)]);
let adata = build_adata(
"T",
vec![
("A1", a_matrix),
("C1", c_matrix),
("T1", t_matrix),
("G1", g_matrix),
],
);
let result = calculate_ref_alt_matrices(adata, &EditingType::AC)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
let others_layer = result.layers.get("others").expect("missing others layer");
assert_eq!(matrix_value(ref_layer, 0, 0), 80, "ref should be T count");
assert_eq!(
matrix_value(alt_layer, 0, 0),
15,
"alt should be G count (complement of C)"
);
assert_eq!(
matrix_value(others_layer, 0, 0),
5,
"obs should be A+C count"
);
}
#[test]
fn test_ct_editing_ref_c_correct_assignment() {
let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 90)]);
let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 8)]);
let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
let adata = build_adata(
"C",
vec![
("C1", c_matrix),
("T1", t_matrix),
("A1", a_matrix),
("G1", g_matrix),
],
);
let result = calculate_ref_alt_matrices(adata, &EditingType::CT)
.expect("ref/alt matrix calculation failed");
let ref_layer = result.layers.get("ref").expect("missing ref layer");
let alt_layer = result.layers.get("alt").expect("missing alt layer");
let others_layer = result.layers.get("others").expect("missing others layer");
assert_eq!(matrix_value(ref_layer, 0, 0), 90, "ref should be C count");
assert_eq!(matrix_value(alt_layer, 0, 0), 8, "alt should be T count");
assert_eq!(
matrix_value(others_layer, 0, 0),
2,
"obs should be A+G count"
);
}
/// CT editing at a site whose reference base is `G` (the complement of `C`):
/// G reads are the reference, A reads (complement of T) are the edited base,
/// and the remaining C+T reads are collected in the `others` layer.
#[test]
fn test_ct_editing_ref_g_correct_assignment() {
    let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
    let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
    let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 8)]);
    let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 90)]);
    let adata = build_adata(
        "G",
        vec![
            ("C1", c_matrix),
            ("T1", t_matrix),
            ("A1", a_matrix),
            ("G1", g_matrix),
        ],
    );
    let result = calculate_ref_alt_matrices(adata, &EditingType::CT)
        .expect("ref/alt matrix calculation failed");
    let ref_layer = result.layers.get("ref").expect("missing ref layer");
    let alt_layer = result.layers.get("alt").expect("missing alt layer");
    let others_layer = result.layers.get("others").expect("missing others layer");
    assert_eq!(matrix_value(ref_layer, 0, 0), 90, "ref should be G count");
    assert_eq!(
        matrix_value(alt_layer, 0, 0),
        8,
        "alt should be A count (complement of T)"
    );
    // Name the `others` layer explicitly: "obs" reads as AnnData's per-cell
    // `obs` table, which this assertion does not inspect.
    assert_eq!(
        matrix_value(others_layer, 0, 0),
        2,
        "others should be C+T count"
    );
}
/// CA editing at a site whose reference base is `C`: C reads are the
/// reference, A reads are the edited base, and the remaining T+G reads are
/// collected in the `others` layer.
#[test]
fn test_ca_editing_ref_c_correct_assignment() {
    let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 85)]);
    let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 12)]);
    let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
    let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
    let adata = build_adata(
        "C",
        vec![
            ("C1", c_matrix),
            ("A1", a_matrix),
            ("T1", t_matrix),
            ("G1", g_matrix),
        ],
    );
    let result = calculate_ref_alt_matrices(adata, &EditingType::CA)
        .expect("ref/alt matrix calculation failed");
    let ref_layer = result.layers.get("ref").expect("missing ref layer");
    let alt_layer = result.layers.get("alt").expect("missing alt layer");
    let others_layer = result.layers.get("others").expect("missing others layer");
    assert_eq!(matrix_value(ref_layer, 0, 0), 85, "ref should be C count");
    assert_eq!(matrix_value(alt_layer, 0, 0), 12, "alt should be A count");
    // Name the `others` layer explicitly: "obs" reads as AnnData's per-cell
    // `obs` table, which this assertion does not inspect.
    assert_eq!(
        matrix_value(others_layer, 0, 0),
        3,
        "others should be T+G count"
    );
}
/// Three sites with different reference bases go through AG editing together:
/// - site 0, ref `A`: ref = A (100), alt = G (10), others = T + C (2 + 3);
/// - site 1, ref `T`: ref = T (100), alt = C (10), others = A + G (2 + 3);
/// - site 2, ref `C`, flagged `is_editing_site = false` / `filter_pass = false`:
///   the test expects every layer to read 0 there.
#[test]
fn test_multi_site_different_ref_bases() {
    let mut adata = AnnDataContainer::new(1, 3);
    adata.obs_names = vec!["cell_0".into()];
    adata.var_names = vec!["chr1:1".into(), "chr1:2".into(), "chr1:3".into()];

    let obs_columns = vec![Series::new("obs_names".into(), &["cell_0"]).into_column()];
    adata.obs = DataFrame::new(obs_columns).expect("Failed to build obs");

    let var_columns = vec![
        Series::new("var_names".into(), &["chr1:1", "chr1:2", "chr1:3"]).into_column(),
        Series::new("ref".into(), &["A", "T", "C"]).into_column(),
        Series::new("is_editing_site".into(), &[true, true, false]).into_column(),
        Series::new("filter_pass".into(), &[true, true, false]).into_column(),
    ];
    adata.var = DataFrame::new(var_columns).expect("Failed to build var");

    // Per-base read counts for (site 0, site 1, site 2) in the single cell.
    let base_layers = [
        ("A1", csr_from_triplets(1, 3, &[(0, 0, 100), (0, 1, 2), (0, 2, 90)])),
        ("G1", csr_from_triplets(1, 3, &[(0, 0, 10), (0, 1, 3), (0, 2, 1)])),
        ("T1", csr_from_triplets(1, 3, &[(0, 0, 2), (0, 1, 100), (0, 2, 8)])),
        ("C1", csr_from_triplets(1, 3, &[(0, 0, 3), (0, 1, 10), (0, 2, 1)])),
    ];
    for (name, matrix) in base_layers {
        adata.layers.insert(name.to_string(), matrix);
    }

    let result = calculate_ref_alt_matrices(adata, &EditingType::AG)
        .expect("ref/alt matrix calculation failed");
    let ref_layer = result.layers.get("ref").expect("missing ref layer");
    let alt_layer = result.layers.get("alt").expect("missing alt layer");
    let others_layer = result.layers.get("others").expect("missing others layer");

    // (site, expected ref, expected alt, expected others)
    for &(site, want_ref, want_alt, want_others) in &[
        (0, 100, 10, 5),
        (1, 100, 10, 5),
        (2, 0, 0, 0),
    ] {
        assert_eq!(matrix_value(ref_layer, 0, site), want_ref);
        assert_eq!(matrix_value(alt_layer, 0, site), want_alt);
        assert_eq!(matrix_value(others_layer, 0, site), want_others);
    }
}
/// Mask-based ref/alt extraction must keep the output sparse. Only cells
/// 0, 10, ..., 90 carry reads at site 0 (ref `A`), and only cells
/// 1, 11, ..., 91 at site 1 (ref `T`), so both layers stay below 20% density.
#[test]
fn test_mask_based_optimization_sparse_efficiency() {
    let (n_cells, n_sites) = (100, 2);

    let mut triplets_a = Vec::new();
    let mut triplets_g = Vec::new();
    let mut triplets_t = Vec::new();
    let mut triplets_c = Vec::new();
    for anchor in (0..n_cells).step_by(10) {
        triplets_a.push((anchor, 0, 50));
        triplets_g.push((anchor, 0, 5));
        triplets_t.push((anchor + 1, 1, 50));
        triplets_c.push((anchor + 1, 1, 5));
    }

    let mut adata = AnnDataContainer::new(n_cells, n_sites);
    adata.obs_names = (0..n_cells).map(|i| format!("cell_{}", i)).collect();
    adata.var_names = vec!["chr1:1".into(), "chr1:2".into()];
    let obs_column = Series::new("obs_names".into(), adata.obs_names.clone()).into_column();
    adata.obs = DataFrame::new(vec![obs_column]).expect("Failed to build obs");
    adata.var = DataFrame::new(vec![
        Series::new("var_names".into(), &["chr1:1", "chr1:2"]).into_column(),
        Series::new("ref".into(), &["A", "T"]).into_column(),
        Series::new("is_editing_site".into(), &[true, true]).into_column(),
        Series::new("filter_pass".into(), &[true, true]).into_column(),
    ])
    .expect("Failed to build var");
    adata.layers.insert("A1".to_string(), csr_from_triplets(n_cells, n_sites, &triplets_a));
    adata.layers.insert("G1".to_string(), csr_from_triplets(n_cells, n_sites, &triplets_g));
    adata.layers.insert("T1".to_string(), csr_from_triplets(n_cells, n_sites, &triplets_t));
    adata.layers.insert("C1".to_string(), csr_from_triplets(n_cells, n_sites, &triplets_c));

    let result = calculate_ref_alt_matrices(adata, &EditingType::AG)
        .expect("ref/alt matrix calculation failed");
    let ref_layer = result.layers.get("ref").expect("missing ref layer");
    let alt_layer = result.layers.get("alt").expect("missing alt layer");

    // Fewer than one in five of the cells x sites entries may be stored.
    let density_cap = n_cells * n_sites / 5;
    assert!(ref_layer.nnz() < density_cap, "ref matrix should be sparse");
    assert!(alt_layer.nnz() < density_cap, "alt matrix should be sparse");

    for &(cell, site) in &[(0, 0), (1, 1)] {
        assert_eq!(matrix_value(ref_layer, cell, site), 50);
        assert_eq!(matrix_value(alt_layer, cell, site), 5);
    }
}
/// Spot-checks mask-based ref/alt/others assignment for three
/// (editing type, reference base) combinations on a single cell/site:
/// AG at an `A` site, AG at a `T` site (alt = C), and CT at a `C` site.
///
/// NOTE(review): despite its name, this test only exercises the AG and CT
/// editing types. Other types (e.g. AC, CA) have dedicated tests in this module.
#[test]
fn test_all_six_editing_types_with_mask_optimization() {
    // AG, ref A: ref = A (100), alt = G (10), others = T + C (2 + 3).
    // Each matrix is moved into `build_adata` and never reused, so no clones are needed.
    let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 100)]);
    let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 10)]);
    let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
    let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 3)]);
    let adata = build_adata(
        "A",
        vec![
            ("A1", a_matrix),
            ("G1", g_matrix),
            ("T1", t_matrix),
            ("C1", c_matrix),
        ],
    );
    let result = calculate_ref_alt_matrices(adata, &EditingType::AG).expect("AG/A failed");
    assert_eq!(matrix_value(result.layers.get("ref").unwrap(), 0, 0), 100);
    assert_eq!(matrix_value(result.layers.get("alt").unwrap(), 0, 0), 10);
    assert_eq!(matrix_value(result.layers.get("others").unwrap(), 0, 0), 5);

    // AG, ref T: ref = T (100), alt = C (10), others = A + G (2 + 3).
    let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 2)]);
    let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 3)]);
    let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 100)]);
    let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 10)]);
    let adata = build_adata(
        "T",
        vec![
            ("A1", a_matrix),
            ("G1", g_matrix),
            ("T1", t_matrix),
            ("C1", c_matrix),
        ],
    );
    let result = calculate_ref_alt_matrices(adata, &EditingType::AG).expect("AG/T failed");
    assert_eq!(matrix_value(result.layers.get("ref").unwrap(), 0, 0), 100);
    assert_eq!(matrix_value(result.layers.get("alt").unwrap(), 0, 0), 10);
    assert_eq!(matrix_value(result.layers.get("others").unwrap(), 0, 0), 5);

    // CT, ref C: ref = C (90), alt = T (8), others = A + G (1 + 1).
    let a_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
    let g_matrix = csr_from_triplets(1, 1, &[(0, 0, 1)]);
    let t_matrix = csr_from_triplets(1, 1, &[(0, 0, 8)]);
    let c_matrix = csr_from_triplets(1, 1, &[(0, 0, 90)]);
    let adata = build_adata(
        "C",
        vec![
            ("A1", a_matrix),
            ("G1", g_matrix),
            ("T1", t_matrix),
            ("C1", c_matrix),
        ],
    );
    let result = calculate_ref_alt_matrices(adata, &EditingType::CT).expect("CT/C failed");
    assert_eq!(matrix_value(result.layers.get("ref").unwrap(), 0, 0), 90);
    assert_eq!(matrix_value(result.layers.get("alt").unwrap(), 0, 0), 8);
    assert_eq!(matrix_value(result.layers.get("others").unwrap(), 0, 0), 2);
}
/// An all-true column mask must return every stored entry unchanged.
#[test]
fn test_apply_column_mask_all_true() {
    let entries = [(0, 0, 1), (0, 1, 2), (0, 2, 3), (1, 0, 4), (1, 2, 5)];
    let m = csr_from_triplets(2, 3, &entries);
    let keep_everything = vec![true; 3];
    let result = apply_column_mask(&m, &keep_everything);
    assert_eq!(result.nnz(), m.nnz());
    for &(row, col, value) in &entries {
        assert_eq!(matrix_value(&result, row, col), value);
    }
}
/// An all-false column mask must leave no stored entries at all.
#[test]
fn test_apply_column_mask_all_false() {
    let drop_everything = vec![false; 3];
    let source = csr_from_triplets(2, 3, &[(0, 0, 1), (0, 1, 2), (1, 2, 3)]);
    let masked = apply_column_mask(&source, &drop_everything);
    assert_eq!(masked.nnz(), 0);
}
/// A mixed mask keeps columns 0 and 3 and drops columns 1 and 2 in every row.
#[test]
fn test_apply_column_mask_selective() {
    let m = csr_from_triplets(
        2,
        4,
        &[
            (0, 0, 10), (0, 1, 20), (0, 2, 30), (0, 3, 40),
            (1, 0, 50), (1, 1, 60), (1, 2, 70), (1, 3, 80),
        ],
    );
    let mask = vec![true, false, false, true];
    let result = apply_column_mask(&m, &mask);

    // Dimensions are preserved (as in the empty-matrix case). Masked-out entries
    // are dropped rather than stored as explicit zeros (as the all-false case
    // asserts), so exactly 2 kept columns x 2 rows remain.
    assert_eq!(result.nrows(), 2);
    assert_eq!(result.ncols(), 4);
    assert_eq!(result.nnz(), 4);

    // Check both rows, including the masked-out columns of row 1.
    let expected = [[10, 0, 0, 40], [50, 0, 0, 80]];
    for (row, cols) in expected.iter().enumerate() {
        for (col, &want) in cols.iter().enumerate() {
            assert_eq!(
                matrix_value(&result, row, col),
                want,
                "mismatch at ({}, {})",
                row,
                col
            );
        }
    }
}
/// Masking an all-zero matrix yields an all-zero matrix of the same shape.
#[test]
fn test_apply_column_mask_on_empty_matrix() {
    let (rows, cols) = (3, 3);
    let empty = CsrMatrix::<u32>::zeros(rows, cols);
    let masked = apply_column_mask(&empty, &vec![true; 3]);
    assert_eq!(masked.nnz(), 0);
    assert_eq!(masked.nrows(), rows);
    assert_eq!(masked.ncols(), cols);
}
/// AG masks for a single `A` site: A is ref, G is alt, and T/C are others.
#[test]
fn test_build_onehot_masks_ag_ref_a() {
    let (ref_masks, alt_masks, others_masks) =
        build_onehot_masks(&['A'], &EditingType::AG, 1);
    let on = vec![true];
    let off = vec![false];
    assert_eq!(ref_masks[&'A'], on);
    assert_eq!(ref_masks[&'G'], off);
    assert_eq!(alt_masks[&'G'], on);
    assert_eq!(alt_masks[&'A'], off);
    for base in ['T', 'C'] {
        assert_eq!(others_masks[&base], on);
    }
    for base in ['A', 'G'] {
        assert_eq!(others_masks[&base], off);
    }
}
/// AG masks for a single `T` site (reverse complement of A): T is ref,
/// C is alt, and A/G are others.
#[test]
fn test_build_onehot_masks_ag_ref_t() {
    let (ref_masks, alt_masks, others_masks) =
        build_onehot_masks(&['T'], &EditingType::AG, 1);
    let on = vec![true];
    assert_eq!(ref_masks[&'T'], on);
    assert_eq!(alt_masks[&'C'], on);
    for base in ['A', 'G'] {
        assert_eq!(others_masks[&base], on);
    }
}
/// CT masks for a single `C` site: C is ref, T is alt, and A/G are others.
#[test]
fn test_build_onehot_masks_ct_ref_c() {
    let (ref_masks, alt_masks, others_masks) =
        build_onehot_masks(&['C'], &EditingType::CT, 1);
    let on = vec![true];
    assert_eq!(ref_masks[&'C'], on);
    assert_eq!(alt_masks[&'T'], on);
    for base in ['A', 'G'] {
        assert_eq!(others_masks[&base], on);
    }
}
/// With sites [A, T, A] under AG, each per-base mask has one entry per site:
/// A/G are active at the A sites and T/C at the T site.
#[test]
fn test_build_onehot_masks_mixed_ref_bases() {
    let sites = ['A', 'T', 'A'];
    let (ref_masks, alt_masks, _) = build_onehot_masks(&sites, &EditingType::AG, 3);
    let at_a_sites = vec![true, false, true];
    let at_t_sites = vec![false, true, false];
    assert_eq!(ref_masks[&'A'], at_a_sites);
    assert_eq!(ref_masks[&'T'], at_t_sites);
    assert_eq!(alt_masks[&'G'], at_a_sites);
    assert_eq!(alt_masks[&'C'], at_t_sites);
}
/// An unrecognised reference base (`N`) contributes to neither the ref nor
/// the alt masks; every base is routed to the others masks instead.
#[test]
fn test_build_onehot_masks_unknown_ref_base() {
    let (ref_masks, alt_masks, others_masks) =
        build_onehot_masks(&['N'], &EditingType::AG, 1);
    for base in ['A', 'T'] {
        assert_eq!(ref_masks[&base], vec![false]);
    }
    for base in ['A', 'G'] {
        assert_eq!(alt_masks[&base], vec![false]);
    }
    for base in ['A', 'T', 'G', 'C'] {
        assert_eq!(others_masks[&base], vec![true]);
    }
}
/// Test helper that pulls the `Coverage`, `ref`, `A`, `T`, `G` and `C`
/// columns out of `var_df` and runs `classify_mismatch_fast` on row 0.
///
/// # Panics
///
/// Panics with the offending column name if any required column cannot be
/// extracted (missing, or an unexpected dtype).
fn classify_from_df(
    var_df: &DataFrame,
    editing_type: &EditingType,
    max_other: f32,
    min_edited: f32,
    min_ref: f32,
    min_cov: u32,
) -> MismatchClassification {
    // A bare `unwrap()` would not say which of the six columns was missing.
    let u32_col = |name: &str| {
        extract_u32_column(var_df, name)
            .unwrap_or_else(|e| panic!("failed to extract u32 column {:?}: {:?}", name, e))
    };
    let cov = u32_col("Coverage");
    let refs = extract_str_column(var_df, "ref")
        .unwrap_or_else(|e| panic!("failed to extract str column \"ref\": {:?}", e));
    let a = u32_col("A");
    let t = u32_col("T");
    let g = u32_col("G");
    let c = u32_col("C");
    classify_mismatch_fast(
        0,
        &cov,
        &refs,
        &a,
        &t,
        &g,
        &c,
        editing_type,
        max_other,
        min_edited,
        min_ref,
        min_cov,
    )
}
/// A site whose coverage (3) is below `min_cov` (5) must not pass and is
/// labelled "-".
#[test]
fn test_classify_mismatch_insufficient_coverage() {
    let mut columns = vec![
        Series::new("var_names".into(), &["chr1:1"]).into_column(),
        Series::new("ref".into(), &["A"]).into_column(),
    ];
    columns.extend(
        [("Coverage", 3u32), ("A", 2), ("G", 1), ("T", 0), ("C", 0)]
            .iter()
            .map(|&(name, count)| Series::new(name.into(), &[count]).into_column()),
    );
    let var_df = DataFrame::new(columns).unwrap();
    let verdict = classify_from_df(&var_df, &EditingType::AG, 0.1, 0.1, 0.1, 5);
    assert!(!verdict.filter_pass);
    assert_eq!(verdict.label, "-");
}
/// A site with reference base `N` must not pass, even with ample coverage.
#[test]
fn test_classify_mismatch_ref_base_n() {
    let mut columns = vec![
        Series::new("var_names".into(), &["chr1:1"]).into_column(),
        Series::new("ref".into(), &["N"]).into_column(),
    ];
    columns.extend(
        [("Coverage", 100u32), ("A", 50), ("G", 50), ("T", 0), ("C", 0)]
            .iter()
            .map(|&(name, count)| Series::new(name.into(), &[count]).into_column()),
    );
    let var_df = DataFrame::new(columns).unwrap();
    let verdict = classify_from_df(&var_df, &EditingType::AG, 0.1, 0.1, 0.1, 5);
    assert!(!verdict.filter_pass);
}
/// A clean AG site (100x: 80 A, 18 G, 1 T, 1 C) passes every threshold and is
/// labelled "AG".
#[test]
fn test_classify_mismatch_passes_all_thresholds() {
    let mut columns = vec![
        Series::new("var_names".into(), &["chr1:1"]).into_column(),
        Series::new("ref".into(), &["A"]).into_column(),
    ];
    columns.extend(
        [("Coverage", 100u32), ("A", 80), ("G", 18), ("T", 1), ("C", 1)]
            .iter()
            .map(|&(name, count)| Series::new(name.into(), &[count]).into_column()),
    );
    let var_df = DataFrame::new(columns).unwrap();
    let verdict = classify_from_df(&var_df, &EditingType::AG, 0.1, 0.01, 0.01, 5);
    assert!(verdict.filter_pass);
    assert_eq!(verdict.label, "AG");
}
/// T + C account for 40 of 100 reads, well above `max_other` = 0.1, so the
/// site must not pass.
#[test]
fn test_classify_mismatch_too_many_others() {
    let mut columns = vec![
        Series::new("var_names".into(), &["chr1:1"]).into_column(),
        Series::new("ref".into(), &["A"]).into_column(),
    ];
    columns.extend(
        [("Coverage", 100u32), ("A", 50), ("G", 10), ("T", 20), ("C", 20)]
            .iter()
            .map(|&(name, count)| Series::new(name.into(), &[count]).into_column()),
    );
    let var_df = DataFrame::new(columns).unwrap();
    let verdict = classify_from_df(&var_df, &EditingType::AG, 0.1, 0.01, 0.01, 5);
    assert!(!verdict.filter_pass);
}
/// A u32 column is returned in row order.
#[test]
fn test_extract_u32_column_basic() {
    let coverage = [10u32, 20, 30];
    let df = DataFrame::new(vec![Series::new("Coverage".into(), &coverage).into_column()])
        .unwrap();
    let extracted = extract_u32_column(&df, "Coverage").unwrap();
    assert_eq!(extracted, coverage);
}
/// Asking for a column that does not exist is an error, not a panic.
#[test]
fn test_extract_u32_column_missing_column() {
    let only_a = Series::new("A".into(), &[1u32]).into_column();
    let df = DataFrame::new(vec![only_a]).unwrap();
    let outcome = extract_u32_column(&df, "NonExistent");
    assert!(outcome.is_err());
}
/// A string column is returned in row order.
#[test]
fn test_extract_str_column_basic() {
    let bases = ["A", "G", "N"];
    let df = DataFrame::new(vec![Series::new("ref".into(), &bases).into_column()]).unwrap();
    let extracted = extract_str_column(&df, "ref").unwrap();
    assert_eq!(extracted, bases);
}
/// Null entries in a string column are reported as "N".
#[test]
fn test_extract_str_column_null_defaults_to_n() {
    let with_gap = [Some("A"), None, Some("C")];
    let column = Series::new("ref".into(), &with_gap).into_column();
    let df = DataFrame::new(vec![column]).unwrap();
    assert_eq!(extract_str_column(&df, "ref").unwrap(), vec!["A", "N", "C"]);
}
/// Layers `A0` and `A1` for the same base are merged into a single `A` entry
/// whose counts are summed (3 + 7). The source layers are removed from
/// `adata`.
#[test]
fn test_collect_strand_aware_base_layers_both_strands() {
    let mut adata = AnnDataContainer::new(2, 2);
    for (layer, count) in [("A0", 3), ("A1", 7)] {
        adata
            .layers
            .insert(layer.into(), csr_from_triplets(2, 2, &[(0, 0, count)]));
    }
    let layers = collect_strand_aware_base_layers(&mut adata).unwrap();
    assert_eq!(layers.len(), 1);
    let (base, merged) = &layers[0];
    assert_eq!(*base, 'A');
    assert_eq!(matrix_value(merged, 0, 0), 10);
    assert!(adata.layers.is_empty());
}
/// A base with only one strand layer (`G1`) is still collected as-is, and the
/// layer is removed from `adata`.
#[test]
fn test_collect_strand_aware_base_layers_single_strand() {
    let mut adata = AnnDataContainer::new(2, 2);
    adata.layers.insert("G1".into(), csr_from_triplets(2, 2, &[(1, 1, 5)]));
    let layers = collect_strand_aware_base_layers(&mut adata).unwrap();
    assert_eq!(layers.len(), 1);
    let (base, counts) = &layers[0];
    assert_eq!(*base, 'G');
    assert_eq!(matrix_value(counts, 1, 1), 5);
    assert!(adata.layers.is_empty());
}
/// With no layers present, collection succeeds and yields nothing.
#[test]
fn test_collect_strand_aware_base_layers_empty() {
    let mut adata = AnnDataContainer::new(2, 2);
    let collected = collect_strand_aware_base_layers(&mut adata).unwrap();
    assert!(collected.is_empty());
}
/// A site is selected only when both `is_editing_site` and `filter_pass` are
/// true.
#[test]
fn test_combined_filter_mask_both_true() {
    let editing = [true, false, true];
    let passing = [true, true, false];
    let var_df = DataFrame::new(vec![
        Series::new("is_editing_site".into(), &editing).into_column(),
        Series::new("filter_pass".into(), &passing).into_column(),
    ])
    .unwrap();
    let expected: Vec<bool> = editing
        .iter()
        .zip(&passing)
        .map(|(&e, &p)| e && p)
        .collect();
    assert_eq!(combined_filter_mask(&var_df).unwrap(), expected);
}
/// Without a `filter_pass` column, the combined mask cannot be built.
#[test]
fn test_combined_filter_mask_missing_column() {
    let only_editing_flag = Series::new("is_editing_site".into(), &[true]).into_column();
    let var_df = DataFrame::new(vec![only_editing_flag]).unwrap();
    let outcome = combined_filter_mask(&var_df);
    assert!(outcome.is_err());
}
/// CEI for a cell with ref = alt = 0 is NaN. With ref = 0 and alt = 5 it is 1.0.
#[test]
fn test_calculate_cei_zero_denominator() {
    let obs = DataFrame::new(vec![
        Series::new("obs_names".into(), &["c0", "c1"]).into_column(),
        Series::new("ref".into(), &[0u32, 0u32]).into_column(),
        Series::new("alt".into(), &[0u32, 5u32]).into_column(),
    ])
    .unwrap();
    let mut adata = AnnDataContainer::new(2, 1);
    adata.obs = obs;

    let result = calculate_cei(adata).unwrap();
    let cei = result.obs.column("CEI").unwrap().f32().unwrap();
    let (no_reads, all_alt) = (cei.get(0), cei.get(1));
    assert!(no_reads.unwrap().is_nan());
    assert_eq!(all_alt, Some(1.0));
}
/// Site mismatch stats for A>G add `AG_ref`, `AG_alt` and `AG_others` columns
/// to `var`. `AG_ref` must match the ref counts at each site (10 and 20).
#[test]
fn test_calculate_site_mismatch_stats_adds_columns() {
    let mut adata = AnnDataContainer::new(2, 2);
    adata.var = DataFrame::new(vec![
        Series::new("var_names".into(), &["s0", "s1"]).into_column(),
    ])
    .unwrap();
    let layers = [
        ("ref", csr_from_triplets(2, 2, &[(0, 0, 10), (1, 1, 20)])),
        ("alt", csr_from_triplets(2, 2, &[(0, 0, 3), (1, 1, 7)])),
        ("others", csr_from_triplets(2, 2, &[(0, 1, 1)])),
    ];
    for (name, matrix) in layers {
        adata.layers.insert(name.into(), matrix);
    }

    let result = calculate_site_mismatch_stats(adata, 'A', 'G').unwrap();
    for column in ["AG_ref", "AG_alt", "AG_others"] {
        assert!(result.var.column(column).is_ok());
    }
    let ref_col = result.var.column("AG_ref").unwrap().u32().unwrap();
    for (site, want) in [(0, 10), (1, 20)] {
        assert_eq!(ref_col.get(site), Some(want));
    }
}
/// Only the sites that appear in the known-sites set are flagged as editing
/// sites. The others are flagged `false`.
#[test]
fn test_mark_editing_sites_partial_match() {
    let mut adata = AnnDataContainer::new(1, 3);
    adata.var_names = vec!["chr1:100".into(), "chr1:200".into(), "chr1:300".into()];
    let names_column = Series::new("var_names".into(), adata.var_names.clone()).into_column();
    adata.var = DataFrame::new(vec![names_column]).unwrap();

    let sites: HashSet<String> = ["chr1:100", "chr1:300"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    let result = mark_editing_sites(adata, &sites).unwrap();
    let is_editing = result.var.column("is_editing_site").unwrap().bool().unwrap();
    for (idx, expected) in [true, false, true].iter().copied().enumerate() {
        assert_eq!(is_editing.get(idx), Some(expected));
    }
}
/// Converting a u32 CSR matrix to f64 keeps its shape, sparsity pattern and
/// stored values.
#[test]
fn test_convert_u32_to_f64_preserves_structure() {
    let source = csr_from_triplets(2, 3, &[(0, 0, 100), (1, 2, 200)]);
    let converted = convert_u32_to_f64_csr(&source);
    assert_eq!((converted.nrows(), converted.ncols()), (2, 3));
    assert_eq!(converted.nnz(), 2);
    let (_, _, values) = converted.csr_data();
    assert_eq!(values[0], 100.0);
    assert_eq!(values[1], 200.0);
}
/// With site metadata present but no per-base count layers, the ref/alt
/// calculation must return an error.
#[test]
fn test_calculate_ref_alt_matrices_no_base_layers_errors() {
    let var = DataFrame::new(vec![
        Series::new("var_names".into(), &["chr1:1"]).into_column(),
        Series::new("ref".into(), &["A"]).into_column(),
        Series::new("is_editing_site".into(), &[true]).into_column(),
        Series::new("filter_pass".into(), &[true]).into_column(),
    ])
    .unwrap();
    let mut adata = AnnDataContainer::new(2, 1);
    adata.var_names = vec!["chr1:1".into()];
    adata.var = var;
    assert!(calculate_ref_alt_matrices(adata, &EditingType::AG).is_err());
}
/// End-to-end AG run on a 500-cell x 100-site matrix (all refs `A`). A
/// cell/site pair carries 50 A and 5 G reads exactly when
/// `(cell + site) % 20 == 0`, and no reads otherwise. Checks output shape,
/// sparsity, and spot-checks individual ref/alt values.
#[test]
fn test_large_scale_correctness() {
    let n_cells = 500;
    let n_sites = 100;
    let mut triplets_a = vec![];
    let mut triplets_g = vec![];
    let triplets_t = vec![];
    let triplets_c = vec![];
    for site_idx in 0..n_sites {
        for cell_idx in 0..n_cells {
            if (cell_idx + site_idx) % 20 == 0 {
                triplets_a.push((cell_idx, site_idx, 50));
                triplets_g.push((cell_idx, site_idx, 5));
            }
        }
    }
    let a_matrix = csr_from_triplets(n_cells, n_sites, &triplets_a);
    let g_matrix = csr_from_triplets(n_cells, n_sites, &triplets_g);
    let t_matrix = csr_from_triplets(n_cells, n_sites, &triplets_t);
    let c_matrix = csr_from_triplets(n_cells, n_sites, &triplets_c);
    let mut adata = AnnDataContainer::new(n_cells, n_sites);
    adata.obs_names = (0..n_cells).map(|i| format!("cell_{}", i)).collect();
    adata.var_names = (0..n_sites).map(|i| format!("chr1:{}", i)).collect();
    adata.obs = DataFrame::new(vec![Series::new(
        "obs_names".into(),
        adata.obs_names.clone(),
    )
    .into_column()])
    .expect("Failed to build obs");
    let ref_bases = vec!["A"; n_sites];
    adata.var = DataFrame::new(vec![
        Series::new("var_names".into(), adata.var_names.clone()).into_column(),
        Series::new("ref".into(), ref_bases).into_column(),
        Series::new("is_editing_site".into(), vec![true; n_sites]).into_column(),
        Series::new("filter_pass".into(), vec![true; n_sites]).into_column(),
    ])
    .expect("Failed to build var");
    adata.layers.insert("A1".to_string(), a_matrix);
    adata.layers.insert("G1".to_string(), g_matrix);
    adata.layers.insert("T1".to_string(), t_matrix);
    adata.layers.insert("C1".to_string(), c_matrix);
    let result =
        calculate_ref_alt_matrices(adata, &EditingType::AG).expect("Large scale test failed");
    let ref_layer = result.layers.get("ref").expect("missing ref layer");
    let alt_layer = result.layers.get("alt").expect("missing alt layer");

    assert_eq!(ref_layer.nrows(), n_cells);
    assert_eq!(ref_layer.ncols(), n_sites);
    assert_eq!(alt_layer.nrows(), n_cells);
    assert_eq!(alt_layer.ncols(), n_sites);

    let total_elements = n_cells * n_sites;
    assert!(ref_layer.nnz() < total_elements / 10);
    assert!(alt_layer.nnz() < total_elements / 10);

    // Shape and sparsity alone would accept wrong values, so spot-check entries.
    // On-pattern pairs: ref = A count, alt = G count.
    for &(cell, site) in &[(0, 0), (20, 0), (19, 1), (499, 1)] {
        assert_eq!(matrix_value(ref_layer, cell, site), 50, "ref at ({}, {})", cell, site);
        assert_eq!(matrix_value(alt_layer, cell, site), 5, "alt at ({}, {})", cell, site);
    }
    // Off-pattern pairs carry no reads at all.
    for &(cell, site) in &[(1, 0), (0, 1), (250, 3)] {
        assert_eq!(matrix_value(ref_layer, cell, site), 0, "ref at ({}, {})", cell, site);
        assert_eq!(matrix_value(alt_layer, cell, site), 0, "alt at ({}, {})", cell, site);
    }
}
}