use std::io::{self, Write};
use std::mem;
use std::num::NonZero;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use bytes::{Bytes, BytesMut};
use crossbeam_channel::{Receiver, Sender, bounded, unbounded};
use bgzf::{BgzfError, CompressionLevel, Compressor};
// Maximum number of uncompressed bytes staged per BGZF block
// (65280 = 0xff00); also the default capacity of the staging buffer.
const BGZF_BLOCK_SIZE: usize = 65280;
// The standard 28-byte BGZF end-of-file marker (an empty gzip block with
// the "BC" extra field); written once after the final data block.
static BGZF_EOF: &[u8] = &[
0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43, 0x02, 0x00, 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ];
// Convenience alias for fallible operations in this module.
type BgzfResult<T> = Result<T, BgzfError>;
/// Metadata for one compressed BGZF block, published on the block-info
/// channel after the block has been written to the inner writer.
#[derive(Debug, Clone, Copy)]
pub struct BlockInfo {
/// Sequential 0-based index of the block, in submission order.
pub block_number: u64,
/// Byte offset of this block within the compressed output stream.
pub compressed_start: u64,
/// Size of the compressed block in bytes.
pub compressed_size: usize,
/// Number of uncompressed input bytes this block encodes.
pub uncompressed_size: usize,
}
/// Receiving side of the per-block metadata channel, exposed to callers.
pub type BlockInfoRx = Receiver<BlockInfo>;
// One compressed frame: (compressed bytes, CRC32 parsed from the block
// trailer, uncompressed input length).
type FrameParts = (Vec<u8>, u32, usize); type BufferedTx = Sender<BgzfResult<FrameParts>>;
// Per-block rendezvous channel carrying a single compression result.
type BufferedRx = Receiver<BgzfResult<FrameParts>>;
// Work queue shared by the deflater pool: raw payload plus reply channel.
type DeflateTx = Sender<(Bytes, BufferedTx)>;
type DeflateRx = Receiver<(Bytes, BufferedTx)>;
// Ordered queue feeding the writer thread: each entry pairs a block's
// result receiver with its block number, preserving submission order.
type WriteTx = Sender<(BufferedRx, u64)>;
type WriteRx = Receiver<(BufferedRx, u64)>;
// Sending side of the per-block metadata channel (writer thread only).
type BlockInfoTx = Sender<BlockInfo>;
/// Lifecycle of the writer's background machinery.
enum State<W> {
/// Background threads are live and accepting work.
Running {
/// Handle to the single writer thread; joining yields the inner writer.
writer_handle: JoinHandle<BgzfResult<W>>,
/// Handles to the deflater worker pool.
deflater_handles: Vec<JoinHandle<()>>,
/// Ordered submission queue to the writer thread.
write_tx: WriteTx,
/// Work queue shared by all deflater threads.
deflate_tx: DeflateTx,
/// Per-block metadata receiver, exposed via `block_info_receiver`.
block_info_rx: BlockInfoRx,
},
/// `finish` has run; no further writes are possible.
Done,
}
/// A BGZF writer that compresses blocks on a pool of deflater threads and
/// writes them, in submission order, from a dedicated writer thread.
pub struct MultithreadedWriter<W>
where
W: Write + Send + 'static,
{
/// Background threads and channels; becomes `Done` after `finish`.
state: State<W>,
/// Staging buffer for the current block's uncompressed bytes.
buf: BytesMut,
/// Maximum number of bytes staged before a block is dispatched.
blocksize: usize,
/// Number of blocks dispatched so far (also the next block's number).
current_block_number: u64,
/// Compressed bytes written so far; updated by the writer thread.
position: Arc<AtomicU64>,
/// Blocks fully written so far; updated by the writer thread.
blocks_written: Arc<AtomicU64>,
}
impl<W> MultithreadedWriter<W>
where
W: Write + Send + 'static,
{
/// Creates a writer with a single deflater worker and the default
/// BGZF block size.
pub fn new(inner: W, compression_level: CompressionLevel) -> Self {
Self::with_worker_count(NonZero::<usize>::MIN, inner, compression_level)
}
/// Creates a writer with `worker_count` deflater threads and the
/// default BGZF block size.
pub fn with_worker_count(
worker_count: NonZero<usize>,
inner: W,
compression_level: CompressionLevel,
) -> Self {
Self::with_capacity(worker_count, inner, compression_level, BGZF_BLOCK_SIZE)
}
/// Core constructor: spawns the deflater pool and the writer thread and
/// wires them together with bounded channels, so the producing thread
/// cannot run arbitrarily far ahead of the workers.
fn with_capacity(
worker_count: NonZero<usize>,
inner: W,
compression_level: CompressionLevel,
blocksize: usize,
) -> Self {
// Shared counters: written by the writer thread, read by accessors.
let position = Arc::new(AtomicU64::new(0));
let blocks_written = Arc::new(AtomicU64::new(0));
// Bounded to the worker count so `send` applies backpressure.
let (deflate_tx, deflate_rx) = bounded(worker_count.get());
let (write_tx, write_rx) = bounded(worker_count.get());
// Unbounded so the writer thread never blocks publishing metadata,
// even if no one is draining the receiver.
let (block_info_tx, block_info_rx) = unbounded();
let deflater_handles = spawn_deflaters(compression_level, worker_count, deflate_rx);
let writer_handle = spawn_writer(
inner,
write_rx,
Arc::clone(&position),
Arc::clone(&blocks_written),
block_info_tx,
);
Self {
state: State::Running {
writer_handle,
deflater_handles,
write_tx,
deflate_tx,
block_info_rx,
},
buf: BytesMut::with_capacity(blocksize),
blocksize,
current_block_number: 0,
position,
blocks_written,
}
}
/// Returns the per-block metadata receiver, or `None` once finished.
#[must_use]
pub fn block_info_receiver(&self) -> Option<&BlockInfoRx> {
match &self.state {
State::Running { block_info_rx, .. } => Some(block_info_rx),
State::Done => None,
}
}
/// Number of blocks dispatched so far (the next block's number).
#[inline]
#[must_use]
pub fn current_block_number(&self) -> u64 {
self.current_block_number
}
/// Number of bytes currently staged in the uncompressed buffer.
#[inline]
#[must_use]
pub fn buffer_offset(&self) -> usize {
self.buf.len()
}
/// Compressed bytes written to the inner writer so far.
#[inline]
#[must_use]
pub fn position(&self) -> u64 {
self.position.load(Ordering::Acquire)
}
/// Number of blocks fully written to the inner writer so far.
#[inline]
#[must_use]
pub fn blocks_written(&self) -> u64 {
self.blocks_written.load(Ordering::Acquire)
}
/// Flushes any buffered data, shuts down the worker and writer threads,
/// and returns the inner writer (the writer thread appends the BGZF EOF
/// marker before returning it).
///
/// # Errors
/// Fails if called twice, if a background thread panicked, or if the
/// writer thread reports a compression or I/O error.
pub fn finish(&mut self) -> BgzfResult<W> {
if !self.buf.is_empty() {
self.send()?;
}
// Transition to `Done` up front so `Drop` will not re-enter `finish`.
let state = mem::replace(&mut self.state, State::Done);
match state {
State::Running {
writer_handle,
mut deflater_handles,
write_tx,
deflate_tx,
block_info_rx: _,
} => {
// Dropping the last deflate sender disconnects the channel;
// each worker's recv loop then ends and the thread can be joined.
drop(deflate_tx);
for handle in deflater_handles.drain(..) {
handle
.join()
.map_err(|_| BgzfError::Io(io::Error::other("Deflater thread panicked")))?;
}
// Likewise: disconnect the write queue, then join the writer
// thread, which drains its queue and writes the EOF marker.
drop(write_tx);
writer_handle
.join()
.map_err(|_| BgzfError::Io(io::Error::other("Writer thread panicked")))?
}
State::Done => Err(BgzfError::Io(io::Error::other("finish() called twice"))),
}
}
/// Free capacity left in the staging buffer.
#[inline]
fn remaining(&self) -> usize {
self.blocksize.saturating_sub(self.buf.len())
}
/// Whether the staging buffer can accept more bytes.
#[inline]
fn has_remaining(&self) -> bool {
self.remaining() > 0
}
/// Dispatches the staged bytes as one block: hands the payload to the
/// deflater pool and enqueues the block's result channel with the writer
/// thread, preserving submission order. No-op when the buffer is empty.
fn send(&mut self) -> BgzfResult<()> {
if self.buf.is_empty() {
return Ok(());
}
let State::Running { write_tx, deflate_tx, .. } = &self.state else {
return Err(BgzfError::Io(io::Error::other("Writer already finished")));
};
// `split` moves the staged bytes out without copying them.
let data = self.buf.split().freeze();
// Rendezvous channel carrying this one block's compression result.
let (buffered_tx, buffered_rx) = bounded(1);
deflate_tx
.send((data, buffered_tx))
.map_err(|_| BgzfError::Io(io::Error::other("Deflate channel closed")))?;
write_tx
.send((buffered_rx, self.current_block_number))
.map_err(|_| BgzfError::Io(io::Error::other("Write channel closed")))?;
self.current_block_number += 1;
Ok(())
}
}
impl<W> Write for MultithreadedWriter<W>
where
    W: Write + Send + 'static,
{
    /// Buffers up to one block's worth of `buf` and dispatches the block
    /// for compression once the staging buffer is full. May accept fewer
    /// bytes than supplied, as the `Write` contract allows.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = usize::min(buf.len(), self.remaining());
        self.buf.extend_from_slice(&buf[..accepted]);
        let block_full = !self.has_remaining();
        if block_full {
            self.send().map_err(|e| io::Error::other(e.to_string()))?;
        }
        Ok(accepted)
    }

    /// Dispatches any partially filled block to the compression pipeline.
    /// Note this enqueues the data; it does not wait for the writer thread
    /// to physically write it.
    fn flush(&mut self) -> io::Result<()> {
        match self.send() {
            Ok(()) => Ok(()),
            Err(e) => Err(io::Error::other(e.to_string())),
        }
    }
}
impl<W> Drop for MultithreadedWriter<W>
where
    W: Write + Send + 'static,
{
    /// Best-effort shutdown: if the writer was never finished explicitly,
    /// flush remaining data and join the background threads, discarding
    /// any error (there is no way to report it from `drop`).
    fn drop(&mut self) {
        if let State::Running { .. } = self.state {
            let _ = self.finish();
        }
    }
}
/// Configures and constructs a [`MultithreadedWriter`].
#[derive(Debug, Clone)]
pub struct Builder {
/// Deflate compression level applied by every worker thread.
compression_level: CompressionLevel,
/// Number of deflater worker threads to spawn.
worker_count: NonZero<usize>,
/// Uncompressed bytes per block (capped at `BGZF_BLOCK_SIZE`).
blocksize: usize,
}
impl Default for Builder {
    /// Defaults: compression level 6, a single worker thread, and the
    /// maximum BGZF block size.
    fn default() -> Self {
        let compression_level =
            CompressionLevel::new(6).expect("compression level 6 is always valid");
        Self {
            compression_level,
            worker_count: NonZero::<usize>::MIN,
            blocksize: BGZF_BLOCK_SIZE,
        }
    }
}
impl Builder {
    /// Sets the deflate compression level used by all workers.
    #[must_use]
    pub fn set_compression_level(mut self, level: CompressionLevel) -> Self {
        self.compression_level = level;
        self
    }

