use std::io::Write;
use byteorder::{LittleEndian, WriteBytesExt};
use rand::rngs::SmallRng;
use rand::SeedableRng;
use zstd::Encoder as ZstdEncoder;
use crate::error::{Result, WriteError};
use crate::header::{BlockHeader, VBinseqHeader};
use crate::Policy;
/// Fixed seed passed to `SmallRng::seed_from_u64` whenever an [`Encoder`] is constructed.
pub const RNG_SEED: u64 = 42;
/// Returns the number of bytes a record without quality scores occupies in a block.
///
/// `schunk` and `xchunk` are the number of 64-bit words holding the encoded primary and
/// extended sequences. Three further words are counted for the record preamble.
pub fn record_byte_size(schunk: usize, xchunk: usize) -> usize {
    /// Words preceding the encoded sequences in every record.
    const PREAMBLE_WORDS: usize = 3;
    /// Bytes per 64-bit word.
    const WORD_BYTES: usize = 8;

    let words = schunk + xchunk + PREAMBLE_WORDS;
    WORD_BYTES * words
}
/// Returns the number of bytes a record with quality scores occupies in a block.
///
/// This is [`record_byte_size`] for the encoded sequences plus `slen` and `xlen` additional
/// bytes for the primary and extended quality strings.
pub fn record_byte_size_quality(schunk: usize, xchunk: usize, slen: usize, xlen: usize) -> usize {
    let sequence_bytes = record_byte_size(schunk, xchunk);
    sequence_bytes + slen + xlen
}
/// Builder for [`VBinseqWriter`].
///
/// Every option may be left unset; `build` substitutes the type's `Default` for a missing
/// header or policy, and `false` for a missing `headless` flag.
#[derive(Default)]
pub struct VBinseqWriterBuilder {
    /// Header handed to the writer.
    header: Option<VBinseqHeader>,
    /// Policy handed to the writer's encoder.
    policy: Option<Policy>,
    /// When `Some(true)`, the writer is constructed without writing the header.
    headless: Option<bool>,
}
impl VBinseqWriterBuilder {
    /// Sets the header the writer will be constructed with.
    pub fn header(self, header: VBinseqHeader) -> Self {
        Self {
            header: Some(header),
            ..self
        }
    }

    /// Sets the policy the writer's encoder will be constructed with.
    pub fn policy(self, policy: Policy) -> Self {
        Self {
            policy: Some(policy),
            ..self
        }
    }

    /// Chooses whether the writer skips writing the header on construction.
    pub fn headless(self, headless: bool) -> Self {
        Self {
            headless: Some(headless),
            ..self
        }
    }

    /// Consumes the builder and constructs a [`VBinseqWriter`] over `inner`.
    ///
    /// Unset options fall back to `Default::default()` (header, policy) or `false` (headless).
    ///
    /// # Errors
    ///
    /// Returns whatever error `VBinseqWriter::new` reports.
    pub fn build<W: Write>(self, inner: W) -> Result<VBinseqWriter<W>> {
        let Self {
            header,
            policy,
            headless,
        } = self;
        let header = header.unwrap_or_default();
        let policy = policy.unwrap_or_default();
        let headless = headless.unwrap_or(false);
        VBinseqWriter::new(inner, header, policy, headless)
    }
}
/// Writer that encodes nucleotide records, stages them in a block buffer and writes flushed
/// blocks to an underlying [`Write`] sink.
///
/// Dropping the writer calls `finish`, which flushes the staged block and the sink.
#[derive(Clone)]
pub struct VBinseqWriter<W: Write> {
    /// Destination for the header and for flushed blocks.
    inner: W,
    /// Header describing the layout (quality / paired / block size / compression) of the output.
    header: VBinseqHeader,
    /// Reusable encoder turning nucleotide bytes into 64-bit words.
    encoder: Encoder,
    /// Block currently being filled.
    cblock: BlockWriter,
}
impl<W: Write> VBinseqWriter<W> {
    /// Creates a writer over `inner`.
    ///
    /// The encoder is configured with `policy`; the block buffer takes its size from
    /// `header.block` and its compression switch from `header.compressed`. Unless `headless` is
    /// `true`, the header is written to `inner` before the writer is returned.
    ///
    /// # Errors
    ///
    /// Propagates any failure from writing the header.
    pub fn new(inner: W, header: VBinseqHeader, policy: Policy, headless: bool) -> Result<Self> {
        let encoder = Encoder::with_policy(policy);
        let cblock = BlockWriter::new(header.block as usize, header.compressed);
        let mut writer = Self {
            inner,
            header,
            encoder,
            cblock,
        };
        if headless {
            return Ok(writer);
        }
        writer.init()?;
        Ok(writer)
    }

