use crate::core::seq_record::SeqRecord;
use anyhow::{Context, Result};
use flate2::write::GzEncoder;
use flate2::Compression;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// Serializes sequence records as FASTA or FASTQ text into any [`Write`] sink.
pub struct SeqWriter<W: Write> {
    /// Destination that receives the serialized bytes.
    writer: W,
    /// Maximum number of bytes per sequence/quality line; `None` leaves the
    /// data on a single line.
    line_width: Option<usize>,
}
impl SeqWriter<Box<dyn Write>> {
pub fn to_path<P: AsRef<Path>>(path: P, compress: bool) -> Result<Self> {
let path = path.as_ref();
let file =
File::create(path).with_context(|| format!("无法创建文件: {}", path.display()))?;
let writer: Box<dyn Write> = if compress {
Box::new(BufWriter::with_capacity(
16384,
GzEncoder::new(file, Compression::default()),
))
} else {
Box::new(BufWriter::with_capacity(16384, file))
};
Ok(Self::new(writer))
}
pub fn to_stdout() -> Self {
let stdout = std::io::stdout();
let writer = BufWriter::with_capacity(16384, stdout);
Self::new(Box::new(writer))
}
}
impl<W: Write> SeqWriter<W> {
    /// Wraps `writer` in a `SeqWriter` with line wrapping disabled.
    pub fn new(writer: W) -> Self {
        Self {
            writer,
            line_width: None,
        }
    }

    /// Sets the maximum number of bytes written per sequence/quality line.
    ///
    /// A `width` of `0` disables wrapping, so the data stays on one line.
    pub fn with_line_width(mut self, width: usize) -> Self {
        // Zero is normalised to `None`, so the stored width is never 0 and
        // `chunks(width)` in `write_wrapped` cannot panic.
        self.line_width = if width > 0 { Some(width) } else { None };
        self
    }

    /// Writes one record: FASTQ when `record.is_fastq()` returns true, FASTA
    /// otherwise.
    ///
    /// # Errors
    /// Returns any I/O error reported by the underlying writer.
    pub fn write_record(&mut self, record: &SeqRecord) -> Result<()> {
        if record.is_fastq() {
            self.write_fastq(record)
        } else {
            self.write_fasta(record)
        }
    }

    /// Writes a FASTA entry: a `>` header line followed by the sequence.
    fn write_fasta(&mut self, record: &SeqRecord) -> Result<()> {
        self.write_header(b">", record)?;
        self.write_wrapped(&record.seq)
    }

    /// Writes a FASTQ entry: `@` header, sequence, `+` separator, quality.
    fn write_fastq(&mut self, record: &SeqRecord) -> Result<()> {
        self.write_header(b"@", record)?;
        self.write_wrapped(&record.seq)?;
        self.writer.write_all(b"+\n")?;
        // No quality line is written when `qual` is absent.
        // NOTE(review): presumably `is_fastq()` implies `qual` is present — confirm.
        if let Some(qual) = &record.qual {
            self.write_wrapped(qual)?;
        }
        Ok(())
    }

    /// Writes the header line: `marker`, the record name, then a space and the
    /// comment when one is present, terminated by a newline.
    fn write_header(&mut self, marker: &[u8], record: &SeqRecord) -> Result<()> {
        self.writer.write_all(marker)?;
        self.writer.write_all(&record.name)?;
        if let Some(comment) = &record.comment {
            self.writer.write_all(b" ")?;
            self.writer.write_all(comment)?;
        }
        self.writer.write_all(b"\n")?;
        Ok(())
    }

    /// Writes `data` followed by a newline, split into lines of at most
    /// `line_width` bytes when wrapping is enabled.
    ///
    /// Empty `data` always produces exactly one empty line. `chunks` yields
    /// nothing for an empty slice, so without this rule a wrapped FASTQ record
    /// with an empty sequence or quality would lose lines and no longer be a
    /// four-line record.
    fn write_wrapped(&mut self, data: &[u8]) -> Result<()> {
        match self.line_width {
            Some(width) if !data.is_empty() => {
                for line in data.chunks(width) {
                    self.writer.write_all(line)?;
                    self.writer.write_all(b"\n")?;
                }
            }
            _ => {
                self.writer.write_all(data)?;
                self.writer.write_all(b"\n")?;
            }
        }
        Ok(())
    }

    /// Flushes the underlying writer.
    ///
    /// # Errors
    /// Returns any I/O error reported by the underlying writer.
    pub fn flush(&mut self) -> Result<()> {
        self.writer.flush()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Writes `records` through a fresh writer (wrapped at `line_width` when
    /// one is given) and returns everything that reached the in-memory sink.
    fn render(records: &[SeqRecord], line_width: Option<usize>) -> String {
        let mut sink = Cursor::new(Vec::new());
        let mut seq_writer = match line_width {
            Some(width) => SeqWriter::new(&mut sink).with_line_width(width),
            None => SeqWriter::new(&mut sink),
        };
        for record in records {
            seq_writer.write_record(record).unwrap();
        }
        seq_writer.flush().unwrap();
        String::from_utf8(sink.into_inner()).unwrap()
    }

    /// A plain FASTA record is a header line plus one sequence line.
    #[test]
    fn test_write_fasta() {
        let record = SeqRecord::new(b"seq1".to_vec(), b"ACGTACGT".to_vec());
        assert_eq!(render(&[record], None), ">seq1\nACGTACGT\n");
    }

    /// The comment follows the name on the header line, separated by a space.
    #[test]
    fn test_write_fasta_with_comment() {
        let mut record = SeqRecord::new(b"seq1".to_vec(), b"ACGT".to_vec());
        record.comment = Some(b"test comment".to_vec());
        assert_eq!(render(&[record], None), ">seq1 test comment\nACGT\n");
    }

    /// A line width of 4 splits an 8-base sequence into two lines.
    #[test]
    fn test_write_fasta_with_line_width() {
        let record = SeqRecord::new(b"seq1".to_vec(), b"ACGTACGT".to_vec());
        assert_eq!(render(&[record], Some(4)), ">seq1\nACGT\nACGT\n");
    }

    /// A record carrying qualities is written in four-line FASTQ form.
    #[test]
    fn test_write_fastq() {
        let record = SeqRecord::with_qual(b"seq1".to_vec(), b"ACGT".to_vec(), b"IIII".to_vec());
        assert_eq!(render(&[record], None), "@seq1\nACGT\n+\nIIII\n");
    }

    /// Consecutive records are written back to back with no separator.
    #[test]
    fn test_write_multiple_records() {
        let first = SeqRecord::new(b"seq1".to_vec(), b"ACGT".to_vec());
        let second = SeqRecord::new(b"seq2".to_vec(), b"TGCA".to_vec());
        assert_eq!(render(&[first, second], None), ">seq1\nACGT\n>seq2\nTGCA\n");
    }
}