// redicat 0.4.2
//
// REDICAT - RNA Editing Cellular Assessment Toolkit: A highly parallelized utility for analyzing RNA editing events in single-cell RNA-seq data
// Documentation
//! End-to-end integration tests for the `call` pipeline.
//!
//! These tests build realistic AnnDataContainer instances and run the full
//! pipeline path (calculate_ref_alt_matrices → calculate_cei →
//! calculate_site_mismatch_stats) to produce a correctness oracle that must
//! be maintained across all refactoring.

#[cfg(test)]
mod tests {
    use crate::core::sparse::SparseOps;
    use crate::pipeline::call::anndata_ops::AnnDataContainer;
    use crate::pipeline::call::base_matrix::{calculate_coverage, count_base_levels, filter_sites_by_coverage};
    use crate::pipeline::call::editing::EditingType;
    use crate::pipeline::call::editing_analysis::{
        calculate_cei, calculate_ref_alt_matrices, calculate_site_mismatch_stats,
    };
    use nalgebra_sparse::CsrMatrix;
    use polars::prelude::*;

    /// Builds an `nrows × ncols` CSR matrix of `u32` counts from
    /// `(row, col, value)` triplets.
    ///
    /// Panics if `SparseOps::from_triplets_u32` does not succeed; that is
    /// acceptable here because the helper is only fed hand-written fixtures.
    fn csr(nrows: usize, ncols: usize, triplets: &[(usize, usize, u32)]) -> CsrMatrix<u32> {
        // Hand the constructor an owned copy of the borrowed fixture slice.
        let owned_triplets = triplets.to_vec();
        let built = SparseOps::from_triplets_u32(nrows, ncols, owned_triplets);
        built.unwrap()
    }

    /// Reads the entry stored at (`row`, `col`) of a CSR matrix.
    ///
    /// Returns `0` when `col` has no stored entry in that row.
    fn val(matrix: &CsrMatrix<u32>, row: usize, col: usize) -> u32 {
        let lane = matrix.row(row);
        // Locate the first stored column index equal to `col`, then read the
        // value sitting at the same offset.
        match lane.col_indices().iter().position(|&c| c == col) {
            Some(slot) => lane.values()[slot],
            None => 0,
        }
    }

    /// Assemble the shared stranded fixture used by the pipeline snapshots.
    ///
    /// Shape: 4 cells × 3 sites, with all eight stranded count layers present
    /// (A0/A1/T0/T1/G0/G1/C0/C1).
    ///
    /// * Site 0 (`chr1:100`): ref=A, flagged as an editing site that passes the filter.
    /// * Site 1 (`chr1:200`): ref=T, flagged likewise (the AG snapshot reads its alt as C).
    /// * Site 2 (`chr1:300`): ref=C, not flagged; the AG snapshot expects it to be dropped.
    fn build_full_stranded_adata() -> AnnDataContainer {
        let n_obs = 4;
        let n_vars = 3;

        // Layer name → (cell, site, count) triplets; anything not listed is zero.
        let layer_specs: [(&str, &[(usize, usize, u32)]); 8] = [
            // Positive-strand layers (*1).
            ("A1", &[(0, 0, 30), (1, 0, 20), (2, 1, 5), (3, 2, 50)]),
            ("T1", &[(0, 1, 40), (1, 1, 35), (3, 0, 2)]),
            ("G1", &[(0, 0, 8), (1, 0, 4), (2, 1, 1), (3, 2, 3)]),
            ("C1", &[(0, 1, 6), (1, 1, 5), (3, 2, 2)]),
            // Negative-strand layers (*0).
            ("A0", &[(0, 0, 10), (2, 1, 3)]),
            ("T0", &[(0, 1, 5), (1, 1, 3)]),
            ("G0", &[(0, 0, 2), (1, 0, 1)]),
            ("C0", &[(0, 1, 1)]),
        ];

        let mut adata = AnnDataContainer::new(n_obs, n_vars);
        adata.obs_names = (0..n_obs).map(|i| format!("cell_{i}")).collect();
        adata.var_names = vec!["chr1:100".into(), "chr1:200".into(), "chr1:300".into()];

        // obs carries nothing but the cell names.
        let obs_columns = vec![
            Series::new("obs_names".into(), adata.obs_names.clone()).into_column(),
        ];
        adata.obs = DataFrame::new(obs_columns).unwrap();

        // Insert in table order so the layers are added in the sequence listed above.
        for (layer_name, entries) in layer_specs {
            adata.layers.insert(layer_name.into(), csr(n_obs, n_vars, entries));
        }

        // var carries the reference base plus the editing-site / filter flags per site.
        let var_columns = vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
            Series::new("ref".into(), &["A", "T", "C"]).into_column(),
            Series::new("is_editing_site".into(), &[true, true, false]).into_column(),
            Series::new("filter_pass".into(), &[true, true, false]).into_column(),
        ];
        adata.var = DataFrame::new(var_columns).unwrap();

