#![forbid(unsafe_code)]
#![allow(
unused,
clippy::missing_panics_doc,
clippy::missing_errors_doc,
clippy::must_use_candidate,
clippy::module_name_repetitions
)]
#[cfg(feature = "bgzf_compressor")]
pub mod bgzf;
use std::time::Duration;
use std::{
error::Error,
io::{self, Read, Write},
sync::Arc,
thread::JoinHandle,
};
use bytes::{Bytes, BytesMut};
use flume::{self, bounded, Receiver, Sender};
use parking_lot::{lock_api::RawMutex, Mutex};
use thiserror::Error;
/// Shared buffer size (128 KiB) used by this crate's I/O; consumers outside
/// this chunk (e.g. the `bgzf` module) presumably use it — TODO confirm.
pub(crate) const BUFSIZE: usize = 128 * 1024;
/// Convenience alias for results produced by the pool.
type PoolResult<T> = Result<T, PoolError>;
/// Errors produced by the compression pool and its writers.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum PoolError {
    /// A flume channel send failed (the receiving side was dropped).
    #[error("Failed to send over channel")]
    ChannelSend,
    /// A flume channel receive failed (the sending side was dropped).
    #[error(transparent)]
    ChannelReceive(#[from] flume::RecvError),
    /// The underlying compressor reported an error (stringified).
    #[error("Error compressing data: {0}")]
    CompressionError(String),
    /// I/O failure from a destination writer.
    #[error(transparent)]
    Io(#[from] io::Error),
}
/// Handle that buffers writes into fixed-size blocks and ships each full
/// block to the shared compression pool.
#[derive(Debug)]
pub struct PooledWriter {
    // This writer's index into the pool's list of destination writers.
    writer_index: usize,
    // Sends uncompressed blocks to the compressor threads.
    compressor_tx: Sender<CompressorMessage>,
    // Registers each block's result-receiver with the writer side, in send
    // order, so compressed blocks are written back in per-writer order.
    writer_tx: Sender<Receiver<WriterMessage>>,
    // Pending uncompressed bytes; capacity is one block (`buffer_size`).
    buffer: BytesMut,
    // Block size; when `buffer` reaches this length a block is dispatched.
    buffer_size: usize,
}
impl PooledWriter {
fn new<C>(
index: usize,
compressor_tx: Sender<CompressorMessage>,
writer_tx: Sender<Receiver<WriterMessage>>,
) -> Self
where
C: Compressor,
{
Self {
writer_index: index,
compressor_tx,
writer_tx,
buffer: BytesMut::with_capacity(C::BLOCK_SIZE),
buffer_size: C::BLOCK_SIZE,
}
}
#[inline]
fn buffer_full(&self) -> bool {
self.buffer.len() == self.buffer_size
}
fn flush_bytes(&mut self, is_last: bool) -> std::io::Result<()> {
if is_last || self.buffer_full() {
self.send_block(is_last)?;
}
Ok(())
}
fn send_block(&mut self, is_last: bool) -> std::io::Result<()> {
let bytes = self.buffer.split_to(self.buffer.len()).freeze();
let (mut m, r) = CompressorMessage::new_parts(self.writer_index, bytes);
m.is_last = is_last;
self.writer_tx
.send(r)
.map_err(|_e| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))?;
self.compressor_tx
.send(m)
.map_err(|_e_| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))
}
pub fn close(mut self) -> std::io::Result<()> {
self.flush_bytes(true)
}
}
impl Drop for PooledWriter {
    /// Best-effort final flush so dropped writers still deliver their data.
    fn drop(&mut self) {
        // BUGFIX: previously `unwrap()`ed here. Panicking in `drop` aborts the
        // process when it happens during unwinding, and the send can
        // legitimately fail once the pool's channels are closed (e.g. after
        // `close()` already sent the last block, or after the pool stopped).
        // Swallow the error instead.
        let _ = self.flush_bytes(true);
    }
}
impl Write for PooledWriter {
    /// Copies `buf` into the internal block buffer, dispatching a block to the
    /// compressors each time the buffer fills. Always reports the whole `buf`
    /// as written unless a dispatch fails.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut remaining = buf;
        while !remaining.is_empty() {
            let room = self.buffer_size - self.buffer.len();
            let (chunk, rest) = remaining.split_at(room.min(remaining.len()));
            self.buffer.extend_from_slice(chunk);
            remaining = rest;
            if self.buffer_full() {
                self.send_block(false)?;
            }
        }
        Ok(buf.len())
    }

    /// NOTE: only submits the buffer when a *full* block has accumulated; a
    /// partial block stays buffered until more writes arrive or the writer is
    /// closed/dropped.
    fn flush(&mut self) -> std::io::Result<()> {
        self.flush_bytes(false)
    }
}
/// Interface for block compressors that can run inside the pool.
///
/// Implementations are `Send + 'static` because each pool thread constructs
/// and owns its own compressor instance.
pub trait Compressor: Sized + Send + 'static
where
    Self::CompressionLevel: Clone + Send + 'static,
    Self::Error: Error + Send + 'static,
{
    /// Error type produced by compression and level validation.
    type Error;
    /// Opaque, cloneable compression-level token.
    type CompressionLevel;
    /// Size in bytes of an uncompressed block handed to [`Self::compress`].
    /// 65280 matches the BGZF maximum uncompressed block size — presumably
    /// chosen for the `bgzf` feature; other implementors may override it.
    const BLOCK_SIZE: usize = 65280;
    /// Creates a compressor configured with `compression_level`.
    fn new(compression_level: Self::CompressionLevel) -> Self;
    /// Returns the implementation's default compression level.
    fn default_compression_level() -> Self::CompressionLevel;
    /// Validates a raw numeric level, converting it into a `CompressionLevel`.
    fn new_compression_level(compression_level: u8) -> Result<Self::CompressionLevel, Self::Error>;
    /// Compresses `input` into `output`. `is_last` marks the final block of a
    /// stream so implementations can emit any trailer/EOF marker.
    fn compress(
        &mut self,
        input: &[u8],
        output: &mut Vec<u8>,
        is_last: bool,
    ) -> Result<(), Self::Error>;
}
/// One unit of work for the compressor threads.
#[derive(Debug)]
struct CompressorMessage {
    // Index of the destination writer in the pool's writer list.
    writer_index: usize,
    // Uncompressed payload (at most one block's worth of bytes).
    buffer: Bytes,
    // Channel on which the compressed block is handed back for writing; used
    // as a one-shot even though it is a regular flume channel.
    oneshot: Sender<WriterMessage>,
    // True for the final block of a writer's stream.
    is_last: bool,
}
impl CompressorMessage {
    /// Builds a message (with `is_last` defaulting to `false`) together with
    /// the receiver on which its compressed block will be delivered.
    fn new_parts(writer_index: usize, buffer: Bytes) -> (Self, Receiver<WriterMessage>) {
        let (oneshot, oneshot_rx) = flume::unbounded();
        let message = Self {
            writer_index,
            buffer,
            oneshot,
            is_last: false,
        };
        (message, oneshot_rx)
    }
}
/// Compressed bytes ready to be written to a destination writer.
#[derive(Debug)]
struct WriterMessage {
    buffer: Vec<u8>,
}
/// Builder for a [`Pool`]: collects writers (exchanged for [`PooledWriter`]s)
/// and configures thread count, queue size, and compression level.
pub struct PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    // Index that will be assigned to the next exchanged writer.
    writer_index: usize,
    // Level passed to each compressor instance the pool spawns.
    compression_level: C::CompressionLevel,
    // Channel capacity; defaulted lazily when the queues are first created.
    queue_size: Option<usize>,
    // Number of pool worker threads.
    threads: usize,
    // Shared compressor channel, created lazily on first `exchange`/`build`.
    compressor_tx: Option<Sender<CompressorMessage>>,
    compressor_rx: Option<Receiver<CompressorMessage>>,
    // Destination writers handed over via `exchange`.
    writers: Vec<W>,
    // Per-writer queues of per-block result receivers (sender/receiver halves).
    writer_txs: Vec<Sender<Receiver<WriterMessage>>>,
    writer_rxs: Vec<Receiver<Receiver<WriterMessage>>>,
}
impl<W, C> PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    /// Default queue capacity is this many messages per worker thread.
    pub const QUEUE_SIZE_THREAD_MULTIPLES: usize = 50;
    /// Number of worker threads used when none is specified.
    pub const DEFAULT_THREADS: usize = 4;

    /// Creates a builder with default threads, queue size, and compression level.
    pub fn new() -> Self {
        PoolBuilder {
            writer_index: 0,
            compression_level: C::default_compression_level(),
            queue_size: None,
            threads: Self::DEFAULT_THREADS,
            compressor_tx: None,
            compressor_rx: None,
            writers: vec![],
            writer_txs: vec![],
            writer_rxs: vec![],
        }
    }

    /// Sets the number of pool worker threads.
    ///
    /// # Panics
    /// Panics if `threads` is 0.
    pub fn threads(mut self, threads: usize) -> Self {
        assert!(threads > 0, "Must provide a number of threads greater than 0.");
        self.threads = threads;
        self
    }

    /// Sets the capacity used for the compressor queue and per-writer queues.
    ///
    /// # Panics
    /// Panics if any writer has already been exchanged (the queues are
    /// created lazily by the first `exchange`).
    pub fn queue_size(mut self, queue_size: usize) -> Self {
        assert!(self.writers.is_empty(), "Cannot set queue_size after writers are exchanged.");
        // Plain assignment instead of discarding `Option::insert`'s return value.
        self.queue_size = Some(queue_size);
        self
    }

    /// Validates and sets the compression level.
    ///
    /// # Errors
    /// Returns [`PoolError::CompressionError`] if `C` rejects `level`.
    pub fn compression_level(mut self, level: u8) -> PoolResult<Self> {
        self.compression_level = C::new_compression_level(level)
            .map_err(|e| PoolError::CompressionError(e.to_string()))?;
        Ok(self)
    }

    /// Lazily creates the shared compressor channel, defaulting the queue
    /// size to `threads * QUEUE_SIZE_THREAD_MULTIPLES` when unset.
    fn ensure_queue_is_setup(&mut self) {
        if self.compressor_tx.is_none() && self.compressor_rx.is_none() {
            let queue_size =
                *self.queue_size.get_or_insert(self.threads * Self::QUEUE_SIZE_THREAD_MULTIPLES);
            let (tx, rx) = bounded(queue_size);
            self.compressor_tx = Some(tx);
            self.compressor_rx = Some(rx);
        }
    }

    /// Exchanges `writer` for a [`PooledWriter`] that routes its blocks
    /// through the pool; ownership of `writer` moves into the builder.
    pub fn exchange(&mut self, writer: W) -> PooledWriter {
        self.ensure_queue_is_setup();
        let (tx, rx): (Sender<Receiver<WriterMessage>>, Receiver<Receiver<WriterMessage>>) =
            flume::bounded(self.queue_size.expect("queue size was set by ensure_queue_is_setup"));
        let pooled = PooledWriter::new::<C>(
            self.writer_index,
            self.compressor_tx
                .as_ref()
                .expect("compressor channel was set by ensure_queue_is_setup")
                .clone(),
            tx.clone(),
        );
        self.writer_index += 1;
        self.writers.push(writer);
        self.writer_txs.push(tx);
        self.writer_rxs.push(rx);
        pooled
    }

    /// Spawns the pool's background thread and returns the running [`Pool`].
    pub fn build(mut self) -> PoolResult<Pool> {
        self.ensure_queue_is_setup();
        let (shutdown_tx, shutdown_rx) = flume::unbounded();
        // Destructure so moving fields into the spawned closure does not rely
        // on closure field-capture rules, and no `mut` binding is needed for
        // the returned Pool.
        let Self { threads, compression_level, compressor_tx, compressor_rx, writers, writer_rxs, .. } =
            self;
        let handle = std::thread::spawn(move || {
            Pool::pool_main::<W, C>(
                threads,
                compression_level,
                compressor_rx.expect("compressor channel was set by ensure_queue_is_setup"),
                writer_rxs,
                writers,
                shutdown_rx,
            )
        });
        Ok(Pool { compressor_tx, shutdown_tx: Some(shutdown_tx), pool_handle: Some(handle) })
    }
}
impl<W, C> Default for PoolBuilder<W, C>
where
W: Write + Send + 'static,
C: Compressor,
{
fn default() -> Self {
Self::new()
}
}
/// A running compression pool. Dropping it (or calling
/// [`Pool::stop_pool`]) drains the queues and joins the background thread.
#[derive(Debug)]
pub struct Pool {
    // Handle to the background thread running `pool_main`; taken on shutdown.
    pool_handle: Option<JoinHandle<PoolResult<()>>>,
    // Keep-alive clone of the compressor sender; dropped during shutdown so
    // worker threads can observe an empty, disconnected queue.
    compressor_tx: Option<Sender<CompressorMessage>>,
    // Never sent on; dropping it signals the workers to shut down
    // (they check `is_disconnected`).
    shutdown_tx: Option<Sender<()>>,
}
impl Pool {
    /// Body of the pool's background thread: spawns `num_threads` workers,
    /// each of which alternates between compressing pending blocks and
    /// draining compressed blocks to the destination writers; then joins the
    /// workers and flushes every writer.
    #[allow(clippy::unnecessary_wraps, clippy::needless_collect, clippy::needless_pass_by_value)]
    fn pool_main<W, C>(
        num_threads: usize,
        compression_level: C::CompressionLevel,
        compressor_rx: Receiver<CompressorMessage>,
        writer_rxs: Vec<Receiver<Receiver<WriterMessage>>>,
        writers: Vec<W>,
        shutdown_rx: Receiver<()>,
    ) -> PoolResult<()>
    where
        W: Write + Send + 'static,
        C: Compressor,
    {
        let writers: Arc<Vec<_>> =
            Arc::new(writers.into_iter().map(|w| Arc::new(Mutex::new(w))).collect());
        // Workers announce "writer N has a compressed block pending" here.
        let (write_available_tx, write_available_rx): (Sender<usize>, Receiver<usize>) =
            flume::unbounded();
        let thread_handles: Vec<JoinHandle<PoolResult<()>>> = (0..num_threads)
            .map(|_thread_idx| {
                let compressor_rx = compressor_rx.clone();
                let mut compressor = C::new(compression_level.clone());
                let writer_rxs = writer_rxs.clone();
                let writers = writers.clone();
                let shutdown_rx = shutdown_rx.clone();
                let sleep_delay = Duration::from_millis(25);
                let write_available_tx = write_available_tx.clone();
                let write_available_rx = write_available_rx.clone();
                std::thread::spawn(move || {
                    loop {
                        let mut did_something = false;
                        // Compression step: pull one block, compress it, and
                        // hand the result back on the block's oneshot channel.
                        if let Ok(message) = compressor_rx.try_recv() {
                            let chunk = &message.buffer;
                            let mut compressed = Vec::new();
                            compressor
                                .compress(chunk, &mut compressed, message.is_last)
                                .map_err(|e| PoolError::CompressionError(e.to_string()))?;
                            // BUGFIX: both sends below previously discarded
                            // their Results, silently dropping delivery
                            // failures. Propagate them as pool errors.
                            message
                                .oneshot
                                .send(WriterMessage { buffer: compressed })
                                .map_err(|_e| PoolError::ChannelSend)?;
                            write_available_tx
                                .send(message.writer_index)
                                .map_err(|_e| PoolError::ChannelSend)?;
                            did_something = true;
                        }
                        // Writing step: pop the next ready writer index and
                        // write its oldest compressed block. The per-writer
                        // queue of oneshot receivers preserves block order.
                        if let Ok(writer_index) = write_available_rx.try_recv() {
                            let mut writer = writers[writer_index].lock();
                            let writer_rx = &writer_rxs[writer_index];
                            let one_shot_rx = writer_rx.recv()?;
                            let write_message = one_shot_rx.recv()?;
                            writer.write_all(&write_message.buffer)?;
                            did_something = true;
                        }
                        if !did_something {
                            // Exit only once shutdown has been signalled (the
                            // shutdown sender was dropped) and all queues are
                            // fully drained; otherwise back off briefly.
                            if shutdown_rx.is_disconnected()
                                && write_available_rx.is_empty()
                                && compressor_rx.is_empty()
                                && writer_rxs.iter().all(|w| w.is_empty())
                            {
                                break;
                            } else {
                                std::thread::sleep(sleep_delay);
                            }
                        }
                    }
                    Ok(())
                })
            })
            .collect();
        // BUGFIX: the joined workers' PoolResult was previously discarded, so
        // compression/write errors never reached the caller. Propagate with `?`.
        thread_handles.into_iter().try_for_each(|handle| match handle.join() {
            Ok(result) => result,
            Err(e) => std::panic::resume_unwind(e),
        })?;
        writers.iter().try_for_each(|w| w.lock().flush())?;
        Ok(())
    }

