use crate::commands::command::Command;
use crate::commands::fastq_readahead::ReadAheadBuilder;
use anyhow::{Result, anyhow};
use clap::Parser;
use clap::builder::RangedU64ValueParser;
use itertools::Itertools;
use log::info;
use pooled_writer::{Pool, PoolBuilder, PooledWriter, bgzf::BgzfCompressor};
use proglog::{CountFormatterKind, ProgLogBuilder};
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
struct ShardWriters<W: Write> {
writers: Vec<W>,
}
impl ShardWriters<PooledWriter> {
fn close(self) -> Result<()> {
self.writers.into_iter().try_for_each(PooledWriter::close)?;
Ok(())
}
}
/// Splits one or more FASTQ files into a number of smaller shards.
///
/// Records are read in lockstep from all inputs (e.g. the R1 and R2 files of a paired-end
/// run), so the Nth record of every input forms one record set. Record sets are distributed
/// round-robin across the shards: the first set goes to shard 1, the second to shard 2, and so
/// on, wrapping back to shard 1 after the last shard. Records that belong together therefore
/// always end up in the same shard, in their original relative order.
///
/// For every shard and every input, one BGZF-compressed FASTQ is written, named:
///
///   <output-prefix>.<shard-prefix><shard>.<read-number-prefix><input>.fq.gz
///
/// where <shard> and <input> are 1-based. For example, with the default prefixes and two
/// inputs, shard 3 is written to <output-prefix>.s3.r1.fq.gz and <output-prefix>.s3.r2.fq.gz.
///
/// All inputs must contain the same number of records; the command fails with an error if
/// one input runs out of records before the others.
#[derive(Parser, Debug)]
#[command(version)]
#[clap(verbatim_doc_comment)]
pub(crate) struct Shard {
    /// One or more input FASTQ files (plain or gzip/BGZF compressed). Records are paired up
    /// by position, so all inputs must contain the same number of records.
    #[clap(long, short = 'i', required = true, num_args = 1..)]
    inputs: Vec<PathBuf>,

    /// Prefix for all output file names; may include a directory path.
    #[clap(long, short = 'o')]
    output_prefix: String,

    /// Text placed before the 1-based shard number in output file names.
    #[clap(long, short = 'S', default_value = "s")]
    shard_prefix: String,

    /// Text placed before the 1-based input number in output file names.
    #[clap(long, short = 'R', default_value = "r")]
    read_number_prefix: String,

    /// Number of shards to split the input into.
    #[clap(long, short = 's', value_parser = RangedU64ValueParser::<usize>::new().range(1..))]
    shards: usize,

