redicat 0.4.2

REDICAT - RNA Editing Cellular Assessment Toolkit: A highly parallelized utility for analyzing RNA editing events in single-cell RNA-seq data
Documentation
//! Optimized base matrix operations with efficient memory usage

use crate::core::error::{RedicatError, Result};
use crate::core::sparse::SparseOps;
use log::info;
use nalgebra_sparse::CsrMatrix;
use polars::prelude::*;
use rayon::prelude::*;

use super::anndata_ops::AnnDataContainer;

/// Build the `"coverage"` layer by summing the per-base count layers.
///
/// Which layers are considered depends on whether an `"A0"` layer is present:
/// - present: `A0, T0, G0, C0, A1, T1, G1, C1` (both strand suffixes);
/// - absent: `A1, T1, G1, C1` only.
///
/// Candidate layers that are missing from `adata.layers` are skipped. The
/// result is inserted into `adata.layers` under the key `"coverage"`:
/// - no candidate present: an all-zero `n_obs x n_vars` matrix;
/// - exactly one present: a clone of that layer;
/// - otherwise: the output of [`SparseOps::parallel_sum_matrices`].
///
/// # Errors
/// Propagates any error returned by `SparseOps::parallel_sum_matrices`.
pub fn calculate_coverage(adata: &mut AnnDataContainer) -> Result<()> {
    info!("Calculating coverage matrix with parallel sparse operations...");

    const STRANDED_LAYERS: [&str; 8] = ["A0", "T0", "G0", "C0", "A1", "T1", "G1", "C1"];
    const UNSTRANDED_LAYERS: [&str; 4] = ["A1", "T1", "G1", "C1"];

    // The presence of "A0" is the sole switch between the two layer sets.
    let candidates: &[&str] = if adata.layers.contains_key("A0") {
        &STRANDED_LAYERS
    } else {
        &UNSTRANDED_LAYERS
    };

    // Borrow whichever candidate layers actually exist.
    let present: Vec<&CsrMatrix<u32>> = candidates
        .iter()
        .filter_map(|&layer| adata.layers.get(layer))
        .collect();

    match present.len() {
        0 => {
            let empty = CsrMatrix::<u32>::zeros(adata.n_obs, adata.n_vars);
            adata.layers.insert("coverage".to_string(), empty);
            info!("No base layers found, created empty coverage matrix");
        }
        1 => {
            // A single layer is its own sum; copy it verbatim.
            let single = present[0].clone();
            let nnz = single.nnz();
            adata.layers.insert("coverage".to_string(), single);
            info!("Coverage matrix calculated with {} non-zero elements", nnz);
        }
        _ => {
            let summed = SparseOps::parallel_sum_matrices(&present)?;
            let nnz = summed.nnz();
            info!("Coverage matrix calculated with {} non-zero elements using parallel reduction", nnz);
            adata.layers.insert("coverage".to_string(), summed);
        }
    }

    Ok(())
}

/// Optimized site filtering with parallel processing
pub fn filter_sites_by_coverage(
    mut adata: AnnDataContainer,
    min_coverage: u32,
) -> Result<AnnDataContainer> {
    info!("Filtering sites with min_coverage: {}", min_coverage);

    // Ensure coverage is calculated
    calculate_coverage(&mut adata)?;

    // Compute site coverage in parallel
    let site_coverage = adata
        .compute_layer_col_sums("coverage")
        .unwrap_or_else(|| vec![0; adata.n_vars]);

    // Create filter mask using parallel iterator
    let filter_mask: Vec<bool> = adata
        .var_names
        .par_iter()
        .zip(site_coverage.par_iter())
        .map(|(name, &cov)| name.starts_with("chr") && cov >= min_coverage)
        .collect();

    let kept_sites = filter_mask.iter().filter(|&&x| x).count();
    info!(
        "Keeping {} out of {} sites ({}% retained)",
        kept_sites,
        adata.n_vars,
        (kept_sites as f64 / adata.n_vars as f64 * 100.0) as u32
    );

    if kept_sites == 0 {
        return Err(RedicatError::EmptyData(
            "No sites passed coverage filter".to_string(),
        ));
    }

    apply_site_filter(adata, &filter_mask)
}

