//! redicat 0.4.2
//!
//! REDICAT - RNA Editing Cellular Assessment Toolkit: a highly parallelized
//! utility for analyzing RNA editing events in single-cell RNA-seq data.
use anyhow::Result;
use redicat_lib::bam2mtx::barcode::BarcodeProcessor;
use redicat_lib::bam2mtx::processor::{
    apply_encoded_call, decode_base, decode_cell_barcode_into, decode_umi_into, encode_call,
    BamProcessorConfig, PositionData, StrandBaseCounts, UMI_CONFLICT_CODE,
};
use rust_htslib::bam::{self, Read};
use rustc_hash::FxHashMap;
use std::path::PathBuf;
use std::sync::Arc;

use super::input::PositionChunk;

/// Integer-encode UMI strings in a chunk-local cache, returning a compact u32 id.
///
/// This keeps UMI deduplication effective within a pileup while avoiding global,
/// run-long retention of all distinct UMIs. Using u32 ids instead of `Arc<str>`
/// reduces the `umi_consensus` key from ~24 bytes to 8 bytes and eliminates
/// atomic reference-counting overhead.
///
/// Generic over the map's hasher so it accepts both `FxHashMap` (a
/// `std::collections::HashMap` alias with the Fx hasher, used by the chunk
/// processor) and the standard `HashMap`. Ids are assigned densely in
/// insertion order starting at 0; re-interning an existing UMI returns its
/// id without allocating.
///
/// # Panics
///
/// Panics if more than `u32::MAX` distinct UMIs are interned into one cache
/// (practically unreachable for a single pileup).
fn intern_umi_id<S: std::hash::BuildHasher>(
    cache: &mut std::collections::HashMap<String, u32, S>,
    umi: &str,
) -> u32 {
    // Fast path: borrow-only lookup; we allocate an owned key only on first sight.
    if let Some(&id) = cache.get(umi) {
        return id;
    }
    // Checked conversion instead of a silently-truncating `as` cast.
    let id = u32::try_from(cache.len()).expect("more than u32::MAX distinct UMIs in one cache");
    cache.insert(umi.to_string(), id);
    id
}

/// Metadata for genomic sites skipped due to exceeding the configured depth ceiling.
#[derive(Debug, Clone)]
pub struct SkippedSite {
    /// Contig (chromosome) name the site lies on.
    pub contig: String,
    /// 1-based genomic coordinate of the skipped site.
    pub pos: u64,
    /// Pileup depth observed at the site when it was skipped.
    pub depth: u32,
}

/// Chunk processor with aggressive reuse of BAM readers and pre-sized hash maps.
/// Uses chunk-local UMI interning to keep peak memory bounded.
pub struct OptimizedChunkProcessor {
    /// Path to the indexed BAM file; thread-local readers are opened via `open_reader`.
    bam_path: PathBuf,
    /// Filtering/decoding configuration (quality thresholds, tag names, strandedness, depth cap).
    config: BamProcessorConfig,
    /// Shared barcode whitelist mapping cell barcodes to compact integer ids.
    barcode_processor: Arc<BarcodeProcessor>,
    /// Contig name -> BAM target id (tid), built once from the header.
    tid_lookup: FxHashMap<String, u32>,
    /// Contig names indexed by tid, shared with downstream consumers.
    contig_names: Arc<Vec<String>>,
    /// Pre-sizing hint for the per-position cell-count map.
    count_capacity_hint: usize,
    /// Pre-sizing hint for the per-position UMI consensus/intern maps.
    umi_capacity_hint: usize,
}

impl OptimizedChunkProcessor {
    /// Build a processor by reading the BAM header once to construct the
    /// contig lookup tables, and derive hash-map capacity hints from the size
    /// of the barcode whitelist.
    ///
    /// # Errors
    ///
    /// Returns an error if the BAM file (or its index) cannot be opened.
    pub fn new(
        bam_path: PathBuf,
        config: BamProcessorConfig,
        barcode_processor: Arc<BarcodeProcessor>,
    ) -> Result<Self> {
        // Open a reader only to copy the header; per-thread readers are
        // created later via `open_reader`.
        let header = bam::IndexedReader::from_path(&bam_path)?
            .header()
            .to_owned();
        let mut tid_lookup =
            FxHashMap::with_capacity_and_hasher(header.target_count() as usize, Default::default());

        let mut contig_names = Vec::with_capacity(header.target_count() as usize);
        for tid in 0..header.target_count() {
            if let Ok(name) = std::str::from_utf8(header.tid2name(tid)) {
                tid_lookup.insert(name.to_string(), tid);
                contig_names.push(name.to_string());
            } else {
                // Non-UTF-8 contig names cannot be looked up by string, but a
                // placeholder keeps `contig_names` aligned with tid indices.
                contig_names.push(String::from("unknown"));
            }
        }

        // Clamp hints so tiny whitelists still pre-size usefully and huge
        // ones do not over-allocate per position.
        let barcode_count = barcode_processor.len();
        let count_capacity_hint = barcode_count.clamp(64, 4096);
        let umi_capacity_hint = count_capacity_hint.saturating_mul(8);

        Ok(Self {
            bam_path,
            config,
            barcode_processor,
            tid_lookup,
            contig_names: Arc::new(contig_names),
            count_capacity_hint,
            umi_capacity_hint,
        })
    }

