Skip to main content

holodeck_lib/commands/
eval.rs

1//! Alignment accuracy evaluation command.
2
3use std::collections::BTreeMap;
4use std::io::Write;
5use std::path::PathBuf;
6
7use anyhow::{Context, Result};
8use bstr::ByteSlice;
9use clap::Parser;
10use noodles::bam;
11
12use super::command::{Command, output_path};
13use super::common::OutputPrefixOptions;
14use crate::read_naming::{parse_encoded_pe_name, parse_encoded_se_name};
15
16/// Evaluate alignment accuracy of simulated reads.
17///
18/// Compares the true (simulated) positions of reads against their mapped
19/// positions in a BAM file.  Reports mapping accuracy, mismapping rate, and
20/// unmapped rate stratified by MAPQ bin.  Truth positions are parsed from
21/// encoded read names (default holodeck format).
22#[derive(Parser, Debug)]
23#[command(after_long_help = "EXAMPLES:\n  \
24    holodeck eval --mapped aligned.bam -o eval_results\n  \
25    holodeck eval --mapped aligned.bam --truth golden.bam -o eval_results")]
26pub struct Eval {
27    /// BAM file of mapped reads to evaluate.
28    #[arg(short = 'm', long, value_name = "BAM")]
29    pub mapped: PathBuf,
30
31    /// Optional golden BAM file with truth alignments. If omitted, truth
32    /// positions are parsed from encoded read names.
33    #[arg(long, value_name = "BAM")]
34    pub truth: Option<PathBuf>,
35
36    #[command(flatten)]
37    pub output: OutputPrefixOptions,
38
39    /// Maximum distance (in bases) between the true and mapped start positions
40    /// of a read for it to be considered correctly mapped. Uses
41    /// `|mapped_start - true_start| <= wiggle` on the same contig.
42    #[arg(long, default_value_t = 5, value_name = "INT")]
43    pub wiggle: u32,
44}
45
/// Tally of evaluation outcomes for the reads that fall into one MAPQ bin.
///
/// The evaluation loop bumps `total` once per read and then exactly one of
/// `correct`, `mismapped`, or `unmapped`, so the three outcome counters always
/// sum to `total`.
#[derive(Clone, Debug, Default)]
struct BinCounts {
    /// Reads placed on the true contig within `--wiggle` bases of the true start.
    correct: u64,
    /// Reads placed on another contig, too far from the true start, or lacking
    /// a usable mapped contig/position.
    mismapped: u64,
    /// Reads whose record is flagged as unmapped.
    unmapped: u64,
    /// Every read attributed to this bin.
    total: u64,
}
58
59impl Command for Eval {
60    fn execute(&self) -> Result<()> {
61        if self.truth.is_some() {
62            log::warn!("--truth (golden BAM) is not yet implemented; using read names");
63        }
64
65        let mut reader = bam::io::reader::Builder
66            .build_from_path(&self.mapped)
67            .with_context(|| format!("Failed to open BAM: {}", self.mapped.display()))?;
68
69        let header = reader.read_header()?;
70
71        // MAPQ bins: 0, 1-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60+
72        let mut bins: BTreeMap<u8, BinCounts> = BTreeMap::new();
73        let mut total_reads: u64 = 0;
74        let mut parse_failures: u64 = 0;
75
76        for result in reader.records() {
77            let record = result.with_context(|| "Failed to read BAM record")?;
78
79            // Skip secondary and supplementary alignments before counting.
80            let flags = record.flags();
81            if flags.is_secondary() || flags.is_supplementary() {
82                continue;
83            }
84            total_reads += 1;
85
86            // Get read name.
87            let name_bytes = record.name().map_or(&b""[..], |n| n.as_bytes());
88            let name = name_bytes.to_str().unwrap_or("");
89
90            // Parse truth from encoded read name. For PE names, pick R1 or R2
91            // based on the record's segment flag; mis-selecting here caused R2
92            // alignments to be scored against the R1 truth position.
93            let truth = if let Some((_, r1, r2)) = parse_encoded_pe_name(name) {
94                if flags.is_last_segment() { Some(r2) } else { Some(r1) }
95            } else {
96                parse_encoded_se_name(name).map(|(_, truth)| truth)
97            };
98
99            let Some(truth) = truth else {
100                parse_failures += 1;
101                continue;
102            };
103
104            let mapq = record.mapping_quality().map_or(0, u8::from);
105            let bin_key = mapq_bin(mapq);
106            let counts = bins.entry(bin_key).or_default();
107            counts.total += 1;
108
109            if flags.is_unmapped() {
110                counts.unmapped += 1;
111                continue;
112            }
113
114            // Get mapped position.
115            let mapped_contig_idx = record.reference_sequence_id().and_then(Result::ok);
116            let mapped_pos = record.alignment_start().and_then(Result::ok).map(usize::from);
117
118            let (Some(mapped_contig_idx), Some(mapped_pos_1based)) =
119                (mapped_contig_idx, mapped_pos)
120            else {
121                counts.mismapped += 1;
122                continue;
123            };
124
125            // Resolve mapped contig name.
126            let Some((contig_name, _)) = header.reference_sequences().get_index(mapped_contig_idx)
127            else {
128                counts.mismapped += 1;
129                continue;
130            };
131            let mapped_contig = String::from_utf8_lossy(contig_name.as_ref());
132
133            // Compare truth vs mapped.
134            #[expect(clippy::cast_possible_truncation, reason = "mapped positions fit u32")]
135            let mapped_pos_u32 = mapped_pos_1based as u32;
136            let is_correct = mapped_contig == truth.contig
137                && mapped_pos_u32.abs_diff(truth.position) <= self.wiggle;
138
139            if is_correct {
140                counts.correct += 1;
141            } else {
142                counts.mismapped += 1;
143            }
144        }
145
146        if parse_failures > 0 {
147            log::warn!("{parse_failures} reads had unparseable names; skipped");
148        }
149
150        // Write results.
151        let output_file = output_path(&self.output.output, ".eval.txt");
152        let mut out = std::fs::File::create(&output_file)
153            .with_context(|| format!("Failed to create {}", output_file.display()))?;
154
155        writeln!(
156            out,
157            "mapq_bin\ttotal\tcorrect\tmismapped\tunmapped\tpct_correct\tpct_mismapped\tpct_unmapped"
158        )?;
159
160        let mut grand_total = BinCounts::default();
161        for (&bin, counts) in &bins {
162            write_bin_row(&mut out, &format_bin_label(bin), counts)?;
163            grand_total.correct += counts.correct;
164            grand_total.mismapped += counts.mismapped;
165            grand_total.unmapped += counts.unmapped;
166            grand_total.total += counts.total;
167        }
168        write_bin_row(&mut out, "ALL", &grand_total)?;
169
170        log::info!(
171            "Evaluated {total_reads} reads: {} correct, {} mismapped, {} unmapped",
172            grand_total.correct,
173            grand_total.mismapped,
174            grand_total.unmapped
175        );
176        log::info!("Results written to: {}", output_file.display());
177
178        Ok(())
179    }
180}
181
/// Collapse a MAPQ into the key of its reporting bin.
///
/// MAPQ 0 and the 1-9 range each get their own key (`0` and `1`). Values
/// 10-59 are grouped by decade under their lower bound (`10`..=`50`).
/// Everything from 60 upward, including 255, shares the key `60`.
fn mapq_bin(mapq: u8) -> u8 {
    match mapq {
        0 => 0,
        1..=9 => 1,
        60..=u8::MAX => 60,
        decade => decade - decade % 10,
    }
}
195
/// Render a bin key (as produced by `mapq_bin`) as the text shown in the
/// report's `mapq_bin` column, e.g. `0`, `1-9`, `20-29`, or `60+`.
fn format_bin_label(bin: u8) -> String {
    if bin == 0 {
        return String::from("0");
    }
    if bin == 1 {
        return String::from("1-9");
    }
    if bin == 60 {
        return String::from("60+");
    }
    let upper = bin + 9;
    format!("{bin}-{upper}")
}
205
206/// Write one row of the evaluation results table.
207fn write_bin_row(out: &mut impl Write, label: &str, counts: &BinCounts) -> Result<()> {
208    let total = counts.total.max(1) as f64;
209    writeln!(
210        out,
211        "{label}\t{}\t{}\t{}\t{}\t{:.2}\t{:.2}\t{:.2}",
212        counts.total,
213        counts.correct,
214        counts.mismapped,
215        counts.unmapped,
216        counts.correct as f64 / total * 100.0,
217        counts.mismapped as f64 / total * 100.0,
218        counts.unmapped as f64 / total * 100.0,
219    )?;
220    Ok(())
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mapq_bin() {
        assert_eq!(mapq_bin(0), 0);
        assert_eq!(mapq_bin(1), 1);
        assert_eq!(mapq_bin(5), 1);
        assert_eq!(mapq_bin(9), 1);
        assert_eq!(mapq_bin(10), 10);
        assert_eq!(mapq_bin(15), 10);
        assert_eq!(mapq_bin(19), 10);
        assert_eq!(mapq_bin(30), 30);
        assert_eq!(mapq_bin(59), 50);
        assert_eq!(mapq_bin(60), 60);
        assert_eq!(mapq_bin(255), 60);
    }

    #[test]
    fn test_format_bin_label() {
        assert_eq!(format_bin_label(0), "0");
        assert_eq!(format_bin_label(1), "1-9");
        assert_eq!(format_bin_label(10), "10-19");
        assert_eq!(format_bin_label(50), "50-59");
        assert_eq!(format_bin_label(60), "60+");
    }

    /// Every possible MAPQ must land in a bin whose printed label covers it.
    #[test]
    fn test_every_mapq_falls_inside_its_label() {
        for mapq in 0..=u8::MAX {
            let label = format_bin_label(mapq_bin(mapq));
            let (lo, hi) = match label.strip_suffix('+') {
                Some(lo) => (lo.parse::<u8>().unwrap(), u8::MAX),
                None => match label.split_once('-') {
                    Some((lo, hi)) => (lo.parse::<u8>().unwrap(), hi.parse::<u8>().unwrap()),
                    None => {
                        let only = label.parse::<u8>().unwrap();
                        (only, only)
                    }
                },
            };
            assert!((lo..=hi).contains(&mapq), "MAPQ {mapq} outside bin label {label}");
        }
    }

    #[test]
    fn test_write_bin_row() {
        let counts = BinCounts { correct: 90, mismapped: 8, unmapped: 2, total: 100 };
        let mut buf = Vec::new();
        write_bin_row(&mut buf, "30-39", &counts).unwrap();
        let line = String::from_utf8(buf).unwrap();
        assert_eq!(line, "30-39\t100\t90\t8\t2\t90.00\t8.00\t2.00\n");
    }

    /// An empty bin must not divide by zero (which would print `NaN`).
    #[test]
    fn test_write_bin_row_zero_total() {
        let mut buf = Vec::new();
        write_bin_row(&mut buf, "ALL", &BinCounts::default()).unwrap();
        let line = String::from_utf8(buf).unwrap();
        assert_eq!(line, "ALL\t0\t0\t0\t0\t0.00\t0.00\t0.00\n");
    }
}