    /// Writes the header bytes to the underlying sink.
    fn init(&mut self) -> Result<()> {
        self.header.write_bytes(&mut self.inner)?;
        Ok(())
    }

    /// Returns the header's `paired` flag.
    pub fn is_paired(&self) -> bool {
        self.header.paired
    }

    /// Returns the header's `qual` flag.
    pub fn has_quality(&self) -> bool {
        self.header.qual
    }

    /// Checks that a write call matches the header's `qual` and `paired` flags.
    ///
    /// `with_quality` / `with_pair` state what the calling method supplies. The quality flag is
    /// checked first, so a quality mismatch is reported in preference to a pairing mismatch.
    fn ensure_layout(&self, with_quality: bool, with_pair: bool) -> Result<()> {
        match (self.header.qual, with_quality) {
            (true, false) => return Err(WriteError::QualityFlagSet.into()),
            (false, true) => return Err(WriteError::QualityFlagNotSet.into()),
            _ => {}
        }
        match (self.header.paired, with_pair) {
            (true, false) => Err(WriteError::PairedFlagSet.into()),
            (false, true) => Err(WriteError::PairedFlagNotSet.into()),
            _ => Ok(()),
        }
    }

    /// Stages a single-end record without quality scores.
    ///
    /// Returns `Ok(true)` when the record was staged and `Ok(false)` when the encoder produced
    /// no encoding for `sequence`. The current block is flushed first if the record would not
    /// fit in it.
    ///
    /// # Errors
    ///
    /// `WriteError::QualityFlagSet` or `WriteError::PairedFlagSet` if the header has the
    /// corresponding flag set; otherwise any encoding, block-size or I/O error.
    pub fn write_nucleotides(&mut self, flag: u64, sequence: &[u8]) -> Result<bool> {
        self.ensure_layout(false, false)?;
        let sbuffer = match self.encoder.encode_single(sequence)? {
            Some(words) => words,
            None => return Ok(false),
        };
        let record_size = record_byte_size(sbuffer.len(), 0);
        if self.cblock.exceeds_block_size(record_size)? {
            self.cblock.flush(&mut self.inner)?;
        }
        let slen = sequence.len() as u64;
        self.cblock
            .write_record(flag, slen, 0, sbuffer, None, None, None)?;
        Ok(true)
    }

    /// Stages a paired record without quality scores.
    ///
    /// Returns `Ok(true)` when the record was staged and `Ok(false)` when the encoder produced
    /// no encoding for the pair. The current block is flushed first if the record would not fit
    /// in it.
    ///
    /// # Errors
    ///
    /// `WriteError::QualityFlagSet` if the header's quality flag is set,
    /// `WriteError::PairedFlagNotSet` if its paired flag is not; otherwise any encoding,
    /// block-size or I/O error.
    pub fn write_nucleotides_paired(
        &mut self,
        flag: u64,
        primary: &[u8],
        extended: &[u8],
    ) -> Result<bool> {
        self.ensure_layout(false, true)?;
        let (sbuffer, xbuffer) = match self.encoder.encode_paired(primary, extended)? {
            Some(pair) => pair,
            None => return Ok(false),
        };
        let record_size = record_byte_size(sbuffer.len(), xbuffer.len());
        if self.cblock.exceeds_block_size(record_size)? {
            self.cblock.flush(&mut self.inner)?;
        }
        let slen = primary.len() as u64;
        let xlen = extended.len() as u64;
        self.cblock
            .write_record(flag, slen, xlen, sbuffer, None, Some(xbuffer), None)?;
        Ok(true)
    }