    /// Gracefully shuts the pool down: waits for queued blocks to drain,
    /// signals shutdown by dropping the channels, and joins the pool thread.
    ///
    /// # Errors
    /// Returns any error produced by the pool thread (compression or I/O).
    ///
    /// # Panics
    /// Panics if called more than once (the internal `Option`s are already
    /// taken), or re-raises a panic from the pool thread.
    pub fn stop_pool(&mut self) -> Result<(), PoolError> {
        let compressor_queue = self.compressor_tx.take().expect("stop_pool called twice");
        // BUGFIX: this drain loop was a pure busy-wait pinning a core; sleep
        // briefly between polls instead.
        while !compressor_queue.is_empty() {
            std::thread::sleep(Duration::from_millis(1));
        }
        drop(compressor_queue);
        drop(self.shutdown_tx.take());
        match self.pool_handle.take().expect("stop_pool called twice").join() {
            Ok(result) => result,
            Err(e) => std::panic::resume_unwind(e),
        }
    }
}
impl Drop for Pool {
    /// Stops the pool if the user did not call [`Pool::stop_pool`] explicitly.
    fn drop(&mut self) {
        if self.compressor_tx.is_some() && self.pool_handle.is_some() {
            let result = self.stop_pool();
            // BUGFIX: never unwrap while already unwinding — a second panic
            // inside `drop` aborts the process and masks the original error.
            if !std::thread::panicking() {
                result.unwrap();
            }
        }
    }
}
#[cfg(test)]
mod test {
    use std::{
        assert_eq, format,
        fs::File,
        io::{BufReader, BufWriter},
        path::{Path, PathBuf},
        vec,
    };

