use anyhow::Result;
use redicat_lib::bam2mtx::barcode::BarcodeProcessor;
use redicat_lib::bam2mtx::processor::{
apply_encoded_call, decode_base, decode_cell_barcode_into, decode_umi_into, encode_call,
BamProcessorConfig, PositionData, StrandBaseCounts, UMI_CONFLICT_CODE,
};
use rust_htslib::bam::{self, Read};
use rustc_hash::FxHashMap;
use std::path::PathBuf;
use std::sync::Arc;
use super::input::PositionChunk;
/// Intern `umi` in `cache`, returning its dense integer id.
///
/// Ids are handed out sequentially in first-seen order, so a cache holding
/// `k` distinct UMIs maps them onto `0..k`. Looking up a UMI that was seen
/// before returns the id assigned on first sight without allocating.
fn intern_umi_id(cache: &mut FxHashMap<String, u32>, umi: &str) -> u32 {
    match cache.get(umi) {
        Some(&existing) => existing,
        None => {
            // Next free id is exactly the number of UMIs interned so far.
            let fresh = cache.len() as u32;
            cache.insert(umi.to_string(), fresh);
            fresh
        }
    }
}
/// A site that was not counted because its pileup depth reached the
/// configured maximum (the pileup would be truncated there, biasing counts).
#[derive(Debug, Clone)]
pub struct SkippedSite {
    /// Contig (chromosome) name the site lies on.
    pub contig: String,
    /// Genomic position as given in the input chunk (1-based, judging by the
    /// `saturating_sub(1)` conversion in `process_chunk` — TODO confirm).
    pub pos: u64,
    /// Observed pileup depth at the moment the site was skipped.
    pub depth: u32,
}
/// Turns chunks of genomic positions into per-cell base counts by walking
/// pileups from an indexed BAM file.
pub struct OptimizedChunkProcessor {
    // Path to the indexed BAM; workers open their own readers from it.
    bam_path: PathBuf,
    // Filtering/counting parameters (quality cutoffs, tag names, strandedness,
    // max depth).
    config: BamProcessorConfig,
    // Shared cell-barcode whitelist mapping barcodes to dense integer ids.
    barcode_processor: Arc<BarcodeProcessor>,
    // Contig name -> BAM target id, built once from the BAM header.
    tid_lookup: FxHashMap<String, u32>,
    // BAM target id -> contig name (index == tid), shared with consumers.
    contig_names: Arc<Vec<String>>,
    // Pre-size hint for the per-position cell-count map.
    count_capacity_hint: usize,
    // Pre-size hint for the per-position UMI maps.
    umi_capacity_hint: usize,
}
impl OptimizedChunkProcessor {
    /// Build a processor for `bam_path`.
    ///
    /// Reads the BAM header once up front to construct the contig-name -> tid
    /// lookup and the tid-indexed contig name table, and derives capacity
    /// hints for the per-position hash maps from the whitelist size.
    ///
    /// # Errors
    /// Fails if the BAM file (or its index) cannot be opened.
    pub fn new(
        bam_path: PathBuf,
        config: BamProcessorConfig,
        barcode_processor: Arc<BarcodeProcessor>,
    ) -> Result<Self> {
        let header = bam::IndexedReader::from_path(&bam_path)?
            .header()
            .to_owned();
        let mut tid_lookup =
            FxHashMap::with_capacity_and_hasher(header.target_count() as usize, Default::default());
        let mut contig_names = Vec::with_capacity(header.target_count() as usize);
        for tid in 0..header.target_count() {
            if let Ok(name) = std::str::from_utf8(header.tid2name(tid)) {
                tid_lookup.insert(name.to_string(), tid);
                contig_names.push(name.to_string());
            } else {
                // Keep contig_names index-aligned with tid even when the name
                // is not valid UTF-8; such a contig simply cannot be resolved
                // through tid_lookup.
                contig_names.push(String::from("unknown"));
            }
        }
        // Clamp so tiny whitelists still get a useful pre-size and huge ones
        // do not force oversized per-chunk allocations.
        let barcode_count = barcode_processor.len();
        let count_capacity_hint = barcode_count.clamp(64, 4096);
        // Heuristic: budget several UMI entries per expected barcode.
        let umi_capacity_hint = count_capacity_hint.saturating_mul(8);
        Ok(Self {
            bam_path,
            config,
            barcode_processor,
            tid_lookup,
            contig_names: Arc::new(contig_names),
            count_capacity_hint,
            umi_capacity_hint,
        })
    }

    /// Shared handle to the tid-indexed contig name table.
    pub fn contig_names(&self) -> Arc<Vec<String>> {
        Arc::clone(&self.contig_names)
    }