    /// Stages a single-end record together with its quality scores.
    ///
    /// Returns `Ok(true)` when the record was staged and `Ok(false)` when the encoder produced
    /// no encoding for `sequence`. The current block is flushed first if the record would not
    /// fit in it.
    ///
    /// # Errors
    ///
    /// `WriteError::QualityFlagNotSet` if the header's quality flag is not set,
    /// `WriteError::PairedFlagSet` if its paired flag is; otherwise any encoding, block-size or
    /// I/O error.
    pub fn write_nucleotides_quality(
        &mut self,
        flag: u64,
        sequence: &[u8],
        quality: &[u8],
    ) -> Result<bool> {
        self.ensure_layout(true, false)?;
        let sbuffer = match self.encoder.encode_single(sequence)? {
            Some(words) => words,
            None => return Ok(false),
        };
        let record_size = record_byte_size_quality(sbuffer.len(), 0, quality.len(), 0);
        if self.cblock.exceeds_block_size(record_size)? {
            self.cblock.flush(&mut self.inner)?;
        }
        let slen = sequence.len() as u64;
        self.cblock
            .write_record(flag, slen, 0, sbuffer, Some(quality), None, None)?;
        Ok(true)
    }

    /// Stages a paired record together with both quality strings.
    ///
    /// Returns `Ok(true)` when the record was staged and `Ok(false)` when the encoder produced
    /// no encoding for the pair. The current block is flushed first if the record would not fit
    /// in it.
    ///
    /// # Errors
    ///
    /// `WriteError::QualityFlagNotSet` or `WriteError::PairedFlagNotSet` if the header lacks
    /// the corresponding flag; otherwise any encoding, block-size or I/O error.
    pub fn write_nucleotides_quality_paired(
        &mut self,
        flag: u64,
        s_seq: &[u8],
        x_seq: &[u8],
        s_qual: &[u8],
        x_qual: &[u8],
    ) -> Result<bool> {
        self.ensure_layout(true, true)?;
        let (sbuffer, xbuffer) = match self.encoder.encode_paired(s_seq, x_seq)? {
            Some(pair) => pair,
            None => return Ok(false),
        };
        let record_size =
            record_byte_size_quality(sbuffer.len(), xbuffer.len(), s_qual.len(), x_qual.len());
        if self.cblock.exceeds_block_size(record_size)? {
            self.cblock.flush(&mut self.inner)?;
        }
        let slen = s_seq.len() as u64;
        let xlen = x_seq.len() as u64;
        self.cblock.write_record(
            flag,
            slen,
            xlen,
            sbuffer,
            Some(s_qual),
            Some(xbuffer),
            Some(x_qual),
        )?;
        Ok(true)
    }

    /// Flushes the staged block to the sink and then flushes the sink itself.
    ///
    /// # Errors
    ///
    /// Propagates any block-flush or I/O error.
    pub fn finish(&mut self) -> Result<()> {
        self.cblock.flush(&mut self.inner)?;
        self.inner.flush()?;
        Ok(())
    }

    /// Mutable access to the underlying sink.
    fn by_ref(&mut self) -> &mut W {
        self.inner.by_ref()
    }

    /// Mutable access to the block currently being filled.
    fn cblock_mut(&mut self) -> &mut BlockWriter {
        &mut self.cblock
    }