    use crate::bgzf::BgzfCompressor;

    use super::*;
    use ::bgzf::Reader;
    use proptest::prelude::*;
    use tempfile::tempdir;

    /// Creates a buffered writer over a freshly created file.
    fn create_output_writer<P: AsRef<Path>>(path: P) -> BufWriter<File> {
        BufWriter::new(File::create(path).unwrap())
    }

    /// Joins `name` onto `dir`, yielding the output file path.
    fn create_output_file_name(name: impl AsRef<Path>, dir: impl AsRef<Path>) -> PathBuf {
        // `Path::join` already returns an owned PathBuf; no need for the
        // intermediate `to_path_buf()` allocation.
        dir.as_ref().join(name)
    }

    #[test]
    fn test_simple() {
        let dir = tempdir().unwrap();
        let output_names: Vec<PathBuf> = (0..20)
            .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), dir.path()))
            .collect();
        let output_writers: Vec<BufWriter<File>> =
            output_names.iter().map(create_output_writer).collect();
        let mut builder =
            PoolBuilder::<_, BgzfCompressor>::new().threads(8).compression_level(2).unwrap();
        let mut pooled_writers: Vec<PooledWriter> =
            output_writers.into_iter().map(|w| builder.exchange(w)).collect();
        let mut pool = builder.build().unwrap();
        for (i, writer) in pooled_writers.iter_mut().enumerate() {
            writer.write_all(format!("This is writer {}.", i).as_bytes()).unwrap();
        }
        // `flush` only submits full blocks; the final partial blocks are sent
        // when the writers are dropped by `into_iter` here.
        pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
        // BUGFIX: the shutdown Result was silently discarded; a failed
        // shutdown should fail the test.
        pool.stop_pool().unwrap();
        for (i, path) in output_names.iter().enumerate() {
            let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
            let mut actual = vec![];
            reader.read_to_end(&mut actual).unwrap();
            assert_eq!(actual, format!("This is writer {}.", i).as_bytes());
        }
    }

    proptest! {
        #[ignore]
        #[test]
        fn test_complete(
            input_size in 1..=BUFSIZE * 4,
            buf_size in 1..=BUFSIZE,
            num_output_files in 1..2*num_cpus::get(),
            threads in 1..=2+num_cpus::get(),
            comp_level in 1..=8_u8,
            write_size in 1..=2*BUFSIZE,
        ) {
            let dir = tempdir().unwrap();
            let output_names: Vec<PathBuf> = (0..num_output_files)
                .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), dir.path()))
                .collect();
            let output_writers: Vec<_> = output_names.iter().map(create_output_writer).collect();
            let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
                .threads(threads)
                .compression_level(comp_level)?;
            let mut pooled_writers: Vec<_> = output_writers.into_iter().map(|w| builder.exchange(w)).collect();
            let mut pool = builder.build()?;
            // Random payloads, one per output file.
            let inputs: Vec<Vec<u8>> = (0..num_output_files).map(|_| {
                (0..input_size).map(|_| rand::random::<u8>()).collect()
            }).collect();
            // Interleave writes across all writers in `write_size` chunks.
            let chunks = (input_size as f64 / write_size as f64).ceil() as usize;
            for i in 0..chunks {
                for (j, writer) in pooled_writers.iter_mut().enumerate() {
                    let input = &inputs[j];
                    let bytes = &input[write_size * i..std::cmp::min(write_size * (i + 1), input.len())];
                    writer.write_all(bytes).unwrap();
                }
            }
            pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
            // BUGFIX: previously discarded the shutdown Result.
            pool.stop_pool().unwrap();
            for (i, path) in output_names.iter().enumerate() {
                let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
                let mut actual = vec![];
                reader.read_to_end(&mut actual).unwrap();
                assert_eq!(actual, inputs[i]);
            }
        }
    }
}