    /// Shared, tid-indexed list of contig names (cheap `Arc` clone).
    pub fn contig_names(&self) -> Arc<Vec<String>> {
        Arc::clone(&self.contig_names)
    }

    /// Create a new BAM reader for this processor's BAM file.
    ///
    /// Callers (typically Rayon thread init closures) use this to obtain a
    /// thread-local reader that is reused across many `process_chunk` calls.
    pub fn open_reader(&self) -> Result<bam::IndexedReader> {
        bam::IndexedReader::from_path(&self.bam_path)
            .map_err(|e| anyhow::anyhow!("Failed to open BAM {}: {}", self.bam_path.display(), e))
    }

    /// Process one chunk of (sorted, single-contig) positions against the BAM,
    /// returning per-position cell/base counts plus any sites skipped for
    /// exceeding the configured depth ceiling.
    ///
    /// Positions in `chunk` use 1-based coordinates; they are converted to the
    /// 0-based coordinates used by the pileup below. A chunk whose contig is
    /// absent from the BAM header yields empty results rather than an error.
    ///
    /// # Errors
    ///
    /// Propagates htslib fetch/pileup errors and tag-decoding errors.
    pub fn process_chunk(
        &self,
        chunk: &PositionChunk,
        reader: &mut bam::IndexedReader,
    ) -> Result<(Vec<PositionData>, Vec<SkippedSite>)> {
        if chunk.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }

        // All positions in a chunk share one contig; resolve its tid once.
        let chrom = &chunk.positions[0].chrom;
        let tid = match self.tid_lookup.get(chrom) {
            Some(&tid) => tid,
            None => {
                return Ok((Vec::new(), Vec::new()));
            }
        };

        // Fetch window: [first_pos - 1, last_pos) in 0-based half-open
        // coordinates, covering every target position in the chunk.
        let fetch_start = chunk
            .positions
            .first()
            .map(|p| p.pos.saturating_sub(1) as u32)
            .unwrap_or(0);
        let fetch_end = chunk
            .positions
            .last()
            .map(|p| p.pos as u32)
            .unwrap_or(fetch_start);

        reader.fetch((tid, fetch_start, fetch_end))?;

        // 0-based positions to match `pileup.pos()`.
        let target_positions: Vec<u32> = chunk
            .positions
            .iter()
            .map(|p| p.pos.saturating_sub(1) as u32)
            .collect();

        let mut chunk_results = Vec::with_capacity(chunk.len());
        let mut skipped_sites: Vec<SkippedSite> = Vec::new();
        let mut position_index = 0usize;

        // Improved HashMap capacity estimation based on chunk characteristics:
        // cap the hints for chunks known to contain near-max-depth sites so a
        // single deep pileup does not balloon the pre-allocation.
        let has_high_depth = chunk.near_max_depth_count() > 0;
        let estimated_count_capacity = if has_high_depth {
            self.count_capacity_hint.min(5_000)
        } else {
            self.count_capacity_hint
        };
        let estimated_umi_capacity = if has_high_depth {
            self.umi_capacity_hint.min(50_000)
        } else {
            self.umi_capacity_hint
        };

        // All three maps are allocated once per chunk and cleared per position.
        let mut counts: FxHashMap<u32, StrandBaseCounts> =
            FxHashMap::with_capacity_and_hasher(estimated_count_capacity, Default::default());
        let mut umi_consensus: FxHashMap<(u32, u32), u8> =
            FxHashMap::with_capacity_and_hasher(estimated_umi_capacity, Default::default());
        let mut local_umi_cache: FxHashMap<String, u32> =
            FxHashMap::with_capacity_and_hasher(estimated_umi_capacity, Default::default());
        // Reusable buffers for BAM tag decoding — avoids per-alignment String allocations.
        let mut cb_buf = String::with_capacity(32);
        let mut umi_buf = String::with_capacity(32);

        let mut pileups = reader.pileup();
        // htslib's max-depth parameter is an i32; clamp before converting.
        pileups.set_max_depth(self.config.max_depth.min(i32::MAX as u32));

