#[cfg(test)]
mod tests {
    // Snapshot tests for the call pipeline. Each test builds a small
    // `AnnDataContainer` by hand and pins the exact numbers produced by the
    // base-matrix and editing-analysis steps. Per-base layers are named by a
    // base letter plus a `1`/`0` suffix. In the stranded fixture, the expected
    // ref/alt values sum both suffixes. NOTE(review): the suffix presumably
    // encodes strand — confirm against the pipeline's layer naming.

    use crate::core::sparse::SparseOps;
    use crate::pipeline::call::anndata_ops::AnnDataContainer;
    use crate::pipeline::call::base_matrix::{
        calculate_coverage, count_base_levels, filter_sites_by_coverage,
    };
    use crate::pipeline::call::editing::EditingType;
    use crate::pipeline::call::editing_analysis::{
        calculate_cei, calculate_ref_alt_matrices, calculate_site_mismatch_stats,
    };
    use nalgebra_sparse::CsrMatrix;
    use polars::prelude::*;

    /// Absolute tolerance shared by every floating-point (CEI) comparison.
    const TOL: f32 = 1e-4;

    /// Builds an `nrows x ncols` u32 CSR matrix from `(row, col, value)` triplets.
    fn csr(nrows: usize, ncols: usize, triplets: &[(usize, usize, u32)]) -> CsrMatrix<u32> {
        SparseOps::from_triplets_u32(nrows, ncols, triplets.to_vec()).unwrap()
    }

    /// Returns the stored value at `(row, col)`, or 0 when no entry is stored there.
    fn val(matrix: &CsrMatrix<u32>, row: usize, col: usize) -> u32 {
        let r = matrix.row(row);
        r.col_indices()
            .iter()
            .zip(r.values())
            .find_map(|(&c, &v)| (c == col).then_some(v))
            .unwrap_or(0)
    }

    /// Asserts that `actual` is within `TOL` of `expected`, naming the quantity on failure.
    fn assert_close(actual: f32, expected: f32, what: &str) {
        assert!(
            (actual - expected).abs() < TOL,
            "{}: got {}, expected {}",
            what,
            actual,
            expected
        );
    }

    /// Asserts that `obs["CEI"]` for `cell` is non-null and within `TOL` of `expected`.
    ///
    /// A null CEI fails with an explicit message. It is not coerced to 0.0,
    /// which would hide a missing value behind an ordinary mismatch.
    fn assert_cei(adata: &AnnDataContainer, cell: usize, expected: f32) {
        let cei = adata
            .obs
            .column("CEI")
            .expect("missing CEI")
            .f32()
            .expect("CEI column should be f32");
        let got = cei
            .get(cell)
            .unwrap_or_else(|| panic!("cell{} CEI is null", cell));
        assert_close(got, expected, &format!("cell{} CEI", cell));
    }