    /// Moves everything held by an in-memory writer into this writer.
    ///
    /// The bytes `other` has already written to its `Vec<u8>` sink are copied to this writer's
    /// sink and removed from `other`; the records staged in `other`'s current block are then
    /// merged into this writer's current block.
    ///
    /// # Errors
    ///
    /// `WriteError::IncompatibleHeaders` if the two headers differ; otherwise any error from
    /// writing or from merging the blocks.
    pub fn ingest(&mut self, other: &mut VBinseqWriter<Vec<u8>>) -> Result<()> {
        if self.header != other.header {
            return Err(WriteError::IncompatibleHeaders(self.header, other.header).into());
        }

        // Bytes already emitted by `other` are forwarded verbatim, then dropped from `other`.
        let emitted = other.by_ref();
        self.inner.write_all(emitted.as_slice())?;
        emitted.clear();

        // Whatever is still staged in `other`'s block is merged record by record.
        self.cblock.ingest(other.cblock_mut(), &mut self.inner)?;
        Ok(())
    }
}
impl<W: Write> Drop for VBinseqWriter<W> {
    /// Flushes any staged data when the writer goes out of scope.
    ///
    /// # Panics
    ///
    /// Panics if `finish` fails, unless the thread is already unwinding from another panic.
    /// In that case the failure is ignored, because a second panic raised inside `drop`
    /// during unwinding aborts the process.
    fn drop(&mut self) {
        // Always attempt the flush, even while unwinding, so as much data as possible lands.
        let outcome = self.finish();
        if std::thread::panicking() {
            return;
        }
        outcome.expect("VBinseqWriter: Failed to finish writing");
    }
}
/// In-memory staging area for the records of one block.
#[derive(Clone)]
struct BlockWriter {
    /// Number of record bytes currently staged in `ubuf`.
    pos: usize,
    /// Byte offset within the block at which each staged record begins.
    starts: Vec<usize>,
    /// Size in bytes of a full (uncompressed) block.
    block_size: usize,
    /// Compression level handed to the zstd encoder.
    level: i32,
    /// Uncompressed bytes of the block being filled.
    ubuf: Vec<u8>,
    /// Scratch buffer receiving the compressed block on flush.
    zbuf: Vec<u8>,
    /// `block_size` zero bytes, sliced to pad a partially filled block on flush.
    padding: Vec<u8>,
    /// Whether blocks are compressed before being written out.
    compress: bool,
}
impl BlockWriter {
    /// Creates an empty block of `block_size` bytes; `compress` selects compressed output.
    fn new(block_size: usize, compress: bool) -> Self {
        Self {
            pos: 0,
            starts: Vec::default(),
            block_size,
            level: 3,
            ubuf: Vec::with_capacity(block_size),
            zbuf: Vec::with_capacity(block_size),
            padding: vec![0; block_size],
            compress,
        }
    }

    /// Reports whether appending `record_size` bytes would overflow the current block.
    ///
    /// # Errors
    ///
    /// `WriteError::RecordSizeExceedsMaximumBlockSize` if the record could not fit even in an
    /// empty block.
    fn exceeds_block_size(&self, record_size: usize) -> Result<bool> {
        if record_size > self.block_size {
            return Err(WriteError::RecordSizeExceedsMaximumBlockSize(
                record_size,
                self.block_size,
            )
            .into());
        }
        Ok(self.pos + record_size > self.block_size)
    }

    /// Appends one record to the block and remembers where it starts.
    ///
    /// Layout: `flag`, `slen`, `xlen` (one little-endian `u64` each), the primary words, the
    /// primary quality bytes if given, the extended words if given, and the extended quality
    /// bytes if given. The caller is responsible for checking that the record fits.
    #[allow(clippy::too_many_arguments)]
    fn write_record(
        &mut self,
        flag: u64,
        slen: u64,
        xlen: u64,
        sbuf: &[u64],
        squal: Option<&[u8]>,
        xbuf: Option<&[u64]>,
        xqual: Option<&[u8]>,
    ) -> Result<()> {
        self.starts.push(self.pos);
        self.write_flag(flag)?;
        self.write_length(slen)?;
        self.write_length(xlen)?;
        self.write_buffer(sbuf)?;
        if let Some(qual) = squal {
            self.write_quality(qual)?;
        }
        if let Some(xbuf) = xbuf {
            self.write_buffer(xbuf)?;
        }
        if let Some(qual) = xqual {
            self.write_quality(qual)?;
        }
        Ok(())
    }