    /// Open a fresh `IndexedReader` on this processor's BAM file.
    ///
    /// # Errors
    /// Returns an error naming the path if the BAM or its index cannot be
    /// opened.
    pub fn open_reader(&self) -> Result<bam::IndexedReader> {
        bam::IndexedReader::from_path(&self.bam_path)
            .map_err(|e| anyhow::anyhow!("Failed to open BAM {}: {}", self.bam_path.display(), e))
    }

    /// Count per-cell base calls at every position in `chunk`.
    ///
    /// Fetches one contiguous BAM region spanning the chunk (all positions
    /// appear to share `chunk.positions[0].chrom` and be sorted ascending —
    /// TODO confirm with the chunk builder), then walks the pileup once,
    /// matching pileup columns against the chunk's target positions. Reads
    /// are filtered by mapping quality, base quality, barcode whitelist
    /// membership, and UMI presence; per-(cell, UMI) calls are collapsed to a
    /// consensus, and UMIs with conflicting calls are discarded.
    ///
    /// Returns the counted positions plus the sites skipped because their
    /// pileup depth reached `config.max_depth` (counts there would be
    /// truncated).
    ///
    /// # Errors
    /// Propagates BAM fetch, pileup, and record-decoding errors.
    pub fn process_chunk(
        &self,
        chunk: &PositionChunk,
        reader: &mut bam::IndexedReader,
    ) -> Result<(Vec<PositionData>, Vec<SkippedSite>)> {
        if chunk.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }
        // A contig absent from this BAM's header can match nothing.
        let chrom = &chunk.positions[0].chrom;
        let tid = match self.tid_lookup.get(chrom) {
            Some(&tid) => tid,
            None => {
                return Ok((Vec::new(), Vec::new()));
            }
        };
        // Chunk positions look 1-based while htslib fetch/pileup coordinates
        // are 0-based, hence the saturating_sub(1); the fetch range is
        // half-open [start, end).
        let fetch_start = chunk
            .positions
            .first()
            .map(|p| p.pos.saturating_sub(1) as u32)
            .unwrap_or(0);
        let fetch_end = chunk
            .positions
            .last()
            .map(|p| p.pos as u32)
            .unwrap_or(fetch_start);
        reader.fetch((tid, fetch_start, fetch_end))?;
        // 0-based targets, in the same order as chunk.positions.
        let target_positions: Vec<u32> = chunk
            .positions
            .iter()
            .map(|p| p.pos.saturating_sub(1) as u32)
            .collect();
        let mut chunk_results = Vec::with_capacity(chunk.len());
        let mut skipped_sites: Vec<SkippedSite> = Vec::new();
        // Cursor into target_positions; advances monotonically with the pileup.
        let mut position_index = 0usize;
        // Cap the scratch-map pre-allocation when the chunk contains sites
        // near the depth ceiling (bounds peak memory per chunk).
        let has_high_depth = chunk.near_max_depth_count() > 0;
        let estimated_count_capacity = if has_high_depth {
            self.count_capacity_hint.min(5_000)
        } else {
            self.count_capacity_hint
        };
        let estimated_umi_capacity = if has_high_depth {
            self.umi_capacity_hint.min(50_000)
        } else {
            self.umi_capacity_hint
        };
        // Scratch maps reused (cleared) across positions to avoid
        // reallocating per site: cell id -> strand/base counts.
        let mut counts: FxHashMap<u32, StrandBaseCounts> =
            FxHashMap::with_capacity_and_hasher(estimated_count_capacity, Default::default());
        // (cell id, UMI id) -> encoded consensus call, or UMI_CONFLICT_CODE.
        let mut umi_consensus: FxHashMap<(u32, u32), u8> =
            FxHashMap::with_capacity_and_hasher(estimated_umi_capacity, Default::default());
        // UMI string -> dense id; cleared per position, so ids are per-site.
        let mut local_umi_cache: FxHashMap<String, u32> =
            FxHashMap::with_capacity_and_hasher(estimated_umi_capacity, Default::default());
        // Reusable decode buffers for the barcode and UMI tag values.
        let mut cb_buf = String::with_capacity(32);
        let mut umi_buf = String::with_capacity(32);
        let mut pileups = reader.pileup();
        pileups.set_max_depth(self.config.max_depth.min(i32::MAX as u32));
        for pileup in pileups {
            let pileup = pileup?;
            let pile_pos = pileup.pos();
            // Skip past targets the pileup has already moved beyond; they
            // received no coverage and produce no output row.
            while position_index < target_positions.len()
                && target_positions[position_index] < pile_pos
            {
                position_index += 1;
            }
            // All targets consumed: the rest of the fetch is irrelevant.
            if position_index >= target_positions.len() {
                break;
            }
            // Pileup column sits between two targets: ignore it.
            if target_positions[position_index] != pile_pos {
                continue;
            }
            let current_index = position_index;
            position_index += 1;
            let depth = pileup.depth() as u32;
            // At the depth ceiling the pileup is truncated, so counts would
            // be biased; record the site as skipped instead of counting it.
            if depth >= self.config.max_depth {
                let position_meta = &chunk.positions[current_index];
                skipped_sites.push(SkippedSite {
                    contig: position_meta.chrom.clone(),
                    pos: position_meta.pos,
                    depth,
                });
                continue;
            }
            counts.clear();
            umi_consensus.clear();
            local_umi_cache.clear();
            for alignment in pileup.alignments() {
                let record = alignment.record();
                if record.mapq() < self.config.min_mapping_quality {
                    continue;
                }
                // No query position at this column (deletion/refskip): skip.
                let qpos = match alignment.qpos() {
                    Some(q) => q,
                    None => continue,
                };
                let base_qual = record.qual().get(qpos).copied().unwrap_or(0);
                if base_qual < self.config.min_base_quality {
                    continue;
                }
                let base = decode_base(&record, Some(qpos))?;
                if base == 'N' {
                    continue;
                }
                // Decode the cell barcode tag and require whitelist
                // membership; reads without a whitelisted barcode are dropped.
                let cell_id = {
                    cb_buf.clear();
                    match decode_cell_barcode_into(&record, self.config.cell_barcode_tag.as_bytes(), &mut cb_buf)? {
                        true => match self.barcode_processor.id_of(&cb_buf) {
                            Some(id) => id,
                            None => continue,
                        },
                        false => continue,
                    }
                };
                // Reads without a UMI tag are dropped.
                let umi_id = {
                    umi_buf.clear();
                    match decode_umi_into(&record, self.config.umi_tag.as_bytes(), &mut umi_buf)? {
                        true => intern_umi_id(&mut local_umi_cache, &umi_buf),
                        false => continue,
                    }
                };
                // Collapse duplicate reads of the same (cell, UMI) pair:
                // agreeing calls keep the first encoding, a disagreeing call
                // poisons the UMI with UMI_CONFLICT_CODE.
                if let Some(encoded) = encode_call(self.config.stranded, base, record.is_reverse())
                {
                    umi_consensus
                        .entry((cell_id, umi_id))
                        .and_modify(|existing| {
                            if *existing != encoded {
                                *existing = UMI_CONFLICT_CODE;
                            }
                        })
                        .or_insert(encoded);
                }
            }
            // One count per surviving UMI, attributed to its cell; conflicted
            // UMIs contribute nothing.
            for ((cell_id, _umi_id), encoded) in umi_consensus.drain() {
                if encoded == UMI_CONFLICT_CODE {
                    continue;
                }
                let counts_entry = counts.entry(cell_id).or_default();
                apply_encoded_call(self.config.stranded, encoded, counts_entry);
            }
            let position_meta = &chunk.positions[current_index];
            chunk_results.push(PositionData {
                contig_id: tid,
                pos: position_meta.pos,
                counts: counts.drain().collect(),
            });
        }
        Ok((chunk_results, skipped_sites))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A UMI seen twice must resolve to the id assigned on first sight.
    #[test]
    fn umi_interner_reuses_id_for_repeated_values() {
        let mut interner: FxHashMap<String, u32> = FxHashMap::default();
        let first_id = intern_umi_id(&mut interner, "ACGTAC");
        let repeat_id = intern_umi_id(&mut interner, "ACGTAC");
        let other_id = intern_umi_id(&mut interner, "TTTTGG");
        assert_eq!(first_id, repeat_id);
        assert_ne!(first_id, other_id);
        assert_eq!(interner.len(), 2);
    }

    /// Clearing the cache restarts id assignment from zero, so one cache can
    /// be reused across positions.
    #[test]
    fn umi_interner_cache_can_be_reused_across_positions() {
        let mut interner: FxHashMap<String, u32> = FxHashMap::default();
        let id_before = intern_umi_id(&mut interner, "AAAAAA");
        assert_eq!(interner.len(), 1);
        interner.clear();
        let id_after = intern_umi_id(&mut interner, "AAAAAA");
        assert_eq!(interner.len(), 1);
        assert_eq!(id_before, 0);
        assert_eq!(id_after, 0);
    }

    /// Fresh UMIs receive consecutive ids starting at zero; repeats do not
    /// consume new ids.
    #[test]
    fn umi_interner_assigns_sequential_ids() {
        let mut interner: FxHashMap<String, u32> = FxHashMap::default();
        for (expected, umi) in ["AAA", "BBB", "CCC"].iter().enumerate() {
            assert_eq!(intern_umi_id(&mut interner, umi), expected as u32);
        }
        assert_eq!(intern_umi_id(&mut interner, "AAA"), 0);
        assert_eq!(interner.len(), 3);
    }
}