//! k2tools_lib/commands/filter.rs
//!
//! The `filter` command: extract reads from FASTQ files by kraken2 classification.

1use std::collections::HashSet;
2use std::fs::File;
3use std::io::{BufWriter, Write};
4use std::path::PathBuf;
5
6use anyhow::{Context, Result};
7use fgoxide::io::Io;
8use fgoxide::iter::IntoChunkedReadAheadIterator;
9use pooled_writer::bgzf::BgzfCompressor;
10use pooled_writer::{PoolBuilder, PooledWriter};
11use seq_io::fastq::{Error as FastqError, OwnedRecord, Reader as FastqReader, Record};
12
13use crate::commands::command::Command;
14use crate::kraken_output::{KrakenOutputReader, KrakenRecord};
15use crate::progress::{ProgressLogger, format_count};
16use crate::report::KrakenReport;
17
/// Number of records per chunk sent through the read-ahead channel.
const READ_AHEAD_CHUNK_SIZE: usize = 1024;

/// Number of buffered chunks in the read-ahead channel.
/// Together with `READ_AHEAD_CHUNK_SIZE` this bounds how far the background
/// reader may run ahead of the filter loop.
const READ_AHEAD_NUM_CHUNKS: usize = 1024;

/// Buffer size used when opening input files for reading (512 KiB).
const IO_BUFFER_SIZE: usize = 512 * 1024;
26
27/// Filter reads from FASTQ files based on kraken2 classification results.
28///
29/// Extracts reads classified to one or more taxon IDs from FASTQ files, using the
30/// kraken2 report (taxonomy tree) and per-read classification output. Supports both
31/// single-end and paired-end reads, and writes bgzf-compressed output.
32///
33/// # Required inputs
34///
35/// The command needs three pieces of data that must all come from the same kraken2 run:
36///
37/// - **`--kraken-report`** (`-r`): The kraken2 report file containing the taxonomy tree
38///   and per-taxon read counts. This is used to resolve taxon IDs, expand descendants,
39///   and estimate the expected number of matching reads.
40/// - **`--kraken-output`** (`-k`): The per-read classification output from kraken2
41///   (generated with `--output`). Each line maps a read name to a taxon ID.
42/// - **`--input`** (`-i`): One FASTQ file for single-end data, or two for paired-end.
43///   Gzip and bgzf compressed inputs are detected and handled automatically.
44///
45/// The kraken output and FASTQ file(s) must contain the same reads in the same order.
46/// The command verifies read name agreement and will error if the files are mismatched
47/// or have different numbers of records.
48///
49/// # Taxon selection
50///
51/// At least one of `--taxon-ids` or `--include-unclassified` must be specified.
52///
53/// - **`--taxon-ids`** (`-t`): One or more NCBI taxon IDs to extract. By default, only
54///   reads classified directly to these exact taxon IDs are included.
55/// - **`--include-descendants`** (`-d`): Expand each taxon ID to include all of its
56///   descendants in the taxonomy tree. For example, specifying a genus-level taxon ID
57///   with `-d` will also extract reads classified to any species or strain within that
58///   genus.
59/// - **`--include-unclassified`** (`-u`): Include reads that kraken2 could not classify
60///   (taxon ID 0). Can be combined with `--taxon-ids` to extract both classified and
61///   unclassified reads in a single pass.
62///
63/// # Output
64///
65/// - **`--output`** (`-o`): Output FASTQ file path(s). Must provide the same number of
66///   output files as input files (one for single-end, two for paired-end). Outputs are
67///   always bgzf-compressed regardless of file extension.
68/// - **`--threads`**: Number of threads used for bgzf compression (default: 4).
69/// - **`--compression-level`**: Bgzf compression level from 0 (fastest) to 9 (smallest),
70///   default 5.
71///
72/// # Examples
73///
74/// Extract all reads classified as _E. coli_ (taxon 562):
75///
76/// ```bash
77/// k2tools filter -r report.txt -k output.txt -i reads.fq.gz -o ecoli.fq.gz -t 562
78/// ```
79///
80/// Extract all Enterobacteriaceae (taxon 543) including every species and strain beneath
81/// it in the taxonomy:
82///
83/// ```bash
84/// k2tools filter -r report.txt -k output.txt \
85///     -i reads.fq.gz -o entero.fq.gz -t 543 -d
86/// ```
87///
88/// Extract unclassified reads from a paired-end run:
89///
90/// ```bash
91/// k2tools filter -r report.txt -k output.txt \
92///     -i r1.fq.gz r2.fq.gz -o unclass_r1.fq.gz unclass_r2.fq.gz -u
93/// ```
94///
95/// Extract human reads plus unclassified in a single pass:
96///
97/// ```bash
98/// k2tools filter -r report.txt -k output.txt \
99///     -i reads.fq.gz -o host_and_unclass.fq.gz -t 9606 -d -u
100/// ```
#[derive(clap::Args)]
pub struct Filter {
    // NOTE: the `///` comments on these fields double as clap `--help` text,
    // so they must not be reworded casually; reviewer notes use `//` instead.

