Skip to main content

k2tools_lib/commands/
report_to_tsv.rs

1use std::fs::File;
2use std::io::{BufWriter, Write, stdout};
3use std::path::PathBuf;
4
5use anyhow::Result;
6use fgoxide::io::DelimFileWriter;
7use serde::Serialize;
8
9use crate::commands::command::Command;
10use crate::report::KrakenReport;
11
/// Convert a kraken2 report file to a clean, header-bearing TSV.
///
/// Reads a kraken2 report (standard 6-column or extended 8-column with minimizer
/// data) and writes a tab-separated file with clearly named columns, derived
/// parent information, descendant counts, and fraction columns.
///
/// Output columns:
///
///   tax_id                   - NCBI taxonomy ID
///   name                     - Scientific name
///   rank                     - Taxonomic rank code (e.g. S, G, D1)
///   level                    - Depth in taxonomy (0 for root/unclassified)
///   parent_tax_id            - Parent taxon ID (empty for root/unclassified)
///   parent_rank              - Parent rank code (empty for root/unclassified)
///   clade_count              - Fragments in clade rooted at this taxon
///   direct_count             - Fragments assigned directly to this taxon
///   descendant_count         - clade_count minus direct_count
///   frac_clade               - clade_count / total_sequences
///   frac_direct              - direct_count / total_sequences
///   frac_descendant          - descendant_count / total_sequences
///   minimizer_count          - Minimizers in clade (empty if not in report)
///   distinct_minimizer_count - Distinct minimizers (empty if not in report)
///
/// Examples:
///
///   k2tools report-to-tsv -r kraken2_report.txt -o report.tsv
///   k2tools report-to-tsv -r kraken2_report.txt   # writes to stdout
// NOTE: with `#[derive(clap::Args)]`, the `///` doc comments on this struct and its
// fields are used by clap as `--help` text, so editing them changes user-visible output.
// The column list above should be kept in sync with the field order of `TsvRow`.
#[derive(clap::Args)]
pub struct ReportToTsv {
    /// Path to the kraken2 report file.
    // Short flag is pinned explicitly to `-r`.
    #[arg(short = 'r', long)]
    kraken_report: PathBuf,

