1use std::fs::File;
2use std::io::{BufWriter, Write, stdout};
3use std::path::PathBuf;
4
5use anyhow::Result;
6use fgoxide::io::DelimFileWriter;
7use serde::Serialize;
8
9use crate::commands::command::Command;
10use crate::report::KrakenReport;
11
12#[derive(clap::Args)]
40pub struct ReportToTsv {
41 #[arg(short = 'r', long)]
43 kraken_report: PathBuf,
44
45 #[arg(short, long)]
47 output: Option<PathBuf>,
48}
49
/// One output row of the TSV produced from a Kraken report.
///
/// Rows are serialized with serde. Field names become the header columns, and the
/// declaration order below is the column order, so renaming or reordering fields changes
/// the output format. `Default` is used by `tsv_header` to serialize a placeholder row.
#[derive(Default, Serialize)]
struct TsvRow {
    /// Taxon ID of this report row.
    tax_id: u64,
    /// Taxon name, without the report's leading indentation.
    name: String,
    /// Taxonomic rank code as written in the report (e.g. `S`, `D`).
    rank: String,
    /// Depth of the row in the report's tree; root and unclassified rows are level 0.
    level: usize,
    /// Parent row's taxon ID as text; empty when the row has no parent.
    parent_tax_id: String,
    /// Parent row's rank code; empty when the row has no parent.
    parent_rank: String,
    /// Clade count for this taxon, as reported.
    clade_count: u64,
    /// Count assigned directly to this taxon, as reported.
    direct_count: u64,
    /// `clade_count - direct_count`: the part of the clade assigned below this taxon.
    descendant_count: u64,
    /// `clade_count` divided by the report's total sequences (0.0 when that total is zero).
    frac_clade: f64,
    /// `direct_count` divided by the report's total sequences (0.0 when that total is zero).
    frac_direct: f64,
    /// `descendant_count` divided by the report's total sequences (0.0 when that total is zero).
    frac_descendant: f64,
    /// Minimizer count as text; empty for reports without minimizer data.
    minimizer_count: String,
    /// Distinct minimizer count as text; empty for reports without minimizer data.
    distinct_minimizer_count: String,
}
68
69fn tsv_header() -> String {
72 let mut csv_writer =
73 csv::WriterBuilder::new().delimiter(b'\t').has_headers(true).from_writer(Vec::new());
74 csv_writer.serialize(TsvRow::default()).unwrap();
75 csv_writer.flush().unwrap();
76 let bytes = csv_writer.into_inner().unwrap();
77 let text = String::from_utf8(bytes).unwrap();
78 text.lines().next().unwrap().to_string()
79}
80
81impl Command for ReportToTsv {
82 fn execute(&self) -> Result<()> {
83 let report = KrakenReport::from_path(&self.kraken_report)?;
84 let rows = build_tsv_rows(&report);
85
86 let writer: BufWriter<Box<dyn Write + Send>> = match &self.output {
87 Some(path) => {
88 let file = File::create(path).map_err(|e| {
89 anyhow::anyhow!("failed to create output {}: {e}", path.display())
90 })?;
91 BufWriter::new(Box::new(file))
92 }
93 None => BufWriter::new(Box::new(stdout())),
94 };
95
96 if rows.is_empty() {
97 let mut w = writer;
100 writeln!(w, "{}", tsv_header())?;
101 w.flush()?;
102 } else {
103 let mut tsv_writer = DelimFileWriter::new(writer, b'\t', true);
104 tsv_writer.write_all(rows)?;
105 }
106
107 log::info!("Wrote {} rows to TSV.", report.len());
108 Ok(())
109 }
110}
111
112fn build_tsv_rows(report: &KrakenReport) -> Vec<TsvRow> {
114 let total_sequences = report.total_sequences();
115 let has_minimizer_data = report.has_minimizer_data();
116 let mut tsv_rows = Vec::with_capacity(report.len());
117
118 for (i, row) in report.rows().iter().enumerate() {
119 let clade_count = row.clade_count();
120 let direct_count = row.direct_count();
121 let descendant_count = clade_count - direct_count;
122
123 #[allow(clippy::cast_precision_loss)]
124 let (frac_clade, frac_direct, frac_descendant) = if total_sequences > 0 {
125 (
126 clade_count as f64 / total_sequences as f64,
127 direct_count as f64 / total_sequences as f64,
128 descendant_count as f64 / total_sequences as f64,
129 )
130 } else {
131 (0.0, 0.0, 0.0)
132 };
133
134 let (minimizer_count, distinct_minimizer_count) = if has_minimizer_data {
135 let mc = row.minimizer_count().unwrap_or(0);
136 let dmc = row.distinct_minimizer_count().unwrap_or(0);
137 (mc.to_string(), dmc.to_string())
138 } else {
139 (String::new(), String::new())
140 };
141
142 let (parent_tax_id, parent_rank) = match report.parent(i) {
143 Some(parent) => (parent.taxon_id().to_string(), parent.taxonomic_rank().to_string()),
144 None => (String::new(), String::new()),
145 };
146
147 tsv_rows.push(TsvRow {
148 tax_id: row.taxon_id(),
149 name: row.name().to_string(),
150 rank: row.taxonomic_rank().to_string(),
151 level: row.depth(),
152 parent_tax_id,
153 parent_rank,
154 clade_count,
155 direct_count,
156 descendant_count,
157 frac_clade,
158 frac_direct,
159 frac_descendant,
160 minimizer_count,
161 distinct_minimizer_count,
162 });
163 }
164
165 tsv_rows
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats one line of a standard Kraken report (percentage, clade count, direct count,
    /// rank, taxon ID, name), indenting `name` by two spaces per `depth` level.
    fn standard_line(
        pct: f64,
        clade: u64,
        direct: u64,
        rank: &str,
        taxid: u64,
        name: &str,
        depth: usize,
    ) -> String {
        let indent = " ".repeat(depth * 2);
        format!("{pct:.2}\t{clade}\t{direct}\t{rank}\t{taxid}\t{indent}{name}")
    }

    /// Formats one line of an extended Kraken report, which adds minimizer and distinct
    /// minimizer counts between the direct count and the rank.
    // One parameter per report column mirrors the line layout directly, which is clearer
    // than a builder for a test-only helper.
    #[allow(clippy::too_many_arguments)]
    fn extended_line(
        pct: f64,
        clade: u64,
        direct: u64,
        minimizers: u64,
        distinct: u64,
        rank: &str,
        taxid: u64,
        name: &str,
        depth: usize,
    ) -> String {
        let indent = " ".repeat(depth * 2);
        format!(
            "{pct:.2}\t{clade}\t{direct}\t{minimizers}\t{distinct}\t{rank}\t{taxid}\t{indent}{name}"
        )
    }

    /// Parses an in-memory report, panicking on malformed input.
    fn parse(report: &str) -> KrakenReport {
        KrakenReport::from_reader(report.as_bytes()).unwrap()
    }

    /// Small two-domain report without minimizer columns (1000 sequences in total).
    fn make_standard_report() -> KrakenReport {
        parse(
            &[
                standard_line(10.0, 100, 100, "U", 0, "unclassified", 0),
                standard_line(90.0, 900, 5, "R", 1, "root", 0),
                standard_line(60.0, 600, 10, "D", 2, "Bacteria", 1),
                standard_line(50.0, 500, 500, "S", 3, "Escherichia coli", 2),
                standard_line(30.0, 300, 10, "D", 4, "Eukaryota", 1),
                standard_line(20.0, 200, 200, "S", 5, "Homo sapiens", 2),
            ]
            .join("\n"),
        )
    }

    /// Same taxonomy as [`make_standard_report`], but with minimizer columns.
    fn make_extended_report() -> KrakenReport {
        parse(
            &[
                extended_line(10.0, 100, 100, 0, 0, "U", 0, "unclassified", 0),
                extended_line(90.0, 900, 5, 500, 400, "R", 1, "root", 0),
                extended_line(60.0, 600, 10, 300, 250, "D", 2, "Bacteria", 1),
                extended_line(50.0, 500, 500, 200, 150, "S", 3, "Escherichia coli", 2),
                extended_line(30.0, 300, 10, 200, 150, "D", 4, "Eukaryota", 1),
                extended_line(20.0, 200, 200, 100, 80, "S", 5, "Homo sapiens", 2),
            ]
            .join("\n"),
        )
    }

    /// Writes `rows` to a temporary TSV via fgoxide and returns the file's contents.
    fn write_rows_to_string(rows: Vec<TsvRow>) -> String {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("out.tsv");
        let df = fgoxide::io::DelimFile::default();
        df.write_tsv(&path, rows).unwrap();
        std::fs::read_to_string(path).unwrap()
    }

    /// Asserts that `actual` is within `1e-9` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-9, "expected {expected}, got {actual}");
    }

    #[test]
    fn test_level_values() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        assert_eq!(rows[0].level, 0);
        assert_eq!(rows[1].level, 0);
        assert_eq!(rows[2].level, 1);
        assert_eq!(rows[3].level, 2);
        assert_eq!(rows[4].level, 1);
        assert_eq!(rows[5].level, 2);
    }

    #[test]
    fn test_parent_fields_empty_for_root_and_unclassified() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // unclassified
        assert_eq!(rows[0].parent_tax_id, "");
        assert_eq!(rows[0].parent_rank, "");

        // root
        assert_eq!(rows[1].parent_tax_id, "");
        assert_eq!(rows[1].parent_rank, "");
    }

    #[test]
    fn test_parent_fields_populated_for_children() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // Bacteria -> root
        assert_eq!(rows[2].parent_tax_id, "1");
        assert_eq!(rows[2].parent_rank, "R");

        // Escherichia coli -> Bacteria
        assert_eq!(rows[3].parent_tax_id, "2");
        assert_eq!(rows[3].parent_rank, "D");

        // Homo sapiens -> Eukaryota
        assert_eq!(rows[5].parent_tax_id, "4");
        assert_eq!(rows[5].parent_rank, "D");
    }

    #[test]
    fn test_descendant_count() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // unclassified: 100 - 100
        assert_eq!(rows[0].descendant_count, 0);
        // root: 900 - 5
        assert_eq!(rows[1].descendant_count, 895);
        // Bacteria: 600 - 10
        assert_eq!(rows[2].descendant_count, 590);
        // Escherichia coli (leaf): 500 - 500
        assert_eq!(rows[3].descendant_count, 0);
    }

    #[test]
    fn test_minimizer_columns_empty_for_standard_report() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        for row in &rows {
            assert_eq!(row.minimizer_count, "");
            assert_eq!(row.distinct_minimizer_count, "");
        }
    }

    #[test]
    fn test_minimizer_columns_populated_for_extended_report() {
        let report = make_extended_report();
        let rows = build_tsv_rows(&report);

        assert_eq!(rows[1].minimizer_count, "500");
        assert_eq!(rows[1].distinct_minimizer_count, "400");

        assert_eq!(rows[2].minimizer_count, "300");
        assert_eq!(rows[2].distinct_minimizer_count, "250");

        // Every row is populated, including unclassified whose counts are "0".
        for row in &rows {
            assert!(!row.minimizer_count.is_empty());
            assert!(!row.distinct_minimizer_count.is_empty());
        }
    }

    #[test]
    fn test_fraction_calculations() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // 1000 sequences in total (100 unclassified + 900 under root).
        let ecoli = &rows[3];
        assert_close(ecoli.frac_clade, 0.5);
        assert_close(ecoli.frac_direct, 0.5);
        assert_close(ecoli.frac_descendant, 0.0);

        let bacteria = &rows[2];
        assert_close(bacteria.frac_clade, 0.6);
        assert_close(bacteria.frac_direct, 0.01);
        assert_close(bacteria.frac_descendant, 0.59);
    }

    #[test]
    fn test_fractions_zero_when_no_sequences() {
        let report = parse(
            &[
                standard_line(0.0, 0, 0, "U", 0, "unclassified", 0),
                standard_line(0.0, 0, 0, "R", 1, "root", 0),
            ]
            .join("\n"),
        );
        let rows = build_tsv_rows(&report);

        for row in &rows {
            assert_close(row.frac_clade, 0.0);
            assert_close(row.frac_direct, 0.0);
            assert_close(row.frac_descendant, 0.0);
        }
    }

    #[test]
    fn test_basic_field_values() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        assert_eq!(rows[0].tax_id, 0);
        assert_eq!(rows[0].name, "unclassified");
        assert_eq!(rows[0].rank, "U");
        assert_eq!(rows[0].clade_count, 100);
        assert_eq!(rows[0].direct_count, 100);

        assert_eq!(rows[3].tax_id, 3);
        assert_eq!(rows[3].name, "Escherichia coli");
        assert_eq!(rows[3].rank, "S");
    }

    #[test]
    fn test_empty_report_produces_no_rows() {
        let report = parse("");
        let rows = build_tsv_rows(&report);
        assert!(rows.is_empty());
    }

    #[test]
    fn test_tsv_header_matches_written_header() {
        // The header used for empty reports must match the one written alongside real rows.
        let rows = build_tsv_rows(&make_standard_report());
        let text = write_rows_to_string(rows);
        assert_eq!(tsv_header(), text.lines().next().unwrap());
    }

    #[test]
    fn test_write_tsv_header_and_rows() {
        let report = parse(
            &[
                standard_line(50.0, 5, 5, "U", 0, "unclassified", 0),
                standard_line(50.0, 5, 5, "R", 1, "root", 0),
            ]
            .join("\n"),
        );
        let rows = build_tsv_rows(&report);
        let text = write_rows_to_string(rows);
        let lines: Vec<&str> = text.lines().collect();

        // One header line plus one line per report row.
        assert_eq!(lines.len(), 3);

        let header_cols: Vec<&str> = lines[0].split('\t').collect();
        assert_eq!(header_cols[0], "tax_id");
        assert_eq!(header_cols[3], "level");
        assert_eq!(header_cols[6], "clade_count");
        assert_eq!(header_cols[9], "frac_clade");
        assert_eq!(header_cols[12], "minimizer_count");
        assert_eq!(header_cols[13], "distinct_minimizer_count");
        assert_eq!(header_cols.len(), 14);

        // Every data line has exactly as many columns as the header.
        for line in &lines[1..] {
            assert_eq!(line.split('\t').count(), header_cols.len());
        }
    }
}