    /// Path to the kraken2 report file.
    #[arg(short = 'r', long)]
    kraken_report: PathBuf,

    /// Path to the kraken2 per-read classification output.
    #[arg(short = 'k', long)]
    kraken_output: PathBuf,

    /// Input FASTQ file(s). One for single-end, two for paired-end.
    /// Supports gzip/bgzf compressed inputs.
    #[arg(short, long, num_args = 1..=2, required = true)]
    input: Vec<PathBuf>,

    /// Output FASTQ file(s). Must match the number of inputs.
    /// Written with bgzf compression.
    // Count agreement with `input` is enforced in `validate_args`, not by clap.
    #[arg(short, long, num_args = 1..=2, required = true)]
    output: Vec<PathBuf>,

    /// Taxon ID(s) to extract reads for. At least one taxon ID or
    /// --include-unclassified must be specified.
    // May legitimately be empty when --include-unclassified is set;
    // the "at least one" rule is enforced in `validate_args`.
    #[arg(short, long, num_args = 1..)]
    taxon_ids: Vec<u64>,

    /// Include reads assigned to any descendant of the specified taxa.
    #[arg(short = 'd', long, default_value_t = false)]
    include_descendants: bool,

    /// Include unclassified reads (taxon ID 0) in the output.
    #[arg(short = 'u', long, default_value_t = false)]
    include_unclassified: bool,

    /// Number of threads for bgzf compression.
    // Must be >= 1; checked in `validate_args`.
    #[arg(long, default_value_t = 4)]
    threads: usize,

    /// Bgzf compression level (0-9).
    // Upper bound checked in `validate_args`.
    #[arg(long, default_value_t = 5)]
    compression_level: u8,
}
142
143impl Command for Filter {
144    fn execute(&self) -> Result<()> {
145        self.validate_args()?;
146
147        let report = KrakenReport::from_path(&self.kraken_report)?;
148        if report.is_empty() {
149            return self.handle_empty_inputs();
150        }
151
152        let (taxon_set, expected) = build_taxon_set_and_expected_count(
153            &report,
154            &self.taxon_ids,
155            self.include_descendants,
156            self.include_unclassified,
157        )?;
158        log::info!(
159            "Filtering for {} taxa; expecting approximately {} reads",
160            format_count(taxon_set.len() as u64),
161            format_count(expected),
162        );
163
164        let (total, kept) = self.run_filter_pipeline(&taxon_set).map_err(|e| {
165            let banner = "#".repeat(72);
166            let output_paths: Vec<_> =
167                self.output.iter().map(|p| format!("  {}", p.display())).collect();
168            eprintln!(
169                "\n{banner}\n\
170                 # ERROR: invalid inputs detected\n\
171                 #\n\
172                 # {e}\n\
173                 #\n\
174                 # WARNING: partial/invalid output files may have been written to:\n\
175                 # {}\n\
176                 {banner}\n",
177                output_paths.join("\n"),
178            );
179            e
180        })?;
181
182        #[allow(clippy::cast_precision_loss)]
183        let pct = if total > 0 { kept as f64 / total as f64 * 100.0 } else { 0.0 };
184        log::info!(
185            "Kept {} / {} reads ({pct:.2}%), expected {}.",
186            format_count(kept),
187            format_count(total),
188            format_count(expected),
189        );
190
191        Ok(())
192    }
193}
194
impl Filter {
    /// Handles the case where kraken2 was run on empty FASTQ files, producing an
    /// empty report and no kraken output file. Verifies that all FASTQ inputs are
    /// truly empty, then writes valid empty bgzf output files.
    ///
    /// # Errors
    ///
    /// Fails if an input cannot be opened, or if any FASTQ input turns out to
    /// contain records (report and FASTQ inputs are then inconsistent).
    fn handle_empty_inputs(&self) -> Result<()> {
        let io = Io::new(u32::from(self.compression_level), IO_BUFFER_SIZE);
        for path in &self.input {
            let reader = io
                .new_reader(path)
                .with_context(|| format!("failed to open FASTQ: {}", path.display()))?;
            let mut fq = FastqReader::new(reader);
            // A single record is enough to prove the inputs disagree with the report.
            if fq.next().is_some() {
                anyhow::bail!(
                    "kraken2 report is empty but FASTQ input {} contains records; \
                     inputs are inconsistent",
                    path.display()
                );
            }
        }

        // Build the writers and immediately close them so each output path is a
        // valid (empty) bgzf file rather than a missing or truncated one.
        let (mut pool, writers) = self.build_writer_pool()?;
        for w in writers {
            w.close()?;
        }
        pool.stop_pool()?;

        log::info!("Report is empty; all inputs are empty. Wrote empty output files.");
        Ok(())
    }

