use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use anyhow::{Context, Result};
use flate2::write::GzEncoder;
use flate2::Compression;
use gzp::deflate::Gzip;
use gzp::par::compress::{ParCompress, ParCompressBuilder};
use super::OwnedRecord;
/// Capacity in bytes (128 KiB) given to every `BufWriter` constructed in this file.
const BUFFER_SIZE: usize = 128 * 1024;
/// Compression level that `FastqWriter::new` forwards to `FastqWriter::with_level`.
pub const DEFAULT_COMPRESSION_LEVEL: u32 = 4;
/// Selects how `FastqWriter` encodes the bytes it writes to its output file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionType {
/// Uncompressed output through a `BufWriter`; this is the `Default` variant.
#[default]
None,
/// Gzip output produced by flate2's `GzEncoder`.
Gzip,
/// Gzip output produced by gzp's `ParCompress`.
ParallelGzip,
}
/// The concrete sink behind a `StdoutFastqWriter`.
enum StdoutWriterInner {
/// Buffered, uncompressed stdout.
Plain(BufWriter<std::io::Stdout>),
/// Gzip encoder wrapping buffered stdout.
Gzip(GzEncoder<BufWriter<std::io::Stdout>>),
}
impl Write for StdoutWriterInner {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StdoutWriterInner::Plain(w) => w.write(buf),
StdoutWriterInner::Gzip(w) => w.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
StdoutWriterInner::Plain(w) => w.flush(),
StdoutWriterInner::Gzip(w) => w.flush(),
}
}
}
impl CompressionType {
    /// Picks a compression type from the textual ending of `path`.
    ///
    /// The path is converted lossily to a string and lowercased, so the match
    /// is case-insensitive. Paths ending in `.gz` or `.gzip` yield
    /// `ParallelGzip`; everything else yields `None`.
    pub fn from_path(path: &Path) -> Self {
        // Suffixes (already lowercase) that indicate gzip output.
        const GZIP_SUFFIXES: [&str; 2] = [".gz", ".gzip"];

        let lowered = path.to_string_lossy().to_lowercase();
        let looks_gzipped = GZIP_SUFFIXES
            .iter()
            .any(|suffix| lowered.ends_with(*suffix));

        match looks_gzipped {
            true => Self::ParallelGzip,
            false => Self::None,
        }
    }
}
/// The concrete sink behind a `FastqWriter`.
enum WriterInner {
/// Buffered, uncompressed file.
Plain(BufWriter<File>),
/// flate2 gzip encoder wrapping a buffered file.
Gzip(GzEncoder<BufWriter<File>>),
/// gzp parallel gzip compressor.
ParallelGzip(ParCompress<Gzip>),
}
impl Write for WriterInner {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
WriterInner::Plain(w) => w.write(buf),
WriterInner::Gzip(w) => w.write(buf),
WriterInner::ParallelGzip(w) => w.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
WriterInner::Plain(w) => w.flush(),
WriterInner::Gzip(w) => w.flush(),
WriterInner::ParallelGzip(w) => w.flush(),
}
}
}
/// FASTQ writer that emits records to standard output, optionally gzip-compressed.
pub struct StdoutFastqWriter {
/// Plain or gzip sink wrapping buffered stdout.
inner: StdoutWriterInner,
/// Value returned by `compression_level()`; `create_stdout_writer` stores 0 when gzip is off.
compression_level: u32,
}
impl StdoutFastqWriter {
    /// Writes one record as four lines: `@` + name, sequence, `+`, quality.
    ///
    /// # Errors
    /// Returns the first I/O error reported by the underlying sink.
    pub fn write_record(&mut self, record: &OwnedRecord) -> Result<()> {
        // The pieces are written in order, each with its own `write_all` call.
        let pieces: [&[u8]; 7] = [
            b"@",
            &record.name,
            b"\n",
            &record.seq,
            b"\n+\n",
            &record.qual,
            b"\n",
        ];
        for piece in pieces {
            self.inner.write_all(piece)?;
        }
        Ok(())
    }