    /// Total number of threads to use (at least 2); all but one are used to compress output.
    #[clap(
        long,
        short = 't',
        default_value = "8",
        value_parser = RangedU64ValueParser::<usize>::new().range(2..)
    )]
    threads: usize,

    /// Compression level for the BGZF-compressed outputs (1-12).
    #[clap(
        long,
        short = 'c',
        default_value = "1",
        value_parser = RangedU64ValueParser::<u8>::new().range(1..=12)
    )]
    compression_level: u8,

    /// Read-ahead chunk size passed to each input reader.
    #[clap(
        long,
        hide = true,
        default_value = "131072",
        value_parser = RangedU64ValueParser::<usize>::new().range(1..)
    )]
    chunk_size: usize,

    /// Number of read-ahead chunks configured for each input reader.
    #[clap(
        long,
        hide = true,
        default_value = "32",
        value_parser = RangedU64ValueParser::<usize>::new().range(1..)
    )]
    chunk_count: usize,
}
impl Shard {
fn build_writer_pool(&self) -> Result<(Pool, Vec<ShardWriters<PooledWriter>>)> {
let mut shard_writers = Vec::with_capacity(self.shards);
for shard in 1..=self.shards {
let mut ws = Vec::with_capacity(self.inputs.len());
for source_idx in 1..=self.inputs.len() {
let path_str = format!(
"{prefix}.{shard_prefix}{shard_num}.{read_prefix}{read_num}.fq.gz",
prefix = self.output_prefix,
shard_prefix = self.shard_prefix,
shard_num = shard,
read_prefix = self.read_number_prefix,
read_num = source_idx
);
let path = Path::new(&path_str);
let writer = BufWriter::new(File::create(path)?);
ws.push(writer);
}
shard_writers.push(ShardWriters { writers: ws });
}
let pool_threads = self.threads - 1;
let mut pool_builder = PoolBuilder::<_, BgzfCompressor>::new()
.threads(pool_threads)
.queue_size(pool_threads * 50)
.compression_level(self.compression_level)?;
let mut pooled_shard_writers = Vec::with_capacity(shard_writers.len());
for shard_writer in shard_writers.into_iter() {
let pooled_writers =
shard_writer.writers.into_iter().map(|w| pool_builder.exchange(w)).collect_vec();
pooled_shard_writers.push(ShardWriters { writers: pooled_writers });
}
let pool = pool_builder.build()?;
Ok((pool, pooled_shard_writers))
}
}
impl Command for Shard {
    /// Reads record sets from all inputs in lockstep and distributes them round-robin across
    /// the shard outputs.
    ///
    /// Record set `i` (0-based; one record from each input) is written to shard
    /// `i % shards`. Within that shard, the record from input `j` goes to writer `j`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - an input cannot be opened or read (the error names the input);
    /// - the inputs run out of records at different points ("out of sync"; the error names
    ///   the inputs that ran out first);
    /// - writing, closing, or stopping the compression pool fails.
    ///
    /// Once the writers are built, every shard is closed and the pool is stopped even if
    /// processing failed. The first error encountered is the one returned.
    fn execute(&self) -> Result<()> {
        info!("Reading {} input FASTQs and generating {} shards.", self.inputs.len(), self.shards);

        let mut readers = Vec::with_capacity(self.inputs.len());
        for path in &self.inputs {
            readers.push(
                ReadAheadBuilder {
                    path: path.clone(),
                    chunk_size: self.chunk_size,
                    chunk_count: self.chunk_count,
                }
                .build()?,
            );
        }
        let n_inputs = readers.len();

        let (mut pool, mut shard_writers) = self.build_writer_pool()?;
        let logger = ProgLogBuilder::new()
            .name("fqtk")
            .noun("record sets")
            .verb("read")
            .unit(5_000_000)
            .count_formatter(CountFormatterKind::Comma)
            .level(log::Level::Info)
            .build();

        let mut target_shard_idx: usize = 0;
        let process_result: Result<()> = 'process: {
            loop {
                // Pull one record from every input before writing anything. A read error or
                // an exhausted input then never leaves a partial record set in the outputs.
                let mut records = Vec::with_capacity(n_inputs);
                let mut exhausted = Vec::new();
                for (reader, path) in readers.iter_mut().zip(&self.inputs) {
                    match reader.next() {
                        Some(Ok(record)) => records.push(record),
                        Some(Err(e)) => {
                            break 'process Err(anyhow!(
                                "Error reading FASTQ input {}: {e}",
                                path.display()
                            ));
                        }
                        None => exhausted.push(path),
                    }
                }