    /// Validates command-line arguments beyond what clap enforces:
    /// input/output count agreement, thread count >= 1, compression level
    /// range, and the "at least one taxon or -u" requirement.
    fn validate_args(&self) -> Result<()> {
        anyhow::ensure!(
            self.input.len() == self.output.len(),
            "number of input files ({}) must match number of output files ({})",
            self.input.len(),
            self.output.len()
        );
        anyhow::ensure!(self.threads >= 1, "threads must be at least 1");
        anyhow::ensure!(self.compression_level <= 9, "compression level must be 0-9");
        anyhow::ensure!(
            !self.taxon_ids.is_empty() || self.include_unclassified,
            "at least one --taxon-ids value or --include-unclassified must be specified"
        );
        Ok(())
    }

    /// Opens all inputs, creates writers, runs the main filter loop, and closes
    /// everything down. Returns (total_reads, kept_reads).
    ///
    /// Shutdown ordering is deliberate here: any filter/verification error is
    /// captured (not propagated with `?`) so writers are always closed and the
    /// pool stopped before the error leaves this function.
    fn run_filter_pipeline(&self, taxon_set: &HashSet<u64>) -> Result<(u64, u64)> {
        let io = Io::new(u32::from(self.compression_level), IO_BUFFER_SIZE);
        let kraken_reader = io.new_reader(&self.kraken_output).with_context(|| {
            format!("failed to open kraken output: {}", self.kraken_output.display())
        })?;
        // `read_ahead` moves decoding onto a background thread so parsing
        // overlaps with the filter loop.
        let mut kraken_iter = KrakenOutputReader::new(kraken_reader)
            .read_ahead(READ_AHEAD_CHUNK_SIZE, READ_AHEAD_NUM_CHUNKS);

        let is_paired = self.input.len() == 2;
        let mut fq_iter1 = FastqReader::new(
            io.new_reader(&self.input[0])
                .with_context(|| format!("failed to open FASTQ: {}", self.input[0].display()))?,
        )
        .into_records()
        .read_ahead(READ_AHEAD_CHUNK_SIZE, READ_AHEAD_NUM_CHUNKS);

        // A second FASTQ iterator exists only for paired-end input.
        let mut fq_iter2 = if is_paired {
            Some(
                FastqReader::new(io.new_reader(&self.input[1]).with_context(|| {
                    format!("failed to open FASTQ: {}", self.input[1].display())
                })?)
                .into_records()
                .read_ahead(READ_AHEAD_CHUNK_SIZE, READ_AHEAD_NUM_CHUNKS),
            )
        } else {
            None
        };

        let (mut pool, mut writers) = self.build_writer_pool()?;
        let mut progress = ProgressLogger::new("k2tools::filter", "reads", 5_000_000);

        // Run the filter and verification, capturing any error so we can
        // shut down the pool cleanly before propagating it (avoids panics
        // in PooledWriter::drop when writers outlive the pool).
        let result = filter_reads(
            &mut kraken_iter,
            &mut fq_iter1,
            fq_iter2.as_mut(),
            taxon_set,
            &mut writers,
            &mut progress,
        )
        .and_then(|(total, kept)| {
            // Once the kraken stream ends, the FASTQ stream(s) must be
            // exhausted too; otherwise the inputs had mismatched record counts.
            verify_fastq_exhausted(&mut fq_iter1, fq_iter2.as_mut(), total)?;
            Ok((total, kept))
        });

        progress.finish();

        // Always close writers before stopping the pool
        for w in writers {
            w.close()?;
        }
        pool.stop_pool()?;

        result
    }

