holodeck_lib/commands/
eval.rs1use std::collections::BTreeMap;
4use std::io::Write;
5use std::path::PathBuf;
6
7use anyhow::{Context, Result};
8use bstr::ByteSlice;
9use clap::Parser;
10use noodles::bam;
11
12use super::command::{Command, output_path};
13use super::common::OutputPrefixOptions;
14use crate::read_naming::{parse_encoded_pe_name, parse_encoded_se_name};
15
16#[derive(Parser, Debug)]
23#[command(after_long_help = "EXAMPLES:\n \
24 holodeck eval --mapped aligned.bam -o eval_results\n \
25 holodeck eval --mapped aligned.bam --truth golden.bam -o eval_results")]
26pub struct Eval {
27 #[arg(short = 'm', long, value_name = "BAM")]
29 pub mapped: PathBuf,
30
31 #[arg(long, value_name = "BAM")]
34 pub truth: Option<PathBuf>,
35
36 #[command(flatten)]
37 pub output: OutputPrefixOptions,
38
39 #[arg(long, default_value_t = 5, value_name = "INT")]
43 pub wiggle: u32,
44}
45
/// Per-MAPQ-bin tally of how evaluated reads were classified.
#[derive(Clone, Debug, Default)]
struct BinCounts {
    /// Reads mapped to the true contig within the allowed position wiggle.
    correct: u64,
    /// Mapped reads that missed the true location, or whose mapped contig or
    /// position could not be resolved.
    mismapped: u64,
    /// Reads flagged as unmapped.
    unmapped: u64,
    /// All reads counted in this bin (reads with unparseable names are never
    /// counted).
    total: u64,
}
58
59impl Command for Eval {
60 fn execute(&self) -> Result<()> {
61 if self.truth.is_some() {
62 log::warn!("--truth (golden BAM) is not yet implemented; using read names");
63 }
64
65 let mut reader = bam::io::reader::Builder
66 .build_from_path(&self.mapped)
67 .with_context(|| format!("Failed to open BAM: {}", self.mapped.display()))?;
68
69 let header = reader.read_header()?;
70
71 let mut bins: BTreeMap<u8, BinCounts> = BTreeMap::new();
73 let mut total_reads: u64 = 0;
74 let mut parse_failures: u64 = 0;
75
76 for result in reader.records() {
77 let record = result.with_context(|| "Failed to read BAM record")?;
78
79 let flags = record.flags();
81 if flags.is_secondary() || flags.is_supplementary() {
82 continue;
83 }
84 total_reads += 1;
85
86 let name_bytes = record.name().map_or(&b""[..], |n| n.as_bytes());
88 let name = name_bytes.to_str().unwrap_or("");
89
90 let truth = if let Some((_, r1, r2)) = parse_encoded_pe_name(name) {
94 if flags.is_last_segment() { Some(r2) } else { Some(r1) }
95 } else {
96 parse_encoded_se_name(name).map(|(_, truth)| truth)
97 };
98
99 let Some(truth) = truth else {
100 parse_failures += 1;
101 continue;
102 };
103
104 let mapq = record.mapping_quality().map_or(0, u8::from);
105 let bin_key = mapq_bin(mapq);
106 let counts = bins.entry(bin_key).or_default();
107 counts.total += 1;
108
109 if flags.is_unmapped() {
110 counts.unmapped += 1;
111 continue;
112 }
113
114 let mapped_contig_idx = record.reference_sequence_id().and_then(Result::ok);
116 let mapped_pos = record.alignment_start().and_then(Result::ok).map(usize::from);
117
118 let (Some(mapped_contig_idx), Some(mapped_pos_1based)) =
119 (mapped_contig_idx, mapped_pos)
120 else {
121 counts.mismapped += 1;
122 continue;
123 };
124
125 let Some((contig_name, _)) = header.reference_sequences().get_index(mapped_contig_idx)
127 else {
128 counts.mismapped += 1;
129 continue;
130 };
131 let mapped_contig = String::from_utf8_lossy(contig_name.as_ref());
132
133 #[expect(clippy::cast_possible_truncation, reason = "mapped positions fit u32")]
135 let mapped_pos_u32 = mapped_pos_1based as u32;
136 let is_correct = mapped_contig == truth.contig
137 && mapped_pos_u32.abs_diff(truth.position) <= self.wiggle;
138
139 if is_correct {
140 counts.correct += 1;
141 } else {
142 counts.mismapped += 1;
143 }
144 }
145
146 if parse_failures > 0 {
147 log::warn!("{parse_failures} reads had unparseable names; skipped");
148 }
149
150 let output_file = output_path(&self.output.output, ".eval.txt");
152 let mut out = std::fs::File::create(&output_file)
153 .with_context(|| format!("Failed to create {}", output_file.display()))?;
154
155 writeln!(
156 out,
157 "mapq_bin\ttotal\tcorrect\tmismapped\tunmapped\tpct_correct\tpct_mismapped\tpct_unmapped"
158 )?;
159
160 let mut grand_total = BinCounts::default();
161 for (&bin, counts) in &bins {
162 write_bin_row(&mut out, &format_bin_label(bin), counts)?;
163 grand_total.correct += counts.correct;
164 grand_total.mismapped += counts.mismapped;
165 grand_total.unmapped += counts.unmapped;
166 grand_total.total += counts.total;
167 }
168 write_bin_row(&mut out, "ALL", &grand_total)?;
169
170 log::info!(
171 "Evaluated {total_reads} reads: {} correct, {} mismapped, {} unmapped",
172 grand_total.correct,
173 grand_total.mismapped,
174 grand_total.unmapped
175 );
176 log::info!("Results written to: {}", output_file.display());
177
178 Ok(())
179 }
180}
181
/// Map a mapping quality to the key of its report bin.
///
/// MAPQ 0 has its own bin (`0`), 1-9 share bin `1`, values from 10 to 59 are
/// rounded down to their decade, and everything from 60 upwards lands in
/// bin `60`.
fn mapq_bin(mapq: u8) -> u8 {
    if mapq == 0 {
        return 0;
    }
    if mapq < 10 {
        return 1;
    }
    // Round down to the decade, then cap so 60..=255 all share one bin.
    (mapq / 10 * 10).min(60)
}
195
/// Render a MAPQ bin key as the label used in the report's first column.
///
/// Keys `0`, `1` and `60` get the fixed labels `"0"`, `"1-9"` and `"60+"`;
/// any other key is shown as the ten-wide range starting at that key
/// (for example `10` becomes `"10-19"`).
fn format_bin_label(bin: u8) -> String {
    match bin {
        0 => "0".to_string(),
        1 => "1-9".to_string(),
        60 => "60+".to_string(),
        // Widen before adding: `bin + 9` in `u8` overflows for keys above 246,
        // which panics in debug builds and wraps in release builds.
        _ => format!("{}-{}", bin, u16::from(bin) + 9),
    }
}
205
206fn write_bin_row(out: &mut impl Write, label: &str, counts: &BinCounts) -> Result<()> {
208 let total = counts.total.max(1) as f64;
209 writeln!(
210 out,
211 "{label}\t{}\t{}\t{}\t{}\t{:.2}\t{:.2}\t{:.2}",
212 counts.total,
213 counts.correct,
214 counts.mismapped,
215 counts.unmapped,
216 counts.correct as f64 / total * 100.0,
217 counts.mismapped as f64 / total * 100.0,
218 counts.unmapped as f64 / total * 100.0,
219 )?;
220 Ok(())
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative MAPQ values land in the expected bin keys.
    #[test]
    fn test_mapq_bin() {
        let cases: [(u8, u8); 7] =
            [(0, 0), (5, 1), (10, 10), (15, 10), (30, 30), (60, 60), (255, 60)];
        for (mapq, expected) in cases {
            assert_eq!(mapq_bin(mapq), expected);
        }
    }

    /// Bin keys are rendered as the labels used in the report.
    #[test]
    fn test_format_bin_label() {
        let cases: [(u8, &str); 4] = [(0, "0"), (1, "1-9"), (10, "10-19"), (60, "60+")];
        for (bin, expected) in cases {
            assert_eq!(format_bin_label(bin), expected);
        }
    }

    /// A row starts with the label and raw counters and includes the percentages.
    #[test]
    fn test_write_bin_row() {
        let counts = BinCounts { correct: 90, mismapped: 8, unmapped: 2, total: 100 };
        let mut sink: Vec<u8> = Vec::new();
        write_bin_row(&mut sink, "30-39", &counts).unwrap();

        let row = String::from_utf8(sink).unwrap();
        assert!(row.starts_with("30-39\t100\t90\t8\t2\t"));
        assert!(row.contains("90.00"));
    }
}