                // Every input is exhausted at the same point: normal end of input.
                if records.is_empty() {
                    break;
                }
                if !exhausted.is_empty() {
                    break 'process Err(anyhow!(
                        "FASTQ sources out of sync; expected {} records but got {}. \
                         Input(s) with no more records: {}",
                        n_inputs,
                        records.len(),
                        exhausted.iter().map(|p| p.display()).join(", ")
                    ));
                }

                let target = &mut shard_writers[target_shard_idx];
                for (record, writer) in records.iter().zip(target.writers.iter_mut()) {
                    if let Err(e) = record.write_unchanged(&mut *writer) {
                        break 'process Err(e.into());
                    }
                }
                target_shard_idx = (target_shard_idx + 1) % self.shards;
                logger.record();
            }
            Ok(())
        };
        info!("Finished reading input FASTQs.");

        // `Result::and` evaluates its argument eagerly, so every shard is closed even after an
        // earlier one fails; the fold keeps only the first error.
        let close_result = shard_writers
            .into_iter()
            .fold(Ok(()), |first: Result<()>, shard| first.and(shard.close()));
        let stop_result = pool.stop_pool().map_err(anyhow::Error::from);
        process_result.and(close_result).and(stop_result)?;
        info!("Output FASTQ writing complete.");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::commands::command::Command;
    use crate::commands::shard::Shard;
    use bstr::ByteSlice;
    use itertools::Itertools;
    use rand;
    use seq_io::fastq::{OwnedRecord, Record};
    use std::collections::HashSet;
    use std::fs::File;
    use std::io::{BufReader, BufWriter, Write};
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Writes `count` FASTQ records to `out`, named `{prefix}{i}{suffix}` for `i` in
    /// `idx..idx + count`.
    ///
    /// Each record has a random 30bp sequence and random Phred+33 qualities in Q2..=Q39.
    fn write_fastq_records<W: Write>(
        out: &mut W,
        prefix: &str,
        suffix: &str,
        idx: usize,
        count: usize,
    ) {
        let bases = "ACGT".as_bytes();
        for i in idx..idx + count {
            let seq = (0..30).map(|_| rand::random_range(0..4)).map(|j| bases[j]).collect_vec();
            let qual = (0..30).map(|_| rand::random_range(2u8..40u8) + 33).collect_vec();
            let rec = OwnedRecord {
                head: format!("{}{}{}", prefix, i, suffix).as_bytes().to_owned(),
                seq,
                qual,
            };
            rec.write(&mut *out).unwrap();
        }
    }

    /// Writes a FASTQ of `count` records (see `write_fastq_records`) to `path` using fgoxide's
    /// `Io` writer.
    ///
    /// NOTE(review): `test_gzip_compressed_inputs` relies on fgoxide choosing gzip output for
    /// a `.gz` path; confirm against fgoxide if that test starts failing.
    fn build_fastq(path: &Path, prefix: &str, suffix: &str, idx: usize, count: usize) {
        let io = fgoxide::io::Io::new(1, 8 * 1024);
        let mut out = io.new_writer(path).unwrap();
        write_fastq_records(&mut out, prefix, suffix, idx, count);
    }

    /// Writes a BGZF-compressed FASTQ of `count` records to `path` with the `bgzf` crate.
    fn build_fastq_bgzf(path: &Path, prefix: &str, suffix: &str, idx: usize, count: usize) {
        let file = BufWriter::new(File::create(path).unwrap());
        let mut writer = bgzf::Writer::new(file, bgzf::CompressionLevel::new(3).unwrap());
        write_fastq_records(&mut writer, prefix, suffix, idx, count);
        writer.finish().unwrap().flush().unwrap();
    }

    /// Shards `inputs` into `shards` outputs under `tmp` and reads the results back, indexed
    /// as `[shard][input][record]`.
    fn run_sharding(tmp: &TempDir, inputs: &[&Path], shards: usize) -> Vec<Vec<Vec<OwnedRecord>>> {
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        build_sharder(inputs, &prefix, shards).execute().unwrap();
        collect_shard_outputs(&prefix, inputs.len(), shards, read_fastq)
    }

    /// Builds a `Shard` command with the `shard`/`read` name prefixes used by these tests.
    fn build_sharder(inputs: &[&Path], output_prefix: &str, shards: usize) -> Shard {
        Shard {
            inputs: inputs.iter().map(|p| p.to_path_buf()).collect_vec(),
            output_prefix: output_prefix.to_string(),
            shard_prefix: "shard".to_string(),
            read_number_prefix: "read".to_string(),
            shards,
            threads: 4,
            compression_level: 1,
            chunk_size: 100,
            chunk_count: 8,
        }
    }

    /// Reads every `{output_prefix}.shard{s}.read{i}.fq.gz` output with `reader`, returning
    /// the records indexed as `[shard][input][record]` (both 1-based in the file names).
    fn collect_shard_outputs<F>(
        output_prefix: &str,
        n_inputs: usize,
        shards: usize,
        reader: F,
    ) -> Vec<Vec<Vec<OwnedRecord>>>
    where
        F: Fn(&Path) -> Vec<OwnedRecord>,
    {
        let mut results: Vec<Vec<Vec<OwnedRecord>>> = Vec::with_capacity(shards);
        for shard in 1..=shards {
            let mut reads_vecs = Vec::with_capacity(n_inputs);
            for input_idx in 1..=n_inputs {
                let path_str = format!("{}.shard{}.read{}.fq.gz", output_prefix, shard, input_idx);
                reads_vecs.push(reader(Path::new(&path_str)));
            }
            results.push(reads_vecs);
        }
        results
    }

    /// Reads all records from a FASTQ using fgoxide's `Io` reader.
    fn read_fastq(path: &Path) -> Vec<OwnedRecord> {
        let io = fgoxide::io::Io::new(1, 8 * 1024);
        let mut reader = io.new_reader(path).unwrap();
        let mut fq_reader = seq_io::fastq::Reader::with_capacity(&mut reader, 8 * 1024);
        fq_reader.records().map(|r| r.unwrap()).collect_vec()
    }

    /// Reads all records from a FASTQ through a strict BGZF reader. This fails if the file is
    /// not BGZF-formatted.
    fn read_fastq_via_bgzf(path: &Path) -> Vec<OwnedRecord> {
        let file = BufReader::new(File::open(path).unwrap());
        let bgzf_reader = bgzf::Reader::new(file);
        let mut fq_reader = seq_io::fastq::Reader::with_capacity(bgzf_reader, 8 * 1024);
        fq_reader.records().map(|r| r.unwrap()).collect_vec()
    }

    /// Extracts the numeric index from a read name such as `q12` or `q12/1`.
    ///
    /// A trailing `/<digits>` read-number suffix is stripped first. Panics if the remaining
    /// name does not end in digits.
    fn read_index(rec: &OwnedRecord) -> usize {
        let head = rec.head.to_str().unwrap();
        let trimmed = head
            .rsplit_once('/')
            .filter(|(_, tail)| tail.chars().all(|c| c.is_ascii_digit()))
            .map(|(head, _)| head)
            .unwrap_or(head);
        let digits: String = trimmed.chars().rev().take_while(|c| c.is_ascii_digit()).collect();
        digits.chars().rev().collect::<String>().parse().unwrap()
    }

    #[test]
    fn test_shard_single_file() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 50);
        let outputs = run_sharding(&tmp, &[&r1], 5);
        assert_eq!(outputs.len(), 5);
        for shard in outputs.iter() {
            assert_eq!(shard.len(), 1);
            assert_eq!(shard[0].len(), 10);
        }
        let read_names: HashSet<&str> =
            outputs.iter().flatten().flatten().map(|r| r.head.to_str().unwrap()).collect();
        assert_eq!(read_names.len(), 50);
    }

    #[test]
    fn test_shard_multiple_files() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        let r2 = PathBuf::from(tmp.path()).join("r2.fq");
        build_fastq(r1.as_path(), "q", "/1", 1, 64);
        build_fastq(r2.as_path(), "q", "/2", 1, 64);
        let outputs = run_sharding(&tmp, &[&r1, &r2], 3);
        assert_eq!(outputs.len(), 3);
        // 64 record sets round-robined over 3 shards: the first shard gets the extra one.
        let expected_sizes = [22, 21, 21];
        for (shard, expected) in outputs.iter().zip(expected_sizes) {
            assert_eq!(shard.len(), 2);
            assert_eq!(shard[0].len(), expected);
            assert_eq!(shard[1].len(), expected);
        }
    }

    #[test]
    fn test_round_robin_assignment() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        let n_reads = 20;
        let n_shards = 4;
        build_fastq(r1.as_path(), "q", "", 1, n_reads);
        let outputs = run_sharding(&tmp, &[&r1], n_shards);
        for (shard_idx, shard) in outputs.iter().enumerate() {
            let expected: Vec<usize> =
                (1..=n_reads).filter(|i| (i - 1) % n_shards == shard_idx).collect();
            let actual: Vec<usize> = shard[0].iter().map(read_index).collect();
            assert_eq!(actual, expected, "shard {} contents wrong", shard_idx + 1);
        }
    }

    #[test]
    fn test_paired_records_stay_aligned() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        let r2 = PathBuf::from(tmp.path()).join("r2.fq");
        build_fastq(r1.as_path(), "q", "/1", 1, 30);
        build_fastq(r2.as_path(), "q", "/2", 1, 30);
        let outputs = run_sharding(&tmp, &[&r1, &r2], 4);
        for shard in outputs.iter() {
            let r1_indices: Vec<usize> = shard[0].iter().map(read_index).collect();
            let r2_indices: Vec<usize> = shard[1].iter().map(read_index).collect();
            assert_eq!(r1_indices, r2_indices);
        }
    }

    #[test]
    fn test_three_inputs_stay_aligned() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        let r2 = PathBuf::from(tmp.path()).join("r2.fq");
        let r3 = PathBuf::from(tmp.path()).join("r3.fq");
        build_fastq(r1.as_path(), "q", "/1", 1, 17);
        build_fastq(r2.as_path(), "q", "/2", 1, 17);
        build_fastq(r3.as_path(), "q", "/3", 1, 17);
        let outputs = run_sharding(&tmp, &[&r1, &r2, &r3], 4);
        assert_eq!(outputs.len(), 4);
        let mut all_indices = Vec::new();
        for shard in outputs.iter() {
            assert_eq!(shard.len(), 3);
            let r1_indices: Vec<usize> = shard[0].iter().map(read_index).collect();
            for other in &shard[1..] {
                let other_indices: Vec<usize> = other.iter().map(read_index).collect();
                assert_eq!(other_indices, r1_indices);
            }
            all_indices.extend(r1_indices);
        }
        all_indices.sort_unstable();
        assert_eq!(all_indices, (1..=17).collect::<Vec<_>>());
    }

    #[test]
    fn test_no_reads_lost_or_duplicated() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 100);
        let outputs = run_sharding(&tmp, &[&r1], 7);
        let all_indices: Vec<usize> =
            outputs.iter().flatten().flatten().map(read_index).sorted().collect();
        let expected: Vec<usize> = (1..=100).collect();
        assert_eq!(all_indices, expected);
    }

    #[test]
    fn test_single_shard() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 12);
        let outputs = run_sharding(&tmp, &[&r1], 1);
        assert_eq!(outputs.len(), 1);
        let actual: Vec<usize> = outputs[0][0].iter().map(read_index).collect();
        assert_eq!(actual, (1..=12).collect::<Vec<_>>());
    }

    #[test]
    fn test_more_shards_than_reads() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 3);
        let outputs = run_sharding(&tmp, &[&r1], 5);
        assert_eq!(outputs.len(), 5);
        assert_eq!(outputs[0][0].len(), 1);
        assert_eq!(outputs[1][0].len(), 1);
        assert_eq!(outputs[2][0].len(), 1);
        assert_eq!(outputs[3][0].len(), 0);
        assert_eq!(outputs[4][0].len(), 0);
    }

    #[test]
    fn test_empty_input() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 0);
        let outputs = run_sharding(&tmp, &[&r1], 4);
        assert_eq!(outputs.len(), 4);
        for shard in outputs.iter() {
            assert_eq!(shard.len(), 1);
            assert!(shard[0].is_empty());
        }
    }

    #[test]
    fn test_mismatched_input_lengths_fails() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        let r2 = PathBuf::from(tmp.path()).join("r2.fq");
        build_fastq(r1.as_path(), "q", "/1", 1, 10);
        build_fastq(r2.as_path(), "q", "/2", 1, 8);
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        let err = build_sharder(&[&r1, &r2], &prefix, 3).execute().unwrap_err();
        assert!(err.to_string().contains("out of sync"), "unexpected error message: {}", err);
    }

    #[test]
    fn test_out_of_sync_inputs_leave_no_orphan_record() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        let r2 = PathBuf::from(tmp.path()).join("r2.fq");
        build_fastq(r1.as_path(), "q", "/1", 1, 3);
        build_fastq(r2.as_path(), "q", "/2", 1, 2);
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        let err = build_sharder(&[&r1, &r2], &prefix, 1).execute().unwrap_err();
        assert!(err.to_string().contains("out of sync"), "unexpected error message: {}", err);
        let r1_out = read_fastq(Path::new(&format!("{}.shard1.read1.fq.gz", prefix)));
        let r2_out = read_fastq(Path::new(&format!("{}.shard1.read2.fq.gz", prefix)));
        assert_eq!(
            r1_out.len(),
            r2_out.len(),
            "R1/R2 outputs left misaligned on the out-of-sync error path"
        );
    }

    #[test]
    fn test_decompression_error_surfaces() {
        let tmp = TempDir::new().unwrap();
        let bad = PathBuf::from(tmp.path()).join("bad.fq.gz");
        std::fs::write(&bad, b"this is definitely not valid gzip content\n").unwrap();
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        let result = build_sharder(&[bad.as_path()], &prefix, 2).execute();
        assert!(result.is_err(), "expected a decompression error but the run succeeded");
    }

    #[test]
    fn test_gzip_compressed_inputs() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq.gz");
        build_fastq(r1.as_path(), "q", "", 1, 25);
        let outputs = run_sharding(&tmp, &[&r1], 5);
        let all_indices: Vec<usize> =
            outputs.iter().flatten().flatten().map(read_index).sorted().collect();
        assert_eq!(all_indices, (1..=25).collect::<Vec<_>>());
    }

    #[test]
    fn test_bgzf_compressed_inputs_multiple_blocks() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq.gz");
        let n_reads = 1500;
        build_fastq_bgzf(r1.as_path(), "q", "", 1, n_reads);
        let outputs = run_sharding(&tmp, &[&r1], 3);
        let all_indices: Vec<usize> =
            outputs.iter().flatten().flatten().map(read_index).sorted().collect();
        assert_eq!(all_indices, (1..=n_reads).collect::<Vec<_>>());
    }

    #[test]
    fn test_output_files_are_bgzf() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 30);
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        build_sharder(&[&r1], &prefix, 3).execute().unwrap();
        let outputs = collect_shard_outputs(&prefix, 1, 3, read_fastq_via_bgzf);
        let all_indices: Vec<usize> =
            outputs.iter().flatten().flatten().map(read_index).sorted().collect();
        assert_eq!(all_indices, (1..=30).collect::<Vec<_>>());
    }

    #[test]
    fn test_custom_prefixes() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 6);
        let prefix = format!("{}/sample", tmp.path().to_str().unwrap());
        let sharder = Shard {
            inputs: vec![r1.clone()],
            output_prefix: prefix.clone(),
            shard_prefix: "chunk_".to_string(),
            read_number_prefix: "R".to_string(),
            shards: 2,
            threads: 2,
            compression_level: 1,
            chunk_size: 64 * 1024,
            chunk_count: 4,
        };
        sharder.execute().unwrap();
        for shard in 1..=2 {
            let path = PathBuf::from(format!("{}.chunk_{}.R1.fq.gz", prefix, shard));
            assert!(path.exists(), "expected output file does not exist: {}", path.display());
            assert_eq!(read_fastq(&path).len(), 3);
        }
    }

    #[test]
    fn test_tiny_chunk_size_reassembles_records() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 40);
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        let sharder = Shard {
            inputs: vec![r1.clone()],
            output_prefix: prefix.clone(),
            shard_prefix: "shard".to_string(),
            read_number_prefix: "read".to_string(),
            shards: 3,
            threads: 4,
            compression_level: 1,
            chunk_size: 3,
            chunk_count: 2,
        };
        sharder.execute().unwrap();
        let outputs = collect_shard_outputs(&prefix, 1, 3, read_fastq);
        let all_indices: Vec<usize> =
            outputs.iter().flatten().flatten().map(read_index).sorted().collect();
        assert_eq!(all_indices, (1..=40).collect::<Vec<_>>());
    }

    #[test]
    fn test_large_chunk_size_single_chunk() {
        let tmp = TempDir::new().unwrap();
        let r1 = PathBuf::from(tmp.path()).join("r1.fq");
        build_fastq(r1.as_path(), "q", "", 1, 20);
        let prefix = format!("{}/test_out", tmp.path().to_str().unwrap());
        let sharder = Shard {
            inputs: vec![r1.clone()],
            output_prefix: prefix.clone(),
            shard_prefix: "shard".to_string(),
            read_number_prefix: "read".to_string(),
            shards: 2,
            threads: 2,
            compression_level: 1,
            chunk_size: 8 * 1024 * 1024,
            chunk_count: 4,
        };
        sharder.execute().unwrap();
        let outputs = collect_shard_outputs(&prefix, 1, 2, read_fastq);
        let all_indices: Vec<usize> =
            outputs.iter().flatten().flatten().map(read_index).sorted().collect();
        assert_eq!(all_indices, (1..=20).collect::<Vec<_>>());
    }
}