    /// Constructs the bgzf writer pool and exchanges output files into pooled writers.
    /// Returns (pool, writers) so that destructuring as `let (pool, writers) = ...`
    /// ensures writers are dropped before the pool (reverse declaration order).
    fn build_writer_pool(&self) -> Result<(pooled_writer::Pool, Vec<PooledWriter>)> {
        // Queue size scales with the number of compression threads (50 per thread).
        let mut pool_builder = PoolBuilder::<_, BgzfCompressor>::new()
            .threads(self.threads)
            .queue_size(self.threads * 50)
            .compression_level(self.compression_level)?;

        let mut writers: Vec<PooledWriter> = Vec::new();
        for path in &self.output {
            let file = File::create(path)
                .with_context(|| format!("failed to create output: {}", path.display()))?;
            writers.push(pool_builder.exchange(BufWriter::new(file)));
        }
        let pool = pool_builder.build()?;
        Ok((pool, writers))
    }
}
321
322/// Runs the main filter loop: co-iterates kraken output and FASTQ iterator(s) in
323/// lockstep, writing matching records to the output writers.
324///
325/// Returns (total_reads_processed, reads_kept).
326fn filter_reads(
327    kraken_iter: &mut impl Iterator<Item = Result<KrakenRecord>>,
328    fq_iter1: &mut impl Iterator<Item = Result<OwnedRecord, FastqError>>,
329    mut fq_iter2: Option<&mut impl Iterator<Item = Result<OwnedRecord, FastqError>>>,
330    taxon_set: &HashSet<u64>,
331    writers: &mut [PooledWriter],
332    progress: &mut ProgressLogger,
333) -> Result<(u64, u64)> {
334    let mut total: u64 = 0;
335    let mut kept: u64 = 0;
336
337    for kraken_result in kraken_iter {
338        let kraken_rec = kraken_result?;
339        total += 1;
340        progress.record();
341
342        let fq_rec1 = fq_iter1
343            .next()
344            .context("FASTQ input ended before kraken output")?
345            .with_context(|| format!("failed to read FASTQ record at kraken line {total}"))?;
346
347        let fq_rec2: Option<OwnedRecord> = if let Some(ref mut iter2) = fq_iter2 {
348            Some(
349                iter2
350                    .next()
351                    .context("second FASTQ input ended before kraken output")?
352                    .with_context(|| {
353                        format!("failed to read FASTQ R2 record at kraken line {total}")
354                    })?,
355            )
356        } else {
357            None
358        };
359
360        if taxon_set.contains(&kraken_rec.taxon_id()) {
361            // Validate read names only for matching reads to avoid overhead
362            validate_read_name(kraken_rec.read_name(), fq_rec1.head(), total)?;
363            if let Some(ref rec2) = fq_rec2 {
364                validate_read_name(kraken_rec.read_name(), rec2.head(), total)?;
365            }
366
367            write_fastq_record(&mut writers[0], &fq_rec1)?;
368            if let Some(ref rec2) = fq_rec2 {
369                write_fastq_record(&mut writers[1], rec2)?;
370            }
371            kept += 1;
372        }
373    }
374
375    Ok((total, kept))
376}
377
378/// Verifies that the FASTQ streams are exhausted after the kraken output ends.
379fn verify_fastq_exhausted(
380    fq_iter1: &mut impl Iterator<Item = Result<OwnedRecord, FastqError>>,
381    fq_iter2: Option<&mut impl Iterator<Item = Result<OwnedRecord, FastqError>>>,
382    total: u64,
383) -> Result<()> {
384    if fq_iter1.next().is_some() {
385        anyhow::bail!("FASTQ input has more records than kraken output ({total} kraken records)");
386    }
387    if let Some(iter2) = fq_iter2 {
388        if iter2.next().is_some() {
389            anyhow::bail!(
390                "second FASTQ input has more records than kraken output ({total} kraken records)"
391            );
392        }
393    }
394    Ok(())
395}
396
397/// Builds the set of taxon IDs to filter for and computes the expected number of
398/// matching reads from the report's count fields.
399///
400/// If `include_descendants` is true, expands each taxon ID to include all its
401/// descendants in the report taxonomy tree. If `include_unclassified` is true,
402/// adds taxon ID 0. The expected count uses `clade_count` when descendants are
403/// included, `direct_count` otherwise.
404///
405/// Returns `(taxon_id_set, expected_read_count)`.
406fn build_taxon_set_and_expected_count(
407    report: &KrakenReport,
408    taxon_ids: &[u64],
409    include_descendants: bool,
410    include_unclassified: bool,
411) -> Result<(HashSet<u64>, u64)> {
412    let mut set = HashSet::new();
413    let mut expected: u64 = 0;
414
415    for &tid in taxon_ids {
416        let idx = report
417            .index_of_taxon_id(tid)
418            .with_context(|| format!("taxon ID {tid} not found in report"))?;
419        let row = report.row(idx);
420        set.insert(tid);
421
422        if include_descendants {
423            expected += row.clade_count();
424            for desc_idx in report.descendants(idx) {
425                set.insert(report.row(desc_idx).taxon_id());
426            }
427        } else {
428            expected += row.direct_count();
429        }
430    }
431
432    if include_unclassified {
433        set.insert(0);
434        if let Some(row) = report.get_by_taxon_id(0) {
435            expected += row.clade_count();
436        }
437    }
438
439    Ok((set, expected))
440}
441
442/// Validates that a kraken read name matches a FASTQ record header.
443///
444/// Expects the FASTQ header to start with the kraken read name (byte-for-byte),
445/// optionally followed by `/1` or `/2` (paired-end suffix) and/or whitespace
446/// plus a comment. Avoids scanning the full header — only checks the prefix
447/// at the kraken name length boundary.
448fn validate_read_name(kraken_name: &str, fastq_head: &[u8], line_number: u64) -> Result<()> {
449    let k = kraken_name.as_bytes();
450    let f = fastq_head;
451
452    if f.len() >= k.len() && f[..k.len()] == *k {
453        let rest = &f[k.len()..];
454        if rest.is_empty()
455            || rest[0] == b' '
456            || rest[0] == b'\t'
457            || (rest.len() >= 2
458                && rest[0] == b'/'
459                && (rest[1] == b'1' || rest[1] == b'2')
460                && (rest.len() == 2 || rest[2] == b' ' || rest[2] == b'\t'))
461        {
462            return Ok(());
463        }
464    }
465
466    // Build a readable FASTQ name for the error message only on failure
467    let name_end = f.iter().position(|&b| b == b' ' || b == b'\t').unwrap_or(f.len());
468    anyhow::bail!(
469        "read name mismatch at kraken line {line_number}: \
470         kraken={kraken_name:?}, FASTQ={:?}",
471        String::from_utf8_lossy(&f[..name_end])
472    );
473}
474
475/// Writes a single FASTQ record to a writer.
476fn write_fastq_record<W: Write>(writer: &mut W, rec: &impl Record) -> Result<()> {
477    writer.write_all(b"@")?;
478    writer.write_all(rec.head())?;
479    writer.write_all(b"\n")?;
480    writer.write_all(rec.seq())?;
481    writer.write_all(b"\n+\n")?;
482    writer.write_all(rec.qual())?;
483    writer.write_all(b"\n")?;
484    Ok(())
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small in-memory report used by every test below.
    ///
    /// Tree (by indentation in the name column): unclassified(0);
    /// root(1) -> Bacteria(2) -> E. coli(3); root(1) -> Eukaryota(4) -> Human(5).
    /// Columns are percent, clade count, direct count, rank, taxon ID, name —
    /// e.g. Bacteria has clade_count 600 and direct_count 10.
    fn make_report() -> KrakenReport {
        // unclassified(0), root(1), Bacteria(2), E.coli(3), Eukaryota(4), Human(5)
        let lines = [
            " 10.00\t100\t100\tU\t0\tunclassified",
            " 90.00\t900\t5\tR\t1\troot",
            " 60.00\t600\t10\tD\t2\t  Bacteria",
            " 50.00\t500\t500\tS\t3\t    Escherichia coli",
            " 30.00\t300\t10\tD\t4\t  Eukaryota",
            " 20.00\t200\t200\tS\t5\t    Homo sapiens",
        ]
        .join("\n");
        KrakenReport::from_reader(lines.as_bytes()).unwrap()
    }

    // Without -d, only the exact taxon is selected and direct_count is used.
    #[test]
    fn test_build_taxon_set_exact() {
        let report = make_report();
        let (set, expected) =
            build_taxon_set_and_expected_count(&report, &[3], false, false).unwrap();
        assert_eq!(set, HashSet::from([3]));
        assert_eq!(expected, 500);
    }

    // With -d, the clade expands (Bacteria -> E. coli) and clade_count is used.
    #[test]
    fn test_build_taxon_set_with_descendants() {
        let report = make_report();
        let (set, expected) =
            build_taxon_set_and_expected_count(&report, &[2], true, false).unwrap();
        assert_eq!(set, HashSet::from([2, 3]));
        assert_eq!(expected, 600);
    }

    // Expanding from root reaches every classified taxon in the report.
    #[test]
    fn test_build_taxon_set_with_descendants_root() {
        let report = make_report();
        let (set, expected) =
            build_taxon_set_and_expected_count(&report, &[1], true, false).unwrap();
        assert_eq!(set, HashSet::from([1, 2, 3, 4, 5]));
        assert_eq!(expected, 900);
    }

    // A taxon missing from the report must be reported as an error.
    #[test]
    fn test_build_taxon_set_unknown_taxon() {
        let report = make_report();
        let result = build_taxon_set_and_expected_count(&report, &[99999], false, false);
        assert!(result.is_err());
    }

    // -u adds taxon 0; expected = 500 (E. coli direct) + 100 (unclassified).
    #[test]
    fn test_build_taxon_set_include_unclassified() {
        let report = make_report();
        let (set, expected) =
            build_taxon_set_and_expected_count(&report, &[3], false, true).unwrap();
        assert_eq!(set, HashSet::from([0, 3]));
        assert_eq!(expected, 600);
    }

    // -u with no taxon IDs is a valid selection on its own.
    #[test]
    fn test_build_taxon_set_only_unclassified() {
        let report = make_report();
        let (set, expected) =
            build_taxon_set_and_expected_count(&report, &[], false, true).unwrap();
        assert_eq!(set, HashSet::from([0]));
        assert_eq!(expected, 100);
    }

    #[test]
    fn test_expected_count_with_descendants() {
        let report = make_report();
        let (_, expected) = build_taxon_set_and_expected_count(&report, &[2], true, false).unwrap();
        assert_eq!(expected, 600);
    }

    // Without -d the expectation uses direct_count (10), not clade_count (600).
    #[test]
    fn test_expected_count_without_descendants() {
        let report = make_report();
        let (_, expected) =
            build_taxon_set_and_expected_count(&report, &[2], false, false).unwrap();
        assert_eq!(expected, 10);
    }

    #[test]
    fn test_expected_count_with_unclassified() {
        let report = make_report();
        let (_, expected) = build_taxon_set_and_expected_count(&report, &[3], false, true).unwrap();
        assert_eq!(expected, 600);
    }

    #[test]
    fn test_validate_read_name_match() {
        assert!(validate_read_name("read1", b"read1", 1).is_ok());
    }

    #[test]
    fn test_validate_read_name_mismatch() {
        assert!(validate_read_name("read1", b"read2", 1).is_err());
    }

    // Paired-end suffixes /1 and /2 on the FASTQ header are accepted.
    #[test]
    fn test_validate_read_name_strip_suffix_1() {
        assert!(validate_read_name("read1", b"read1/1", 1).is_ok());
    }

    #[test]
    fn test_validate_read_name_strip_suffix_2() {
        assert!(validate_read_name("read1", b"read1/2", 1).is_ok());
    }

    // A whitespace-delimited comment after the name is accepted.
    #[test]
    fn test_validate_read_name_with_comment() {
        assert!(validate_read_name("read1", b"read1 length=150", 1).is_ok());
    }

    #[test]
    fn test_validate_read_name_suffix_and_comment() {
        assert!(validate_read_name("read1", b"read1/1 length=150", 1).is_ok());
    }

    // Two inputs but one output must fail validation.
    #[test]
    fn test_validate_args_mismatched_counts() {
        let filter = Filter {
            kraken_report: PathBuf::from("r.txt"),
            kraken_output: PathBuf::from("k.txt"),
            input: vec![PathBuf::from("a.fq"), PathBuf::from("b.fq")],
            output: vec![PathBuf::from("c.fq")],
            taxon_ids: vec![1],
            include_descendants: false,
            include_unclassified: false,
            threads: 4,
            compression_level: 6,
        };
        assert!(filter.validate_args().is_err());
    }

    // Neither --taxon-ids nor --include-unclassified selects anything: error.
    #[test]
    fn test_validate_args_no_taxa_or_unclassified() {
        let filter = Filter {
            kraken_report: PathBuf::from("r.txt"),
            kraken_output: PathBuf::from("k.txt"),
            input: vec![PathBuf::from("a.fq")],
            output: vec![PathBuf::from("b.fq")],
            taxon_ids: vec![],
            include_descendants: false,
            include_unclassified: false,
            threads: 4,
            compression_level: 6,
        };
        assert!(filter.validate_args().is_err());
    }
}