redicat_lib/pipeline/bam2mtx/processor.rs

1//! BAM file processing for single-cell data
2use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Result};
6use rust_htslib::bam::{self, pileup::Alignment, record::Record, Read};
7use rustc_hash::FxHashMap;
8
9use crate::pipeline::bam2mtx::barcode::BarcodeProcessor;
10
11/// Base counts for a specific position
12#[derive(Debug, Clone, Default, serde::Serialize)]
13pub struct BaseCounts {
14    /// Count of adenine (A) nucleotides
15    pub a: u32,
16    /// Count of thymine (T) nucleotides
17    pub t: u32,
18    /// Count of guanine (G) nucleotides
19    pub g: u32,
20    /// Count of cytosine (C) nucleotides
21    pub c: u32,
22}
23
24/// Strand-specific base counts
25#[derive(Debug, Clone, Default, serde::Serialize)]
26pub struct StrandBaseCounts {
27    /// Base counts for the forward strand
28    pub forward: BaseCounts,
29    /// Base counts for the reverse strand
30    pub reverse: BaseCounts,
31}
32
33/// Processed data for a specific genomic position
34#[derive(Debug, Clone, serde::Serialize)]
35pub struct PositionData {
36    /// Numeric contig identifier (matches the BAM header TID)
37    pub contig_id: u32,
38    /// 1-based genomic position
39    pub pos: u64,
40    /// Counts per cell barcode (indexed by whitelist order)
41    pub counts: FxHashMap<u32, StrandBaseCounts>,
42}
43
/// Sentinel stored in the per-UMI consensus when reads of the same UMI disagree.
///
/// `encode_call` never produces a value above 7, so `0xFF` cannot collide with a real call.
pub const UMI_CONFLICT_CODE: u8 = 0xFF;
46
/// Normalise a raw BAM tag value.
///
/// Keeps only the text before the first `-` (so a suffixed barcode such as
/// `AAACCTGAGAAACCAT-1` becomes `AAACCTGAGAAACCAT`), trims surrounding whitespace, and
/// returns `None` when nothing is left.
fn clean_tag_value(raw: &str) -> Option<String> {
    // `-` is ASCII, so slicing at its byte index is always on a char boundary.
    let head = match raw.find('-') {
        Some(dash) => &raw[..dash],
        None => raw,
    };
    // The prefix can never contain `-`, so values like "-" or "-1" are rejected by the
    // emptiness check alone.
    let head = head.trim();
    if head.is_empty() {
        None
    } else {
        Some(head.to_owned())
    }
}
55
56/// Extract and normalize a cell barcode from the requested BAM tag.
57pub fn decode_cell_barcode(record: &Record, tag: &[u8]) -> Result<Option<String>> {
58    match record.aux(tag) {
59        Ok(bam::record::Aux::String(s)) => Ok(clean_tag_value(s)),
60        Ok(bam::record::Aux::ArrayU8(arr)) => {
61            let bytes: Vec<u8> = arr.iter().collect();
62            let raw = std::str::from_utf8(&bytes)?;
63            Ok(clean_tag_value(raw))
64        }
65        Ok(_) => Ok(None),
66        Err(_) => Ok(None),
67    }
68}
69
70/// Extract and normalize a UMI from the requested BAM tag.
71pub fn decode_umi(record: &Record, tag: &[u8]) -> Result<Option<String>> {
72    match record.aux(tag) {
73        Ok(bam::record::Aux::String(s)) => Ok(clean_tag_value(s)),
74        Ok(bam::record::Aux::ArrayU8(arr)) => {
75            let bytes: Vec<u8> = arr.iter().collect();
76            let raw = std::str::from_utf8(&bytes)?;
77            Ok(clean_tag_value(raw))
78        }
79        Ok(_) => Ok(None),
80        Err(_) => Ok(None),
81    }
82}
83
84/// Retrieve the canonical base at the requested query position.
85pub fn decode_base(record: &Record, qpos: Option<usize>) -> Result<char> {
86    let qpos = qpos.ok_or_else(|| anyhow!("Invalid query position"))?;
87    let seq = record.seq();
88    let base = seq.as_bytes()[qpos];
89
90    Ok(match base {
91        b'A' | b'a' => 'A',
92        b'T' | b't' => 'T',
93        b'G' | b'g' => 'G',
94        b'C' | b'c' => 'C',
95        _ => 'N',
96    })
97}
98
/// Pack a base call into a compact code used for UMI consensus.
///
/// `A`, `T`, `G`, `C` map to 0–3, and any other character yields `None`. When `stranded` is
/// true, the base index is shifted left by one and the low bit records the read strand
/// (1 = reverse), giving codes 0–7. Otherwise the strand is ignored.
#[inline]
pub fn encode_call(stranded: bool, base: char, is_reverse: bool) -> Option<u8> {
    const ALPHABET: [char; 4] = ['A', 'T', 'G', 'C'];

    let index = ALPHABET.iter().position(|&symbol| symbol == base)? as u8;
    let code = if stranded {
        (index << 1) | u8::from(is_reverse)
    } else {
        index
    };
    Some(code)
}
116
117#[inline]
118pub fn apply_encoded_call(stranded: bool, code: u8, counts_entry: &mut StrandBaseCounts) {
119    if stranded {
120        let strand_bit = code & 1;
121        let base_code = code >> 1;
122        let target = if strand_bit == 1 {
123            &mut counts_entry.reverse
124        } else {
125            &mut counts_entry.forward
126        };
127
128        match base_code {
129            0 => target.a += 1,
130            1 => target.t += 1,
131            2 => target.g += 1,
132            3 => target.c += 1,
133            _ => {}
134        }
135    } else {
136        match code {
137            0 => counts_entry.forward.a += 1,
138            1 => counts_entry.forward.t += 1,
139            2 => counts_entry.forward.g += 1,
140            3 => counts_entry.forward.c += 1,
141            _ => {}
142        }
143    }
144}
145
/// Tunable parameters controlling which reads and positions a `BamProcessor` counts.
#[derive(Debug)]
#[derive(Clone)]
pub struct BamProcessorConfig {
    /// Reads with a mapping quality below this value are ignored.
    pub min_mapping_quality: u8,
    /// Bases with a base quality below this value are ignored.
    pub min_base_quality: u8,
    /// Minimum depth (not counting Ns) a position needs in order to be kept.
    pub min_depth: u32,
    /// Denominator for the allowed N fraction (limit is `depth / max_n_fraction`).
    pub max_n_fraction: u32,
    /// Threshold used when requiring support from more than one base.
    pub editing_threshold: u32,
    /// `true` for stranded libraries, `false` for unstranded ones.
    pub stranded: bool,
    /// Pileup depth cap; positions whose depth reaches it are skipped.
    pub max_depth: u32,
    /// BAM tag holding the UMI (Unique Molecular Identifier).
    pub umi_tag: String,
    /// BAM tag holding the cell barcode.
    pub cell_barcode_tag: String,
}
168
169impl Default for BamProcessorConfig {
170    fn default() -> Self {
171        Self {
172            min_mapping_quality: 255,
173            min_base_quality: 30,
174            min_depth: 10,
175            max_n_fraction: 20,
176            editing_threshold: 1000,
177            stranded: true,
178            max_depth: 65_536,
179            umi_tag: "UB".to_string(),
180            cell_barcode_tag: "CB".to_string(),
181        }
182    }
183}
184
185/// Main processor for BAM files
186pub struct BamProcessor {
187    /// Configuration for BAM processing
188    config: BamProcessorConfig,
189    /// Processor for validating cell barcodes
190    barcode_processor: Arc<BarcodeProcessor>,
191}
192
193impl BamProcessor {
194    /// Create a new BamProcessor
195    pub fn new(config: BamProcessorConfig, barcode_processor: Arc<BarcodeProcessor>) -> Self {
196        Self {
197            config,
198            barcode_processor,
199        }
200    }
201
202    /// Process a single genomic position
203    pub fn process_position(&self, bam_path: &Path, chrom: &str, pos: u64) -> Result<PositionData> {
204        let mut reader = bam::IndexedReader::from_path(bam_path)?;
205
206        // Convert to 0-based position for rust-htslib
207        let start_pos = (pos - 1) as u32;
208        let end_pos = pos as u32;
209
210        // Get chromosome ID
211        let header = reader.header().to_owned();
212        let tid = header
213            .tid(chrom.as_bytes())
214            .ok_or_else(|| anyhow::anyhow!("Chromosome '{}' not found", chrom))?;
215
216        // Fetch the region
217        reader.fetch((tid, start_pos, end_pos))?;
218        let mut pileups: bam::pileup::Pileups<'_, bam::IndexedReader> = reader.pileup();
219        pileups.set_max_depth(self.config.max_depth.min(i32::MAX as u32));
220        let mut counts: FxHashMap<u32, StrandBaseCounts> = FxHashMap::default();
221        let mut umi_consensus: FxHashMap<(u32, String), u8> = FxHashMap::default();
222
223        // Process pileup
224        for pileup in pileups {
225            let pileup = pileup?;
226            if pileup.pos() != start_pos {
227                continue;
228            }
229
230            if (pileup.depth() as u32) >= self.config.max_depth {
231                continue;
232            }
233
234            // let mut processed = 0u32;
235
236            for read in pileup.alignments() {
237                if !self.should_process_read(&read) {
238                    continue;
239                }
240
241                // processed = processed.saturating_add(1);
242                // if processed > self.config.max_depth {
243                //     break;
244                // }
245
246                let record = read.record();
247                let cell_id =
248                    match decode_cell_barcode(&record, self.config.cell_barcode_tag.as_bytes())? {
249                        Some(barcode) => match self.barcode_processor.id_of(&barcode) {
250                            Some(id) => id,
251                            None => continue,
252                        },
253                        None => continue,
254                    };
255
256                let umi = match decode_umi(&record, self.config.umi_tag.as_bytes())? {
257                    Some(umi) => umi,
258                    None => continue,
259                };
260
261                let base = decode_base(&record, read.qpos())?;
262                if let Some(encoded) = encode_call(self.config.stranded, base, record.is_reverse())
263                {
264                    umi_consensus
265                        .entry((cell_id, umi))
266                        .and_modify(|existing| {
267                            if *existing != encoded {
268                                *existing = UMI_CONFLICT_CODE;
269                            }
270                        })
271                        .or_insert(encoded);
272                }
273            }
274        }
275
276        // Aggregate counts by cell barcode
277        for ((cell_id, _umi), encoded) in umi_consensus.drain() {
278            if encoded == UMI_CONFLICT_CODE {
279                continue;
280            }
281
282            let counts_entry = counts.entry(cell_id).or_default();
283
284            apply_encoded_call(self.config.stranded, encoded, counts_entry);
285        }
286
287        Ok(PositionData {
288            contig_id: tid,
289            pos,
290            counts,
291        })
292    }
293
294    /// Check if a read should be processed
295    fn should_process_read(&self, read: &Alignment) -> bool {
296        if read.is_del() || read.is_refskip() {
297            return false;
298        }
299
300        let record = read.record();
301
302        // Check mapping quality
303        if record.mapq() < self.config.min_mapping_quality {
304            return false;
305        }
306
307        // Check base quality
308        if let Some(qpos) = read.qpos() {
309            if let Some(qual) = record.qual().get(qpos) {
310                if *qual < self.config.min_base_quality {
311                    return false;
312                }
313            }
314        }
315
316        true
317    }
318}