crispr_screen 0.2.5

A fast and configurable differential expression analysis tool for CRISPR screens
use super::logging::Logger;
use ndarray::Array2;

/// Return type of [`filter_low_counts`]: the filtered matrix, followed by the
/// sgRNA names and the gene names belonging to the rows that were kept.
type FilterTuple = (Array2<f64>, Vec<String>, Vec<String>);
/// Drops every row of `norm_matrix` whose mean across `Axis(1)` is below `min_base`.
///
/// The entries of `sgrna_names` and `gene_names` found at the surviving row
/// indices are cloned, so the returned names stay aligned with the rows of the
/// returned matrix.
///
/// `min_base` is handed to `logger.filtering` before any work is done, and the
/// number of dropped entries (`sgrna_names.len()` minus the number of kept rows)
/// is handed to `logger.num_filtered` just before returning.
///
/// # Panics
///
/// Panics if `mean_axis` yields `None`, or if a kept row index is out of bounds
/// for `sgrna_names` or `gene_names`.
pub fn filter_low_counts(
    norm_matrix: &Array2<f64>,
    sgrna_names: &[String],
    gene_names: &[String],
    min_base: f64,
    logger: &Logger,
) -> FilterTuple {
    logger.filtering(min_base);

    let row_means = norm_matrix
        .mean_axis(ndarray::Axis(1))
        .expect("Failed to calculate mean of normalized matrix");

    // Indices of the rows whose mean reaches the threshold, in ascending order.
    let mut kept_rows: Vec<usize> = Vec::new();
    for (row, mean) in row_means.iter().enumerate() {
        if *mean >= min_base {
            kept_rows.push(row);
        }
    }

    // Clones the name stored at each kept row index.
    let pick_names = |names: &[String]| -> Vec<String> {
        kept_rows.iter().map(|&row| names[row].clone()).collect()
    };

    let kept_matrix = norm_matrix.select(ndarray::Axis(0), &kept_rows);
    let kept_sgrnas = pick_names(sgrna_names);
    let kept_genes = pick_names(gene_names);

    logger.num_filtered(sgrna_names.len() - kept_rows.len());
    (kept_matrix, kept_sgrnas, kept_genes)
}

#[cfg(test)]
mod testing {

    use super::*;
    use ndarray::array;

    /// Turns a list of string literals into the owned names the filter expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// The single row whose mean lies below the threshold is removed, together
    /// with its sgRNA and gene name; all other rows keep their order.
    #[test]
    fn test_filtering() {
        let logger = Logger::new();
        let norm_matrix = array![
            [11.0, 12.0],
            [13.0, 14.0],
            [5.0, 6.0],
            [17.0, 18.0],
            [19.0, 20.0],
        ];
        println!("{:?}", norm_matrix);
        let sgrna_names = owned(&["sgrna1", "sgrna2", "sgrna3", "sgrna4", "sgrna5"]);
        let gene_names = owned(&["gene1", "gene2", "gene1", "gene2", "gene1"]);
        let min_base = 10.0;

        let (kept_matrix, kept_sgrnas, kept_genes) =
            filter_low_counts(&norm_matrix, &sgrna_names, &gene_names, min_base, &logger);

        // Row index 2 (mean 5.5) falls under the threshold of 10.0.
        assert_eq!(
            kept_matrix,
            array![[11.0, 12.0], [13.0, 14.0], [17.0, 18.0], [19.0, 20.0]]
        );
        assert_eq!(kept_sgrnas, vec!["sgrna1", "sgrna2", "sgrna4", "sgrna5"]);
        assert_eq!(kept_genes, vec!["gene1", "gene2", "gene2", "gene1"]);
    }
}