    /// Writes every record in `records` in order, stopping at the first error.
    pub fn write_batch(&mut self, records: &[OwnedRecord]) -> Result<()> {
        records
            .iter()
            .try_for_each(|record| self.write_record(record))
    }

    /// Flushes the underlying sink.
    pub fn flush(&mut self) -> Result<()> {
        Ok(self.inner.flush()?)
    }

    /// Returns the compression level recorded when this writer was built.
    pub fn compression_level(&self) -> u32 {
        self.compression_level
    }
}
/// Best-effort flush when the writer goes out of scope.
impl Drop for StdoutFastqWriter {
    fn drop(&mut self) {
        // There is no way to report a failure from `drop`, so the flush
        // result is intentionally discarded.
        drop(self.flush());
    }
}
/// Builds a `StdoutFastqWriter` over a `BUFFER_SIZE`-byte buffered stdout handle.
///
/// When `use_gzip` is true the output is wrapped in a gzip encoder at `level`
/// and that level is recorded on the writer; otherwise the output is plain
/// and the recorded level is 0.
pub fn create_stdout_writer(use_gzip: bool, level: u32) -> Result<StdoutFastqWriter> {
    let sink = BufWriter::with_capacity(BUFFER_SIZE, std::io::stdout());

    let (inner, compression_level) = if use_gzip {
        let encoder = GzEncoder::new(sink, Compression::new(level));
        (StdoutWriterInner::Gzip(encoder), level)
    } else {
        (StdoutWriterInner::Plain(sink), 0)
    };

    Ok(StdoutFastqWriter {
        inner,
        compression_level,
    })
}
/// FASTQ writer that emits records to a file, optionally gzip-compressed.
pub struct FastqWriter {
/// Plain, gzip or parallel-gzip sink wrapping the output file.
writer: WriterInner,
/// Level passed to `with_level`; stored as given for every `CompressionType`, including `None`.
compression_level: u32,
}
impl FastqWriter {
pub fn new(path: &Path, compression: CompressionType) -> Result<Self> {
Self::with_level(path, compression, DEFAULT_COMPRESSION_LEVEL)
}
pub fn with_level(path: &Path, compression: CompressionType, level: u32) -> Result<Self> {
let file = File::create(path)
.with_context(|| format!("Failed to create file: {}", path.display()))?;
let writer = match compression {
CompressionType::None => {
WriterInner::Plain(BufWriter::with_capacity(BUFFER_SIZE, file))
}
CompressionType::Gzip => {
let buf_writer = BufWriter::with_capacity(BUFFER_SIZE, file);
let encoder = GzEncoder::new(buf_writer, Compression::new(level));
WriterInner::Gzip(encoder)
}
CompressionType::ParallelGzip => {
let cpu_count = num_cpus::get();
let num_threads = if cpu_count <= 2 {
1 } else {
(cpu_count / 2).min(8) };
let par_writer: ParCompress<Gzip> = ParCompressBuilder::new()
.num_threads(num_threads)
.unwrap()
.compression_level(gzp::Compression::new(level))
.from_writer(file);
WriterInner::ParallelGzip(par_writer)
}
};
Ok(Self {
writer,
compression_level: level,
})
}
pub fn write_record(&mut self, record: &OwnedRecord) -> Result<()> {
self.writer.write_all(b"@")?;
self.writer.write_all(&record.name)?;
self.writer.write_all(b"\n")?;
self.writer.write_all(&record.seq)?;
self.writer.write_all(b"\n+\n")?;
self.writer.write_all(&record.qual)?;
self.writer.write_all(b"\n")?;
Ok(())
}
pub fn write_batch(&mut self, records: &[OwnedRecord]) -> Result<()> {
for record in records {
self.write_record(record)?;
}
Ok(())
}
pub fn flush(&mut self) -> Result<()> {
self.writer.flush()?;
Ok(())
}
pub fn compression_level(&self) -> u32 {
self.compression_level
}
pub fn write_raw(&mut self, data: &[u8]) -> Result<()> {
self.writer.write_all(data)?;
Ok(())
}
}
/// Best-effort flush when the writer goes out of scope.
impl Drop for FastqWriter {
    fn drop(&mut self) {
        // There is no way to report a failure from `drop`, so the flush
        // result is intentionally discarded.
        drop(self.flush());
    }
}
/// Writes paired-end records to two files, one `FastqWriter` per mate.
pub struct PairedFastqWriter {
/// Receives the first record of each pair (R1).
writer1: FastqWriter,
/// Receives the second record of each pair (R2).
writer2: FastqWriter,
}
impl PairedFastqWriter {
pub fn new(path1: &Path, path2: &Path, compression: CompressionType) -> Result<Self> {
let writer1 = FastqWriter::new(path1, compression)
.with_context(|| format!("Failed to create R1 file: {}", path1.display()))?;
let writer2 = FastqWriter::new(path2, compression)
.with_context(|| format!("Failed to create R2 file: {}", path2.display()))?;
Ok(Self { writer1, writer2 })
}
pub fn with_level(
path1: &Path,
path2: &Path,
compression: CompressionType,
level: u32,
) -> Result<Self> {
let writer1 = FastqWriter::with_level(path1, compression, level)
.with_context(|| format!("Failed to create R1 file: {}", path1.display()))?;
let writer2 = FastqWriter::with_level(path2, compression, level)
.with_context(|| format!("Failed to create R2 file: {}", path2.display()))?;
Ok(Self { writer1, writer2 })
}
pub fn write_pair(&mut self, r1: &OwnedRecord, r2: &OwnedRecord) -> Result<()> {
self.writer1.write_record(r1)?;
self.writer2.write_record(r2)?;
Ok(())
}
pub fn write_batch(&mut self, pairs: &[(OwnedRecord, OwnedRecord)]) -> Result<()> {
for (r1, r2) in pairs {
self.write_pair(r1, r2)?;
}
Ok(())
}
pub fn flush(&mut self) -> Result<()> {
self.writer1.flush()?;
self.writer2.flush()?;
Ok(())
}
pub fn write_raw(&mut self, r1_data: &[u8], r2_data: &[u8]) -> Result<()> {
self.writer1.write_raw(r1_data)?;
self.writer2.write_raw(r2_data)?;
Ok(())
}
}
/// Best-effort flush of both files when the paired writer goes out of scope.
impl Drop for PairedFastqWriter {
    fn drop(&mut self) {
        // There is no way to report a failure from `drop`, so the flush
        // result is intentionally discarded.
        drop(self.flush());
    }
}
// Unit tests for the file-based writers. Each test creates its writer inside
// an inner scope so the writer is dropped (and its `Drop` impl flushes)
// before the output file is read back.
#[cfg(test)]
mod tests {
use super::*;
use crate::io::FastqReader;
use std::io::Read;
use tempfile::{tempdir, NamedTempFile};
// Builds an `OwnedRecord` from byte-slice literals.
fn create_test_record(name: &[u8], seq: &[u8], qual: &[u8]) -> OwnedRecord {
OwnedRecord::new(name.to_vec(), seq.to_vec(), qual.to_vec())
}
// An uncompressed file contains the record's name, sequence and quality.
#[test]
fn test_writer_plain_text() {
let file = NamedTempFile::with_suffix(".fastq").unwrap();
let record = create_test_record(b"read1", b"ACGT", b"IIII");
{
let mut writer = FastqWriter::new(file.path(), CompressionType::None).unwrap();
writer.write_record(&record).unwrap();
}
let contents = std::fs::read_to_string(file.path()).unwrap();
assert!(contents.contains("@read1"));
assert!(contents.contains("ACGT"));
assert!(contents.contains("IIII"));
}
// `CompressionType::Gzip` output can be decoded by flate2's `GzDecoder`.
#[test]
fn test_writer_gzip() {
let file = NamedTempFile::with_suffix(".fastq.gz").unwrap();
let record = create_test_record(b"read1", b"ACGTACGT", b"IIIIIIII");
{
let mut writer = FastqWriter::new(file.path(), CompressionType::Gzip).unwrap();
writer.write_record(&record).unwrap();
}
let gz_file = File::open(file.path()).unwrap();
let mut decoder = flate2::read::GzDecoder::new(gz_file);
let mut contents = String::new();
decoder.read_to_string(&mut contents).unwrap();
assert!(contents.contains("@read1"));
assert!(contents.contains("ACGTACGT"));
}
// `CompressionType::ParallelGzip` output can be decoded by flate2's `GzDecoder`.
#[test]
fn test_writer_parallel_gzip() {
let file = NamedTempFile::with_suffix(".fastq.gz").unwrap();
let record = create_test_record(b"read1", b"ACGTACGT", b"IIIIIIII");
{
let mut writer = FastqWriter::new(file.path(), CompressionType::ParallelGzip).unwrap();
writer.write_record(&record).unwrap();
}
let gz_file = File::open(file.path()).unwrap();
let mut decoder = flate2::read::GzDecoder::new(gz_file);
let mut contents = String::new();
decoder.read_to_string(&mut contents).unwrap();
assert!(contents.contains("@read1"));
assert!(contents.contains("ACGTACGT"));
}
// `write_batch` emits a header line for every record in the slice.
#[test]
fn test_write_batch() {
let file = NamedTempFile::with_suffix(".fastq").unwrap();
let records = vec![
create_test_record(b"read1", b"AAAA", b"IIII"),
create_test_record(b"read2", b"CCCC", b"HHHH"),
create_test_record(b"read3", b"GGGG", b"JJJJ"),
];
{
let mut writer = FastqWriter::new(file.path(), CompressionType::None).unwrap();
writer.write_batch(&records).unwrap();
}
let contents = std::fs::read_to_string(file.path()).unwrap();
assert!(contents.contains("@read1"));
assert!(contents.contains("@read2"));
assert!(contents.contains("@read3"));
}
// Records written uncompressed are read back unchanged by `FastqReader`.
#[test]
fn test_roundtrip_plain() {
let file = NamedTempFile::with_suffix(".fastq").unwrap();
let records = vec![
create_test_record(b"read1", b"ACGTACGT", b"IIIIIIII"),
create_test_record(b"read2", b"TGCATGCA", b"HHHHHHHH"),
];
{
let mut writer = FastqWriter::new(file.path(), CompressionType::None).unwrap();
writer.write_batch(&records).unwrap();
}
let mut reader = FastqReader::new(file.path()).unwrap();
let read_records = reader.read_batch(10).unwrap();
assert_eq!(read_records.len(), 2);
assert_eq!(read_records[0].name, records[0].name);
assert_eq!(read_records[0].seq, records[0].seq);
assert_eq!(read_records[0].qual, records[0].qual);
assert_eq!(read_records[1].name, records[1].name);
}
// Records written with gzip are read back by `FastqReader` from a `.gz` path.
#[test]
fn test_roundtrip_gzip() {
let file = NamedTempFile::with_suffix(".fastq.gz").unwrap();
let records = vec![
create_test_record(b"read1", b"ACGTACGT", b"IIIIIIII"),
create_test_record(b"read2", b"TGCATGCA", b"HHHHHHHH"),
];
{
let mut writer = FastqWriter::new(file.path(), CompressionType::Gzip).unwrap();
writer.write_batch(&records).unwrap();
}
let mut reader = FastqReader::new(file.path()).unwrap();
let read_records = reader.read_batch(10).unwrap();
assert_eq!(read_records.len(), 2);
assert_eq!(read_records[0].name, records[0].name);
assert_eq!(read_records[0].seq, records[0].seq);
}
// `write_pair` sends R1 to the first file and R2 to the second.
#[test]
fn test_paired_writer() {
let dir = tempdir().unwrap();
let path1 = dir.path().join("r1.fastq");
let path2 = dir.path().join("r2.fastq");
let r1 = create_test_record(b"read1/1", b"AAAA", b"IIII");
let r2 = create_test_record(b"read1/2", b"TTTT", b"IIII");
{
let mut writer =
PairedFastqWriter::new(&path1, &path2, CompressionType::None).unwrap();
writer.write_pair(&r1, &r2).unwrap();
}
let contents1 = std::fs::read_to_string(&path1).unwrap();
let contents2 = std::fs::read_to_string(&path2).unwrap();
assert!(contents1.contains("@read1/1"));
assert!(contents1.contains("AAAA"));
assert!(contents2.contains("@read1/2"));
assert!(contents2.contains("TTTT"));
}
// The paired `write_batch` routes every pair's mates to their respective files.
#[test]
fn test_paired_writer_batch() {
let dir = tempdir().unwrap();
let path1 = dir.path().join("r1.fastq");
let path2 = dir.path().join("r2.fastq");
let pairs = vec![
(
create_test_record(b"read1/1", b"AAAA", b"IIII"),
create_test_record(b"read1/2", b"TTTT", b"IIII"),
),
(
create_test_record(b"read2/1", b"CCCC", b"HHHH"),
create_test_record(b"read2/2", b"GGGG", b"HHHH"),
),
];
{
let mut writer =
PairedFastqWriter::new(&path1, &path2, CompressionType::None).unwrap();
writer.write_batch(&pairs).unwrap();
}
let contents1 = std::fs::read_to_string(&path1).unwrap();
let contents2 = std::fs::read_to_string(&path2).unwrap();
assert!(contents1.contains("@read1/1"));
assert!(contents1.contains("@read2/1"));
assert!(contents2.contains("@read1/2"));
assert!(contents2.contains("@read2/2"));
}
// `from_path` maps `.gz` / `.gzip` endings to `ParallelGzip` and other names to `None`.
#[test]
fn test_compression_type_from_path() {
assert_eq!(
CompressionType::from_path(Path::new("file.fastq")),
CompressionType::None
);
assert_eq!(
CompressionType::from_path(Path::new("file.fq")),
CompressionType::None
);
assert_eq!(
CompressionType::from_path(Path::new("file.fastq.gz")),
CompressionType::ParallelGzip
);
assert_eq!(
CompressionType::from_path(Path::new("file.fq.gz")),
CompressionType::ParallelGzip
);
assert_eq!(
CompressionType::from_path(Path::new("file.gzip")),
CompressionType::ParallelGzip
);
}
// An explicit level is reported by `compression_level()` and the output is still readable.
#[test]
fn test_compression_level() {
let file = NamedTempFile::with_suffix(".fastq.gz").unwrap();
let record = create_test_record(b"read1", b"ACGT", b"IIII");
{
let mut writer =
FastqWriter::with_level(file.path(), CompressionType::Gzip, 9).unwrap();
assert_eq!(writer.compression_level(), 9);
writer.write_record(&record).unwrap();
}
let mut reader = FastqReader::new(file.path()).unwrap();
let records = reader.read_batch(10).unwrap();
assert_eq!(records.len(), 1);
}
// A record with empty sequence and quality is written without error and keeps its header.
#[test]
fn test_empty_record() {
let file = NamedTempFile::with_suffix(".fastq").unwrap();
let record = create_test_record(b"empty", b"", b"");
{
let mut writer = FastqWriter::new(file.path(), CompressionType::None).unwrap();
writer.write_record(&record).unwrap();
}
let contents = std::fs::read_to_string(file.path()).unwrap();
assert!(contents.contains("@empty"));
}
}