    /// Appends the record flag as a little-endian `u64`.
    fn write_flag(&mut self, flag: u64) -> Result<()> {
        self.ubuf.write_u64::<LittleEndian>(flag)?;
        self.pos += 8;
        Ok(())
    }

    /// Appends a sequence length as a little-endian `u64`.
    fn write_length(&mut self, length: u64) -> Result<()> {
        self.ubuf.write_u64::<LittleEndian>(length)?;
        self.pos += 8;
        Ok(())
    }

    /// Appends encoded sequence words, each as a little-endian `u64`.
    fn write_buffer(&mut self, ebuf: &[u64]) -> Result<()> {
        for &word in ebuf {
            self.ubuf.write_u64::<LittleEndian>(word)?;
        }
        self.pos += 8 * ebuf.len();
        Ok(())
    }

    /// Appends raw quality bytes.
    fn write_quality(&mut self, quality: &[u8]) -> Result<()> {
        self.ubuf.write_all(quality)?;
        self.pos += quality.len();
        Ok(())
    }

    /// Compresses the (already padded) block and writes a block header followed by the
    /// compressed bytes. The header records the compressed length and the record count.
    fn flush_compressed<W: Write>(&mut self, inner: &mut W) -> Result<()> {
        let mut encoder = ZstdEncoder::new(&mut self.zbuf, self.level)?;
        encoder.write_all(&self.ubuf)?;
        encoder.finish()?;
        let header = BlockHeader::new(self.zbuf.len() as u64, self.starts.len() as u32);
        header.write_bytes(inner)?;
        inner.write_all(&self.zbuf)?;
        Ok(())
    }

    /// Writes a block header followed by the (already padded) uncompressed block. The header
    /// records `block_size` and the record count.
    fn flush_uncompressed<W: Write>(&mut self, inner: &mut W) -> Result<()> {
        let header = BlockHeader::new(self.block_size as u64, self.starts.len() as u32);
        header.write_bytes(inner)?;
        inner.write_all(&self.ubuf)?;
        Ok(())
    }

    /// Pads the block with zeros up to `block_size`, writes it to `inner` and resets the block.
    ///
    /// Does nothing when no bytes are staged.
    fn flush<W: Write>(&mut self, inner: &mut W) -> Result<()> {
        if self.pos == 0 {
            return Ok(());
        }
        let bytes_to_next_start = self.block_size - self.pos;
        self.ubuf.write_all(&self.padding[..bytes_to_next_start])?;
        if self.compress {
            self.flush_compressed(inner)?;
        } else {
            self.flush_uncompressed(inner)?;
        }
        self.clear();
        Ok(())
    }

    /// Resets the block to empty while keeping the buffers' allocations.
    fn clear(&mut self) {
        self.pos = 0;
        self.starts.clear();
        self.ubuf.clear();
        self.zbuf.clear();
    }

    /// Moves every record staged in `other` into this block, leaving `other` empty.
    ///
    /// If `other` does not fit in the space left, the leading records that do fit are moved
    /// first, this block is flushed to `inner`, and the rest of `other` is moved into the
    /// now-empty block.
    ///
    /// # Errors
    ///
    /// `WriteError::IncompatibleBlockSizes` if the two blocks differ in `block_size`;
    /// otherwise any error raised while flushing.
    fn ingest<W: Write>(&mut self, other: &mut Self, inner: &mut W) -> Result<()> {
        if self.block_size != other.block_size {
            return Err(
                WriteError::IncompatibleBlockSizes(self.block_size, other.block_size).into(),
            );
        }
        let remaining = self.block_size - self.pos;
        if other.pos <= remaining {
            self.ingest_all(other)
        } else {
            self.ingest_subset(other)?;
            self.flush(inner)?;
            self.ingest_all(other)
        }
    }