/// Keep only the sites (columns) flagged `true` in `filter_mask`.
///
/// Every layer in `adata.layers` is column-filtered through
/// [`SparseOps::filter_columns_u32`], the `var` DataFrame is reduced to the
/// selected rows, and `var_names` / `n_vars` are updated to match.
///
/// Layers are processed one at a time (remove → filter → drop → reinsert), so
/// this function holds at most one unfiltered matrix together with its
/// filtered replacement at any moment.
///
/// # Errors
/// * `RedicatError::EmptyData` when the mask selects no site.
/// * `RedicatError::DataProcessing` when a selected index has no entry in
///   `adata.var_names` (for example a mask longer than the name list).
/// * Any error returned by `SparseOps::filter_columns_u32` or by filtering
///   the `var` DataFrame.
pub fn apply_site_filter(
    mut adata: AnnDataContainer,
    filter_mask: &[bool],
) -> Result<AnnDataContainer> {
    // Positions of the `true` entries, computed once and shared by every step.
    let selected_indices: Vec<usize> = filter_mask
        .par_iter()
        .enumerate()
        .filter_map(|(i, &keep)| keep.then_some(i))
        .collect();

    if selected_indices.is_empty() {
        return Err(RedicatError::EmptyData(
            "No sites selected after filtering".to_string(),
        ));
    }

    // Snapshot the keys so the map can be mutated while walking them.
    let layer_names: Vec<String> = adata.layers.keys().cloned().collect();
    for name in layer_names {
        // The key was read from the map just above, so a miss is not expected;
        // skipping instead of unwrapping keeps this loop free of panic paths.
        let Some(matrix) = adata.layers.remove(&name) else {
            continue;
        };
        let filtered = SparseOps::filter_columns_u32(&matrix, &selected_indices)?;
        // Release the unfiltered matrix before moving on to the next layer.
        drop(matrix);
        adata.layers.insert(name, filtered);
    }

    // Filter var DataFrame
    let filtered_var = filter_dataframe_by_indices(&adata.var, &selected_indices)?;

    // Checked lookup: an out-of-range index becomes an error rather than an
    // index-out-of-bounds panic.
    let filtered_var_names: Vec<String> = selected_indices
        .par_iter()
        .map(|&i| adata.var_names.get(i).cloned())
        .collect::<Option<Vec<String>>>()
        .ok_or_else(|| {
            RedicatError::DataProcessing(format!(
                "Filter mask selects a site index outside the {} available var_names",
                adata.var_names.len()
            ))
        })?;

    // Update adata with filtered data
    adata.var = filtered_var;
    adata.n_vars = selected_indices.len();
    adata.var_names = filtered_var_names;

    info!("Site filtering completed: {} sites retained", adata.n_vars);
    Ok(adata)
}

/// Append per-site base totals to `adata.var`.
///
/// For each base in `A, T, G, C` the column sums of the layer named
/// `"<base>1"` are computed (all zeros when the sums are unavailable) and
/// added to `var` as a column named after the base. A fifth column,
/// `"Coverage"`, holds the saturating sum of the four base columns.
///
/// # Errors
/// Propagates the error from `DataFrame::hstack_mut` if the new columns
/// cannot be attached to `var`.
pub fn count_base_levels(mut adata: AnnDataContainer) -> Result<AnnDataContainer> {
    info!("Counting base levels at each site...");

    let n_sites = adata.n_vars;

    // Column sums for the four base layers, computed concurrently.
    let per_base: Vec<(char, Vec<u32>)> = ['A', 'T', 'G', 'C']
        .par_iter()
        .map(|&base| {
            let layer_name = format!("{}1", base);
            let sums = match adata.compute_layer_col_sums(&layer_name) {
                Some(sums) => sums,
                None => vec![0u32; n_sites],
            };
            (base, sums)
        })
        .collect();

    let mut totals = vec![0u32; n_sites];
    let mut new_columns: Vec<Column> = Vec::with_capacity(5);

    for (base, sums) in per_base {
        // Saturate instead of wrapping when a site total exceeds u32::MAX.
        totals
            .iter_mut()
            .zip(sums.iter())
            .for_each(|(acc, &count)| *acc = acc.saturating_add(count));
        new_columns.push(Series::new(base.to_string().into(), sums).into_column());
    }

    new_columns.push(Series::new("Coverage".into(), totals).into_column());

    // `new_columns` always contains at least the "Coverage" column here.
    adata.var.hstack_mut(&new_columns)?;

    info!("Base levels counted successfully");
    Ok(adata)
}

