1use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Result};
6use rust_htslib::bam::{self, pileup::Alignment, record::Record, Read};
7use rustc_hash::FxHashMap;
8
9use crate::pipeline::bam2mtx::barcode::BarcodeProcessor;
10
/// Counts of the four nucleotides A, T, G and C.
///
/// In this module the fields are incremented only by [`apply_encoded_call`],
/// which `BamProcessor::process_position` calls once per non-conflicting
/// (cell, UMI) consensus call.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct BaseCounts {
    /// Number of `A` calls.
    pub a: u32,
    /// Number of `T` calls.
    pub t: u32,
    /// Number of `G` calls.
    pub g: u32,
    /// Number of `C` calls.
    pub c: u32,
}
23
/// Base counts split by the alignment strand of the reads that produced them.
///
/// In unstranded mode [`apply_encoded_call`] records every call in `forward`,
/// so `reverse` stays at zero.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct StrandBaseCounts {
    /// Calls from forward-strand reads (and every call when counting is unstranded).
    pub forward: BaseCounts,
    /// Calls from reads whose record is flagged reverse (stranded mode only).
    pub reverse: BaseCounts,
}
32
/// Per-cell base counts at one reference position, as returned by
/// `BamProcessor::process_position`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PositionData {
    /// Target id (`tid`) of the contig, looked up by name in the BAM header.
    pub contig_id: u32,
    /// The position passed to `process_position`, which treats it as 1-based
    /// (it piles up at `pos - 1`).
    pub pos: u64,
    /// Counts keyed by the numeric cell id from `BarcodeProcessor::id_of`.
    /// Cells with no non-conflicting UMI call at this position have no entry.
    pub counts: FxHashMap<u32, StrandBaseCounts>,
}
43
/// Marker stored in the UMI consensus map when reads sharing a (cell, UMI)
/// disagree on their encoded call; such UMIs are dropped from the counts.
///
/// `0xFF` is `u8::MAX`, which [`encode_call`] can never produce (its codes
/// are at most 7), so the marker cannot be mistaken for a real call.
pub const UMI_CONFLICT_CODE: u8 = 0xFF;
46
/// Normalise a raw BAM tag value (cell barcode or UMI).
///
/// Keeps only the text before the first `-` (dropping suffixes such as the
/// `-1` in `AAACCTG-1`), trims surrounding whitespace, and returns `None`
/// when nothing remains.
fn clean_tag_value(raw: &str) -> Option<String> {
    let head = match raw.split_once('-') {
        Some((prefix, _suffix)) => prefix,
        None => raw,
    };
    let trimmed = head.trim();
    // The prefix can never contain '-', so emptiness is the only rejection case.
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
55
/// Buffer-reusing variant of [`clean_tag_value`].
///
/// On success the cleaned value is appended to `buf` and `true` is returned.
/// When the cleaned value would be empty, `buf` is left untouched and `false`
/// is returned. `buf` is never cleared here; callers reusing a buffer must
/// clear it themselves.
fn clean_tag_value_into(raw: &str, buf: &mut String) -> bool {
    let prefix = raw.split_once('-').map_or(raw, |(before, _)| before);
    let value = prefix.trim();
    if value.is_empty() {
        return false;
    }
    buf.push_str(value);
    true
}
70
71pub fn decode_cell_barcode(record: &Record, tag: &[u8]) -> Result<Option<String>> {
73 match record.aux(tag) {
74 Ok(bam::record::Aux::String(s)) => Ok(clean_tag_value(s)),
75 Ok(bam::record::Aux::ArrayU8(arr)) => {
76 let bytes: Vec<u8> = arr.iter().collect();
77 let raw = std::str::from_utf8(&bytes)?;
78 Ok(clean_tag_value(raw))
79 }
80 Ok(_) => Ok(None),
81 Err(_) => Ok(None),
82 }
83}
84
85pub fn decode_cell_barcode_into(record: &Record, tag: &[u8], buf: &mut String) -> Result<bool> {
91 match record.aux(tag) {
92 Ok(bam::record::Aux::String(s)) => Ok(clean_tag_value_into(s, buf)),
93 Ok(bam::record::Aux::ArrayU8(arr)) => {
94 let bytes: Vec<u8> = arr.iter().collect();
95 let raw = std::str::from_utf8(&bytes)?;
96 Ok(clean_tag_value_into(raw, buf))
97 }
98 Ok(_) => Ok(false),
99 Err(_) => Ok(false),
100 }
101}
102
103pub fn decode_umi(record: &Record, tag: &[u8]) -> Result<Option<String>> {
105 match record.aux(tag) {
106 Ok(bam::record::Aux::String(s)) => Ok(clean_tag_value(s)),
107 Ok(bam::record::Aux::ArrayU8(arr)) => {
108 let bytes: Vec<u8> = arr.iter().collect();
109 let raw = std::str::from_utf8(&bytes)?;
110 Ok(clean_tag_value(raw))
111 }
112 Ok(_) => Ok(None),
113 Err(_) => Ok(None),
114 }
115}
116
117pub fn decode_umi_into(record: &Record, tag: &[u8], buf: &mut String) -> Result<bool> {
121 match record.aux(tag) {
122 Ok(bam::record::Aux::String(s)) => Ok(clean_tag_value_into(s, buf)),
123 Ok(bam::record::Aux::ArrayU8(arr)) => {
124 let bytes: Vec<u8> = arr.iter().collect();
125 let raw = std::str::from_utf8(&bytes)?;
126 Ok(clean_tag_value_into(raw, buf))
127 }
128 Ok(_) => Ok(false),
129 Err(_) => Ok(false),
130 }
131}
132
133pub fn decode_base(record: &Record, qpos: Option<usize>) -> Result<char> {
135 let qpos = qpos.ok_or_else(|| anyhow!("Invalid query position"))?;
136 let seq = record.seq();
137 let base = seq.as_bytes()[qpos];
138
139 Ok(match base {
140 b'A' | b'a' => 'A',
141 b'T' | b't' => 'T',
142 b'G' | b'g' => 'G',
143 b'C' | b'c' => 'C',
144 _ => 'N',
145 })
146}
147
/// Pack a base call into a compact code.
///
/// Base indices are A=0, T=1, G=2, C=3. In stranded mode the index is shifted
/// left by one and the low bit carries the strand (1 = reverse), giving codes
/// 0..=7. In unstranded mode the code is just the index and `is_reverse` is
/// ignored. Any other `base` (including lowercase letters and `'N'`) yields
/// `None`.
#[inline]
pub fn encode_call(stranded: bool, base: char, is_reverse: bool) -> Option<u8> {
    const ORDER: [char; 4] = ['A', 'T', 'G', 'C'];
    let index = ORDER.iter().position(|&candidate| candidate == base)? as u8;
    let code = if stranded {
        (index << 1) | u8::from(is_reverse)
    } else {
        index
    };
    Some(code)
}
165
166#[inline]
167pub fn apply_encoded_call(stranded: bool, code: u8, counts_entry: &mut StrandBaseCounts) {
168 if stranded {
169 let strand_bit = code & 1;
170 let base_code = code >> 1;
171 let target = if strand_bit == 1 {
172 &mut counts_entry.reverse
173 } else {
174 &mut counts_entry.forward
175 };
176
177 match base_code {
178 0 => target.a += 1,
179 1 => target.t += 1,
180 2 => target.g += 1,
181 3 => target.c += 1,
182 _ => {}
183 }
184 } else {
185 match code {
186 0 => counts_entry.forward.a += 1,
187 1 => counts_entry.forward.t += 1,
188 2 => counts_entry.forward.g += 1,
189 3 => counts_entry.forward.c += 1,
190 _ => {}
191 }
192 }
193}
194
/// Read-filtering and tag settings used by [`BamProcessor`].
#[derive(Debug, Clone)]
pub struct BamProcessorConfig {
    /// Reads whose MAPQ is below this value are skipped.
    pub min_mapping_quality: u8,
    /// Reads whose base quality at the queried position is below this value are skipped.
    pub min_base_quality: u8,
    /// When `true`, calls are counted separately for forward- and reverse-strand reads.
    pub stranded: bool,
    /// Pileup depth cap. It is passed to htslib, and columns whose depth reaches it are skipped.
    pub max_depth: u32,
    /// Aux tag read as the UMI.
    pub umi_tag: String,
    /// Aux tag read as the cell barcode.
    pub cell_barcode_tag: String,
}

impl Default for BamProcessorConfig {
    /// MAPQ must equal 255, base quality must be at least 30, counting is
    /// stranded, the depth cap is 65 536, and `UB`/`CB` are the UMI and
    /// cell-barcode tags.
    fn default() -> Self {
        let umi_tag = String::from("UB");
        let cell_barcode_tag = String::from("CB");
        BamProcessorConfig {
            // 255 is the largest u8, so only reads with exactly MAPQ 255 pass.
            min_mapping_quality: 255,
            min_base_quality: 30,
            stranded: true,
            max_depth: 65_536,
            umi_tag,
            cell_barcode_tag,
        }
    }
}
224
225pub struct BamProcessor {
227 config: BamProcessorConfig,
229 barcode_processor: Arc<BarcodeProcessor>,
231}
232
233impl BamProcessor {
234 pub fn new(config: BamProcessorConfig, barcode_processor: Arc<BarcodeProcessor>) -> Self {
236 Self {
237 config,
238 barcode_processor,
239 }
240 }
241
242 pub fn process_position(&self, bam_path: &Path, chrom: &str, pos: u64) -> Result<PositionData> {
244 let mut reader = bam::IndexedReader::from_path(bam_path)?;
245
246 let start_pos = (pos - 1) as u32;
248 let end_pos = pos as u32;
249
250 let header = reader.header().to_owned();
252 let tid = header
253 .tid(chrom.as_bytes())
254 .ok_or_else(|| anyhow::anyhow!("Chromosome '{}' not found", chrom))?;
255
256 reader.fetch((tid, start_pos, end_pos))?;
258 let mut pileups: bam::pileup::Pileups<'_, bam::IndexedReader> = reader.pileup();
259 pileups.set_max_depth(self.config.max_depth.min(i32::MAX as u32));
260 let mut counts: FxHashMap<u32, StrandBaseCounts> = FxHashMap::default();
261 let mut umi_consensus: FxHashMap<(u32, String), u8> = FxHashMap::default();
262
263 for pileup in pileups {
265 let pileup = pileup?;
266 if pileup.pos() != start_pos {
267 continue;
268 }
269
270 if (pileup.depth() as u32) >= self.config.max_depth {
271 continue;
272 }
273
274 for read in pileup.alignments() {
277 if !self.should_process_read(&read) {
278 continue;
279 }
280
281 let record = read.record();
287 let cell_id =
288 match decode_cell_barcode(&record, self.config.cell_barcode_tag.as_bytes())? {
289 Some(barcode) => match self.barcode_processor.id_of(&barcode) {
290 Some(id) => id,
291 None => continue,
292 },
293 None => continue,
294 };
295
296 let umi = match decode_umi(&record, self.config.umi_tag.as_bytes())? {
297 Some(umi) => umi,
298 None => continue,
299 };
300
301 let base = decode_base(&record, read.qpos())?;
302 if let Some(encoded) = encode_call(self.config.stranded, base, record.is_reverse())
303 {
304 umi_consensus
305 .entry((cell_id, umi))
306 .and_modify(|existing| {
307 if *existing != encoded {
308 *existing = UMI_CONFLICT_CODE;
309 }
310 })
311 .or_insert(encoded);
312 }
313 }
314 }
315
316 for ((cell_id, _umi), encoded) in umi_consensus.drain() {
318 if encoded == UMI_CONFLICT_CODE {
319 continue;
320 }
321
322 let counts_entry = counts.entry(cell_id).or_default();
323
324 apply_encoded_call(self.config.stranded, encoded, counts_entry);
325 }
326
327 Ok(PositionData {
328 contig_id: tid,
329 pos,
330 counts,
331 })
332 }
333
334 fn should_process_read(&self, read: &Alignment) -> bool {
336 if read.is_del() || read.is_refskip() {
337 return false;
338 }
339
340 let record = read.record();
341
342 if record.mapq() < self.config.min_mapping_quality {
344 return false;
345 }
346
347 if let Some(qpos) = read.qpos() {
349 if let Some(qual) = record.qual().get(qpos) {
350 if *qual < self.config.min_base_quality {
351 return false;
352 }
353 }
354 }
355
356 true
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use rust_htslib::bam::{self, Read};
    use std::collections::BTreeSet;

    /// Flatten strand counts into `[fwd A, T, G, C, rev A, T, G, C]` so two
    /// results can be compared, and reported in full, with one assertion.
    fn count_vector(counts: &StrandBaseCounts) -> [u32; 8] {
        let (f, r) = (&counts.forward, &counts.reverse);
        [f.a, f.t, f.g, f.c, r.a, r.t, r.g, r.c]
    }

    /// Return the sorted, de-duplicated cleaned cell barcodes of reads with a
    /// base aligned at 1-based `pos` (deletions and reference skips excluded).
    fn collect_barcodes_at_pos(
        bam_path: &Path,
        chrom: &str,
        pos: u64,
        cell_tag: &str,
    ) -> Result<Vec<String>> {
        let mut reader = bam::IndexedReader::from_path(bam_path)?;
        let header = reader.header().to_owned();
        let tid = header
            .tid(chrom.as_bytes())
            .ok_or_else(|| anyhow!("chromosome '{}' not found", chrom))?;
        reader.fetch((tid, (pos - 1) as u32, pos as u32))?;

        let mut barcodes = BTreeSet::new();
        for pileup in reader.pileup() {
            let pileup = pileup?;
            if pileup.pos() != (pos - 1) as u32 {
                continue;
            }
            for aln in pileup.alignments() {
                if aln.is_del() || aln.is_refskip() {
                    continue;
                }
                let record = aln.record();
                if let Some(cb) = decode_cell_barcode(&record, cell_tag.as_bytes())? {
                    barcodes.insert(cb);
                }
            }
        }

        Ok(barcodes.into_iter().collect())
    }

    /// Straight-line reference version of `BamProcessor::process_position`.
    ///
    /// It inlines the read filter rather than calling `should_process_read`, so
    /// the chr22 test cross-checks the processor's wiring. It does share the
    /// decode/encode helpers with the processor.
    fn manual_consensus(
        bam_path: &Path,
        chrom: &str,
        pos: u64,
        config: &BamProcessorConfig,
        barcode_processor: &BarcodeProcessor,
    ) -> Result<FxHashMap<u32, StrandBaseCounts>> {
        let mut reader = bam::IndexedReader::from_path(bam_path)?;
        let header = reader.header().to_owned();
        let tid = header
            .tid(chrom.as_bytes())
            .ok_or_else(|| anyhow!("chromosome '{}' not found", chrom))?;
        reader.fetch((tid, (pos - 1) as u32, pos as u32))?;

        let mut pileups = reader.pileup();
        pileups.set_max_depth(config.max_depth.min(i32::MAX as u32));

        let mut umi_consensus: FxHashMap<(u32, String), u8> = FxHashMap::default();
        let mut counts: FxHashMap<u32, StrandBaseCounts> = FxHashMap::default();

        for pileup in pileups {
            let pileup = pileup?;
            if pileup.pos() != (pos - 1) as u32 {
                continue;
            }

            if (pileup.depth() as u32) >= config.max_depth {
                continue;
            }

            for read in pileup.alignments() {
                if read.is_del() || read.is_refskip() {
                    continue;
                }

                let record = read.record();
                if record.mapq() < config.min_mapping_quality {
                    continue;
                }

                if let Some(qpos) = read.qpos() {
                    if let Some(qual) = record.qual().get(qpos) {
                        if *qual < config.min_base_quality {
                            continue;
                        }
                    }
                }

                let cell_id = match decode_cell_barcode(&record, config.cell_barcode_tag.as_bytes())? {
                    Some(cb) => match barcode_processor.id_of(&cb) {
                        Some(id) => id,
                        None => continue,
                    },
                    None => continue,
                };

                let umi = match decode_umi(&record, config.umi_tag.as_bytes())? {
                    Some(umi) => umi,
                    None => continue,
                };

                let base = decode_base(&record, read.qpos())?;
                if let Some(encoded) = encode_call(config.stranded, base, record.is_reverse()) {
                    umi_consensus
                        .entry((cell_id, umi))
                        .and_modify(|existing| {
                            if *existing != encoded {
                                *existing = UMI_CONFLICT_CODE;
                            }
                        })
                        .or_insert(encoded);
                }
            }
        }

        for ((cell_id, _umi), encoded) in umi_consensus.drain() {
            if encoded == UMI_CONFLICT_CODE {
                continue;
            }
            let counts_entry = counts.entry(cell_id).or_default();
            apply_encoded_call(config.stranded, encoded, counts_entry);
        }

        Ok(counts)
    }

    #[test]
    fn clean_tag_value_strips_suffix_and_whitespace() {
        assert_eq!(clean_tag_value("AAACCTG-1"), Some("AAACCTG".to_string()));
        assert_eq!(clean_tag_value(" TTTGCAA "), Some("TTTGCAA".to_string()));
        assert_eq!(clean_tag_value("-"), None);
        assert_eq!(clean_tag_value(" "), None);
    }

    #[test]
    fn encode_and_apply_calls_work_for_stranded_and_unstranded() {
        let mut stranded_counts = StrandBaseCounts::default();
        let mut unstranded_counts = StrandBaseCounts::default();

        let fwd_a = encode_call(true, 'A', false).unwrap();
        let rev_g = encode_call(true, 'G', true).unwrap();
        apply_encoded_call(true, fwd_a, &mut stranded_counts);
        apply_encoded_call(true, rev_g, &mut stranded_counts);

        assert_eq!(stranded_counts.forward.a, 1);
        assert_eq!(stranded_counts.reverse.g, 1);

        let t = encode_call(false, 'T', true).unwrap();
        let c = encode_call(false, 'C', false).unwrap();
        apply_encoded_call(false, t, &mut unstranded_counts);
        apply_encoded_call(false, c, &mut unstranded_counts);

        assert_eq!(unstranded_counts.forward.t, 1);
        assert_eq!(unstranded_counts.forward.c, 1);
        assert_eq!(unstranded_counts.reverse.t, 0);
        assert!(encode_call(false, 'N', false).is_none());
    }

    #[test]
    fn process_position_chr22_matches_manual_consensus() -> Result<()> {
        // Integration fixture; silently skipped when the test BAM is absent.
        let bam_path = Path::new("test/chr22.bam");
        if !bam_path.exists() {
            return Ok(());
        }

        let chrom = "chr22";
        let pos = 50_783_283u64;

        let config = BamProcessorConfig {
            min_mapping_quality: 255,
            min_base_quality: 30,
            stranded: true,
            max_depth: 10_000,
            umi_tag: "UB".to_string(),
            cell_barcode_tag: "CB".to_string(),
        };

        let barcodes = collect_barcodes_at_pos(bam_path, chrom, pos, &config.cell_barcode_tag)?;
        if barcodes.is_empty() {
            return Ok(());
        }

        let barcode_processor = Arc::new(BarcodeProcessor::from_vec(barcodes));
        let processor = BamProcessor::new(config.clone(), Arc::clone(&barcode_processor));

        let observed = processor.process_position(bam_path, chrom, pos)?;
        let expected = manual_consensus(bam_path, chrom, pos, &config, &barcode_processor)?;

        assert_eq!(observed.pos, pos);
        assert_eq!(observed.counts.len(), expected.len());

        for (cell_id, exp) in expected.iter() {
            let got = observed
                .counts
                .get(cell_id)
                .unwrap_or_else(|| panic!("missing cell_id {} in observed counts", cell_id));
            assert_eq!(
                count_vector(got),
                count_vector(exp),
                "counts differ for cell_id {}",
                cell_id
            );
        }

        Ok(())
    }

    #[test]
    fn encode_call_unstranded_ignores_strand_and_rejects_n() {
        for (base, expected_code_unstranded) in [('A', 0u8), ('T', 1), ('G', 2), ('C', 3)] {
            let code = encode_call(false, base, false).unwrap();
            assert_eq!(code, expected_code_unstranded, "unstranded encode for {}", base);

            let code_rev = encode_call(false, base, true).unwrap();
            assert_eq!(code_rev, expected_code_unstranded, "unstranded+reverse for {}", base);
        }
        assert!(encode_call(false, 'N', false).is_none());
        assert!(encode_call(true, 'N', true).is_none());
    }

    #[test]
    fn encode_call_stranded_exact_codes() {
        // Stranded layout: (base_index << 1) | strand_bit.
        for (base, index) in [('A', 0u8), ('T', 1), ('G', 2), ('C', 3)] {
            assert_eq!(encode_call(true, base, false), Some(index << 1), "fwd {}", base);
            assert_eq!(encode_call(true, base, true), Some((index << 1) | 1), "rev {}", base);
        }
        // Only uppercase input is accepted; decode_base upper-cases beforehand.
        assert!(encode_call(true, 'a', false).is_none());
    }

    #[test]
    fn encode_call_stranded_distinguishes_strands() {
        for base in ['A', 'T', 'G', 'C'] {
            let fwd = encode_call(true, base, false).unwrap();
            let rev = encode_call(true, base, true).unwrap();
            assert_ne!(fwd, rev, "stranded codes should differ for base {}", base);
            assert_eq!(fwd & 1, 0);
            assert_eq!(rev & 1, 1);
        }
    }

    #[test]
    fn apply_encoded_call_increments_correct_fields() {
        let bases = ['A', 'T', 'G', 'C'];
        for &base in &bases {
            for is_reverse in [false, true] {
                let mut counts = StrandBaseCounts::default();
                let code = encode_call(true, base, is_reverse).unwrap();
                apply_encoded_call(true, code, &mut counts);

                let (fwd, rev) = (&counts.forward, &counts.reverse);
                let total = fwd.a + fwd.t + fwd.g + fwd.c + rev.a + rev.t + rev.g + rev.c;
                assert_eq!(total, 1, "exactly one field should be incremented");

                if !is_reverse {
                    match base {
                        'A' => assert_eq!(fwd.a, 1),
                        'T' => assert_eq!(fwd.t, 1),
                        'G' => assert_eq!(fwd.g, 1),
                        'C' => assert_eq!(fwd.c, 1),
                        _ => unreachable!(),
                    }
                } else {
                    match base {
                        'A' => assert_eq!(rev.a, 1),
                        'T' => assert_eq!(rev.t, 1),
                        'G' => assert_eq!(rev.g, 1),
                        'C' => assert_eq!(rev.c, 1),
                        _ => unreachable!(),
                    }
                }
            }
        }
    }

    #[test]
    fn apply_encoded_call_ignores_conflict_and_out_of_range_codes() {
        for stranded in [true, false] {
            let mut counts = StrandBaseCounts::default();
            // Base index 4 (codes 8/9 stranded, 8/9 unstranded too) and the
            // conflict marker must never touch a counter.
            for code in [UMI_CONFLICT_CODE, 8, 9] {
                apply_encoded_call(stranded, code, &mut counts);
            }
            assert_eq!(count_vector(&counts), [0u32; 8], "stranded = {}", stranded);
        }
    }

    #[test]
    fn apply_encoded_call_accumulates() {
        let mut counts = StrandBaseCounts::default();
        let code_a_fwd = encode_call(true, 'A', false).unwrap();
        for _ in 0..5 {
            apply_encoded_call(true, code_a_fwd, &mut counts);
        }
        assert_eq!(counts.forward.a, 5);
    }

    #[test]
    fn umi_conflict_code_is_max_u8() {
        assert_eq!(UMI_CONFLICT_CODE, 0xFF);
        for base in ['A', 'T', 'G', 'C'] {
            for stranded in [true, false] {
                for is_reverse in [true, false] {
                    if let Some(code) = encode_call(stranded, base, is_reverse) {
                        assert_ne!(code, UMI_CONFLICT_CODE);
                    }
                }
            }
        }
    }

    #[test]
    fn clean_tag_value_edge_cases() {
        assert_eq!(clean_tag_value(""), None);
        assert_eq!(clean_tag_value("-"), None);
        assert_eq!(clean_tag_value("ABC-1-2"), Some("ABC".to_string()));
        assert_eq!(clean_tag_value("NOPREFIX"), Some("NOPREFIX".to_string()));
    }

    #[test]
    fn bam_processor_config_defaults() {
        let config = BamProcessorConfig::default();
        assert_eq!(config.min_mapping_quality, 255);
        assert_eq!(config.min_base_quality, 30);
        assert!(config.stranded);
        assert_eq!(config.max_depth, 65_536);
        assert_eq!(config.umi_tag, "UB");
        assert_eq!(config.cell_barcode_tag, "CB");
    }

    #[test]
    fn base_counts_default_is_zero() {
        let bc = BaseCounts::default();
        assert_eq!(bc.a, 0);
        assert_eq!(bc.t, 0);
        assert_eq!(bc.g, 0);
        assert_eq!(bc.c, 0);
    }

    #[test]
    fn strand_base_counts_default_is_zero() {
        let sbc = StrandBaseCounts::default();
        assert_eq!(count_vector(&sbc), [0u32; 8]);
    }

    #[test]
    fn clean_tag_value_into_writes_to_buffer() {
        let mut buf = String::new();
        assert!(clean_tag_value_into("AAACCTG-1", &mut buf));
        assert_eq!(buf, "AAACCTG");
    }

    #[test]
    fn clean_tag_value_into_strips_whitespace() {
        let mut buf = String::new();
        assert!(clean_tag_value_into(" TTTGCAA ", &mut buf));
        assert_eq!(buf, "TTTGCAA");
    }

    #[test]
    fn clean_tag_value_into_returns_false_for_empty_and_dash() {
        let mut buf = String::new();
        assert!(!clean_tag_value_into("", &mut buf));
        assert!(buf.is_empty());

        assert!(!clean_tag_value_into("-", &mut buf));
        assert!(buf.is_empty());

        assert!(!clean_tag_value_into(" ", &mut buf));
        assert!(buf.is_empty());
    }

    #[test]
    fn clean_tag_value_into_appends_without_clearing() {
        // The *_into helpers append; callers that reuse a buffer must clear it.
        let mut buf = String::from("X");
        assert!(clean_tag_value_into("ACGT-1", &mut buf));
        assert_eq!(buf, "XACGT");
        assert!(!clean_tag_value_into("-", &mut buf));
        assert_eq!(buf, "XACGT");
    }

    #[test]
    fn clean_tag_value_into_matches_original() {
        let inputs = [
            "AAACCTG-1",
            "NOPREFIX",
            "ABC-1-2",
            " TTTGCAA ",
            "-",
            "",
            " ",
        ];
        for input in inputs {
            let original = clean_tag_value(input);
            let mut buf = String::new();
            let ok = clean_tag_value_into(input, &mut buf);
            match original {
                Some(ref s) => {
                    assert!(ok, "expected true for {:?}", input);
                    assert_eq!(&buf, s, "mismatch for {:?}", input);
                }
                None => {
                    assert!(!ok, "expected false for {:?}", input);
                    assert!(buf.is_empty(), "buffer should be empty for {:?}", input);
                }
            }
        }
    }
}