    /// Sets the number of deflater worker threads.
    #[must_use]
    pub fn set_worker_count(mut self, count: NonZero<usize>) -> Self {
        self.worker_count = count;
        self
    }

    /// Sets the uncompressed block size, clamped to `1..=BGZF_BLOCK_SIZE`.
    ///
    /// The upper bound matches the maximum BGZF payload; the lower bound
    /// rejects a zero block size, which would leave the writer with no
    /// capacity so every `write` call would return `Ok(0)` and callers
    /// using `write_all` would fail with `WriteZero`.
    #[must_use]
    pub fn set_blocksize(mut self, size: usize) -> Self {
        self.blocksize = size.clamp(1, BGZF_BLOCK_SIZE);
        self
    }

    /// Consumes the builder and constructs a writer around `writer`.
    pub fn build_from_writer<W>(self, writer: W) -> MultithreadedWriter<W>
    where
        W: Write + Send + 'static,
    {
        MultithreadedWriter::with_capacity(
            self.worker_count,
            writer,
            self.compression_level,
            self.blocksize,
        )
    }
}
/// Spawns `worker_count` deflater threads, each pulling (payload, reply
/// channel) jobs from `deflate_rx` until the channel disconnects.
#[allow(clippy::needless_pass_by_value)]
fn spawn_deflaters(
    compression_level: CompressionLevel,
    worker_count: NonZero<usize>,
    deflate_rx: DeflateRx,
) -> Vec<JoinHandle<()>> {
    let mut handles = Vec::with_capacity(worker_count.get());
    for _ in 0..worker_count.get() {
        let jobs = deflate_rx.clone();
        let handle = thread::spawn(move || {
            // One compressor and one scratch buffer per worker, reused
            // across jobs.
            let mut compressor = Compressor::new(compression_level);
            let mut scratch = Vec::new();
            // The iterator ends once every sender has been dropped.
            for (payload, reply_tx) in jobs.iter() {
                let outcome = compress_block(&mut compressor, &payload, &mut scratch);
                // The receiver may already be gone (writer shut down early);
                // a failed send is deliberately ignored.
                let _ = reply_tx.send(outcome);
            }
        });
        handles.push(handle);
    }
    handles
}
fn compress_block(
compressor: &mut Compressor,
input: &[u8],
buffer: &mut Vec<u8>,
) -> BgzfResult<FrameParts> {
buffer.clear();
compressor.compress(input, buffer)?;
let crc32 = u32::from_le_bytes([
buffer[buffer.len() - 8],
buffer[buffer.len() - 7],
buffer[buffer.len() - 6],
buffer[buffer.len() - 5],
]);
Ok((buffer.clone(), crc32, input.len()))
}
fn spawn_writer<W>(
mut writer: W,
write_rx: WriteRx,
position: Arc<AtomicU64>,
blocks_written: Arc<AtomicU64>,
block_info_tx: BlockInfoTx,
) -> JoinHandle<BgzfResult<W>>
where
W: Write + Send + 'static,
{
thread::spawn(move || {
while let Ok((buffered_rx, block_number)) = write_rx.recv() {
let (compressed_data, _crc32, uncompressed_size) = buffered_rx
.recv()
.map_err(|_| BgzfError::Io(io::Error::other("Compression channel closed")))??;
let compressed_start = position.load(Ordering::Acquire);
let compressed_size = compressed_data.len();
writer.write_all(&compressed_data)?;
position.fetch_add(compressed_size as u64, Ordering::Release);
blocks_written.fetch_add(1, Ordering::Release);
let _ = block_info_tx.send(BlockInfo {
block_number,
compressed_start,
compressed_size,
uncompressed_size,
});
}
writer.write_all(BGZF_EOF)?;
position.fetch_add(BGZF_EOF.len() as u64, Ordering::Release);
Ok(writer)
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use bgzf::Reader;
    use std::io::Read;

    /// Default compression level shared by the tests below.
    fn level() -> CompressionLevel {
        CompressionLevel::new(6).expect("valid compression level 6")
    }

    #[test]
    fn test_new_and_finish() {
        let mut writer = MultithreadedWriter::new(Vec::new(), level());
        assert!(writer.finish().is_ok());
    }

    #[test]
    fn test_roundtrip() {
        let mut writer = MultithreadedWriter::new(Vec::new(), level());
        writer.write_all(b"hello world").expect("write_all should succeed");
        let compressed = writer.finish().expect("finish should succeed");
        let mut decompressed = Vec::new();
        Reader::new(&compressed[..])
            .read_to_end(&mut decompressed)
            .expect("read_to_end should succeed");
        assert_eq!(decompressed, b"hello world");
    }

    #[test]
    fn test_position_tracking_initial() {
        let writer = MultithreadedWriter::new(Vec::new(), level());
        // A fresh writer has written nothing and buffered nothing.
        assert_eq!(writer.position(), 0);
        assert_eq!(writer.current_block_number(), 0);
        assert_eq!(writer.blocks_written(), 0);
        assert_eq!(writer.buffer_offset(), 0);
    }

    #[test]
    fn test_position_tracking_after_writes() {
        let mut writer = MultithreadedWriter::new(Vec::new(), level());
        writer.write_all(b"hello").expect("write_all should succeed");
        assert_eq!(writer.buffer_offset(), 5);
        // Flushing dispatches the staged block and empties the buffer.
        writer.flush().expect("flush should succeed");
        assert_eq!(writer.current_block_number(), 1);
        assert_eq!(writer.buffer_offset(), 0);
        writer.finish().expect("finish should succeed");
    }

    #[test]
    fn test_block_info_notifications() {
        let mut writer = MultithreadedWriter::with_worker_count(
            NonZero::new(2).expect("non-zero value 2"),
            Vec::new(),
            level(),
        );
        let rx = writer
            .block_info_receiver()
            .expect("block_info_receiver should return a receiver")
            .clone();
        for payload in [&b"block1"[..], &b"block2"[..]] {
            writer.write_all(payload).expect("write_all should succeed");
            writer.flush().expect("flush should succeed");
        }
        writer.finish().expect("finish should succeed");
        let infos: Vec<_> = rx.try_iter().collect();
        assert_eq!(infos.len(), 2);
        assert_eq!(infos[0].block_number, 0);
        assert_eq!(infos[1].block_number, 1);
        // The second block starts after the first block's bytes.
        assert!(infos[1].compressed_start > 0);
    }

    #[test]
    fn test_multiple_workers() {
        for worker_count in [1, 2, 4] {
            let mut writer = MultithreadedWriter::with_worker_count(
                NonZero::new(worker_count).expect("worker_count is non-zero"),
                Vec::new(),
                level(),
            );
            for i in 0..20 {
                writer
                    .write_all(format!("block{i}").as_bytes())
                    .expect("write_all should succeed");
                writer.flush().expect("flush should succeed");
            }
            let compressed = writer.finish().expect("finish should succeed");
            let mut text = String::new();
            Reader::new(&compressed[..])
                .read_to_string(&mut text)
                .expect("read_to_string should succeed");
            for i in 0..20 {
                assert!(text.contains(&format!("block{i}")));
            }
        }
    }

    #[test]
    fn test_large_data() {
        let mut writer = MultithreadedWriter::with_worker_count(
            NonZero::new(4).expect("non-zero value 4"),
            Vec::new(),
            level(),
        );
        // Five full blocks' worth of data exercises the block splitting.
        let payload = vec![b'A'; BGZF_BLOCK_SIZE * 5];
        writer.write_all(&payload).expect("write_all should succeed");
        let compressed = writer.finish().expect("finish should succeed");
        let mut roundtripped = Vec::new();
        Reader::new(&compressed[..])
            .read_to_end(&mut roundtripped)
            .expect("read_to_end should succeed");
        assert_eq!(roundtripped, payload);
    }

    #[test]
    fn test_eof_marker() {
        let mut writer = MultithreadedWriter::new(Vec::new(), level());
        writer.write_all(b"test").expect("write_all should succeed");
        let data = writer.finish().expect("finish should succeed");
        // The stream must end with the standard BGZF EOF block.
        assert!(data.ends_with(BGZF_EOF));
    }

    #[test]
    fn test_builder() {
        let writer = Builder::default()
            .set_compression_level(CompressionLevel::new(9).expect("valid compression level 9"))
            .set_worker_count(NonZero::new(4).expect("non-zero value 4"))
            .set_blocksize(32768)
            .build_from_writer(Vec::new());
        assert!(writer.block_info_receiver().is_some());
    }
}