        // Merge-walk the sorted pileup stream against the sorted target list.
        for pileup in pileups {
            let pileup = pileup?;
            let pile_pos = pileup.pos();

            // Skip targets the pileup has already passed (no coverage there).
            while position_index < target_positions.len()
                && target_positions[position_index] < pile_pos
            {
                position_index += 1;
            }

            // All targets consumed: nothing left to collect in this chunk.
            if position_index >= target_positions.len() {
                break;
            }

            // Pileup column is not one of our targets; keep scanning.
            if target_positions[position_index] != pile_pos {
                continue;
            }

            let current_index = position_index;
            position_index += 1;

            // Sites at or above the depth ceiling are recorded and skipped
            // rather than counted (the pileup is truncated at max_depth anyway).
            let depth = pileup.depth() as u32;
            if depth >= self.config.max_depth {
                let position_meta = &chunk.positions[current_index];
                skipped_sites.push(SkippedSite {
                    contig: position_meta.chrom.clone(),
                    pos: position_meta.pos,
                    depth,
                });
                continue;
            }

            // Reuse the per-position maps; clearing keeps their capacity.
            counts.clear();
            umi_consensus.clear();
            local_umi_cache.clear();

            for alignment in pileup.alignments() {
                let record = alignment.record();

                if record.mapq() < self.config.min_mapping_quality {
                    continue;
                }

                // No query position means a deletion/refskip at this column.
                let qpos = match alignment.qpos() {
                    Some(q) => q,
                    None => continue,
                };

                let base_qual = record.qual().get(qpos).copied().unwrap_or(0);
                if base_qual < self.config.min_base_quality {
                    continue;
                }

                let base = decode_base(&record, Some(qpos))?;
                if base == 'N' {
                    continue;
                }

                // Resolve the cell barcode tag to a compact whitelist id;
                // reads without a whitelisted barcode are dropped.
                let cell_id = {
                    cb_buf.clear();
                    match decode_cell_barcode_into(&record, self.config.cell_barcode_tag.as_bytes(), &mut cb_buf)? {
                        true => match self.barcode_processor.id_of(&cb_buf) {
                            Some(id) => id,
                            None => continue,
                        },
                        false => continue,
                    }
                };

                // Intern the UMI into a chunk-local id; reads without a UMI
                // tag are dropped.
                let umi_id = {
                    umi_buf.clear();
                    match decode_umi_into(&record, self.config.umi_tag.as_bytes(), &mut umi_buf)? {
                        true => intern_umi_id(&mut local_umi_cache, &umi_buf),
                        false => continue,
                    }
                };

                // Per-(cell, UMI) consensus: disagreeing reads for the same
                // molecule are flagged with UMI_CONFLICT_CODE and excluded below.
                if let Some(encoded) = encode_call(self.config.stranded, base, record.is_reverse())
                {
                    umi_consensus
                        .entry((cell_id, umi_id))
                        .and_modify(|existing| {
                            if *existing != encoded {
                                *existing = UMI_CONFLICT_CODE;
                            }
                        })
                        .or_insert(encoded);
                }
            }

            // Collapse UMI consensus into per-cell strand/base counts,
            // counting each molecule once.
            for ((cell_id, _umi_id), encoded) in umi_consensus.drain() {
                if encoded == UMI_CONFLICT_CODE {
                    continue;
                }

                let counts_entry = counts.entry(cell_id).or_default();

                apply_encoded_call(self.config.stranded, encoded, counts_entry);
            }

            let position_meta = &chunk.positions[current_index];
            chunk_results.push(PositionData {
                contig_id: tid,
                pos: position_meta.pos,
                counts: counts.drain().collect(),
            });
        }

        Ok((chunk_results, skipped_sites))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn umi_interner_reuses_id_for_repeated_values() {
        let mut interner: FxHashMap<String, u32> = FxHashMap::default();

        let id_first = intern_umi_id(&mut interner, "ACGTAC");
        let id_repeat = intern_umi_id(&mut interner, "ACGTAC");
        let id_other = intern_umi_id(&mut interner, "TTTTGG");

        // Re-interning the same UMI yields the same id; a new UMI does not,
        // and only the two distinct UMIs occupy the cache.
        assert_eq!(id_first, id_repeat);
        assert_ne!(id_first, id_other);
        assert_eq!(interner.len(), 2);
    }

    #[test]
    fn umi_interner_cache_can_be_reused_across_positions() {
        let mut interner: FxHashMap<String, u32> = FxHashMap::default();

        let id_before = intern_umi_id(&mut interner, "AAAAAA");
        assert_eq!(interner.len(), 1);

        interner.clear();

        // Clearing resets the id space, so the same UMI is interned afresh.
        // Each call is the first entry of its cache generation, hence both
        // ids are 0; the key property is that reuse after clear() is sound.
        let id_after = intern_umi_id(&mut interner, "AAAAAA");
        assert_eq!(interner.len(), 1);
        assert_eq!(id_before, 0);
        assert_eq!(id_after, 0);
    }

    #[test]
    fn umi_interner_assigns_sequential_ids() {
        let mut interner: FxHashMap<String, u32> = FxHashMap::default();
        for (expected_id, umi) in [(0, "AAA"), (1, "BBB"), (2, "CCC")] {
            assert_eq!(intern_umi_id(&mut interner, umi), expected_id);
        }
        // A repeated UMI keeps its original id and adds no new entry.
        assert_eq!(intern_umi_id(&mut interner, "AAA"), 0);
        assert_eq!(interner.len(), 3);
    }
}