Skip to main content

rsomics_bam_calmd/
lib.rs

1//! `samtools calmd` port: recompute the MD and NM aux tags of every alignment
2//! against a reference FASTA, then re-emit the BAM.
3//!
4//! The MD/NM walk mirrors `bam_fillmd1_core` in samtools `bam_md.c` exactly,
5//! including its nibble-level base comparison (htslib `seq_nt16_table` /
6//! `bam_seqi`), its run-length MD string with `^`-prefixed deletions, and its
7//! "only rewrite the tag when the value differs" guard so unchanged records
8//! stay byte-for-byte identical (matching aux ordering and integer subtype).
9//!
10//! Reading uses the shared [`rsomics_bamio`] reader (libdeflate BGZF, parallel
11//! at `workers >= 2`); records are processed on the raw BAM byte level —
12//! seq/qual/cigar are never decoded into noodles types. MD/NM are written
13//! directly into the raw aux tail via `RawRecord::set_aux`. This avoids the
14//! full `RecordBuf` decode+re-encode round-trip (the former bottleneck at
15//! `-t4` and above, accounting for 67% of wall time). Output goes through the
16//! bamio work-stealing BGZF writer — at `workers >= 2` via the
17//! [`WsBatchBamWriter`], which frames each batch on its own thread so the
18//! main thread's per-record write cost no longer caps multi-thread throughput.
19//!
//! The reference is read from an indexed FASTA (`.fai`). The serial paths hold
//! only the most recently used contig (a single-slot cache), while the parallel
//! path additionally keeps every contig it has fetched for the rest of the run.
//! Coordinate-sorted input therefore fetches each contig once — the common calmd
//! case.
23//!
24//! At `workers >= 2` the MD/NM computation is parallelised with rayon: records
25//! are collected into a batch, the needed contigs are fetched serially into a
26//! shared read-only map (`Arc<Vec<u8>>` per contig), then `par_iter_mut` runs
27//! the raw MD/NM pass on every record simultaneously. Output is written in
28//! original batch order so the byte stream is identical to the serial path.
29
30use std::collections::HashMap;
31use std::num::NonZero;
32use std::path::Path;
33use std::sync::Arc;
34use std::time::Instant;
35
36use noodles::bam;
37use noodles::fasta;
38use noodles::sam::alignment::RecordBuf;
39use noodles::sam::alignment::io::Write as AlignmentWrite;
40use noodles::sam::alignment::record::cigar::op::Kind;
41use noodles::sam::alignment::record::data::field::Tag;
42use noodles::sam::alignment::record_buf::data::field::Value;
43use rayon::prelude::*;
44use rsomics_bamio::raw::{self, RawRecord};
45use rsomics_bamio::{WsBamWriter, WsBatchBamWriter};
46use rsomics_common::{Result, RsomicsError};
47use serde::Serialize;
48
/// Decoded-path (`RecordBuf`) key of the `NM` (edit distance) aux tag.
const TAG_NM: Tag = Tag::EDIT_DISTANCE;
/// Decoded-path (`RecordBuf`) key of the `MD` (mismatched positions) aux tag.
const TAG_MD: Tag = Tag::MISMATCHED_POSITIONS;
51
/// BAM aux type code of a NUL-terminated string value (how MD is written).
const AUX_TYPE_Z: u8 = b'Z';
/// BAM aux type code of a signed 32-bit integer value (how NM is written).
const AUX_TYPE_I: u8 = b'i';

/// Raw two-byte key of the NM aux tag, for the undecoded `RawRecord` path.
const NM_TAG: [u8; 2] = *b"NM";
/// Raw two-byte key of the MD aux tag, for the undecoded `RawRecord` path.
const MD_TAG: [u8; 2] = *b"MD";