    /// Moves all of `other`'s staged bytes and record starts into this block.
    ///
    /// The caller must have checked that `other.pos` bytes fit in the space left.
    fn ingest_all(&mut self, other: &mut Self) -> Result<()> {
        // Record starts from `other` are relative to its own block; rebase them onto ours.
        let offset = self.pos;
        let n_bytes = other.pos;
        self.ubuf.append(&mut other.ubuf);
        self.starts
            .extend(other.starts.drain(..).map(|start| start + offset));
        self.pos += n_bytes;
        other.clear();
        Ok(())
    }

    /// Moves the leading records of `other` that fit in the space left into this block.
    ///
    /// The split point is the last record start in `other` that is `<=` the remaining space:
    /// every record before it ends at or before that offset and therefore fits. The records
    /// left in `other` are shifted so that the first of them starts at offset 0.
    fn ingest_subset(&mut self, other: &mut Self) -> Result<()> {
        let remaining = self.block_size - self.pos;

        // `starts` is ascending (each record is appended after the previous one), so the
        // starts within `remaining` form a prefix and a binary search finds its length.
        let within = other.starts.partition_point(|&start| start <= remaining);
        let start_index = match within.checked_sub(1) {
            Some(index) => index,
            // No record is staged in `other`; there is nothing to move.
            None => return Ok(()),
        };
        let end_byte = other.starts[start_index];

        let offset = self.pos;
        self.ubuf.extend(other.ubuf.drain(..end_byte));
        self.starts.extend(
            other
                .starts
                .drain(..start_index)
                .map(|start| start + offset),
        );
        for start in &mut other.starts {
            *start -= end_byte;
        }
        self.pos += end_byte;
        other.pos -= end_byte;
        Ok(())
    }
}
/// Reusable nucleotide encoder with a fallback policy for sequences that fail to encode.
#[derive(Clone)]
pub struct Encoder {
    /// Encoded words of the primary sequence.
    sbuffer: Vec<u64>,
    /// Encoded words of the extended sequence.
    xbuffer: Vec<u64>,
    /// Scratch buffer the policy fills for the primary sequence when direct encoding fails.
    s_ibuf: Vec<u8>,
    /// Scratch buffer the policy fills for the extended sequence when direct encoding fails.
    x_ibuf: Vec<u8>,
    /// Policy consulted when direct encoding fails.
    policy: Policy,
    /// Random number generator handed to the policy; seeded with `RNG_SEED`.
    rng: SmallRng,
}
impl Default for Encoder {
fn default() -> Self {
Self::with_policy(Policy::default())
}
}
impl Encoder {
    /// Creates an encoder that uses `Policy::default()`.
    pub fn new() -> Self {
        Self::with_policy(Policy::default())
    }

    /// Creates an encoder that uses `policy`, with empty buffers and a generator seeded from
    /// `RNG_SEED`.
    pub fn with_policy(policy: Policy) -> Self {
        let rng = SmallRng::seed_from_u64(RNG_SEED);
        Self {
            sbuffer: Vec::default(),
            xbuffer: Vec::default(),
            s_ibuf: Vec::default(),
            x_ibuf: Vec::default(),
            policy,
            rng,
        }
    }

    /// Encodes a single sequence.
    ///
    /// `primary` is encoded directly first. If that fails, the policy is asked to produce a
    /// replacement sequence, which is then encoded instead. Returns `Ok(None)` when the policy
    /// declines, otherwise the encoded words (borrowed from the encoder).
    ///
    /// # Errors
    ///
    /// Propagates errors from the policy or from encoding the replacement sequence.
    pub fn encode_single(&mut self, primary: &[u8]) -> Result<Option<&[u64]>> {
        self.clear();
        let direct = bitnuc::encode(primary, &mut self.sbuffer);
        if direct.is_err() {
            // Discard any partial output before retrying through the policy.
            self.clear();
            let usable = self
                .policy
                .handle(primary, &mut self.s_ibuf, &mut self.rng)?;
            if !usable {
                return Ok(None);
            }
            bitnuc::encode(&self.s_ibuf, &mut self.sbuffer)?;
        }
        Ok(Some(&self.sbuffer))
    }