    /// Output TSV file path. Defaults to stdout.
    // `None` makes `execute` write to stdout instead of creating a file.
    #[arg(short, long)]
    output: Option<PathBuf>,
}
49
50/// One row of TSV output, assembled from the report before writing.
51#[derive(Default, Serialize)]
52struct TsvRow {
53    tax_id: u64,
54    name: String,
55    rank: String,
56    level: usize,
57    parent_tax_id: String,
58    parent_rank: String,
59    clade_count: u64,
60    direct_count: u64,
61    descendant_count: u64,
62    frac_clade: f64,
63    frac_direct: f64,
64    frac_descendant: f64,
65    minimizer_count: String,
66    distinct_minimizer_count: String,
67}
68
69/// Derives the TSV header line from `TsvRow`'s field names by serializing a default
70/// row through a csv writer and extracting just the header.
71fn tsv_header() -> String {
72    let mut csv_writer =
73        csv::WriterBuilder::new().delimiter(b'\t').has_headers(true).from_writer(Vec::new());
74    csv_writer.serialize(TsvRow::default()).unwrap();
75    csv_writer.flush().unwrap();
76    let bytes = csv_writer.into_inner().unwrap();
77    let text = String::from_utf8(bytes).unwrap();
78    text.lines().next().unwrap().to_string()
79}
80
81impl Command for ReportToTsv {
82    fn execute(&self) -> Result<()> {
83        let report = KrakenReport::from_path(&self.kraken_report)?;
84        let rows = build_tsv_rows(&report);
85
86        let writer: BufWriter<Box<dyn Write + Send>> = match &self.output {
87            Some(path) => {
88                let file = File::create(path).map_err(|e| {
89                    anyhow::anyhow!("failed to create output {}: {e}", path.display())
90                })?;
91                BufWriter::new(Box::new(file))
92            }
93            None => BufWriter::new(Box::new(stdout())),
94        };
95
96        if rows.is_empty() {
97            // csv::Writer only emits headers on the first serialize call, so with
98            // zero rows we derive the header from TsvRow's field names directly.
99            let mut w = writer;
100            writeln!(w, "{}", tsv_header())?;
101            w.flush()?;
102        } else {
103            let mut tsv_writer = DelimFileWriter::new(writer, b'\t', true);
104            tsv_writer.write_all(rows)?;
105        }
106
107        log::info!("Wrote {} rows to TSV.", report.len());
108        Ok(())
109    }
110}
111
112/// Builds `TsvRow` structs from every row in the report.
113fn build_tsv_rows(report: &KrakenReport) -> Vec<TsvRow> {
114    let total_sequences = report.total_sequences();
115    let has_minimizer_data = report.has_minimizer_data();
116    let mut tsv_rows = Vec::with_capacity(report.len());
117
118    for (i, row) in report.rows().iter().enumerate() {
119        let clade_count = row.clade_count();
120        let direct_count = row.direct_count();
121        let descendant_count = clade_count - direct_count;
122
123        #[allow(clippy::cast_precision_loss)]
124        let (frac_clade, frac_direct, frac_descendant) = if total_sequences > 0 {
125            (
126                clade_count as f64 / total_sequences as f64,
127                direct_count as f64 / total_sequences as f64,
128                descendant_count as f64 / total_sequences as f64,
129            )
130        } else {
131            (0.0, 0.0, 0.0)
132        };
133
134        let (minimizer_count, distinct_minimizer_count) = if has_minimizer_data {
135            let mc = row.minimizer_count().unwrap_or(0);
136            let dmc = row.distinct_minimizer_count().unwrap_or(0);
137            (mc.to_string(), dmc.to_string())
138        } else {
139            (String::new(), String::new())
140        };
141
142        let (parent_tax_id, parent_rank) = match report.parent(i) {
143            Some(parent) => (parent.taxon_id().to_string(), parent.taxonomic_rank().to_string()),
144            None => (String::new(), String::new()),
145        };
146
147        tsv_rows.push(TsvRow {
148            tax_id: row.taxon_id(),
149            name: row.name().to_string(),
150            rank: row.taxonomic_rank().to_string(),
151            level: row.depth(),
152            parent_tax_id,
153            parent_rank,
154            clade_count,
155            direct_count,
156            descendant_count,
157            frac_clade,
158            frac_direct,
159            frac_descendant,
160            minimizer_count,
161            distinct_minimizer_count,
162        });
163    }
164
165    tsv_rows
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a standard (6-column) report string from structured parameters.
    fn standard_line(
        pct: f64,
        clade: u64,
        direct: u64,
        rank: &str,
        taxid: u64,
        name: &str,
        depth: usize,
    ) -> String {
        let indent = " ".repeat(depth * 2);
        format!("{pct:.2}\t{clade}\t{direct}\t{rank}\t{taxid}\t{indent}{name}")
    }

    /// Builds an extended (8-column) report string with minimizer columns.
    #[allow(clippy::too_many_arguments)] // mirrors the 8 report columns plus indent depth
    fn extended_line(
        pct: f64,
        clade: u64,
        direct: u64,
        minimizers: u64,
        distinct: u64,
        rank: &str,
        taxid: u64,
        name: &str,
        depth: usize,
    ) -> String {
        let indent = " ".repeat(depth * 2);
        format!(
            "{pct:.2}\t{clade}\t{direct}\t{minimizers}\t{distinct}\t{rank}\t{taxid}\t{indent}{name}"
        )
    }

    /// Parses report text, panicking on malformed input (test-only helper).
    fn parse(report: &str) -> KrakenReport {
        KrakenReport::from_reader(report.as_bytes()).unwrap()
    }

    fn make_standard_report() -> KrakenReport {
        parse(
            &[
                standard_line(10.0, 100, 100, "U", 0, "unclassified", 0),
                standard_line(90.0, 900, 5, "R", 1, "root", 0),
                standard_line(60.0, 600, 10, "D", 2, "Bacteria", 1),
                standard_line(50.0, 500, 500, "S", 3, "Escherichia coli", 2),
                standard_line(30.0, 300, 10, "D", 4, "Eukaryota", 1),
                standard_line(20.0, 200, 200, "S", 5, "Homo sapiens", 2),
            ]
            .join("\n"),
        )
    }

    fn make_extended_report() -> KrakenReport {
        parse(
            &[
                extended_line(10.0, 100, 100, 0, 0, "U", 0, "unclassified", 0),
                extended_line(90.0, 900, 5, 500, 400, "R", 1, "root", 0),
                extended_line(60.0, 600, 10, 300, 250, "D", 2, "Bacteria", 1),
                extended_line(50.0, 500, 500, 200, 150, "S", 3, "Escherichia coli", 2),
                extended_line(30.0, 300, 10, 200, 150, "D", 4, "Eukaryota", 1),
                extended_line(20.0, 200, 200, 100, 80, "S", 5, "Homo sapiens", 2),
            ]
            .join("\n"),
        )
    }

    /// Writes rows to a temp file using `DelimFile::write_tsv` and returns the TSV text.
    fn write_rows_to_string(rows: Vec<TsvRow>) -> String {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("out.tsv");
        let df = fgoxide::io::DelimFile::default();
        df.write_tsv(&path, rows).unwrap();
        std::fs::read_to_string(path).unwrap()
    }

    #[test]
    fn test_level_values() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        assert_eq!(rows[0].level, 0); // unclassified
        assert_eq!(rows[1].level, 0); // root
        assert_eq!(rows[2].level, 1); // Bacteria (child of root)
        assert_eq!(rows[3].level, 2); // E. coli (child of Bacteria)
        assert_eq!(rows[4].level, 1); // Eukaryota (child of root)
        assert_eq!(rows[5].level, 2); // Homo sapiens (child of Eukaryota)
    }

    #[test]
    fn test_parent_fields_empty_for_root_and_unclassified() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // unclassified (depth 0, no parent)
        assert_eq!(rows[0].parent_tax_id, "");
        assert_eq!(rows[0].parent_rank, "");

        // root (depth 0, no parent)
        assert_eq!(rows[1].parent_tax_id, "");
        assert_eq!(rows[1].parent_rank, "");
    }

    #[test]
    fn test_parent_fields_populated_for_children() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // Bacteria -> parent is root (tax_id 1, rank R)
        assert_eq!(rows[2].parent_tax_id, "1");
        assert_eq!(rows[2].parent_rank, "R");

        // E. coli -> parent is Bacteria (tax_id 2, rank D)
        assert_eq!(rows[3].parent_tax_id, "2");
        assert_eq!(rows[3].parent_rank, "D");

        // Homo sapiens -> parent is Eukaryota (tax_id 4, rank D)
        assert_eq!(rows[5].parent_tax_id, "4");
        assert_eq!(rows[5].parent_rank, "D");
    }

    #[test]
    fn test_descendant_count() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // unclassified: clade=100, direct=100, desc=0
        assert_eq!(rows[0].descendant_count, 0);

        // root: clade=900, direct=5, desc=895
        assert_eq!(rows[1].descendant_count, 895);

        // Bacteria: clade=600, direct=10, desc=590
        assert_eq!(rows[2].descendant_count, 590);

        // E. coli: clade=500, direct=500, desc=0
        assert_eq!(rows[3].descendant_count, 0);
    }

    #[test]
    fn test_minimizer_columns_empty_for_standard_report() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        for row in &rows {
            assert_eq!(row.minimizer_count, "");
            assert_eq!(row.distinct_minimizer_count, "");
        }
    }

    #[test]
    fn test_minimizer_columns_populated_for_extended_report() {
        let report = make_extended_report();
        let rows = build_tsv_rows(&report);

        // Root row: minimizers=500, distinct=400
        assert_eq!(rows[1].minimizer_count, "500");
        assert_eq!(rows[1].distinct_minimizer_count, "400");

        // Bacteria: minimizers=300, distinct=250
        assert_eq!(rows[2].minimizer_count, "300");
        assert_eq!(rows[2].distinct_minimizer_count, "250");

        // All minimizer count fields should be non-empty
        for row in &rows {
            assert!(!row.minimizer_count.is_empty());
            assert!(!row.distinct_minimizer_count.is_empty());
        }
    }

    #[test]
    fn test_fraction_calculations() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        // total_sequences = 100 (unclassified) + 900 (root) = 1000
        // E. coli: clade=500, direct=500, descendant=0
        let ecoli = &rows[3];
        assert!((ecoli.frac_clade - 0.5).abs() < 1e-9);
        assert!((ecoli.frac_direct - 0.5).abs() < 1e-9);
        assert!((ecoli.frac_descendant - 0.0).abs() < 1e-9);

        // Bacteria: clade=600, direct=10, descendant=590
        let bacteria = &rows[2];
        assert!((bacteria.frac_clade - 0.6).abs() < 1e-9);
        assert!((bacteria.frac_direct - 0.01).abs() < 1e-9);
        assert!((bacteria.frac_descendant - 0.59).abs() < 1e-9);
    }

    #[test]
    fn test_fractions_zero_when_no_sequences() {
        let report = parse(
            &[
                standard_line(0.0, 0, 0, "U", 0, "unclassified", 0),
                standard_line(0.0, 0, 0, "R", 1, "root", 0),
            ]
            .join("\n"),
        );
        let rows = build_tsv_rows(&report);

        for row in &rows {
            assert!((row.frac_clade - 0.0).abs() < 1e-9);
            assert!((row.frac_direct - 0.0).abs() < 1e-9);
            assert!((row.frac_descendant - 0.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_basic_field_values() {
        let report = make_standard_report();
        let rows = build_tsv_rows(&report);

        assert_eq!(rows[0].tax_id, 0);
        assert_eq!(rows[0].name, "unclassified");
        assert_eq!(rows[0].rank, "U");
        assert_eq!(rows[0].clade_count, 100);
        assert_eq!(rows[0].direct_count, 100);

        assert_eq!(rows[3].tax_id, 3);
        assert_eq!(rows[3].name, "Escherichia coli");
        assert_eq!(rows[3].rank, "S");
    }

    #[test]
    fn test_empty_report_produces_no_rows() {
        let report = parse("");
        let rows = build_tsv_rows(&report);
        assert!(rows.is_empty());
    }

    #[test]
    fn test_write_tsv_header_and_rows() {
        let report = parse(
            &[
                standard_line(50.0, 5, 5, "U", 0, "unclassified", 0),
                standard_line(50.0, 5, 5, "R", 1, "root", 0),
            ]
            .join("\n"),
        );
        let rows = build_tsv_rows(&report);
        let text = write_rows_to_string(rows);
        let lines: Vec<&str> = text.lines().collect();

        // Header + 2 data rows
        assert_eq!(lines.len(), 3);

        // Header should match field names
        let header_cols: Vec<&str> = lines[0].split('\t').collect();
        assert_eq!(header_cols[0], "tax_id");
        assert_eq!(header_cols[3], "level");
        assert_eq!(header_cols[6], "clade_count");
        assert_eq!(header_cols[9], "frac_clade");
        assert_eq!(header_cols[12], "minimizer_count");
        assert_eq!(header_cols[13], "distinct_minimizer_count");
        assert_eq!(header_cols.len(), 14);

        // Verify column count consistency
        for line in &lines[1..] {
            assert_eq!(line.split('\t').count(), header_cols.len());
        }
    }

    /// The zero-row path in `execute` writes `tsv_header()` by hand; it must be
    /// identical to the header the csv writer emits when there are rows to write.
    #[test]
    fn test_tsv_header_matches_serialized_header() {
        let header = tsv_header();
        let cols: Vec<&str> = header.split('\t').collect();
        assert_eq!(cols.len(), 14);
        assert_eq!(cols[0], "tax_id");
        assert_eq!(cols[13], "distinct_minimizer_count");

        let text = write_rows_to_string(build_tsv_rows(&make_standard_report()));
        assert_eq!(text.lines().next(), Some(header.as_str()));
    }
}