/// Raw records gathered per rayon batch in the parallel pipeline.
///
/// A `RawRecord` carries the on-disk payload (roughly 350 bytes for a 150 bp
/// read), so a batch is about 1.4 MB and should stay cache-resident. The size
/// gives each rayon worker a few milliseconds of work per batch — far above
/// scheduling overhead — while keeping the serial read/write phases short
/// enough to stay pipelined.
const BATCH_SIZE: usize = 4 * 1024;
67
/// htslib's `seq_nt16_table` (`hts.c`): ASCII byte → 4-bit nucleotide code.
///
/// `=` is 0. A/C/G/T are the one-hot codes 1/2/4/8, and every IUPAC ambiguity
/// letter is the OR of the bases it stands for (e.g. R = A|G = 5, N = 15).
/// Lower case maps like upper case, U like T, and the digits `0`–`3` alias
/// A/C/G/T. Every other byte maps to 15. The MD/NM match test compares these
/// codes, not ASCII, just as `bam_fillmd1_core` does via `bam_seqi`. The table
/// is built at compile time.
const SEQ_NT16_TABLE: [u8; 256] = {
    // Upper-case letters and their codes; lower case is filled in alongside.
    const LETTERS: [(u8, u8); 16] = [
        (b'A', 1),
        (b'C', 2),
        (b'M', 3),
        (b'G', 4),
        (b'R', 5),
        (b'S', 6),
        (b'V', 7),
        (b'T', 8),
        (b'U', 8),
        (b'W', 9),
        (b'Y', 10),
        (b'H', 11),
        (b'K', 12),
        (b'D', 13),
        (b'B', 14),
        (b'N', 15),
    ];
    // Codes for the digit aliases '0'..='3' (A, C, G, T).
    const DIGITS: [u8; 4] = [1, 2, 4, 8];

    let mut table = [15u8; 256];
    table[b'=' as usize] = 0;

    let mut i = 0;
    while i < LETTERS.len() {
        let (base, code) = LETTERS[i];
        table[base as usize] = code;
        table[base.to_ascii_lowercase() as usize] = code;
        i += 1;
    }

    let mut d = 0;
    while d < DIGITS.len() {
        table[(b'0' + d as u8) as usize] = DIGITS[d];
        d += 1;
    }

    table
};
91
/// User-facing switches for [`calmd`].
#[derive(Clone, Debug, Default)]
pub struct CalmdOpts {
    /// `-e`: write `=` in place of every read base that matches the reference.
    /// Only the output SEQ changes; MD/NM are computed the same either way.
    pub use_equal: bool,
}
97
98#[derive(Debug, Default, Clone, Serialize)]
99pub struct CalmdStats {
100    pub records: u64,
101    /// Mapped records whose MD/NM were (re)computed.
102    pub computed: u64,
103    /// Mapped records whose contig was missing from the reference, left untouched.
104    pub missing_ref: u64,
105    /// Records skipped because they carry no stored sequence (`l_qseq == 0`).
106    pub no_sequence: u64,
107}
108
109/// Port of the MD/NM walk in `bam_fillmd1_core` (samtools `bam_md.c`).
110///
111/// `seq` is the read's ASCII bases (mutated in place to `=` on matches when
112/// `use_equal`), `cigar` the (kind, len) op list, `ref_seq` the contig bases,
113/// and `pos` the 0-based reference start. The constructed MD string is written
114/// into `md` (cleared by the caller); the return value is the recomputed NM.
115///
116/// M/=/X consume read+ref and compare nucleotide codes (`SEQ_NT16_TABLE`): a
117/// match needs equal non-N codes, or a read code of 0 (`=`); a mismatch flushes
118/// the current match run then the uppercased ref base. D emits `^` + uppercased
119/// ref bases; I/S advance the read (I also bumps NM); N advances the reference
120/// only. Out-of-bounds / NUL-padded ref ends the walk early, exactly as the C
121/// `break` does.
122fn compute_md_nm(
123    seq: &mut [u8],
124    cigar: &[(Kind, usize)],
125    ref_seq: &[u8],
126    pos: usize,
127    use_equal: bool,
128    md: &mut Vec<u8>,
129) -> i32 {
130    let qual_len = seq.len();
131    let ref_len = ref_seq.len();
132    let mut nm: i32 = 0;
133    let mut matched: i64 = 0;
134    let mut qpos: usize = 0;
135    let mut rpos: usize = pos;
136
137    'outer: for &(kind, oplen) in cigar {
138        match kind {
139            Kind::Match | Kind::SequenceMatch | Kind::SequenceMismatch => {
140                let mut j = 0;
141                while j < oplen {
142                    let z = qpos + j;
143                    let rp = rpos + j;
144                    if rp >= ref_len || z >= qual_len || ref_seq[rp] == 0 {
145                        break;
146                    }
147                    let c1 = SEQ_NT16_TABLE[seq[z] as usize];
148                    let c2 = SEQ_NT16_TABLE[ref_seq[rp] as usize];
149                    let is_match = (c1 == c2 && c1 != 15 && c2 != 15) || c1 == 0;
150                    if is_match {
151                        if use_equal {
152                            seq[z] = b'=';
153                        }
154                        matched += 1;
155                    } else {
156                        append_int(md, matched);
157                        md.push(ref_seq[rp].to_ascii_uppercase());
158                        matched = 0;
159                        nm += 1;
160                    }
161                    j += 1;
162                }
163                if j < oplen {
164                    break 'outer;
165                }
166                rpos += oplen;
167                qpos += oplen;
168            }
169            Kind::Deletion => {
170                append_int(md, matched);
171                md.push(b'^');
172                let mut j = 0;
173                while j < oplen {
174                    let rp = rpos + j;
175                    if rp >= ref_len || ref_seq[rp] == 0 {
176                        break;
177                    }
178                    md.push(ref_seq[rp].to_ascii_uppercase());
179                    j += 1;
180                }
181                matched = 0;
182                rpos += j;
183                nm += j as i32;
184                if j < oplen {
185                    break 'outer;
186                }
187            }
188            Kind::Insertion => {
189                qpos += oplen;
190                nm += oplen as i32;
191            }
192            Kind::SoftClip => {
193                qpos += oplen;
194            }
195            Kind::Skip => {
196                rpos += oplen;
197            }
198            Kind::HardClip | Kind::Pad => {}
199        }
200    }
201    append_int(md, matched);
202
203    nm
204}
205
206/// BAM CIGAR op-code to `noodles::sam::alignment::record::cigar::op::Kind`.
207/// The op codes in the BAM 4-bit encoding are: 0=M 1=I 2=D 3=N 4=S 5=H 6=P 7== 8=X.
208fn bam_op_to_kind(op: u8) -> Kind {
209    match op {
210        0 => Kind::Match,
211        1 => Kind::Insertion,
212        2 => Kind::Deletion,
213        3 => Kind::Skip,
214        4 => Kind::SoftClip,
215        5 => Kind::HardClip,
216        6 => Kind::Pad,
217        7 => Kind::SequenceMatch,
218        8 => Kind::SequenceMismatch,
219        _ => Kind::Match,
220    }
221}
222
223/// Port of the MD/NM walk operating directly on packed BAM nibble SEQ.
224///
225/// The BAM packed SEQ format stores two bases per byte: high nibble = even
226/// index, low nibble = odd index. Nibble codes are the `seq_nt16` values
227/// directly (`=`→0, A→1, C→2, G→4, T→8, N→15), so no table lookup is needed
228/// for the read side — the nibble IS the `c1` value in `bam_fillmd1_core`.
229///
230/// `seq_bytes` is the packed SEQ field (mutated in place when `use_equal` sets
231/// matched bases to 0). `seq_len` is the number of query bases (not byte count).
232/// `cigar` is the raw BAM CIGAR op list as `(op_code, len)` pairs. `ref_seq`
233/// and `pos` are the reference contig and 0-based start. `md` is cleared and
234/// filled with the MD string bytes; the return value is the recomputed NM.
235fn compute_md_nm_raw(
236    seq_bytes: &mut [u8],
237    seq_len: usize,
238    cigar: &[(u8, u32)],
239    ref_seq: &[u8],
240    pos: usize,
241    use_equal: bool,
242    md: &mut Vec<u8>,
243) -> i32 {
244    let ref_len = ref_seq.len();
245    let mut nm: i32 = 0;
246    let mut matched: i64 = 0;
247    let mut qpos: usize = 0;
248    let mut rpos: usize = pos;
249
250    'outer: for &(op, oplen_u32) in cigar {
251        let oplen = oplen_u32 as usize;
252        let kind = bam_op_to_kind(op);
253        match kind {
254            Kind::Match | Kind::SequenceMatch | Kind::SequenceMismatch => {
255                let mut j = 0;
256                while j < oplen {
257                    let z = qpos + j;
258                    let rp = rpos + j;
259                    if rp >= ref_len || z >= seq_len || ref_seq[rp] == 0 {
260                        break;
261                    }
262                    // BAM nibble codes are already seq_nt16 values: no table lookup.
263                    let byte_idx = z / 2;
264                    let c1 = if z.is_multiple_of(2) {
265                        seq_bytes[byte_idx] >> 4
266                    } else {
267                        seq_bytes[byte_idx] & 0x0f
268                    };
269                    let c2 = SEQ_NT16_TABLE[ref_seq[rp] as usize];
270                    let is_match = (c1 == c2 && c1 != 15 && c2 != 15) || c1 == 0;
271                    if is_match {
272                        if use_equal {
273                            // Set nibble to 0 (the `=` code in seq_nt16).
274                            if z.is_multiple_of(2) {
275                                seq_bytes[byte_idx] &= 0x0f;
276                            } else {
277                                seq_bytes[byte_idx] &= 0xf0;
278                            }
279                        }
280                        matched += 1;
281                    } else {
282                        append_int(md, matched);
283                        md.push(ref_seq[rp].to_ascii_uppercase());
284                        matched = 0;
285                        nm += 1;
286                    }
287                    j += 1;
288                }
289                if j < oplen {
290                    break 'outer;
291                }
292                rpos += oplen;
293                qpos += oplen;
294            }
295            Kind::Deletion => {
296                append_int(md, matched);
297                md.push(b'^');
298                let mut j = 0;
299                while j < oplen {
300                    let rp = rpos + j;
301                    if rp >= ref_len || ref_seq[rp] == 0 {
302                        break;
303                    }
304                    md.push(ref_seq[rp].to_ascii_uppercase());
305                    j += 1;
306                }
307                matched = 0;
308                rpos += j;
309                nm += j as i32;
310                if j < oplen {
311                    break 'outer;
312                }
313            }
314            Kind::Insertion => {
315                qpos += oplen;
316                nm += oplen as i32;
317            }
318            Kind::SoftClip => {
319                qpos += oplen;
320            }
321            Kind::Skip => {
322                rpos += oplen;
323            }
324            Kind::HardClip | Kind::Pad => {}
325        }
326    }
327    append_int(md, matched);
328
329    nm
330}
331
/// Append `value` in decimal to `buf` (htslib's `kputw`), the encoding MD uses
/// for match-run lengths.
///
/// Digits are pushed least-significant first and then reversed in place, so
/// no scratch array or formatter is involved. Run lengths are never negative.
/// In release builds a negative `value` appends nothing.
fn append_int(buf: &mut Vec<u8>, value: i64) {
    debug_assert!(value >= 0, "MD match run lengths are non-negative");
    match value {
        0 => buf.push(b'0'),
        v if v < 0 => {}
        mut v => {
            let start = buf.len();
            while v != 0 {
                buf.push(b'0' + (v % 10) as u8);
                v /= 10;
            }
            buf[start..].reverse();
        }
    }
}
351
352/// Apply the MD/NM update rules of `bam_fillmd1_core` to a decoded record:
353/// append the tag if absent, replace-and-move-to-end if the value differs, and
354/// leave it in place (preserving aux order + integer subtype) if it is already
355/// correct. Mirrors samtools' `bam_aux_get` / `bam_aux_del` / `bam_aux_append`.
356/// `md` is taken by value to avoid copying the MD bytes into the tag.
357fn apply_tags(record: &mut RecordBuf, nm: i32, md: &[u8]) {
358    let data = record.data_mut();
359
360    let nm_same = data
361        .get(&TAG_NM)
362        .and_then(Value::as_int)
363        .is_some_and(|old| old == i64::from(nm));
364    if !nm_same {
365        replace_or_append(data, TAG_NM, Value::Int32(nm));
366    }
367
368    let md_same = data.get(&TAG_MD).is_some_and(|v| match v {
369        Value::String(s) => s.eq_ignore_ascii_case(md),
370        _ => false,
371    });
372    if !md_same {
373        replace_or_append(data, TAG_MD, Value::String(md.into()));
374    }
375}
376
377/// Apply MD/NM tags directly to a `RawRecord`'s aux tail, bypassing full
378/// decode+re-encode. This is the hot path for the raw parallel pipeline.
379///
380/// NM is written as BAM type `i` (signed 32-bit). MD is written as BAM type
381/// `Z` (NUL-terminated string). `set_aux` removes the old field (if any) and
382/// appends the new value at the end, matching samtools' `bam_aux_del` +
383/// `bam_aux_append` behaviour.
384fn apply_tags_raw(record: &mut RawRecord, nm: i32, md: &[u8]) {
385    let nm_same = record
386        .aux_value(NM_TAG)
387        .and_then(|v| {
388            if v.len() == 4 {
389                Some(i32::from_le_bytes(v.try_into().unwrap()))
390            } else {
391                None
392            }
393        })
394        .is_some_and(|old| old == nm);
395
396    if !nm_same {
397        record.set_aux(NM_TAG, AUX_TYPE_I, &nm.to_le_bytes());
398    }
399
400    let md_same = record
401        .aux_value(MD_TAG)
402        .and_then(|v| {
403            // The stored Z value includes a NUL terminator; strip it for comparison.
404            let stored = v.strip_suffix(&[0]).unwrap_or(v);
405            if stored.eq_ignore_ascii_case(md) {
406                Some(())
407            } else {
408                None
409            }
410        })
411        .is_some();
412
413    if !md_same {
414        // Z values are NUL-terminated on disk.
415        let mut md_z = Vec::with_capacity(md.len() + 1);
416        md_z.extend_from_slice(md);
417        md_z.push(0);
418        record.set_aux(MD_TAG, AUX_TYPE_Z, &md_z);
419    }
420}
421
422/// samtools' `bam_aux_del` + `bam_aux_append`: when the tag already exists with
423/// a different value it is deleted and re-appended at the end of the aux block;
424/// when it is absent it is appended. The absent case (the overwhelmingly common
425/// one — a fresh calmd run) is a plain `insert`, which appends in O(1) amortised.
426/// Only an existing tag triggers the order-preserving rebuild (noodles'
427/// `Data::remove` is a swap-remove, which would scramble the surviving order).
428fn replace_or_append(data: &mut noodles::sam::alignment::record_buf::Data, tag: Tag, value: Value) {
429    if data.get(&tag).is_none() {
430        data.insert(tag, value);
431        return;
432    }
433    let kept: Vec<(Tag, Value)> = data
434        .iter()
435        .filter(|(t, _)| *t != tag)
436        .map(|(t, v)| (t, v.clone()))
437        .collect();
438    let mut rebuilt: noodles::sam::alignment::record_buf::Data = kept.into_iter().collect();
439    rebuilt.insert(tag, value);
440    *data = rebuilt;
441}
442
443/// A reference contig loaded once and reused across consecutive records on it.
444///
445/// The sequence is stored as `Arc<Vec<u8>>` so that cloning for the parallel
446/// batch path is an atomic refcount bump (O(1), no copy), not a full chromosome
447/// memcpy. The `None` sentinel means the contig was not found in the reference.
448struct RefCache<R> {
449    reader: fasta::io::IndexedReader<R>,
450    current: Option<(usize, Arc<Vec<u8>>)>,
451}
452
453impl<R> RefCache<R>
454where
455    R: std::io::BufRead + std::io::Seek,
456{
457    /// Fetch the contig for `tid` (header-resolved name), reusing the held one
458    /// when the tid is unchanged. `None` means the contig is absent from the
459    /// reference — calmd leaves such records untouched (matching samtools).
460    fn get(&mut self, tid: usize, name: &[u8]) -> Result<Option<&[u8]>> {
461        if self.current.as_ref().is_none_or(|(t, _)| *t != tid) {
462            let region = noodles::core::Region::new(name.to_vec(), ..);
463            match self.reader.query(&region) {
464                Ok(record) => {
465                    self.current = Some((tid, Arc::new(record.sequence().as_ref().to_vec())));
466                }
467                Err(_) => {
468                    self.current = None;
469                    return Ok(None);
470                }
471            }
472        }
473        Ok(self.current.as_ref().map(|(_, seq)| seq.as_slice()))
474    }
475
476    /// Return an `Arc` handle to the contig for `tid`. On a cache hit this is
477    /// a single atomic refcount bump — no chromosome copy.
478    fn get_arc(&mut self, tid: usize, name: &[u8]) -> Result<Option<Arc<Vec<u8>>>> {
479        if self.current.as_ref().is_none_or(|(t, _)| *t != tid) {
480            let region = noodles::core::Region::new(name.to_vec(), ..);
481            match self.reader.query(&region) {
482                Ok(record) => {
483                    self.current = Some((tid, Arc::new(record.sequence().as_ref().to_vec())));
484                }
485                Err(_) => {
486                    self.current = None;
487                    return Ok(None);
488                }
489            }
490        }
491        Ok(self.current.as_ref().map(|(_, arc)| Arc::clone(arc)))
492    }
493}
494
495pub fn calmd(
496    input: &Path,
497    reference: &Path,
498    output_path: Option<&Path>,
499    opts: &CalmdOpts,
500    workers: NonZero<usize>,
501) -> Result<CalmdStats> {
502    let fasta_reader = fasta::io::indexed_reader::Builder::default()
503        .build_from_path(reference)
504        .map_err(|e| {
505            RsomicsError::InvalidInput(format!("reference {}: {e}", reference.display()))
506        })?;
507    let mut refs = RefCache {
508        reader: fasta_reader,
509        current: None,
510    };
511
512    let mut reader = rsomics_bamio::open_with_workers(input, workers)?;
513    let header = reader.read_header().map_err(RsomicsError::Io)?;
514
515    match output_path {
516        Some(path) => {
517            // All thread counts use the work-stealing BGZF writer so the
518            // compressed output is byte-identical across -t1/-t4/-t8. The WS
519            // writer runs libdeflate level 6 on a fixed ring with no per-block
520            // allocation regardless of worker count.
521            let ws_writer = rsomics_bamio::create_ws_with_workers(path, workers)?;
522            if workers.get() == 1 {
523                run_serial_raw_ws(reader.get_mut(), ws_writer, &header, &mut refs, opts)
524            } else {
525                run_parallel_raw_ws(
526                    reader.get_mut(),
527                    ws_writer,
528                    &header,
529                    &mut refs,
530                    opts,
531                    workers,
532                )
533            }
534        }
535        None => {
536            let mut writer = bam::io::Writer::new(std::io::stdout().lock());
537            run_serial_fallback(&mut reader, &mut writer, &header, &mut refs, opts)
538        }
539    }
540}
541
542/// Serial raw path backed by the work-stealing BGZF writer, for `-t1` with file
543/// output. Uses the same WS writer as the parallel path so the compressed
544/// output is byte-identical to `-t4`/`-t8` runs on the same input.
545fn run_serial_raw_ws<R, F>(
546    inner: &mut R,
547    mut writer: WsBamWriter,
548    header: &noodles::sam::Header,
549    refs: &mut RefCache<F>,
550    opts: &CalmdOpts,
551) -> Result<CalmdStats>
552where
553    R: std::io::BufRead,
554    F: std::io::BufRead + std::io::Seek,
555{
556    use rsomics_bamio::finish_ws_bam_writer;
557    use rsomics_bamio::raw::write_record;
558
559    writer.write_header(header).map_err(RsomicsError::Io)?;
560
561    let mut stats = CalmdStats::default();
562    let mut record = RawRecord::default();
563    let mut cigar_buf: Vec<(u8, u32)> = Vec::new();
564    let mut md: Vec<u8> = Vec::new();
565
566    loop {
567        let n = raw::read_record(inner, &mut record)?;
568        if n == 0 {
569            break;
570        }
571        stats.records += 1;
572        process_record_raw(
573            &mut record,
574            header,
575            refs,
576            opts,
577            &mut stats,
578            &mut cigar_buf,
579            &mut md,
580        )?;
581        // Write directly to the WsBamWriter's inner stream (the WS BGZF
582        // writer), bypassing the noodles bam codec round-trip.
583        write_record(writer.get_mut(), &record)?;
584    }
585
586    // Flush all pending BGZF blocks and append the EOF marker.
587    finish_ws_bam_writer(writer)?;
588
589    Ok(stats)
590}
591
592/// Fallback serial path for stdout output using noodles RecordBuf (stdout has no
593/// parallel BGZF writer, so we use the original decoded path for correctness).
594fn run_serial_fallback<R, W, F>(
595    reader: &mut bam::io::Reader<R>,
596    writer: &mut bam::io::Writer<W>,
597    header: &noodles::sam::Header,
598    refs: &mut RefCache<F>,
599    opts: &CalmdOpts,
600) -> Result<CalmdStats>
601where
602    R: std::io::Read,
603    W: std::io::Write,
604    F: std::io::BufRead + std::io::Seek,
605{
606    writer.write_header(header).map_err(RsomicsError::Io)?;
607
608    let mut stats = CalmdStats::default();
609    let mut record = RecordBuf::default();
610    let mut cigar: Vec<(Kind, usize)> = Vec::new();
611    let mut md: Vec<u8> = Vec::new();
612    while reader
613        .read_record_buf(header, &mut record)
614        .map_err(RsomicsError::Io)?
615        != 0
616    {
617        stats.records += 1;
618        process_record(
619            &mut record,
620            header,
621            refs,
622            opts,
623            &mut stats,
624            &mut cigar,
625            &mut md,
626        )?;
627        writer
628            .write_alignment_record(header, &record)
629            .map_err(RsomicsError::Io)?;
630    }
631
632    Ok(stats)
633}
634
635/// Parallel raw path: read `BATCH_SIZE` raw records serially, prefetch contigs,
636/// rayon-compute MD/NM on packed nibble SEQ in parallel, hand each computed batch
637/// to the work-stealing batched writer in order.
638///
639/// Uses [`WsBatchBamWriter`] over a [`WsBamWriter`]: the underlying BGZF
640/// compressor is the fixed-ring work-stealing writer (no per-block channel
641/// allocation), which removes the throughput ceiling at `-t4` and above that
642/// noodles' per-block-channel `MultithreadedWriter` imposed.
643fn run_parallel_raw_ws<R, F>(
644    inner: &mut R,
645    mut writer: WsBamWriter,
646    header: &noodles::sam::Header,
647    refs: &mut RefCache<F>,
648    opts: &CalmdOpts,
649    workers: NonZero<usize>,
650) -> Result<CalmdStats>
651where
652    R: std::io::BufRead,
653    F: std::io::BufRead + std::io::Seek,
654{
655    writer.write_header(header).map_err(RsomicsError::Io)?;
656    let mut batch_writer = WsBatchBamWriter::new(writer);
657
658    let timing = std::env::var("CALMD_PHASE_TIMING").is_ok();
659
660    let pool = rayon::ThreadPoolBuilder::new()
661        .num_threads(workers.get())
662        .build()
663        .map_err(|e| RsomicsError::InvalidInput(format!("rayon pool: {e}")))?;
664
665    let mut stats = CalmdStats::default();
666    let mut contig_map: HashMap<usize, Arc<Vec<u8>>> = HashMap::new();
667
668    let mut t_read = std::time::Duration::ZERO;
669    let mut t_compute = std::time::Duration::ZERO;
670    let mut t_write = std::time::Duration::ZERO;
671
672    loop {
673        let t0 = Instant::now();
674        let mut batch: Vec<RawRecord> = Vec::with_capacity(BATCH_SIZE);
675        for _ in 0..BATCH_SIZE {
676            let mut record = RawRecord::default();
677            let n = raw::read_record(inner, &mut record)?;
678            if n == 0 {
679                break;
680            }
681            batch.push(record);
682        }
683        if timing {
684            t_read += t0.elapsed();
685        }
686        if batch.is_empty() {
687            break;
688        }
689
690        for record in &batch {
691            let flags = record.flags();
692            if flags & 0x4 != 0 {
693                continue;
694            }
695            let tid = record.reference_sequence_id();
696            if tid < 0 {
697                continue;
698            }
699            let tid_usize = tid as usize;
700            if contig_map.contains_key(&tid_usize) {
701                continue;
702            }
703            let Some((name, _)) = header.reference_sequences().get_index(tid_usize) else {
704                continue;
705            };
706            if let Some(seq) = refs.get_arc(tid_usize, name.as_ref())? {
707                contig_map.insert(tid_usize, seq);
708            }
709        }
710
711        let t1 = Instant::now();
712        let contig_map_ref = &contig_map;
713        let opts_ref = opts;
714        let header_ref = header;
715
716        let batch_computed = std::sync::atomic::AtomicU64::new(0);
717        let batch_missing = std::sync::atomic::AtomicU64::new(0);
718        let batch_noseq = std::sync::atomic::AtomicU64::new(0);
719
720        pool.install(|| {
721            batch.par_iter_mut().for_each(|record| {
722                process_record_raw_parallel(
723                    record,
724                    header_ref,
725                    contig_map_ref,
726                    opts_ref,
727                    &batch_computed,
728                    &batch_missing,
729                    &batch_noseq,
730                );
731            });
732        });
733        if timing {
734            t_compute += t1.elapsed();
735        }
736
737        let t2 = Instant::now();
738        stats.records += batch.len() as u64;
739        stats.computed += batch_computed.load(std::sync::atomic::Ordering::Relaxed);
740        stats.missing_ref += batch_missing.load(std::sync::atomic::Ordering::Relaxed);
741        stats.no_sequence += batch_noseq.load(std::sync::atomic::Ordering::Relaxed);
742        batch_writer.write_records_batch(batch)?;
743        if timing {
744            t_write += t2.elapsed();
745        }
746    }
747
748    batch_writer.finish()?;
749
750    if timing {
751        eprintln!(
752            "PHASE TIMING: read={:.3}s compute={:.3}s write={:.3}s total_phase={:.3}s",
753            t_read.as_secs_f64(),
754            t_compute.as_secs_f64(),
755            t_write.as_secs_f64(),
756            (t_read + t_compute + t_write).as_secs_f64()
757        );
758    }
759
760    Ok(stats)
761}
762
763/// Per-record raw MD/NM pass for the parallel path. Reads nibble SEQ and raw
764/// CIGAR from `RawRecord`, updates MD/NM aux in-place. No noodles decode.
765fn process_record_raw_parallel(
766    record: &mut RawRecord,
767    header: &noodles::sam::Header,
768    contig_map: &HashMap<usize, Arc<Vec<u8>>>,
769    opts: &CalmdOpts,
770    computed: &std::sync::atomic::AtomicU64,
771    missing_ref: &std::sync::atomic::AtomicU64,
772    no_sequence: &std::sync::atomic::AtomicU64,
773) {
774    use std::sync::atomic::Ordering::Relaxed;
775
776    let flags = record.flags();
777    if flags & 0x4 != 0 {
778        return;
779    }
780    let tid = record.reference_sequence_id();
781    if tid < 0 {
782        return;
783    }
784    let tid_usize = tid as usize;
785    let pos_raw = record.alignment_start();
786    if pos_raw < 0 {
787        return;
788    }
789    if header.reference_sequences().get_index(tid_usize).is_none() {
790        return;
791    }
792
793    let Some(ref_seq) = contig_map.get(&tid_usize) else {
794        missing_ref.fetch_add(1, Relaxed);
795        return;
796    };
797
798    let seq_len = record.sequence_len();
799    if seq_len == 0 {
800        no_sequence.fetch_add(1, Relaxed);
801        return;
802    }
803
804    let pos = pos_raw as usize;
805
806    // Collect CIGAR ops from the raw payload into a small inline buffer.
807    let mut cigar: Vec<(u8, u32)> = record.cigar_ops().collect();
808
809    let mut md: Vec<u8> = Vec::new();
810
811    // Compute MD/NM directly on the raw nibble SEQ, mutating the record's
812    // packed SEQ bytes in place for `use_equal`. `seq_bytes_mut()` accesses
813    // the packed [(l_seq+1)/2] bytes starting after name+cigar in the payload.
814    let nm = {
815        let seq_bytes = record.seq_bytes_mut();
816        compute_md_nm_raw(
817            seq_bytes,
818            seq_len,
819            &cigar,
820            ref_seq,
821            pos,
822            opts.use_equal,
823            &mut md,
824        )
825    };
826
827    // Patch the raw aux tail: set NM and MD without decoding the rest of the record.
828    apply_tags_raw(record, nm, &md);
829    computed.fetch_add(1, Relaxed);
830    cigar.clear();
831}
832
833/// Per-record raw MD/NM pass for the serial path (reuses scratch buffers).
834fn process_record_raw<F>(
835    record: &mut RawRecord,
836    header: &noodles::sam::Header,
837    refs: &mut RefCache<F>,
838    opts: &CalmdOpts,
839    stats: &mut CalmdStats,
840    cigar_buf: &mut Vec<(u8, u32)>,
841    md: &mut Vec<u8>,
842) -> Result<()>
843where
844    F: std::io::BufRead + std::io::Seek,
845{
846    let flags = record.flags();
847    if flags & 0x4 != 0 {
848        return Ok(());
849    }
850    let tid = record.reference_sequence_id();
851    if tid < 0 {
852        return Ok(());
853    }
854    let tid_usize = tid as usize;
855    let pos_raw = record.alignment_start();
856    if pos_raw < 0 {
857        return Ok(());
858    }
859    let Some((name, _)) = header.reference_sequences().get_index(tid_usize) else {
860        return Ok(());
861    };
862
863    let ref_seq = match refs.get(tid_usize, name.as_ref())? {
864        Some(seq) => seq,
865        None => {
866            stats.missing_ref += 1;
867            return Ok(());
868        }
869    };
870
871    let seq_len = record.sequence_len();
872    if seq_len == 0 {
873        stats.no_sequence += 1;
874        return Ok(());
875    }
876
877    let pos = pos_raw as usize;
878
879    cigar_buf.clear();
880    cigar_buf.extend(record.cigar_ops());
881    md.clear();
882
883    let nm = {
884        let seq_bytes = record.seq_bytes_mut();
885        compute_md_nm_raw(
886            seq_bytes,
887            seq_len,
888            cigar_buf,
889            ref_seq,
890            pos,
891            opts.use_equal,
892            md,
893        )
894    };
895
896    apply_tags_raw(record, nm, md);
897    stats.computed += 1;
898    Ok(())
899}
900
901/// The per-record body of samtools' `bam_fillmd` loop: skip unmapped/refless
902/// records, fetch the contig (held by reference — never copied), run the MD/NM
903/// walk in place, and apply the tag updates. Used by the stdout fallback path.
904fn process_record<F>(
905    record: &mut RecordBuf,
906    header: &noodles::sam::Header,
907    refs: &mut RefCache<F>,
908    opts: &CalmdOpts,
909    stats: &mut CalmdStats,
910    cigar: &mut Vec<(Kind, usize)>,
911    md: &mut Vec<u8>,
912) -> Result<()>
913where
914    F: std::io::BufRead + std::io::Seek,
915{
916    if record.flags().is_unmapped() {
917        return Ok(());
918    }
919    let Some(tid) = record.reference_sequence_id() else {
920        return Ok(());
921    };
922    let Some(start) = record.alignment_start() else {
923        return Ok(());
924    };
925    let Some((name, _)) = header.reference_sequences().get_index(tid) else {
926        return Ok(());
927    };
928
929    // The contig slice borrows the cache and is only read in `compute_md_nm`;
930    // the record mutation that follows borrows `record`, a disjoint object. The
931    // contig is therefore never copied per record (the MD pass stays
932    // O(records + contig_len), not O(records × contig_len)).
933    let ref_seq = match refs.get(tid, name.as_ref())? {
934        Some(seq) => seq,
935        None => {
936            stats.missing_ref += 1;
937            return Ok(());
938        }
939    };
940
941    if record.sequence().is_empty() {
942        stats.no_sequence += 1;
943        return Ok(());
944    }
945
946    let pos = start.get() - 1;
947    cigar.clear();
948    cigar.extend(
949        record
950            .cigar()
951            .as_ref()
952            .iter()
953            .map(|op| (op.kind(), op.len())),
954    );
955    md.clear();
956
957    let seq = record.sequence_mut().as_mut();
958    let nm = compute_md_nm(seq, cigar, ref_seq, pos, opts.use_equal, md);
959
960    apply_tags(record, nm, md);
961    stats.computed += 1;
962    Ok(())
963}
964
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the decoded-SEQ MD/NM walk with `use_equal` off and returns the MD
    /// string together with NM. `pos` is the 0-based offset into `rf`.
    fn run_walk(seq: &[u8], cigar: &[(Kind, usize)], rf: &[u8], pos: usize) -> (String, i32) {
        let mut s = seq.to_vec();
        let mut md = Vec::new();
        let nm = compute_md_nm(&mut s, cigar, rf, pos, false, &mut md);
        (String::from_utf8(md).unwrap(), nm)
    }

    #[test]
    fn perfect_match() {
        let (md, nm) = run_walk(b"ACGTACGT", &[(Kind::Match, 8)], b"ACGTACGT", 0);
        assert_eq!((md.as_str(), nm), ("8", 0));
    }

    #[test]
    fn two_mismatches() {
        let (md, nm) = run_walk(b"ATGTACAT", &[(Kind::Match, 8)], b"ACGTACGT", 0);
        assert_eq!((md.as_str(), nm), ("1C4G1", 2));
    }

    #[test]
    fn deletion() {
        // 4M2D2M over ref ACGT GG AC, read = ACGT AC.
        let (md, nm) = run_walk(
            b"ACGTAC",
            &[(Kind::Match, 4), (Kind::Deletion, 2), (Kind::Match, 2)],
            b"ACGTGGAC",
            0,
        );
        assert_eq!((md.as_str(), nm), ("4^GG2", 2));
    }

    #[test]
    fn insertion_adds_only_nm() {
        // 3M2I3M; the inserted bases never appear in MD but bump NM by 2.
        let (md, nm) = run_walk(
            b"ACGTTACG",
            &[(Kind::Match, 3), (Kind::Insertion, 2), (Kind::Match, 3)],
            b"ACGACG",
            0,
        );
        assert_eq!((md.as_str(), nm), ("6", 2));
    }

    #[test]
    fn soft_clip_and_skip() {
        // 2S3M3N3M: soft-clip consumes read only; N consumes ref only.
        let (md, nm) = run_walk(
            b"NNACGTAC",
            &[
                (Kind::SoftClip, 2),
                (Kind::Match, 3),
                (Kind::Skip, 3),
                (Kind::Match, 3),
            ],
            b"ACGXXXTAC",
            0,
        );
        assert_eq!((md.as_str(), nm), ("6", 0));
    }

    #[test]
    fn n_in_reference_is_mismatch() {
        // A read base over an N reference base is a mismatch emitting the ref N.
        let (md, nm) = run_walk(b"ACGT", &[(Kind::Match, 4)], b"ACNT", 0);
        assert_eq!((md.as_str(), nm), ("2N1", 1));
    }

    #[test]
    fn equal_base_in_read_is_match() {
        // A read code of 0 (`=`) is always a match regardless of the ref base.
        let (md, nm) = run_walk(b"=CGT", &[(Kind::Match, 4)], b"ACGT", 0);
        assert_eq!((md.as_str(), nm), ("4", 0));
    }

    #[test]
    fn use_equal_rewrites_matches() {
        let mut s = b"ATGT".to_vec();
        let mut md = Vec::new();
        // Position 1 mismatches (T vs C); the other three match → become `=`.
        let nm = compute_md_nm(&mut s, &[(Kind::Match, 4)], b"ACGT", 0, true, &mut md);
        assert_eq!(&s, b"=T==");
        assert_eq!((String::from_utf8(md).unwrap().as_str(), nm), ("1C2", 1));
    }

    #[test]
    fn nonzero_alignment_start() {
        // `pos` is a 0-based contig offset: the read aligns to ref[2..6] = GTAC,
        // so its second base (A) mismatches the reference T.
        let (md, nm) = run_walk(b"GAAC", &[(Kind::Match, 4)], b"ACGTAC", 2);
        assert_eq!((md.as_str(), nm), ("1T2", 1));
    }

    #[test]
    fn leading_mismatch_emits_zero_run() {
        // MD always opens with a match count, even a zero-length one.
        let (md, nm) = run_walk(b"TCGT", &[(Kind::Match, 4)], b"ACGT", 0);
        assert_eq!((md.as_str(), nm), ("0A3", 1));
    }

    #[test]
    fn trailing_mismatch_emits_zero_run() {
        // MD always closes with a match count, even a zero-length one.
        let (md, nm) = run_walk(b"ACGA", &[(Kind::Match, 4)], b"ACGT", 0);
        assert_eq!((md.as_str(), nm), ("3T0", 1));
    }

    #[test]
    fn mismatch_after_deletion_emits_zero_run() {
        // 2M1D2M over ref AC G AT, read = AC TT: a `0` separates the deleted
        // `^G` from the immediately following mismatched ref base `A`.
        let (md, nm) = run_walk(
            b"ACTT",
            &[(Kind::Match, 2), (Kind::Deletion, 1), (Kind::Match, 2)],
            b"ACGAT",
            0,
        );
        assert_eq!((md.as_str(), nm), ("2^G0A1", 2));
    }
}