    /// Stranded A-to-G fixture: 4 cells x 3 sites, with all eight `{A,T,G,C}{1,0}` layers.
    ///
    /// Site `chr1:300` is marked `is_editing_site = false` and `filter_pass = false`.
    /// The full-pipeline snapshot expects `calculate_ref_alt_matrices` to drop it.
    fn build_full_stranded_adata() -> AnnDataContainer {
        let n_obs = 4;
        let n_vars = 3;
        let mut adata = AnnDataContainer::new(n_obs, n_vars);
        adata.obs_names = (0..n_obs).map(|i| format!("cell_{}", i)).collect();
        adata.var_names = vec!["chr1:100".into(), "chr1:200".into(), "chr1:300".into()];
        adata.obs = DataFrame::new(vec![
            Series::new("obs_names".into(), adata.obs_names.clone()).into_column(),
        ])
        .unwrap();

        adata.layers.insert(
            "A1".into(),
            csr(n_obs, n_vars, &[(0, 0, 30), (1, 0, 20), (2, 1, 5), (3, 2, 50)]),
        );
        adata.layers.insert(
            "T1".into(),
            csr(n_obs, n_vars, &[(0, 1, 40), (1, 1, 35), (3, 0, 2)]),
        );
        adata.layers.insert(
            "G1".into(),
            csr(n_obs, n_vars, &[(0, 0, 8), (1, 0, 4), (2, 1, 1), (3, 2, 3)]),
        );
        adata.layers.insert(
            "C1".into(),
            csr(n_obs, n_vars, &[(0, 1, 6), (1, 1, 5), (3, 2, 2)]),
        );
        adata
            .layers
            .insert("A0".into(), csr(n_obs, n_vars, &[(0, 0, 10), (2, 1, 3)]));
        adata
            .layers
            .insert("T0".into(), csr(n_obs, n_vars, &[(0, 1, 5), (1, 1, 3)]));
        adata
            .layers
            .insert("G0".into(), csr(n_obs, n_vars, &[(0, 0, 2), (1, 0, 1)]));
        adata
            .layers
            .insert("C0".into(), csr(n_obs, n_vars, &[(0, 1, 1)]));

        adata.var = DataFrame::new(vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
            Series::new("ref".into(), &["A", "T", "C"]).into_column(),
            Series::new("is_editing_site".into(), &[true, true, false]).into_column(),
            Series::new("filter_pass".into(), &[true, true, false]).into_column(),
        ])
        .unwrap();
        adata
    }

    #[test]
    fn snapshot_full_pipeline_ag_stranded() {
        let adata = build_full_stranded_adata();
        let adata = calculate_ref_alt_matrices(adata, &EditingType::AG)
            .expect("ref/alt matrix calculation failed");
        assert_eq!(adata.n_vars, 2, "site2 should have been filtered out");
        assert_eq!(adata.var_names, vec!["chr1:100", "chr1:200"]);

        let ref_layer = adata.layers.get("ref").expect("missing ref");
        let alt_layer = adata.layers.get("alt").expect("missing alt");
        let others_layer = adata.layers.get("others").expect("missing others");

        // Site 0 (ref A): ref = A1 + A0, alt = G1 + G0, others = T + C.
        assert_eq!(val(ref_layer, 0, 0), 40, "cell0 site0 ref(A)");
        assert_eq!(val(ref_layer, 1, 0), 20, "cell1 site0 ref(A)");
        assert_eq!(val(alt_layer, 0, 0), 10, "cell0 site0 alt(G)");
        assert_eq!(val(alt_layer, 1, 0), 5, "cell1 site0 alt(G)");
        assert_eq!(val(others_layer, 3, 0), 2, "cell3 site0 others(T+C)");

        // Site 1 (ref T): ref = T1 + T0, alt = C1 + C0, others = A + G.
        assert_eq!(val(ref_layer, 0, 1), 45, "cell0 site1 ref(T)");
        assert_eq!(val(ref_layer, 1, 1), 38, "cell1 site1 ref(T)");
        assert_eq!(val(alt_layer, 0, 1), 7, "cell0 site1 alt(C)");
        assert_eq!(val(alt_layer, 1, 1), 5, "cell1 site1 alt(C)");
        assert_eq!(val(others_layer, 2, 1), 9, "cell2 site1 others(A+G)");

        let x = adata.x.as_ref().expect("missing X");
        assert_eq!(x.nrows(), 4);
        assert_eq!(x.ncols(), 2);

        // The expected CEI is total alt / total (ref + alt) over the two retained sites:
        // cell0 = (10 + 7) / (40 + 10 + 45 + 7), cell1 = (5 + 5) / (20 + 5 + 38 + 5).
        let adata = calculate_cei(adata).expect("CEI calculation failed");
        assert_cei(&adata, 0, 17.0 / 102.0);
        assert_cei(&adata, 1, 10.0 / 68.0);

        let adata =
            calculate_site_mismatch_stats(adata, 'A', 'G').expect("mismatch stats failed");
        let ag_ref = adata
            .var
            .column("AG_ref")
            .expect("missing AG_ref")
            .u32()
            .expect("AG_ref should be u32");
        let ag_alt = adata
            .var
            .column("AG_alt")
            .expect("missing AG_alt")
            .u32()
            .expect("AG_alt should be u32");
        assert!(adata.var.column("AG_others").is_ok(), "missing AG_others");

        // Per-site sums of the ref/alt layers across all cells.
        assert_eq!(ag_ref.get(0), Some(60), "AG_ref site0");
        assert_eq!(ag_alt.get(0), Some(15), "AG_alt site0");
        assert_eq!(ag_ref.get(1), Some(83), "AG_ref site1");
        assert_eq!(ag_alt.get(1), Some(12), "AG_alt site1");
    }

    #[test]
    fn snapshot_pipeline_unstranded() {
        let n_obs = 2;
        let n_vars = 2;
        let mut adata = AnnDataContainer::new(n_obs, n_vars);
        adata.obs_names = vec!["c0".into(), "c1".into()];
        adata.var_names = vec!["chr1:10".into(), "chr1:20".into()];
        adata.obs = DataFrame::new(vec![
            Series::new("obs_names".into(), adata.obs_names.clone()).into_column(),
        ])
        .unwrap();

        // Only the `*1` layers are present in this fixture.
        adata.layers.insert("A1".into(), csr(2, 2, &[(0, 0, 50), (1, 0, 30)]));
        adata.layers.insert("T1".into(), csr(2, 2, &[(0, 1, 60), (1, 1, 40)]));
        adata.layers.insert("G1".into(), csr(2, 2, &[(0, 0, 5), (1, 0, 3)]));
        adata.layers.insert("C1".into(), csr(2, 2, &[(0, 1, 8), (1, 1, 6)]));
        adata.var = DataFrame::new(vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
            Series::new("ref".into(), &["A", "T"]).into_column(),
            Series::new("is_editing_site".into(), &[true, true]).into_column(),
            Series::new("filter_pass".into(), &[true, true]).into_column(),
        ])
        .unwrap();

        let adata = calculate_ref_alt_matrices(adata, &EditingType::AG).unwrap();
        let ref_layer = adata.layers.get("ref").unwrap();
        let alt_layer = adata.layers.get("alt").unwrap();
        assert_eq!(val(ref_layer, 0, 0), 50, "c0 site0 ref(A)");
        assert_eq!(val(alt_layer, 0, 0), 5, "c0 site0 alt(G)");
        assert_eq!(val(ref_layer, 0, 1), 60, "c0 site1 ref(T)");
        assert_eq!(val(alt_layer, 0, 1), 8, "c0 site1 alt(C)");
        assert_eq!(val(ref_layer, 1, 0), 30, "c1 site0 ref(A)");
        assert_eq!(val(alt_layer, 1, 0), 3, "c1 site0 alt(G)");
        assert_eq!(val(ref_layer, 1, 1), 40, "c1 site1 ref(T)");
        assert_eq!(val(alt_layer, 1, 1), 6, "c1 site1 alt(C)");

        // c0 = (5 + 8) / (50 + 5 + 60 + 8); c1 = (3 + 6) / (30 + 3 + 40 + 6).
        let adata = calculate_cei(adata).unwrap();
        assert_cei(&adata, 0, 13.0 / 123.0);
        assert_cei(&adata, 1, 9.0 / 79.0);
    }

    #[test]
    fn snapshot_coverage_stranded_all_layers() {
        let mut adata = build_full_stranded_adata();
        calculate_coverage(&mut adata).unwrap();
        let cov = adata.layers.get("coverage").unwrap();
        let row_sums = SparseOps::compute_row_sums(cov);
        // The expected totals are the sums of all eight base layers for that cell.
        assert_eq!(row_sums[0], 102, "cell0 total coverage");
        assert_eq!(row_sums[1], 68, "cell1 total coverage");
    }

    /// Checks that `filter_sites_by_coverage` drops the low-coverage middle site.
    ///
    /// NOTE(review): despite its name, this test never runs the ref/alt step.
    /// The fixture's `var` has no `ref` / `is_editing_site` / `filter_pass`
    /// columns. The name is kept so existing test filters still match.
    #[test]
    fn snapshot_filter_by_coverage_then_ref_alt() {
        let n_obs = 2;
        let n_vars = 3;
        let mut adata = AnnDataContainer::new(n_obs, n_vars);
        adata.obs_names = vec!["c0".into(), "c1".into()];
        adata.var_names = vec!["chr1:1".into(), "chr1:2".into(), "chr1:3".into()];
        adata.obs = DataFrame::new(vec![
            Series::new("obs_names".into(), adata.obs_names.clone()).into_column(),
        ])
        .unwrap();

        // chr1:2 has a single read in total; the other sites have tens of reads.
        adata.layers.insert(
            "A1".into(),
            csr(2, 3, &[(0, 0, 50), (1, 0, 40), (0, 2, 60), (1, 2, 55)]),
        );
        adata.layers.insert("T1".into(), csr(2, 3, &[(0, 1, 1)]));
        adata.layers.insert(
            "G1".into(),
            csr(2, 3, &[(0, 0, 5), (1, 0, 3), (0, 2, 7), (1, 2, 4)]),
        );
        adata.layers.insert("C1".into(), CsrMatrix::<u32>::zeros(2, 3));
        adata.var = DataFrame::new(vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
        ])
        .unwrap();

        let filtered = filter_sites_by_coverage(adata, 5).unwrap();
        assert_eq!(filtered.n_vars, 2);
        assert_eq!(filtered.var_names, vec!["chr1:1", "chr1:3"]);
    }

    #[test]
    fn snapshot_count_base_levels() {
        let mut adata = AnnDataContainer::new(3, 2);
        adata.var_names = vec!["s0".into(), "s1".into()];
        adata.var = DataFrame::new(vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
        ])
        .unwrap();
        adata.layers.insert("A1".into(), csr(3, 2, &[(0, 0, 10), (1, 0, 20)]));
        adata.layers.insert("T1".into(), csr(3, 2, &[(0, 1, 5), (2, 1, 15)]));
        adata.layers.insert("G1".into(), csr(3, 2, &[(1, 0, 3)]));
        adata.layers.insert("C1".into(), csr(3, 2, &[(2, 0, 1)]));

        let adata = count_base_levels(adata).unwrap();
        let a_col = adata.var.column("A").unwrap().u32().unwrap();
        let t_col = adata.var.column("T").unwrap().u32().unwrap();
        let g_col = adata.var.column("G").unwrap().u32().unwrap();
        let c_col = adata.var.column("C").unwrap().u32().unwrap();
        let cov = adata.var.column("Coverage").unwrap().u32().unwrap();

        // Site s0: A = 10 + 20, G = 3, C = 1, so Coverage = 34.
        assert_eq!(a_col.get(0), Some(30));
        assert_eq!(t_col.get(0), Some(0));
        assert_eq!(g_col.get(0), Some(3));
        assert_eq!(c_col.get(0), Some(1));
        assert_eq!(cov.get(0), Some(34));

        // Site s1: only T is observed (5 + 15), so G and C are both 0.
        assert_eq!(a_col.get(1), Some(0));
        assert_eq!(t_col.get(1), Some(20));
        assert_eq!(g_col.get(1), Some(0));
        assert_eq!(c_col.get(1), Some(0));
        assert_eq!(cov.get(1), Some(20));
    }

    #[test]
    fn snapshot_ct_editing_pipeline() {
        let n_obs = 2;
        let n_vars = 2;
        let mut adata = AnnDataContainer::new(n_obs, n_vars);
        adata.obs_names = vec!["c0".into(), "c1".into()];
        adata.var_names = vec!["chr1:10".into(), "chr1:20".into()];
        adata.obs = DataFrame::new(vec![
            Series::new("obs_names".into(), adata.obs_names.clone()).into_column(),
        ])
        .unwrap();
        adata.layers.insert("C1".into(), csr(2, 2, &[(0, 0, 80), (1, 0, 70)]));
        adata.layers.insert("T1".into(), csr(2, 2, &[(0, 0, 10), (1, 0, 8)]));
        adata.layers.insert("G1".into(), csr(2, 2, &[(0, 1, 90), (1, 1, 85)]));
        adata.layers.insert("A1".into(), csr(2, 2, &[(0, 1, 12), (1, 1, 9)]));
        adata.var = DataFrame::new(vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
            Series::new("ref".into(), &["C", "G"]).into_column(),
            Series::new("is_editing_site".into(), &[true, true]).into_column(),
            Series::new("filter_pass".into(), &[true, true]).into_column(),
        ])
        .unwrap();

        let adata = calculate_ref_alt_matrices(adata, &EditingType::CT).unwrap();
        let ref_layer = adata.layers.get("ref").unwrap();
        let alt_layer = adata.layers.get("alt").unwrap();
        assert_eq!(val(ref_layer, 0, 0), 80, "c0 site0 ref=C");
        assert_eq!(val(alt_layer, 0, 0), 10, "c0 site0 alt=T");
        assert_eq!(val(ref_layer, 0, 1), 90, "c0 site1 ref=G");
        assert_eq!(val(alt_layer, 0, 1), 12, "c0 site1 alt=A");

        // c0 = (10 + 12) / (80 + 10 + 90 + 12); c1 = (8 + 9) / (70 + 8 + 85 + 9).
        let adata = calculate_cei(adata).unwrap();
        assert_cei(&adata, 0, 22.0 / 192.0);
        assert_cei(&adata, 1, 17.0 / 172.0);

        let adata = calculate_site_mismatch_stats(adata, 'C', 'T').unwrap();
        let ct_ref = adata
            .var
            .column("CT_ref")
            .expect("missing CT_ref")
            .u32()
            .expect("CT_ref should be u32");
        let ct_alt = adata
            .var
            .column("CT_alt")
            .expect("missing CT_alt")
            .u32()
            .expect("CT_alt should be u32");
        assert_eq!(ct_ref.get(0), Some(150), "CT_ref site0");
        assert_eq!(ct_ref.get(1), Some(175), "CT_ref site1");
        assert_eq!(ct_alt.get(0), Some(18), "CT_alt site0");
        assert_eq!(ct_alt.get(1), Some(21), "CT_alt site1");
    }
}