/// Optimized DataFrame filtering using vectorized operations
fn filter_dataframe_by_indices(df: &DataFrame, indices: &[usize]) -> Result<DataFrame> {
    if indices.is_empty() {
        return Err(RedicatError::EmptyData(
            "No indices provided for filtering".to_string(),
        ));
    }

    let idx_chunked = UInt32Chunked::from_vec(
        "idx".into(),
        indices.iter().map(|&idx| idx as u32).collect(),
    );
    df.take(&idx_chunked)
        .map_err(|e| RedicatError::DataProcessing(format!("Failed to filter DataFrame: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::sparse::SparseOps;
    use polars::prelude::{DataFrame, Series};

    /// `(row, column, value)` entry describing one non-zero cell of a fixture.
    type Triplet = (usize, usize, u32);

    /// Build a sparse fixture of the given shape from triplets.
    fn make_matrix(nrows: usize, ncols: usize, triplets: Vec<Triplet>) -> CsrMatrix<u32> {
        SparseOps::from_triplets_u32(nrows, ncols, triplets).unwrap()
    }

    /// Insert a layer called `name` with the given `(rows, cols)` shape.
    fn add_layer(
        adata: &mut AnnDataContainer,
        name: &str,
        shape: (usize, usize),
        triplets: Vec<Triplet>,
    ) {
        let (nrows, ncols) = shape;
        adata
            .layers
            .insert(name.to_string(), make_matrix(nrows, ncols, triplets));
    }

    /// Set `var_names` and rebuild `var` as a single `var_names` column.
    fn set_var_names(adata: &mut AnnDataContainer, names: Vec<&str>) {
        let owned: Vec<String> = names.into_iter().map(String::from).collect();
        adata.var_names = owned;
        adata.var = DataFrame::new(vec![Series::new("var_names".into(), adata.var_names.clone()).into()])
            .unwrap();
    }

    // ----- calculate_coverage -----

    #[test]
    fn calculate_coverage_sums_existing_layers() {
        let shape = (2, 3);
        let mut adata = AnnDataContainer::new(2, 3);
        add_layer(&mut adata, "A1", shape, vec![(0, 0, 2), (1, 2, 1)]);
        add_layer(&mut adata, "T1", shape, vec![(0, 0, 3), (0, 1, 4)]);

        calculate_coverage(&mut adata).unwrap();

        let summed = adata.layers.get("coverage").unwrap();
        assert_eq!(SparseOps::compute_col_sums(summed), vec![5, 4, 1]);
        assert_eq!(SparseOps::compute_row_sums(summed), vec![9, 1]);
    }

    #[test]
    fn calculate_coverage_handles_stranded_data() {
        let mut adata = AnnDataContainer::new(2, 2);
        let fixtures: Vec<(&str, Vec<Triplet>)> = vec![
            ("A0", vec![(0, 0, 1)]),
            ("A1", vec![(0, 0, 2)]),
            ("T0", vec![(1, 1, 3)]),
            ("T1", vec![(1, 1, 4)]),
            ("G0", vec![]),
            ("G1", vec![]),
            ("C0", vec![]),
            ("C1", vec![]),
        ];
        for (name, triplets) in fixtures {
            add_layer(&mut adata, name, (2, 2), triplets);
        }

        calculate_coverage(&mut adata).unwrap();

        let summed = adata.layers.get("coverage").unwrap();
        // Column 0 gets A0 + A1, column 1 gets T0 + T1.
        assert_eq!(SparseOps::compute_col_sums(summed), vec![3, 7]);
    }

    #[test]
    fn calculate_coverage_no_layers_creates_empty() {
        let mut adata = AnnDataContainer::new(3, 2);

        calculate_coverage(&mut adata).unwrap();

        let empty = adata.layers.get("coverage").unwrap();
        assert_eq!(empty.nnz(), 0);
        assert_eq!(empty.nrows(), 3);
        assert_eq!(empty.ncols(), 2);
    }

    // ----- filter_sites_by_coverage -----

    #[test]
    fn filter_sites_by_coverage_applies_chr_prefix_and_threshold() {
        let shape = (2, 3);
        let mut adata = AnnDataContainer::new(2, 3);
        set_var_names(&mut adata, vec!["chr22:10", "chr1:20", "site_3"]);
        add_layer(&mut adata, "A1", shape, vec![(0, 0, 2), (0, 2, 100), (1, 0, 3)]);
        add_layer(&mut adata, "T1", shape, vec![(1, 1, 3)]);

        let filtered = filter_sites_by_coverage(adata, 4).unwrap();

        // Only "chr22:10" has both the prefix and a total of at least 4.
        assert_eq!(filtered.n_vars, 1);
        assert_eq!(filtered.var_names, vec!["chr22:10".to_string()]);
        filtered
            .layers
            .values()
            .for_each(|matrix| assert_eq!(matrix.ncols(), 1));
    }

    #[test]
    fn filter_sites_by_coverage_returns_empty_error() {
        let mut adata = AnnDataContainer::new(2, 2);
        set_var_names(&mut adata, vec!["site1", "site2"]);
        add_layer(&mut adata, "A1", (2, 2), vec![(0, 0, 1), (1, 1, 1)]);

        let failure = filter_sites_by_coverage(adata, 10).unwrap_err();

        assert!(failure.to_string().contains("No sites passed coverage filter"));
    }

    // ----- apply_site_filter -----

    #[test]
    fn apply_site_filter_filters_layers_var_and_names() {
        let mut adata = AnnDataContainer::new(2, 4);
        set_var_names(&mut adata, vec!["chr1:1", "chr1:2", "chr1:3", "chr1:4"]);
        add_layer(
            &mut adata,
            "A1",
            (2, 4),
            vec![
                (0, 0, 1),
                (0, 1, 2),
                (0, 2, 3),
                (0, 3, 4),
                (1, 0, 5),
                (1, 1, 6),
                (1, 2, 7),
                (1, 3, 8),
            ],
        );

        let filtered = apply_site_filter(adata, &[true, false, true, false]).unwrap();

        assert_eq!(filtered.n_vars, 2);
        assert_eq!(filtered.var_names, vec!["chr1:1", "chr1:3"]);
        assert_eq!(filtered.var.height(), 2);

        let layer = filtered.layers.get("A1").unwrap();
        assert_eq!(layer.ncols(), 2);
        // Row 0 keeps old columns 0 and 2, i.e. values 1 and 3.
        let first_row = layer.row(0);
        let observed: Vec<u32> = first_row.values().to_vec();
        assert_eq!(observed, vec![1, 3]);
    }

    #[test]
    fn apply_site_filter_all_false_returns_error() {
        let adata = AnnDataContainer::new(2, 3);

        let outcome = apply_site_filter(adata, &[false; 3]);

        assert!(outcome.is_err());
    }

    #[test]
    fn apply_site_filter_all_true_keeps_everything() {
        let mut adata = AnnDataContainer::new(2, 3);
        set_var_names(&mut adata, vec!["chr1:1", "chr1:2", "chr1:3"]);
        add_layer(&mut adata, "G1", (2, 3), vec![(0, 0, 42)]);

        let filtered = apply_site_filter(adata, &[true; 3]).unwrap();

        assert_eq!(filtered.n_vars, 3);
        assert_eq!(filtered.var_names.len(), 3);
    }

    // ----- count_base_levels -----

    #[test]
    fn count_base_levels_uses_saturating_add() {
        // (u32::MAX - 1) + 5 would wrap; the total must clamp at u32::MAX.
        let mut adata = AnnDataContainer::new(1, 1);
        add_layer(&mut adata, "A1", (1, 1), vec![(0, 0, u32::MAX - 1)]);
        add_layer(&mut adata, "T1", (1, 1), vec![(0, 0, 5)]);
        add_layer(&mut adata, "G1", (1, 1), vec![]);
        add_layer(&mut adata, "C1", (1, 1), vec![]);

        let out = count_base_levels(adata).unwrap();

        let totals = out.var.column("Coverage").unwrap().u32().unwrap();
        assert_eq!(totals.get(0), Some(u32::MAX));
    }

    #[test]
    fn count_base_levels_adds_base_columns_and_total_coverage() {
        let shape = (2, 3);
        let mut adata = AnnDataContainer::new(2, 3);
        add_layer(&mut adata, "A1", shape, vec![(0, 0, 2)]);
        add_layer(&mut adata, "T1", shape, vec![(1, 0, 1), (1, 1, 2)]);
        add_layer(&mut adata, "G1", shape, vec![(0, 2, 5)]);
        add_layer(&mut adata, "C1", shape, vec![]);

        let out = count_base_levels(adata).unwrap();

        let totals = out.var.column("Coverage").unwrap().u32().unwrap();
        for (site, expected) in [3u32, 2, 5].iter().enumerate() {
            assert_eq!(totals.get(site), Some(*expected));
        }

        for base in ["A", "T", "G", "C"].iter() {
            assert!(out.var.column(*base).is_ok());
        }
    }
}