    /// Encodes a pair of sequences.
    ///
    /// Both sequences are encoded directly first. If either fails, the policy is asked for a
    /// replacement of the primary and then, only if it accepted the primary, of the extended
    /// sequence; the replacements are encoded instead. Returns `Ok(None)` when the policy
    /// declines either sequence, otherwise both encoded word slices (borrowed from the encoder).
    ///
    /// # Errors
    ///
    /// Propagates errors from the policy or from encoding the replacement sequences.
    pub fn encode_paired(
        &mut self,
        primary: &[u8],
        extended: &[u8],
    ) -> Result<Option<(&[u64], &[u64])>> {
        self.clear();
        // The extended sequence is only attempted once the primary has encoded cleanly.
        let direct_ok = bitnuc::encode(primary, &mut self.sbuffer).is_ok()
            && bitnuc::encode(extended, &mut self.xbuffer).is_ok();
        if !direct_ok {
            // Discard any partial output before retrying through the policy.
            self.clear();
            if !self
                .policy
                .handle(primary, &mut self.s_ibuf, &mut self.rng)?
            {
                return Ok(None);
            }
            if !self
                .policy
                .handle(extended, &mut self.x_ibuf, &mut self.rng)?
            {
                return Ok(None);
            }
            bitnuc::encode(&self.s_ibuf, &mut self.sbuffer)?;
            bitnuc::encode(&self.x_ibuf, &mut self.xbuffer)?;
        }
        Ok(Some((&self.sbuffer, &self.xbuffer)))
    }

    /// Empties all four internal buffers, keeping their allocations.
    pub fn clear(&mut self) {
        self.s_ibuf.clear();
        self.x_ibuf.clear();
        self.sbuffer.clear();
        self.xbuffer.clear();
    }
}
#[cfg(test)]
mod tests {
    use crate::{header::SIZE_HEADER, *};

    /// In-memory writer type used throughout these tests.
    type MemWriter = super::VBinseqWriter<Vec<u8>>;

    /// Sequence staged by every test that writes records.
    const SEQ: &[u8] = b"ACGTACGTACGT";

    /// Builds a headless in-memory writer for `header`.
    fn headless_writer(header: VBinseqHeader) -> crate::Result<MemWriter> {
        VBinseqWriterBuilder::default()
            .header(header)
            .headless(true)
            .build(Vec::new())
    }

    /// Writes `count` copies of `SEQ` (flag 1, no quality) to `writer`.
    fn stage(writer: &mut MemWriter, count: usize) -> crate::Result<()> {
        for _ in 0..count {
            writer.write_nucleotides(1, SEQ)?;
        }
        Ok(())
    }

    /// A headless writer emits nothing on construction; a regular one emits the header.
    #[test]
    fn test_headless_writer() -> crate::Result<()> {
        let headless = VBinseqWriterBuilder::default()
            .headless(true)
            .build(Vec::new())?;
        assert_eq!(headless.inner.len(), 0);

        let with_header = VBinseqWriterBuilder::default()
            .headless(false)
            .build(Vec::new())?;
        assert_eq!(with_header.inner.len(), SIZE_HEADER);
        Ok(())
    }

    /// Ingesting an empty writer leaves both sinks empty.
    #[test]
    fn test_ingest_empty_writer() -> crate::Result<()> {
        let header = VBinseqHeader::new(false, false, false);
        let mut source = headless_writer(header)?;
        let mut dest = headless_writer(header)?;

        dest.ingest(&mut source)?;

        assert_eq!(source.by_ref().len(), 0);
        assert_eq!(dest.by_ref().len(), 0);
        Ok(())
    }

