use std::collections::HashSet;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::PathBuf;

use anyhow::{Context, Result};
use fgoxide::io::Io;
use fgoxide::iter::IntoChunkedReadAheadIterator;
use pooled_writer::bgzf::BgzfCompressor;
use pooled_writer::{PoolBuilder, PooledWriter};
use seq_io::fastq::{Error as FastqError, OwnedRecord, Reader as FastqReader, Record};

use crate::commands::command::Command;
use crate::kraken_output::{KrakenOutputReader, KrakenRecord};
use crate::progress::{ProgressLogger, format_count};
use crate::report::KrakenReport;

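/// Number of records per chunk used by the read-ahead iterators.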
const READ_AHEAD_CHUNK_SIZE: usize = 1024;

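/// Maximum number of chunks buffered by each read-ahead iterator.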
const READ_AHEAD_NUM_CHUNKS: usize = 1024;

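/// Buffer size, in bytes, for buffered file I/O.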
const IO_BUFFER_SIZE: usize = 512 * 1024;

#[derive(clap::Args)]
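/// Filter FASTQ reads by kraken2 classification.
///
/// Streams one or two (paired) FASTQ files alongside the matching kraken2
/// per-read output, keeping only reads assigned to the requested taxon IDs
/// (optionally including their descendants and/or unclassified reads), and
/// writes the kept reads to bgzf-compressed FASTQ output files.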
pub struct Filter {
    #[arg(short = 'r', long)]
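    /// Path to the kraken2 report.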
    kraken_report: PathBuf,

    #[arg(short = 'k', long)]
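    /// Path to the kraken2 per-read output.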
    kraken_output: PathBuf,

    #[arg(short, long, num_args = 1..=2, required = true)]
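    /// Input FASTQ file(s): one for single-end data, two for paired-end.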
    input: Vec<PathBuf>,

    #[arg(short, long, num_args = 1..=2, required = true)]
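    /// Output FASTQ file(s), bgzf-compressed; must match the number of inputs.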
    output: Vec<PathBuf>,

    #[arg(short, long, num_args = 1..)]
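    /// Taxon ID(s) whose reads should be kept.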
    taxon_ids: Vec<u64>,

    #[arg(short = 'd', long, default_value_t = false)]
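    /// Also keep reads assigned to any descendant of the given taxon IDs.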
    include_descendants: bool,

    #[arg(short = 'u', long, default_value_t = false)]
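    /// Also keep unclassified reads (taxon ID 0).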
    include_unclassified: bool,

    #[arg(long, default_value_t = 4)]
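    /// Number of compression threads for the writer pool.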
    threads: usize,

    #[arg(long, default_value_t = 5)]
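    /// Compression level (0-9) for the bgzf output.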
    compression_level: u8,
}

impl Command for Filter {
    fn execute(&self) -> Result<()> {
        self.validate_args()?;

        let report = KrakenReport::from_path(&self.kraken_report)?;
        if report.is_empty() {
            return self.handle_empty_inputs();
        }

        let (taxon_set, expected) = build_taxon_set_and_expected_count(
            &report,
            &self.taxon_ids,
            self.include_descendants,
            self.include_unclassified,
        )?;
        log::info!(
            "Filtering for {} taxa; expecting approximately {} reads",
            format_count(taxon_set.len() as u64),
            format_count(expected),
        );

        let (total, kept) = self.run_filter_pipeline(&taxon_set).map_err(|e| {
            let banner = "#".repeat(72);
            let output_paths: Vec<_> =
                self.output.iter().map(|p| format!(" {}", p.display())).collect();
            eprintln!(
                "\n{banner}\n\
                 # ERROR: invalid inputs detected\n\
                 #\n\
                 # {e}\n\
                 #\n\
                 # WARNING: partial/invalid output files may have been written to:\n\
                 # {}\n\
                 {banner}\n",
                output_paths.join("\n"),
            );
            e
        })?;

        #[allow(clippy::cast_precision_loss)]
        let pct = if total > 0 { kept as f64 / total as f64 * 100.0 } else { 0.0 };
        log::info!(
            "Kept {} / {} reads ({pct:.2}%), expected {}.",
            format_count(kept),
            format_count(total),
            format_count(expected),
        );

        Ok(())
    }
}

impl Filter {
    fn handle_empty_inputs(&self) -> Result<()> {
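        // An empty report implies empty inputs: confirm each FASTQ really is
        // empty, then still write (empty) outputs so downstream tooling finds
        // the files it expects.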
        let io = Io::new(u32::from(self.compression_level), IO_BUFFER_SIZE);
        for path in &self.input {
            let reader = io
                .new_reader(path)
                .with_context(|| format!("failed to open FASTQ: {}", path.display()))?;
            let mut fq = FastqReader::new(reader);
            if fq.next().is_some() {
                anyhow::bail!(
                    "kraken2 report is empty but FASTQ input {} contains records; \
                    inputs are inconsistent",
                    path.display()
                );
            }
        }

        let (mut pool, writers) = self.build_writer_pool()?;
        for w in writers {
            w.close()?;
        }
        pool.stop_pool()?;

        log::info!("Report is empty; all inputs are empty. Wrote empty output files.");
        Ok(())
    }

    fn validate_args(&self) -> Result<()> {
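        // Reject argument combinations that cannot produce meaningful output.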
        anyhow::ensure!(
            self.input.len() == self.output.len(),
            "number of input files ({}) must match number of output files ({})",
            self.input.len(),
            self.output.len()
        );
        anyhow::ensure!(self.threads >= 1, "threads must be at least 1");
        anyhow::ensure!(self.compression_level <= 9, "compression level must be 0-9");
        anyhow::ensure!(
            !self.taxon_ids.is_empty() || self.include_unclassified,
            "at least one --taxon-ids value or --include-unclassified must be specified"
        );
        Ok(())
    }

    fn run_filter_pipeline(&self, taxon_set: &HashSet<u64>) -> Result<(u64, u64)> {
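        // Stream the kraken output and the FASTQ input(s) in lockstep,
        // writing reads whose taxon IDs match through a pooled bgzf writer.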
        let io = Io::new(u32::from(self.compression_level), IO_BUFFER_SIZE);
        let kraken_reader = io.new_reader(&self.kraken_output).with_context(|| {
            format!("failed to open kraken output: {}", self.kraken_output.display())
        })?;
        let mut kraken_iter = KrakenOutputReader::new(kraken_reader)
            .read_ahead(READ_AHEAD_CHUNK_SIZE, READ_AHEAD_NUM_CHUNKS);

        let is_paired = self.input.len() == 2;
        let mut fq_iter1 = FastqReader::new(
            io.new_reader(&self.input[0])
                .with_context(|| format!("failed to open FASTQ: {}", self.input[0].display()))?,
        )
        .into_records()
        .read_ahead(READ_AHEAD_CHUNK_SIZE, READ_AHEAD_NUM_CHUNKS);

        let mut fq_iter2 = if is_paired {
            Some(
                FastqReader::new(io.new_reader(&self.input[1]).with_context(|| {
                    format!("failed to open FASTQ: {}", self.input[1].display())
                })?)
                .into_records()
                .read_ahead(READ_AHEAD_CHUNK_SIZE, READ_AHEAD_NUM_CHUNKS),
            )
        } else {
            None
        };

        let (mut pool, mut writers) = self.build_writer_pool()?;
        let mut progress = ProgressLogger::new("k2tools::filter", "reads", 5_000_000);

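        // Run the filter, then verify that the FASTQ inputs were fully
        // consumed; leftover records mean the inputs do not correspond.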
        let result = filter_reads(
            &mut kraken_iter,
            &mut fq_iter1,
            fq_iter2.as_mut(),
            taxon_set,
            &mut writers,
            &mut progress,
        )
        .and_then(|(total, kept)| {
            verify_fastq_exhausted(&mut fq_iter1, fq_iter2.as_mut(), total)?;
            Ok((total, kept))
        });

        progress.finish();

        for w in writers {
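            // Close each writer before stopping the pool so that all queued
            // bytes are compressed and flushed.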
            w.close()?;
        }
        pool.stop_pool()?;

        result
    }

    fn build_writer_pool(&self) -> Result<(pooled_writer::Pool, Vec<PooledWriter>)> {
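        // All outputs share one compressor pool; each output's BufWriter is
        // exchanged for a PooledWriter that feeds the pool.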
        let mut pool_builder = PoolBuilder::<_, BgzfCompressor>::new()
            .threads(self.threads)
            .queue_size(self.threads * 50)
            .compression_level(self.compression_level)?;

        let mut writers: Vec<PooledWriter> = Vec::new();
        for path in &self.output {
            let file = File::create(path)
                .with_context(|| format!("failed to create output: {}", path.display()))?;
            writers.push(pool_builder.exchange(BufWriter::new(file)));
        }
        let pool = pool_builder.build()?;
        Ok((pool, writers))
    }
}

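/// Streams kraken records and FASTQ records in lockstep, writing each read
/// whose taxon ID is in `taxon_set` to the matching writer.
///
/// Returns the `(total, kept)` record counts.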
fn filter_reads(
    kraken_iter: &mut impl Iterator<Item = Result<KrakenRecord>>,
    fq_iter1: &mut impl Iterator<Item = Result<OwnedRecord, FastqError>>,
    mut fq_iter2: Option<&mut impl Iterator<Item = Result<OwnedRecord, FastqError>>>,
    taxon_set: &HashSet<u64>,
    writers: &mut [PooledWriter],
    progress: &mut ProgressLogger,
) -> Result<(u64, u64)> {
    let mut total: u64 = 0;
    let mut kept: u64 = 0;

    for kraken_result in kraken_iter {
        let kraken_rec = kraken_result?;
        total += 1;
        progress.record();

        let fq_rec1 = fq_iter1
            .next()
            .context("FASTQ input ended before kraken output")?
            .with_context(|| format!("failed to read FASTQ record at kraken line {total}"))?;

        let fq_rec2: Option<OwnedRecord> = if let Some(ref mut iter2) = fq_iter2 {
            Some(
                iter2
                    .next()
                    .context("second FASTQ input ended before kraken output")?
                    .with_context(|| {
                        format!("failed to read FASTQ R2 record at kraken line {total}")
                    })?,
            )
        } else {
            None
        };

        if taxon_set.contains(&kraken_rec.taxon_id()) {
            validate_read_name(kraken_rec.read_name(), fq_rec1.head(), total)?;
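            // For paired data both mates must carry the same read name.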
            if let Some(ref rec2) = fq_rec2 {
                validate_read_name(kraken_rec.read_name(), rec2.head(), total)?;
            }

            write_fastq_record(&mut writers[0], &fq_rec1)?;
            if let Some(ref rec2) = fq_rec2 {
                write_fastq_record(&mut writers[1], rec2)?;
            }
            kept += 1;
        }
    }

    Ok((total, kept))
}

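/// Fails if either FASTQ iterator still has records after the kraken output is
/// exhausted, since that means the inputs do not correspond.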
fn verify_fastq_exhausted(
    fq_iter1: &mut impl Iterator<Item = Result<OwnedRecord, FastqError>>,
    fq_iter2: Option<&mut impl Iterator<Item = Result<OwnedRecord, FastqError>>>,
    total: u64,
) -> Result<()> {
    if fq_iter1.next().is_some() {
        anyhow::bail!("FASTQ input has more records than kraken output ({total} kraken records)");
    }
    if let Some(iter2) = fq_iter2 {
        if iter2.next().is_some() {
            anyhow::bail!(
                "second FASTQ input has more records than kraken output ({total} kraken records)"
            );
        }
    }
    Ok(())
}

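/// Builds the set of taxon IDs to keep, plus the number of reads the report
/// says should match.
///
/// Every requested taxon ID must appear in the report. With
/// `include_descendants`, descendant taxa are added and clade counts are used;
/// otherwise only direct counts contribute. With `include_unclassified`,
/// taxon ID 0 and its clade count are included.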
fn build_taxon_set_and_expected_count(
    report: &KrakenReport,
    taxon_ids: &[u64],
    include_descendants: bool,
    include_unclassified: bool,
) -> Result<(HashSet<u64>, u64)> {
    let mut set = HashSet::new();
    let mut expected: u64 = 0;

    for &tid in taxon_ids {
        let idx = report
            .index_of_taxon_id(tid)
            .with_context(|| format!("taxon ID {tid} not found in report"))?;
        let row = report.row(idx);
        set.insert(tid);

        if include_descendants {
            expected += row.clade_count();
            for desc_idx in report.descendants(idx) {
                set.insert(report.row(desc_idx).taxon_id());
            }
        } else {
            expected += row.direct_count();
        }
    }

    if include_unclassified {
        set.insert(0);
        if let Some(row) = report.get_by_taxon_id(0) {
            expected += row.clade_count();
        }
    }

    Ok((set, expected))
}

fn validate_read_name(kraken_name: &str, fastq_head: &[u8], line_number: u64) -> Result<()> {
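    // The FASTQ head matches when it equals the kraken read name exactly, or
    // continues with a comment (space or tab), or with a `/1` or `/2` mate
    // suffix that is itself at end-of-name or followed by a comment.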
    let k = kraken_name.as_bytes();
    let f = fastq_head;

    if f.len() >= k.len() && f[..k.len()] == *k {
        let rest = &f[k.len()..];
        if rest.is_empty()
            || rest[0] == b' '
            || rest[0] == b'\t'
            || (rest.len() >= 2
                && rest[0] == b'/'
                && (rest[1] == b'1' || rest[1] == b'2')
                && (rest.len() == 2 || rest[2] == b' ' || rest[2] == b'\t'))
        {
            return Ok(());
        }
    }

    let name_end = f.iter().position(|&b| b == b' ' || b == b'\t').unwrap_or(f.len());
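    // Show only the FASTQ read name (up to the first whitespace) in the error.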
    anyhow::bail!(
        "read name mismatch at kraken line {line_number}: \
         kraken={kraken_name:?}, FASTQ={:?}",
        String::from_utf8_lossy(&f[..name_end])
    );
}

fn write_fastq_record<W: Write>(writer: &mut W, rec: &impl Record) -> Result<()> {
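    // Emit the four FASTQ lines: header, sequence, separator, qualities.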
    writer.write_all(b"@")?;
    writer.write_all(rec.head())?;
    writer.write_all(b"\n")?;
    writer.write_all(rec.seq())?;
    writer.write_all(b"\n+\n")?;
    writer.write_all(rec.qual())?;
    writer.write_all(b"\n")?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_report() -> KrakenReport {
        let lines = [
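            // columns: percent, clade_count, direct_count, rank, taxon_id, name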
494 " 10.00\t100\t100\tU\t0\tunclassified",
495 " 90.00\t900\t5\tR\t1\troot",
496 " 60.00\t600\t10\tD\t2\t Bacteria",
497 " 50.00\t500\t500\tS\t3\t Escherichia coli",
498 " 30.00\t300\t10\tD\t4\t Eukaryota",
499 " 20.00\t200\t200\tS\t5\t Homo sapiens",
500 ]
501 .join("\n");
502 KrakenReport::from_reader(lines.as_bytes()).unwrap()
503 }
504
505 #[test]
506 fn test_build_taxon_set_exact() {
507 let report = make_report();
508 let (set, expected) =
509 build_taxon_set_and_expected_count(&report, &[3], false, false).unwrap();
510 assert_eq!(set, HashSet::from([3]));
511 assert_eq!(expected, 500);
512 }
513
514 #[test]
515 fn test_build_taxon_set_with_descendants() {
516 let report = make_report();
517 let (set, expected) =
518 build_taxon_set_and_expected_count(&report, &[2], true, false).unwrap();
519 assert_eq!(set, HashSet::from([2, 3]));
520 assert_eq!(expected, 600);
521 }
522
523 #[test]
524 fn test_build_taxon_set_with_descendants_root() {
525 let report = make_report();
526 let (set, expected) =
527 build_taxon_set_and_expected_count(&report, &[1], true, false).unwrap();
528 assert_eq!(set, HashSet::from([1, 2, 3, 4, 5]));
529 assert_eq!(expected, 900);
530 }
531
532 #[test]
533 fn test_build_taxon_set_unknown_taxon() {
534 let report = make_report();
535 let result = build_taxon_set_and_expected_count(&report, &[99999], false, false);
536 assert!(result.is_err());
537 }
538
539 #[test]
540 fn test_build_taxon_set_include_unclassified() {
541 let report = make_report();
542 let (set, expected) =
543 build_taxon_set_and_expected_count(&report, &[3], false, true).unwrap();
544 assert_eq!(set, HashSet::from([0, 3]));
545 assert_eq!(expected, 600);
546 }
547
548 #[test]
549 fn test_build_taxon_set_only_unclassified() {
550 let report = make_report();
551 let (set, expected) =
552 build_taxon_set_and_expected_count(&report, &[], false, true).unwrap();
553 assert_eq!(set, HashSet::from([0]));
554 assert_eq!(expected, 100);
555 }
556
557 #[test]
558 fn test_expected_count_with_descendants() {
559 let report = make_report();
560 let (_, expected) = build_taxon_set_and_expected_count(&report, &[2], true, false).unwrap();
561 assert_eq!(expected, 600);
562 }
563
564 #[test]
565 fn test_expected_count_without_descendants() {
566 let report = make_report();
567 let (_, expected) =
568 build_taxon_set_and_expected_count(&report, &[2], false, false).unwrap();
569 assert_eq!(expected, 10);
570 }
571
572 #[test]
573 fn test_expected_count_with_unclassified() {
574 let report = make_report();
575 let (_, expected) = build_taxon_set_and_expected_count(&report, &[3], false, true).unwrap();
576 assert_eq!(expected, 600);
577 }
578
579 #[test]
580 fn test_validate_read_name_match() {
581 assert!(validate_read_name("read1", b"read1", 1).is_ok());
582 }
583
584 #[test]
585 fn test_validate_read_name_mismatch() {
586 assert!(validate_read_name("read1", b"read2", 1).is_err());
587 }
588
589 #[test]
590 fn test_validate_read_name_strip_suffix_1() {
591 assert!(validate_read_name("read1", b"read1/1", 1).is_ok());
592 }
593
594 #[test]
595 fn test_validate_read_name_strip_suffix_2() {
596 assert!(validate_read_name("read1", b"read1/2", 1).is_ok());
597 }
598
599 #[test]
600 fn test_validate_read_name_with_comment() {
601 assert!(validate_read_name("read1", b"read1 length=150", 1).is_ok());
602 }
603
604 #[test]
605 fn test_validate_read_name_suffix_and_comment() {
606 assert!(validate_read_name("read1", b"read1/1 length=150", 1).is_ok());
607 }
608
609 #[test]
610 fn test_validate_args_mismatched_counts() {
611 let filter = Filter {
612 kraken_report: PathBuf::from("r.txt"),
613 kraken_output: PathBuf::from("k.txt"),
614 input: vec![PathBuf::from("a.fq"), PathBuf::from("b.fq")],
615 output: vec![PathBuf::from("c.fq")],
616 taxon_ids: vec![1],
617 include_descendants: false,
618 include_unclassified: false,
619 threads: 4,
620 compression_level: 6,
621 };
622 assert!(filter.validate_args().is_err());
623 }
624
625 #[test]
626 fn test_validate_args_no_taxa_or_unclassified() {
627 let filter = Filter {
628 kraken_report: PathBuf::from("r.txt"),
629 kraken_output: PathBuf::from("k.txt"),
630 input: vec![PathBuf::from("a.fq")],
631 output: vec![PathBuf::from("b.fq")],
632 taxon_ids: vec![],
633 include_descendants: false,
634 include_unclassified: false,
635 threads: 4,
636 compression_level: 6,
637 };
638 assert!(filter.validate_args().is_err());
639 }
640}