        adata
    }

    // ====================================================================
    // Snapshot: full pipeline for AG editing with stranded data
    // ====================================================================

    /// End-to-end A→G snapshot on the stranded fixture.
    ///
    /// Chains `calculate_ref_alt_matrices` → `calculate_cei` →
    /// `calculate_site_mismatch_stats` and pins the numbers each stage yields.
    #[test]
    fn snapshot_full_pipeline_ag_stranded() {
        let fixture = build_full_stranded_adata();

        // --- Stage 1: calculate_ref_alt_matrices ---
        // Outcome pinned for AG: only sites whose ref base is A or T survive, and
        // ref/alt/others are derived from per-base totals summed over both strands.
        let adata = calculate_ref_alt_matrices(fixture, &EditingType::AG)
            .expect("ref/alt matrix calculation failed");

        // site2 (ref=C) must be gone, leaving site0 (ref=A) and site1 (ref=T).
        assert_eq!(adata.n_vars, 2, "site2 should have been filtered out");
        assert_eq!(adata.var_names, vec!["chr1:100", "chr1:200"]);

        let ref_layer = adata.layers.get("ref").expect("missing ref");
        let alt_layer = adata.layers.get("alt").expect("missing alt");
        let others_layer = adata.layers.get("others").expect("missing others");

        // Per-base totals (strand 0 + strand 1) behind the expectations below.
        //
        // site0, ref=A → alt=G, others=T+C
        //   A: cell0=40 cell1=20 | G: cell0=10 cell1=5 | T: cell3=2 | C: none
        // site1, ref=T → alt=C (complementary), others=A+G
        //   T: cell0=45 cell1=38 | C: cell0=7 cell1=5 | A: cell2=8 | G: cell2=1
        //
        // Columns: (layer, cell, site, expected count, failure label).
        let expected_counts = [
            (ref_layer, 0, 0, 40, "cell0 site0 ref(A)"),
            (ref_layer, 1, 0, 20, "cell1 site0 ref(A)"),
            (alt_layer, 0, 0, 10, "cell0 site0 alt(G)"),
            (alt_layer, 1, 0, 5, "cell1 site0 alt(G)"),
            (others_layer, 3, 0, 2, "cell3 site0 others(T+C)"),
            (ref_layer, 0, 1, 45, "cell0 site1 ref(T)"),
            (ref_layer, 1, 1, 38, "cell1 site1 ref(T)"),
            (alt_layer, 0, 1, 7, "cell0 site1 alt(C)"),
            (alt_layer, 1, 1, 5, "cell1 site1 alt(C)"),
            (others_layer, 2, 1, 9, "cell2 site1 others(A+G)"),
        ];
        for (layer, cell, site, expected, label) in expected_counts {
            assert_eq!(val(layer, cell, site), expected, "{}", label);
        }

        // X must be present with one row per cell and one column per surviving
        // site; only its shape is checked here, not its contents.
        let x = adata.x.as_ref().expect("missing X");
        assert_eq!(x.nrows(), 4);
        assert_eq!(x.ncols(), 2);

        // --- Stage 2: calculate_cei ---
        let adata = calculate_cei(adata).expect("CEI calculation failed");
        let cei: Vec<f32> = adata
            .obs
            .column("CEI")
            .expect("missing CEI")
            .f32()
            .unwrap()
            .into_iter()
            .map(|v| v.unwrap_or(0.0))
            .collect();

        // cell0: alt = 10 + 7 = 17, ref = 40 + 45 = 85 → CEI = 17 / (85 + 17)
        let expected_cei_0: f32 = 17.0 / 102.0;
        let cell0_delta = (cei[0] - expected_cei_0).abs();
        assert!(cell0_delta < 1e-4, "cell0 CEI: got {}, expected {}", cei[0], expected_cei_0);

        // cell1: alt = 5 + 5 = 10, ref = 20 + 38 = 58 → CEI = 10 / (58 + 10)
        let expected_cei_1: f32 = 10.0 / 68.0;
        let cell1_delta = (cei[1] - expected_cei_1).abs();
        assert!(cell1_delta < 1e-4, "cell1 CEI: got {}, expected {}", cei[1], expected_cei_1);

        // --- Stage 3: calculate_site_mismatch_stats ---
        let adata = calculate_site_mismatch_stats(adata, 'A', 'G')
            .expect("mismatch stats failed");
        for stat_column in ["AG_ref", "AG_alt", "AG_others"] {
            assert!(adata.var.column(stat_column).is_ok());
        }

        let ag_ref = adata.var.column("AG_ref").unwrap().u32().unwrap();
        let ag_alt = adata.var.column("AG_alt").unwrap().u32().unwrap();

        // Per-site sums over all four cells of the ref / alt layers:
        //   site0: ref 40+20 = 60, alt 10+5 = 15
        //   site1: ref 45+38 = 83, alt 7+5 = 12
        let expected_site_totals = [
            (ag_ref, 0, Some(60), "AG_ref site0"),
            (ag_alt, 0, Some(15), "AG_alt site0"),
            (ag_ref, 1, Some(83), "AG_ref site1"),
            (ag_alt, 1, Some(12), "AG_alt site1"),
        ];
        for (totals, site, expected, label) in expected_site_totals {
            assert_eq!(totals.get(site), expected, "{}", label);
        }
    }

    // ====================================================================
    // Snapshot: unstranded data (only *1 layers)
    // ====================================================================

    /// A→G snapshot for input that carries only the `*1` layers
    /// (2 cells × 2 sites, ref bases A and T).
    #[test]
    fn snapshot_pipeline_unstranded() {
        let n_obs = 2;
        let n_vars = 2;

        let mut input = AnnDataContainer::new(n_obs, n_vars);
        input.obs_names = vec!["c0".into(), "c1".into()];
        input.var_names = vec!["chr1:10".into(), "chr1:20".into()];

        let obs_columns = vec![
            Series::new("obs_names".into(), input.obs_names.clone()).into_column(),
        ];
        input.obs = DataFrame::new(obs_columns).unwrap();

        // Layer name → (cell, site, count) triplets; no `*0` layers are supplied.
        let layer_specs: [(&str, [(usize, usize, u32); 2]); 4] = [
            ("A1", [(0, 0, 50), (1, 0, 30)]),
            ("T1", [(0, 1, 60), (1, 1, 40)]),
            ("G1", [(0, 0, 5), (1, 0, 3)]),
            ("C1", [(0, 1, 8), (1, 1, 6)]),
        ];
        for (layer_name, entries) in layer_specs {
            input.layers.insert(layer_name.into(), csr(n_obs, n_vars, &entries));
        }

        let var_columns = vec![
            Series::new("var_names".into(), input.var_names.clone()).into_column(),
            Series::new("ref".into(), &["A", "T"]).into_column(),
            Series::new("is_editing_site".into(), &[true, true]).into_column(),
            Series::new("filter_pass".into(), &[true, true]).into_column(),
        ];
        input.var = DataFrame::new(var_columns).unwrap();

        let adata = calculate_ref_alt_matrices(input, &EditingType::AG).unwrap();

        let ref_layer = adata.layers.get("ref").unwrap();
        let alt_layer = adata.layers.get("alt").unwrap();

        // Expectations for cell c0:
        //   site0 (ref=A): ref comes from A1, alt from G1
        //   site1 (ref=T): ref comes from T1, alt from C1
        let expected_counts = [
            (ref_layer, 0, 50),
            (alt_layer, 0, 5),
            (ref_layer, 1, 60),
            (alt_layer, 1, 8),
        ];
        for (layer, site, expected) in expected_counts {
            assert_eq!(val(layer, 0, site), expected);
        }

        let adata = calculate_cei(adata).unwrap();
        let cei_column = adata.obs.column("CEI").unwrap();
        let cei: Vec<f32> = cei_column
            .f32()
            .unwrap()
            .into_iter()
            .map(|v| v.unwrap_or(0.0))
            .collect();

        // c0: alt = 5 + 8 = 13, ref = 50 + 60 = 110 → CEI = 13 / (110 + 13)
        let c0_delta = (cei[0] - 13.0 / 123.0).abs();
        assert!(c0_delta < 1e-4);
    }

    // ====================================================================
    // Snapshot: coverage calculation consistency
    // ====================================================================

    /// Checks that the `coverage` layer written by `calculate_coverage` sums to
    /// the expected per-cell totals on the stranded fixture.
    #[test]
    fn snapshot_coverage_stranded_all_layers() {
        let mut fixture = build_full_stranded_adata();
        calculate_coverage(&mut fixture).unwrap();

        let coverage = fixture.layers.get("coverage").unwrap();
        let per_cell_totals = SparseOps::compute_row_sums(coverage);

        // Expected totals, summing every stranded layer of the fixture:
        //
        // cell0  site0: A0(10) + A1(30) + G0(2) + G1(8) = 50
        //        site1: T0(5) + T1(40) + C0(1) + C1(6) = 52
        //        row sum = 102
        // cell1  site0: A1(20) + G1(4) + G0(1)         = 25
        //        site1: T1(35) + C1(5) + T0(3)         = 43
        //        row sum = 68
        let expected_totals = [
            (0, 102, "cell0 total coverage"),
            (1, 68, "cell1 total coverage"),
        ];
        for (cell, expected, label) in expected_totals {
            assert_eq!(per_cell_totals[cell], expected, "{}", label);
        }
    }

    // ====================================================================
    // Snapshot: filter_sites_by_coverage retains correct sites
    // ====================================================================

    /// Checks which sites survive `filter_sites_by_coverage` with a threshold of 5.
    ///
    /// NOTE(review): despite the `_then_ref_alt` suffix, this test stops after the
    /// coverage filter; no ref/alt matrices are computed here.
    #[test]
    fn snapshot_filter_by_coverage_then_ref_alt() {
        let n_obs = 2;
        let n_vars = 3;

        let mut input = AnnDataContainer::new(n_obs, n_vars);
        input.obs_names = vec!["c0".into(), "c1".into()];
        input.var_names = vec!["chr1:1".into(), "chr1:2".into(), "chr1:3".into()];

        let obs_columns = vec![
            Series::new("obs_names".into(), input.obs_names.clone()).into_column(),
        ];
        input.obs = DataFrame::new(obs_columns).unwrap();

        // Sites 0 and 2 are deeply covered; site 1 has a single read.
        let populated_layers: [(&str, &[(usize, usize, u32)]); 3] = [
            ("A1", &[(0, 0, 50), (1, 0, 40), (0, 2, 60), (1, 2, 55)]),
            ("T1", &[(0, 1, 1)]),
            ("G1", &[(0, 0, 5), (1, 0, 3), (0, 2, 7), (1, 2, 4)]),
        ];
        for (layer_name, entries) in populated_layers {
            input.layers.insert(layer_name.into(), csr(n_obs, n_vars, entries));
        }
        // C1 is present but holds no counts at all.
        input.layers.insert("C1".into(), CsrMatrix::<u32>::zeros(n_obs, n_vars));

        // Per the original author's note, the filter needs var_names starting
        // with "chr"; the fixture names above satisfy that.
        let var_columns = vec![
            Series::new("var_names".into(), input.var_names.clone()).into_column(),
        ];
        input.var = DataFrame::new(var_columns).unwrap();

        let filtered = filter_sites_by_coverage(input, 5).unwrap();

        // Column sums across all layers:
        //   site0: 50 + 40 + 5 + 3 = 98  → kept
        //   site1: 1                     → dropped
        //   site2: 60 + 55 + 7 + 4 = 126 → kept
        assert_eq!(filtered.n_vars, 2);
        assert_eq!(filtered.var_names, vec!["chr1:1", "chr1:3"]);
    }

    // ====================================================================
    // Snapshot: count_base_levels produces correct per-site base counts
    // ====================================================================

    /// Snapshot of the per-site base totals that `count_base_levels` writes
    /// into `var` (columns `A`, `T`, `G`, `C` and `Coverage`).
    ///
    /// Fixture: 3 cells × 2 sites with only the `*1` layers populated. Every
    /// value listed in the expectation comments is asserted, including the
    /// zero G/C totals at site1.
    #[test]
    fn snapshot_count_base_levels() {
        let mut adata = AnnDataContainer::new(3, 2);
        adata.var_names = vec!["s0".into(), "s1".into()];
        adata.var = DataFrame::new(vec![
            Series::new("var_names".into(), adata.var_names.clone()).into_column(),
        ]).unwrap();
        adata.layers.insert("A1".into(), csr(3, 2, &[(0, 0, 10), (1, 0, 20)]));
        adata.layers.insert("T1".into(), csr(3, 2, &[(0, 1, 5), (2, 1, 15)]));
        adata.layers.insert("G1".into(), csr(3, 2, &[(1, 0, 3)]));
        adata.layers.insert("C1".into(), csr(3, 2, &[(2, 0, 1)]));

        let adata = count_base_levels(adata).unwrap();
        let a_col = adata.var.column("A").unwrap().u32().unwrap();
        let t_col = adata.var.column("T").unwrap().u32().unwrap();
        let g_col = adata.var.column("G").unwrap().u32().unwrap();
        let c_col = adata.var.column("C").unwrap().u32().unwrap();
        let cov = adata.var.column("Coverage").unwrap().u32().unwrap();

        // site0: A=10+20=30, T=0, G=3, C=1 → cov=34
        assert_eq!(a_col.get(0), Some(30));
        assert_eq!(t_col.get(0), Some(0));
        assert_eq!(g_col.get(0), Some(3));
        assert_eq!(c_col.get(0), Some(1));
        assert_eq!(cov.get(0), Some(34));

        // site1: A=0, T=5+15=20, G=0, C=0 → cov=20
        assert_eq!(a_col.get(1), Some(0));
        assert_eq!(t_col.get(1), Some(20));
        // G and C have no reads at site1; pin the zeros the comment promises.
        assert_eq!(g_col.get(1), Some(0));
        assert_eq!(c_col.get(1), Some(0));
        assert_eq!(cov.get(1), Some(20));
    }

    // ====================================================================
    // Snapshot: CT editing type pipeline
    // ====================================================================

    /// C→T snapshot on unstranded input (2 cells × 2 sites, ref bases C and G),
    /// covering ref/alt construction, CEI, and the `CT_ref` per-site totals.
    #[test]
    fn snapshot_ct_editing_pipeline() {
        let n_obs = 2;
        let n_vars = 2;

        let mut input = AnnDataContainer::new(n_obs, n_vars);
        input.obs_names = vec!["c0".into(), "c1".into()];
        input.var_names = vec!["chr1:10".into(), "chr1:20".into()];

        let obs_columns = vec![
            Series::new("obs_names".into(), input.obs_names.clone()).into_column(),
        ];
        input.obs = DataFrame::new(obs_columns).unwrap();

        // site0 has ref=C (reads in C1/T1); site1 has ref=G, the complementary
        // case for CT (reads in G1/A1).
        let layer_specs: [(&str, [(usize, usize, u32); 2]); 4] = [
            ("C1", [(0, 0, 80), (1, 0, 70)]),
            ("T1", [(0, 0, 10), (1, 0, 8)]),
            ("G1", [(0, 1, 90), (1, 1, 85)]),
            ("A1", [(0, 1, 12), (1, 1, 9)]),
        ];
        for (layer_name, entries) in layer_specs {
            input.layers.insert(layer_name.into(), csr(n_obs, n_vars, &entries));
        }

        let var_columns = vec![
            Series::new("var_names".into(), input.var_names.clone()).into_column(),
            Series::new("ref".into(), &["C", "G"]).into_column(),
            Series::new("is_editing_site".into(), &[true, true]).into_column(),
            Series::new("filter_pass".into(), &[true, true]).into_column(),
        ];
        input.var = DataFrame::new(var_columns).unwrap();

        let adata = calculate_ref_alt_matrices(input, &EditingType::CT).unwrap();

        let ref_layer = adata.layers.get("ref").unwrap();
        let alt_layer = adata.layers.get("alt").unwrap();

        // Expectations for cell c0:
        //   site0: ref=C, CT editing → alt=T
        //   site1: ref=G, CT editing on the complementary base → alt=A
        let expected_counts = [
            (ref_layer, 0, 80, "c0 site0 ref=C"),
            (alt_layer, 0, 10, "c0 site0 alt=T"),
            (ref_layer, 1, 90, "c0 site1 ref=G"),
            (alt_layer, 1, 12, "c0 site1 alt=A"),
        ];
        for (layer, site, expected, label) in expected_counts {
            assert_eq!(val(layer, 0, site), expected, "{}", label);
        }

        let adata = calculate_cei(adata).unwrap();
        let adata = calculate_site_mismatch_stats(adata, 'C', 'T').unwrap();
        let ct_ref = adata.var.column("CT_ref").unwrap().u32().unwrap();

        // Per-site ref totals over both cells: site0 = 80 + 70, site1 = 90 + 85.
        let expected_ref_totals = [(0, Some(150)), (1, Some(175))];
        for (site, expected) in expected_ref_totals {
            assert_eq!(ct_ref.get(site), expected);
        }
    }
}