    /// A single staged record moves from the source block to the destination block.
    #[test]
    fn test_ingest_single_record() -> crate::Result<()> {
        let header = VBinseqHeader::new(false, false, false);
        let mut source = headless_writer(header)?;
        stage(&mut source, 1)?;
        assert!(source.by_ref().is_empty());

        let mut dest = headless_writer(header)?;
        dest.ingest(&mut source)?;

        assert_eq!(source.by_ref().len(), 0);
        assert!(source.cblock.ubuf.is_empty());
        assert!(dest.by_ref().is_empty());
        assert!(!dest.cblock.ubuf.is_empty());
        Ok(())
    }

    /// Several staged records that fit in one block move without anything being flushed.
    #[test]
    fn test_ingest_multi_record() -> crate::Result<()> {
        let header = VBinseqHeader::new(false, false, false);
        let mut source = headless_writer(header)?;
        stage(&mut source, 30)?;
        assert!(source.by_ref().is_empty());

        let mut dest = headless_writer(header)?;
        dest.ingest(&mut source)?;

        assert_eq!(source.by_ref().len(), 0);
        assert!(source.cblock.ubuf.is_empty());
        assert!(dest.by_ref().is_empty());
        assert!(!dest.cblock.ubuf.is_empty());
        Ok(())
    }

    /// Enough records to flush blocks in the source: flushed bytes and the staged block both
    /// end up in the destination.
    #[test]
    fn test_ingest_block_boundary() -> crate::Result<()> {
        let header = VBinseqHeader::new(false, false, false);
        let mut source = headless_writer(header)?;
        stage(&mut source, 30000)?;
        assert!(!source.by_ref().is_empty());

        let mut dest = headless_writer(header)?;
        dest.ingest(&mut source)?;

        assert_eq!(source.by_ref().len(), 0);
        assert!(source.cblock.ubuf.is_empty());
        assert!(!dest.by_ref().is_empty());
        assert!(!dest.cblock.ubuf.is_empty());
        Ok(())
    }

    /// Records carrying quality scores are ingested as well.
    #[test]
    fn test_ingest_with_quality_scores() -> crate::Result<()> {
        let source_header = VBinseqHeader::new(true, false, false);
        let dest_header = VBinseqHeader::new(true, false, false);

        let mut source = headless_writer(source_header)?;
        for i in 0..5 {
            let qual = vec![40; SEQ.len()];
            source.write_nucleotides_quality(i, SEQ, &qual)?;
        }

        let mut dest = headless_writer(dest_header)?;
        dest.ingest(&mut source)?;

        assert_eq!(source.by_ref().len(), 0);
        assert!(!dest.cblock.ubuf.is_empty());
        Ok(())
    }

    /// Same as the block-boundary test, using a header built with `true` as its second
    /// argument (the test name indicates this enables compression).
    #[test]
    fn test_ingest_with_compression() -> crate::Result<()> {
        let header = VBinseqHeader::new(false, true, false);
        let mut source = headless_writer(header)?;
        stage(&mut source, 30000)?;

        let mut dest = headless_writer(header)?;
        dest.ingest(&mut source)?;

        assert_eq!(source.by_ref().len(), 0);
        assert!(source.cblock.ubuf.is_empty());
        assert!(!dest.by_ref().is_empty());
        assert!(!dest.cblock.ubuf.is_empty());
        Ok(())
    }

    /// Writers whose headers differ refuse to ingest one another.
    #[test]
    fn test_ingest_incompatible_headers() -> crate::Result<()> {
        let source_header = VBinseqHeader::new(false, false, false);
        let dest_header = VBinseqHeader::new(true, false, false);

        let mut source = headless_writer(source_header)?;
        let mut dest = headless_writer(dest_header)?;

        assert!(dest.ingest(&mut source).is_err());
        Ok(())
    }
}