use std::path::Path;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use rust_htslib::bam::{self, pileup::Alignment, record::Record, Read};
use rustc_hash::FxHashMap;
use crate::pipeline::bam2mtx::barcode::BarcodeProcessor;
/// Tallies of observed calls for each of the four bases A, T, G and C.
///
/// All counters start at zero via `Default`.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct BaseCounts {
/// Number of `A` calls.
pub a: u32,
/// Number of `T` calls.
pub t: u32,
/// Number of `G` calls.
pub g: u32,
/// Number of `C` calls.
pub c: u32,
}
/// Base tallies split by strand.
///
/// In unstranded mode `apply_encoded_call` only ever increments `forward`,
/// so `reverse` stays at zero.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct StrandBaseCounts {
/// Calls attributed to the forward strand (and all calls in unstranded mode).
pub forward: BaseCounts,
/// Calls attributed to the reverse strand.
pub reverse: BaseCounts,
}
/// Result of counting one genomic position, as produced by
/// `BamProcessor::process_position`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PositionData {
/// Target id of the contig as resolved from the BAM header.
pub contig_id: u32,
/// The position exactly as passed by the caller (`process_position` inspects
/// the pileup column `pos - 1`).
pub pos: u64,
/// Strand-aware base tallies keyed by the numeric cell id returned by the
/// barcode processor.
pub counts: FxHashMap<u32, StrandBaseCounts>,
}
/// Sentinel stored in place of an encoded call when reads sharing the same
/// (cell, UMI) key disagree; entries holding it are skipped when counts are
/// tallied. `encode_call` never produces this value (its codes are at most 7).
pub const UMI_CONFLICT_CODE: u8 = u8::MAX;
/// Normalises a raw tag value into an owned string.
///
/// Everything from the first `-` onwards is discarded (e.g. the `-1` suffix in
/// `AAACCTG-1`) and surrounding whitespace is trimmed. Returns `None` when
/// nothing is left after that.
fn clean_tag_value(raw: &str) -> Option<String> {
    // Keep only the part in front of the first dash, if there is one.
    let head = match raw.find('-') {
        Some(dash_at) => &raw[..dash_at],
        None => raw,
    };
    let trimmed = head.trim();
    if trimmed.is_empty() {
        return None;
    }
    Some(trimmed.to_owned())
}
/// Buffer-writing variant of the tag normalisation.
///
/// Discards everything from the first `-` onwards, trims whitespace, and
/// appends the remainder to `buf`. The buffer is NOT cleared first, so callers
/// reusing a buffer must clear it themselves. Returns `true` when something
/// was appended and `false` (leaving `buf` untouched) when the cleaned value
/// is empty.
fn clean_tag_value_into(raw: &str, buf: &mut String) -> bool {
    // Keep only the part in front of the first dash, if there is one.
    let head = match raw.find('-') {
        Some(dash_at) => &raw[..dash_at],
        None => raw,
    };
    let trimmed = head.trim();
    let has_value = !trimmed.is_empty();
    if has_value {
        buf.push_str(trimmed);
    }
    has_value
}
/// Extracts the cleaned cell barcode stored under aux tag `tag`.
///
/// String values and `u8` arrays (interpreted as UTF-8) are passed through
/// `clean_tag_value`. A missing tag, a failed aux lookup, or any other aux
/// type yields `Ok(None)`.
///
/// # Errors
///
/// Fails only when a `u8` array value is not valid UTF-8.
pub fn decode_cell_barcode(record: &Record, tag: &[u8]) -> Result<Option<String>> {
    let barcode = match record.aux(tag).ok() {
        Some(bam::record::Aux::String(text)) => clean_tag_value(text),
        Some(bam::record::Aux::ArrayU8(values)) => {
            let raw_bytes = values.iter().collect::<Vec<u8>>();
            clean_tag_value(std::str::from_utf8(&raw_bytes)?)
        }
        // Lookup failures and unsupported aux types are treated as "no barcode".
        _ => None,
    };
    Ok(barcode)
}
/// Appends the cleaned cell barcode stored under aux tag `tag` to `buf`.
///
/// String values and `u8` arrays (interpreted as UTF-8) are passed through
/// `clean_tag_value_into`, which appends without clearing `buf`. Returns
/// `Ok(true)` when a value was appended; a missing tag, a failed aux lookup,
/// any other aux type, or an empty cleaned value yields `Ok(false)`.
///
/// # Errors
///
/// Fails only when a `u8` array value is not valid UTF-8.
pub fn decode_cell_barcode_into(record: &Record, tag: &[u8], buf: &mut String) -> Result<bool> {
    let appended = match record.aux(tag).ok() {
        Some(bam::record::Aux::String(text)) => clean_tag_value_into(text, buf),
        Some(bam::record::Aux::ArrayU8(values)) => {
            let raw_bytes = values.iter().collect::<Vec<u8>>();
            clean_tag_value_into(std::str::from_utf8(&raw_bytes)?, buf)
        }
        // Lookup failures and unsupported aux types leave the buffer untouched.
        _ => false,
    };
    Ok(appended)
}
/// Extracts the cleaned UMI stored under aux tag `tag`.
///
/// String values and `u8` arrays (interpreted as UTF-8) are passed through
/// `clean_tag_value`. A missing tag, a failed aux lookup, or any other aux
/// type yields `Ok(None)`.
///
/// # Errors
///
/// Fails only when a `u8` array value is not valid UTF-8.
pub fn decode_umi(record: &Record, tag: &[u8]) -> Result<Option<String>> {
    let aux = match record.aux(tag) {
        Ok(value) => value,
        // A tag that cannot be read is treated the same as an absent one.
        Err(_) => return Ok(None),
    };
    let umi = match aux {
        bam::record::Aux::String(text) => clean_tag_value(text),
        bam::record::Aux::ArrayU8(values) => {
            let raw_bytes: Vec<u8> = values.iter().collect();
            let text = std::str::from_utf8(&raw_bytes)?;
            clean_tag_value(text)
        }
        _ => None,
    };
    Ok(umi)
}
/// Appends the cleaned UMI stored under aux tag `tag` to `buf`.
///
/// String values and `u8` arrays (interpreted as UTF-8) are passed through
/// `clean_tag_value_into`, which appends without clearing `buf`. Returns
/// `Ok(true)` when a value was appended; a missing tag, a failed aux lookup,
/// any other aux type, or an empty cleaned value yields `Ok(false)`.
///
/// # Errors
///
/// Fails only when a `u8` array value is not valid UTF-8.
pub fn decode_umi_into(record: &Record, tag: &[u8], buf: &mut String) -> Result<bool> {
    let aux = match record.aux(tag) {
        Ok(value) => value,
        // A tag that cannot be read is treated the same as an absent one.
        Err(_) => return Ok(false),
    };
    let appended = match aux {
        bam::record::Aux::String(text) => clean_tag_value_into(text, buf),
        bam::record::Aux::ArrayU8(values) => {
            let raw_bytes: Vec<u8> = values.iter().collect();
            let text = std::str::from_utf8(&raw_bytes)?;
            clean_tag_value_into(text, buf)
        }
        _ => false,
    };
    Ok(appended)
}
/// Returns the base call of `record` at query position `qpos`.
///
/// `A`, `T`, `G` and `C` (either case) map to their upper-case character;
/// every other symbol maps to `'N'`.
///
/// # Errors
///
/// Fails when `qpos` is `None`, or when it lies outside the read sequence
/// (previously an out-of-range position panicked).
pub fn decode_base(record: &Record, qpos: Option<usize>) -> Result<char> {
    let qpos = qpos.ok_or_else(|| anyhow!("Invalid query position"))?;
    let seq = record.seq();
    if qpos >= seq.len() {
        return Err(anyhow!(
            "Query position {} is outside the read sequence (length {})",
            qpos,
            seq.len()
        ));
    }
    // Indexing `Seq` decodes just this one base; `as_bytes()` would decode and
    // allocate the entire read for every pileup alignment.
    Ok(match seq[qpos] {
        b'A' | b'a' => 'A',
        b'T' | b't' => 'T',
        b'G' | b'g' => 'G',
        b'C' | b'c' => 'C',
        _ => 'N',
    })
}
/// Packs a base call (and, optionally, its strand) into a small integer code.
///
/// Bases are numbered `A = 0`, `T = 1`, `G = 2`, `C = 3`. In unstranded mode
/// that number is the code and `is_reverse` is ignored. In stranded mode the
/// number is shifted left by one and the lowest bit carries the strand
/// (`1` = reverse). Any other `base` yields `None`.
#[inline]
pub fn encode_call(stranded: bool, base: char, is_reverse: bool) -> Option<u8> {
    let base_index: u8 = match base {
        'A' => 0,
        'T' => 1,
        'G' => 2,
        'C' => 3,
        _ => return None,
    };
    if !stranded {
        return Some(base_index);
    }
    // Low bit = strand flag, remaining bits = base index.
    let strand_flag = u8::from(is_reverse);
    Some((base_index << 1) | strand_flag)
}
/// Adds one observation described by `code` to `counts_entry`.
///
/// `code` is interpreted with the same layout `encode_call` produces: in
/// stranded mode the lowest bit selects the reverse (`1`) or forward (`0`)
/// tallies and the remaining bits are the base index; in unstranded mode the
/// whole code is the base index and only the forward tallies are touched.
/// Base indices other than 0..=3 are ignored.
#[inline]
pub fn apply_encoded_call(stranded: bool, code: u8, counts_entry: &mut StrandBaseCounts) {
    // Resolve which strand's tallies to update and which base index to bump.
    let (tallies, base_index) = if stranded {
        let tallies = if code & 1 == 1 {
            &mut counts_entry.reverse
        } else {
            &mut counts_entry.forward
        };
        (tallies, code >> 1)
    } else {
        (&mut counts_entry.forward, code)
    };
    let slot = match base_index {
        0 => &mut tallies.a,
        1 => &mut tallies.t,
        2 => &mut tallies.g,
        3 => &mut tallies.c,
        _ => return,
    };
    *slot += 1;
}
/// Filtering and tag settings used by `BamProcessor`.
#[derive(Debug, Clone)]
pub struct BamProcessorConfig {
/// Reads with a mapping quality below this value are skipped.
pub min_mapping_quality: u8,
/// Reads whose base quality at the inspected position is below this value
/// are skipped.
pub min_base_quality: u8,
/// When `true`, calls are tallied per strand; otherwise everything goes to
/// the forward tallies.
pub stranded: bool,
/// Depth cap handed to the pileup engine; pileup columns whose depth reaches
/// this value are skipped entirely.
pub max_depth: u32,
/// Name of the aux tag holding the UMI.
pub umi_tag: String,
/// Name of the aux tag holding the cell barcode.
pub cell_barcode_tag: String,
}
impl Default for BamProcessorConfig {
fn default() -> Self {
Self {
min_mapping_quality: 255,
min_base_quality: 30,
stranded: true,
max_depth: 65_536,
umi_tag: "UB".to_string(),
cell_barcode_tag: "CB".to_string(),
}
}
}
/// Counts UMI-collapsed base calls per cell at individual positions of an
/// indexed BAM file.
pub struct BamProcessor {
// Read filters and tag names applied to every pileup alignment.
config: BamProcessorConfig,
// Shared lookup translating a cleaned barcode string into a numeric cell id;
// reads whose barcode it does not know are skipped.
barcode_processor: Arc<BarcodeProcessor>,
}
impl BamProcessor {
pub fn new(config: BamProcessorConfig, barcode_processor: Arc<BarcodeProcessor>) -> Self {
Self {
config,
barcode_processor,
}
}
pub fn process_position(&self, bam_path: &Path, chrom: &str, pos: u64) -> Result<PositionData> {
let mut reader = bam::IndexedReader::from_path(bam_path)?;
let start_pos = (pos - 1) as u32;
let end_pos = pos as u32;
let header = reader.header().to_owned();
let tid = header
.tid(chrom.as_bytes())
.ok_or_else(|| anyhow::anyhow!("Chromosome '{}' not found", chrom))?;
reader.fetch((tid, start_pos, end_pos))?;
let mut pileups: bam::pileup::Pileups<'_, bam::IndexedReader> = reader.pileup();
pileups.set_max_depth(self.config.max_depth.min(i32::MAX as u32));
let mut counts: FxHashMap<u32, StrandBaseCounts> = FxHashMap::default();
let mut umi_consensus: FxHashMap<(u32, String), u8> = FxHashMap::default();
for pileup in pileups {
let pileup = pileup?;
if pileup.pos() != start_pos {
continue;
}
if (pileup.depth() as u32) >= self.config.max_depth {
continue;
}
for read in pileup.alignments() {
if !self.should_process_read(&read) {
continue;
}
let record = read.record();
let cell_id =
match decode_cell_barcode(&record, self.config.cell_barcode_tag.as_bytes())? {
Some(barcode) => match self.barcode_processor.id_of(&barcode) {
Some(id) => id,
None => continue,
},
None => continue,
};
let umi = match decode_umi(&record, self.config.umi_tag.as_bytes())? {
Some(umi) => umi,
None => continue,
};
let base = decode_base(&record, read.qpos())?;
if let Some(encoded) = encode_call(self.config.stranded, base, record.is_reverse())
{
umi_consensus
.entry((cell_id, umi))
.and_modify(|existing| {
if *existing != encoded {
*existing = UMI_CONFLICT_CODE;
}
})
.or_insert(encoded);
}
}
}
for ((cell_id, _umi), encoded) in umi_consensus.drain() {
if encoded == UMI_CONFLICT_CODE {
continue;
}
let counts_entry = counts.entry(cell_id).or_default();
apply_encoded_call(self.config.stranded, encoded, counts_entry);
}
Ok(PositionData {
contig_id: tid,
pos,
counts,
})
}
fn should_process_read(&self, read: &Alignment) -> bool {
if read.is_del() || read.is_refskip() {
return false;
}
let record = read.record();
if record.mapq() < self.config.min_mapping_quality {
return false;
}
if let Some(qpos) = read.qpos() {
if let Some(qual) = record.qual().get(qpos) {
if *qual < self.config.min_base_quality {
return false;
}
}
}
true
}
}
// Unit tests for the tag cleaning helpers, the call encoding, the config
// defaults, and (when the BAM fixture is present) `process_position`.
#[cfg(test)]
mod tests {
use super::*;
use rust_htslib::bam::{self, Read};
use std::collections::BTreeSet;
/// Collects the distinct cleaned cell barcodes of all reads covering `pos`
/// (pileup column `pos - 1`), skipping deletions and reference skips.
/// The result is sorted because it is gathered in a `BTreeSet`.
fn collect_barcodes_at_pos(
bam_path: &Path,
chrom: &str,
pos: u64,
cell_tag: &str,
) -> Result<Vec<String>> {
let mut reader = bam::IndexedReader::from_path(bam_path)?;
let header = reader.header().to_owned();
let tid = header
.tid(chrom.as_bytes())
.ok_or_else(|| anyhow!("chromosome '{}' not found", chrom))?;
reader.fetch((tid, (pos - 1) as u32, pos as u32))?;
let mut barcodes = BTreeSet::new();
for pileup in reader.pileup() {
let pileup = pileup?;
if pileup.pos() != (pos - 1) as u32 {
continue;
}
for aln in pileup.alignments() {
if aln.is_del() || aln.is_refskip() {
continue;
}
let record = aln.record();
if let Some(cb) = decode_cell_barcode(&record, cell_tag.as_bytes())? {
barcodes.insert(cb);
}
}
}
Ok(barcodes.into_iter().collect())
}
/// Independent re-implementation of the counting done by
/// `BamProcessor::process_position`, used as the expected value in the
/// fixture-based test below. The filters and the UMI collapse are written
/// out inline instead of going through `BamProcessor`.
fn manual_consensus(
bam_path: &Path,
chrom: &str,
pos: u64,
config: &BamProcessorConfig,
barcode_processor: &BarcodeProcessor,
) -> Result<FxHashMap<u32, StrandBaseCounts>> {
let mut reader = bam::IndexedReader::from_path(bam_path)?;
let header = reader.header().to_owned();
let tid = header
.tid(chrom.as_bytes())
.ok_or_else(|| anyhow!("chromosome '{}' not found", chrom))?;
reader.fetch((tid, (pos - 1) as u32, pos as u32))?;
let mut pileups = reader.pileup();
pileups.set_max_depth(config.max_depth.min(i32::MAX as u32));
let mut umi_consensus: FxHashMap<(u32, String), u8> = FxHashMap::default();
let mut counts: FxHashMap<u32, StrandBaseCounts> = FxHashMap::default();
for pileup in pileups {
let pileup = pileup?;
if pileup.pos() != (pos - 1) as u32 {
continue;
}
if (pileup.depth() as u32) >= config.max_depth {
continue;
}
for read in pileup.alignments() {
// Same read-level filters as `BamProcessor::should_process_read`.
if read.is_del() || read.is_refskip() {
continue;
}
let record = read.record();
if record.mapq() < config.min_mapping_quality {
continue;
}
if let Some(qpos) = read.qpos() {
if let Some(qual) = record.qual().get(qpos) {
if *qual < config.min_base_quality {
continue;
}
}
}
let cell_id = match decode_cell_barcode(&record, config.cell_barcode_tag.as_bytes())? {
Some(cb) => match barcode_processor.id_of(&cb) {
Some(id) => id,
None => continue,
},
None => continue,
};
let umi = match decode_umi(&record, config.umi_tag.as_bytes())? {
Some(umi) => umi,
None => continue,
};
let base = decode_base(&record, read.qpos())?;
if let Some(encoded) = encode_call(config.stranded, base, record.is_reverse()) {
// Disagreeing reads of one (cell, UMI) pair poison the entry.
umi_consensus
.entry((cell_id, umi))
.and_modify(|existing| {
if *existing != encoded {
*existing = UMI_CONFLICT_CODE;
}
})
.or_insert(encoded);
}
}
}
for ((cell_id, _umi), encoded) in umi_consensus.drain() {
if encoded == UMI_CONFLICT_CODE {
continue;
}
let counts_entry = counts.entry(cell_id).or_default();
apply_encoded_call(config.stranded, encoded, counts_entry);
}
Ok(counts)
}
/// `clean_tag_value` drops the `-suffix`, trims whitespace, and returns
/// `None` for values that end up empty.
#[test]
fn clean_tag_value_strips_suffix_and_whitespace() {
assert_eq!(clean_tag_value("AAACCTG-1"), Some("AAACCTG".to_string()));
assert_eq!(clean_tag_value(" TTTGCAA "), Some("TTTGCAA".to_string()));
assert_eq!(clean_tag_value("-"), None);
assert_eq!(clean_tag_value(" "), None);
}
/// Round-trips calls through `encode_call` / `apply_encoded_call` in both
/// modes; unstranded calls must land in the forward tallies only.
#[test]
fn encode_and_apply_calls_work_for_stranded_and_unstranded() {
let mut stranded_counts = StrandBaseCounts::default();
let mut unstranded_counts = StrandBaseCounts::default();
let fwd_a = encode_call(true, 'A', false).unwrap();
let rev_g = encode_call(true, 'G', true).unwrap();
apply_encoded_call(true, fwd_a, &mut stranded_counts);
apply_encoded_call(true, rev_g, &mut stranded_counts);
assert_eq!(stranded_counts.forward.a, 1);
assert_eq!(stranded_counts.reverse.g, 1);
let t = encode_call(false, 'T', true).unwrap();
let c = encode_call(false, 'C', false).unwrap();
apply_encoded_call(false, t, &mut unstranded_counts);
apply_encoded_call(false, c, &mut unstranded_counts);
assert_eq!(unstranded_counts.forward.t, 1);
assert_eq!(unstranded_counts.forward.c, 1);
assert_eq!(unstranded_counts.reverse.t, 0);
assert!(encode_call(false, 'N', false).is_none());
}
/// Compares `process_position` against `manual_consensus` on a BAM fixture.
///
/// NOTE(review): the test returns `Ok(())` without asserting anything when
/// `test/chr22.bam` is missing or no barcodes are found at the position, so
/// it passes vacuously in those cases.
#[test]
fn process_position_chr22_matches_manual_consensus() -> Result<()> {
let bam_path = Path::new("test/chr22.bam");
if !bam_path.exists() {
return Ok(());
}
let chrom = "chr22";
let pos = 50_783_283u64;
let config = BamProcessorConfig {
min_mapping_quality: 255,
min_base_quality: 30,
stranded: true,
max_depth: 10_000,
umi_tag: "UB".to_string(),
cell_barcode_tag: "CB".to_string(),
};
let barcodes = collect_barcodes_at_pos(bam_path, chrom, pos, &config.cell_barcode_tag)?;
if barcodes.is_empty() {
return Ok(());
}
let barcode_processor = Arc::new(BarcodeProcessor::from_vec(barcodes));
let processor = BamProcessor::new(config.clone(), Arc::clone(&barcode_processor));
let observed = processor.process_position(bam_path, chrom, pos)?;
let expected = manual_consensus(bam_path, chrom, pos, &config, &barcode_processor)?;
assert_eq!(observed.pos, pos);
assert_eq!(observed.counts.len(), expected.len());
for (cell_id, exp) in expected.iter() {
let got = observed
.counts
.get(cell_id)
.unwrap_or_else(|| panic!("missing cell_id {} in observed counts", cell_id));
assert_eq!(got.forward.a, exp.forward.a);
assert_eq!(got.forward.t, exp.forward.t);
assert_eq!(got.forward.g, exp.forward.g);
assert_eq!(got.forward.c, exp.forward.c);
assert_eq!(got.reverse.a, exp.reverse.a);
assert_eq!(got.reverse.t, exp.reverse.t);
assert_eq!(got.reverse.g, exp.reverse.g);
assert_eq!(got.reverse.c, exp.reverse.c);
}
Ok(())
}
/// Checks the unstranded codes of all four bases and that `N` is rejected.
///
/// NOTE(review): despite its name this test only exercises `encode_call`;
/// `decode_base` is never called here.
#[test]
fn decode_base_handles_all_bases_and_n() {
for (base, expected_code_unstranded) in [('A', 0u8), ('T', 1), ('G', 2), ('C', 3)] {
let code = encode_call(false, base, false).unwrap();
assert_eq!(code, expected_code_unstranded, "unstranded encode for {}", base);
let code_rev = encode_call(false, base, true).unwrap();
assert_eq!(code_rev, expected_code_unstranded, "unstranded+reverse for {}", base);
}
assert!(encode_call(false, 'N', false).is_none());
assert!(encode_call(true, 'N', true).is_none());
}
/// In stranded mode the lowest bit of the code carries the strand.
#[test]
fn encode_call_stranded_distinguishes_strands() {
for base in ['A', 'T', 'G', 'C'] {
let fwd = encode_call(true, base, false).unwrap();
let rev = encode_call(true, base, true).unwrap();
assert_ne!(fwd, rev, "stranded codes should differ for base {}", base);
assert_eq!(fwd & 1, 0);
assert_eq!(rev & 1, 1);
}
}
/// Every stranded (base, strand) combination increments exactly one field,
/// and it is the matching one.
#[test]
fn apply_encoded_call_increments_correct_fields() {
let bases = ['A', 'T', 'G', 'C'];
for &base in &bases {
for is_reverse in [false, true] {
let mut counts = StrandBaseCounts::default();
let code = encode_call(true, base, is_reverse).unwrap();
apply_encoded_call(true, code, &mut counts);
let (fwd, rev) = (&counts.forward, &counts.reverse);
let total = fwd.a + fwd.t + fwd.g + fwd.c + rev.a + rev.t + rev.g + rev.c;
assert_eq!(total, 1, "exactly one field should be incremented");
if !is_reverse {
match base {
'A' => assert_eq!(fwd.a, 1),
'T' => assert_eq!(fwd.t, 1),
'G' => assert_eq!(fwd.g, 1),
'C' => assert_eq!(fwd.c, 1),
_ => unreachable!(),
}
} else {
match base {
'A' => assert_eq!(rev.a, 1),
'T' => assert_eq!(rev.t, 1),
'G' => assert_eq!(rev.g, 1),
'C' => assert_eq!(rev.c, 1),
_ => unreachable!(),
}
}
}
}
}
/// Repeated application of the same code keeps adding to the same field.
#[test]
fn apply_encoded_call_accumulates() {
let mut counts = StrandBaseCounts::default();
let code_a_fwd = encode_call(true, 'A', false).unwrap();
for _ in 0..5 {
apply_encoded_call(true, code_a_fwd, &mut counts);
}
assert_eq!(counts.forward.a, 5);
}
/// The conflict sentinel is `0xFF` and can never collide with a real code.
#[test]
fn umi_conflict_code_is_max_u8() {
assert_eq!(UMI_CONFLICT_CODE, 0xFF);
for base in ['A', 'T', 'G', 'C'] {
for stranded in [true, false] {
for is_reverse in [true, false] {
if let Some(code) = encode_call(stranded, base, is_reverse) {
assert_ne!(code, UMI_CONFLICT_CODE);
}
}
}
}
}
/// Empty input, a lone dash, multiple dashes, and dash-free input.
#[test]
fn clean_tag_value_edge_cases() {
assert_eq!(clean_tag_value(""), None);
assert_eq!(clean_tag_value("-"), None);
assert_eq!(clean_tag_value("ABC-1-2"), Some("ABC".to_string()));
assert_eq!(clean_tag_value("NOPREFIX"), Some("NOPREFIX".to_string()));
}
/// Pins the values returned by `BamProcessorConfig::default()`.
#[test]
fn bam_processor_config_defaults() {
let config = BamProcessorConfig::default();
assert_eq!(config.min_mapping_quality, 255);
assert_eq!(config.min_base_quality, 30);
assert!(config.stranded);
assert_eq!(config.max_depth, 65_536);
assert_eq!(config.umi_tag, "UB");
assert_eq!(config.cell_barcode_tag, "CB");
}
/// A default `BaseCounts` has all four counters at zero.
#[test]
fn base_counts_default_is_zero() {
let bc = BaseCounts::default();
assert_eq!(bc.a, 0);
assert_eq!(bc.t, 0);
assert_eq!(bc.g, 0);
assert_eq!(bc.c, 0);
}
/// A default `StrandBaseCounts` has both strands at zero.
#[test]
fn strand_base_counts_default_is_zero() {
let sbc = StrandBaseCounts::default();
assert_eq!(sbc.forward.a + sbc.forward.t + sbc.forward.g + sbc.forward.c, 0);
assert_eq!(sbc.reverse.a + sbc.reverse.t + sbc.reverse.g + sbc.reverse.c, 0);
}
/// The buffer variant writes the cleaned value into the supplied buffer.
#[test]
fn clean_tag_value_into_writes_to_buffer() {
let mut buf = String::new();
assert!(clean_tag_value_into("AAACCTG-1", &mut buf));
assert_eq!(buf, "AAACCTG");
}
/// The buffer variant trims surrounding whitespace.
#[test]
fn clean_tag_value_into_strips_whitespace() {
let mut buf = String::new();
assert!(clean_tag_value_into(" TTTGCAA ", &mut buf));
assert_eq!(buf, "TTTGCAA");
}
/// Values that clean to nothing return `false` and leave the buffer untouched.
#[test]
fn clean_tag_value_into_returns_false_for_empty_and_dash() {
let mut buf = String::new();
assert!(!clean_tag_value_into("", &mut buf));
assert!(buf.is_empty());
assert!(!clean_tag_value_into("-", &mut buf));
assert!(buf.is_empty());
assert!(!clean_tag_value_into(" ", &mut buf));
assert!(buf.is_empty());
}
/// The buffer variant agrees with `clean_tag_value` on a set of inputs,
/// using a fresh buffer for each input.
#[test]
fn clean_tag_value_into_matches_original() {
let inputs = [
"AAACCTG-1",
"NOPREFIX",
"ABC-1-2",
" TTTGCAA ",
"-",
"",
" ",
];
for input in inputs {
let original = clean_tag_value(input);
let mut buf = String::new();
let ok = clean_tag_value_into(input, &mut buf);
match original {
Some(ref s) => {
assert!(ok, "expected true for {:?}", input);
assert_eq!(&buf, s, "mismatch for {:?}", input);
}
None => {
assert!(!ok, "expected false for {:?}", input);
assert!(buf.is_empty(), "buffer should be empty for {:?}", input);
}
}
}
}
}