Skip to main content

pooled_writer/
lib.rs

1//! A pooled writer and compressor.
2//!
3//! # Overview
4//!
5//! `pooled-writer` solves the problem of compressing and writing data to a set of writers using
6//! multiple threads, where the number of writers and threads cannot easily be equal.  For example
//! writing to hundreds of gzipped files using 16 threads, or writing to four gzipped files
8//! using 32 threads.
9//!
10//! To accomplish this, a pool is configured and writers are exchanged for [`PooledWriter`]s
11//! that can be used in place of the original writers.  This is accomplished using the
12//! [`PoolBuilder`] which is the preferred way to configure and create a pool.  The [`Pool`] and
13//! builder require two generic types: the `W` Writer type and the `C` compressor type. `W` may
14//! usually be elided if calls to [`PoolBuilder::exchange`] may be used to infer the type. `C`
15//! must be specified as something that implements [`Compressor`].
16//!
17//! The [`Pool`] consists of a single thread pool that consumes work from both a compression queue
18//! and a writing queue.  All concurrency is managed via message passing over channels.
19//!
20//! Every time the internal buffer of a [`PooledWriter`] reaches capacity (defined by
21//! [`Compressor::BLOCK_SIZE`]) it sends two messages:
22//! 1. It sends a message over the corresponding writer's channel to the writer pool, enqueueing
23//!    a one-shot receiver channel in the writers queue that will receive the compressed bytes
24//!    once the compressor is done. This is done to maintain the output order.
25//! 2. It sends a message to the compressor pool that contains a buffer of bytes to compress
26//!    as well as the sender side of the one-shot channel to send the compressed bytes on.
27//!
//! The threads in the thread pool loop continuously until the pool is shut down, and attempt
//! first to receive and compress one block, and then to receive and write one compressed block.
30//! A third internal channel is used to manage the queue of writes to be performed so that the
31//! individual per-writer channels (of which there may be many) are only polled if there is likely
32//! to be data available for writing.  When data is available to be written, the appropriate
33//! underlying writer is locked, and the data written.
34//!
35//! When all writing to [`PooledWriter`]s is complete, the writers should be close()'d or drop()'d
36//! and then the pool should be stopped using [`Pool::stop_pool`].  Writers that are not closed
37//! may have data buffered that is never written!  
38//!
39//! [`Pool::stop_pool`] will shutdown channels in a safe order ensuring that data submitted to the
40//! pool is compressed and written before threads are stopped.  After initiating the pool shutdown
41//! any subsequent attempts to write to [`PooledWriter`]s will result in errors.  Likewise any
//! calls to [`PooledWriter::close`] that cause data to be flushed into the compression queue will
43//! raise errors.
44//!
45//! # Example
46//!
47//! ```rust
48//! use std::{
49//!     error::Error,
50//!     fs::File,
51//!     io::{BufWriter, Write},
52//!     path::Path,
53//! };
54//!
55//! use pooled_writer::{Compressor, PoolBuilder, Pool, bgzf::BgzfCompressor};
56//!
57//! type DynError = Box<dyn Error + 'static>;
58//!
59//! fn create_writer<P: AsRef<Path>>(name: P) -> Result<BufWriter<File>, DynError> {
60//!     Ok(BufWriter::new(File::create(name)?))
61//! }
62//!
63//! fn main() -> Result<(), DynError> {
64//!     let writers = vec![
65//!         create_writer("/tmp/test1.txt.gz")?,
66//!         create_writer("/tmp/test2.txt.gz")?,
67//!         create_writer("/tmp/test3.txt.gz")?,
68//!     ];
69//!
70//!     let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
71//!         .threads(8)
72//!         .compression_level(5)?;
73//!
//!     let mut pooled_writers = writers.into_iter().map(|w| builder.exchange(w)).collect::<Vec<_>>();
//!     let mut pool = builder.build()?;
76//!
77//!     writeln!(&mut pooled_writers[1], "This is writer2")?;
78//!     writeln!(&mut pooled_writers[0], "This is writer1")?;
79//!     writeln!(&mut pooled_writers[2], "This is writer3")?;
80//!     pooled_writers.into_iter().try_for_each(|w| w.close())?;
81//!     pool.stop_pool()?;
82//!
83//!     Ok(())
84//! }
85//! ```
86#![forbid(unsafe_code)]
87#![allow(
88    unused,
89    clippy::missing_panics_doc,
90    clippy::missing_errors_doc,
91    clippy::must_use_candidate,
92    clippy::module_name_repetitions
93)]
94
95#[cfg(feature = "bgzf_compressor")]
96pub mod bgzf;
97
98use std::time::Duration;
99use std::{
100    error::Error,
101    io::{self, Read, Write},
102    sync::Arc,
103    thread::JoinHandle,
104};
105
106use bytes::{Bytes, BytesMut};
107use flume::{self, Receiver, Sender, bounded};
108use parking_lot::{Mutex, lock_api::RawMutex};
109use thiserror::Error;
110
/// 128 KB default buffer size, same as pigz.
///
/// NOTE(review): within this file this constant is only referenced by the
/// property tests — confirm whether downstream crate code uses it.
pub(crate) const BUFSIZE: usize = 128 * 1024;

/// Convenience type for functions that return [`PoolError`].
type PoolResult<T> = Result<T, PoolError>;
116
/// Represents errors that may be generated by any `Pool` related functionality.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum PoolError {
    /// A send over one of the internal channels failed; this normally means the
    /// receiving half was dropped (e.g. the pool has already been stopped).
    #[error("Failed to send over channel")]
    ChannelSend,
    /// A receive on one of the internal flume channels failed.
    #[error(transparent)]
    ChannelReceive(#[from] flume::RecvError),

    // TODO: figure out how to better pass in a generic / dynamic error type to this.
    /// The underlying [`Compressor`] reported an error; payload is the stringified source error.
    #[error("Error compressing data: {0}")]
    CompressionError(String),
    /// An I/O error from one of the underlying writers.
    #[error(transparent)]
    Io(#[from] io::Error),
}
132
133////////////////////////////////////////////////////////////////////////////////
// The PooledWriter and its impls
135////////////////////////////////////////////////////////////////////////////////
136
/// A [`PooledWriter`] is created by exchanging a writer with a [`Pool`].
///
/// The pooled writer will internally buffer writes, sending bytes to the [`Pool`]
/// after the internal buffer has been filled.
///
/// Note that the `compressor_tx` channel is shared by all pooled writers, whereas the `writer_tx`
/// is specific to the _underlying_ writer that this pooled writer encapsulates.
#[derive(Debug)]
pub struct PooledWriter {
    /// The index/serial number of the pooled writer within the pool.
    /// Pool worker threads use it to locate the matching underlying writer.
    writer_index: usize,
    /// Channel to send messages containing bytes to compress to the compressors' pool.
    /// This sender is cloned from the single compression channel shared by all writers.
    compressor_tx: Sender<CompressorMessage>,
    /// Channel to send the receiving end of the one-shot channel that will be
    /// used to send the compressed bytes. This effectively "place holds" the
    /// position of the compressed bytes in the writers queue until the compressed bytes
    /// are ready.
    writer_tx: Sender<oneshot::Receiver<WriterMessage>>,
    /// The internal buffer to gather bytes to send.
    buffer: BytesMut,
    /// The desired size of the internal buffer; set from [`Compressor::BLOCK_SIZE`].
    buffer_size: usize,
}
160
161impl PooledWriter {
162    /// Create a new [`PooledWriter`] that has an internal buffer capacity that matches [`bgzf::BGZF_BLOCK_SIZE`].
163    ///
164    /// # Arguments
165    /// - `index` - a usize representing that this is the nth pooled writer created within the pool
166    /// - `compressor_tx` - The channel to send uncompressed bytes to the compressor pool.
167    /// - `writer_tx` - The `Send` end of the channel that transmits the `Receiver` end of the one-shot
168    ///                 channel, which will be consumed when the compressor sends the compressed bytes.
169    fn new<C>(
170        index: usize,
171        compressor_tx: Sender<CompressorMessage>,
172        writer_tx: Sender<oneshot::Receiver<WriterMessage>>,
173    ) -> Self
174    where
175        C: Compressor,
176    {
177        Self {
178            writer_index: index,
179            compressor_tx,
180            writer_tx,
181            buffer: BytesMut::with_capacity(C::BLOCK_SIZE),
182            buffer_size: C::BLOCK_SIZE,
183        }
184    }
185
186    /// Test whether the internal buffer has reached capacity.
187    #[inline]
188    fn buffer_full(&self) -> bool {
189        self.buffer.len() == self.buffer_size
190    }
191
192    /// Send all bytes in the current buffer to the compressor.
193    ///
194    /// If `is_last` is `true`, the message sent to the compressor will also have the `is_last` true flag set
195    /// and the compressor will finish the BGZF stream.
196    ///
197    /// If `is_last` is not true then only full block will be sent. If `is_last` is true, an incomplete block may be set
198    /// as the final block.
199    fn flush_bytes(&mut self, is_last: bool) -> std::io::Result<()> {
200        if is_last || self.buffer_full() {
201            self.send_block(is_last)?;
202        }
203        Ok(())
204    }
205
206    /// Send a single block
207    fn send_block(&mut self, is_last: bool) -> std::io::Result<()> {
208        let bytes = self.buffer.split_to(self.buffer.len()).freeze();
209        let (mut m, r) = CompressorMessage::new_parts(self.writer_index, bytes);
210        m.is_last = is_last;
211        self.writer_tx
212            .send(r)
213            .map_err(|_e| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))?;
214        self.compressor_tx
215            .send(m)
216            .map_err(|_e_| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))
217    }
218
219    /// Flush any remaining bytes and consume self, triggering drops of the senders.
220    pub fn close(mut self) -> std::io::Result<()> {
221        self.flush_bytes(true)
222    }
223}
224
225impl Drop for PooledWriter {
226    /// Drop [`PooledWriter`].
227    ///
228    /// This will flush the writer.
229    fn drop(&mut self) {
230        self.flush_bytes(true).unwrap();
231    }
232}
233
234impl Write for PooledWriter {
235    /// Send all bytes in `buf` to the [`Pool`].
236    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
237        let mut bytes_added = 0;
238
239        while bytes_added < buf.len() {
240            let bytes_to_append =
241                std::cmp::min(buf.len() - bytes_added, self.buffer_size - self.buffer.len());
242
243            self.buffer.extend_from_slice(&buf[bytes_added..bytes_added + bytes_to_append]);
244            bytes_added += bytes_to_append;
245            if self.buffer_full() {
246                self.send_block(false)?;
247            }
248        }
249
250        Ok(buf.len())
251    }
252
253    /// Send whatever is in the current buffer even if it is not a full buffer.
254    fn flush(&mut self) -> std::io::Result<()> {
255        self.flush_bytes(false)
256    }
257}
258
259////////////////////////////////////////////////////////////////////////////////
260// The Compressor trait
261////////////////////////////////////////////////////////////////////////////////
262
/// A [`Compressor`] is used in the compressor pool to compress bytes.
///
/// An implementation must be provided as the `C` type parameter of the
/// [`PoolBuilder`] so that the pool knows what kind of compression to use.
///
/// See the module level example for more details.
pub trait Compressor: Sized + Send + 'static
where
    Self::CompressionLevel: Clone + Send + 'static,
    Self::Error: Error + Send + 'static,
{
    /// Error type returned by compression and by compression-level validation.
    type Error;
    /// Opaque, cloneable representation of a validated compression level.
    type CompressionLevel;

    /// The `BLOCK_SIZE` is used to set the buffer size of the [`PooledWriter`]s and should match the max
    /// size allowed by the block compression format being used.
    const BLOCK_SIZE: usize = 65280;

    /// Create a new compressor with the given compression level.
    fn new(compression_level: Self::CompressionLevel) -> Self;

    /// Returns the default compression level for the compressor.
    fn default_compression_level() -> Self::CompressionLevel;

    /// Create an instance of the compression level.
    ///
    /// The validity of the compression level should be checked here.
    fn new_compression_level(compression_level: u8) -> Result<Self::CompressionLevel, Self::Error>;

    /// Compress a set of bytes into the `output` vec. If `is_last` is true, and depending on the
    /// block compression format, an EOF block may be appended as well.
    fn compress(
        &mut self,
        input: &[u8],
        output: &mut Vec<u8>,
        is_last: bool,
    ) -> Result<(), Self::Error>;
}
301
302////////////////////////////////////////////////////////////////////////////////
303// The messages passed between threads
304////////////////////////////////////////////////////////////////////////////////
305
/// A message that is sent from a [`PooledWriter`] to the compressor threadpool within a [`Pool`].
struct CompressorMessage {
    /// The index of the destination writer.
    writer_index: usize,
    /// The bytes to compress.
    buffer: Bytes,
    /// Where the compressed bytes will be sent after compression. The matching
    /// receiver is enqueued in the destination writer's queue to keep ordering.
    oneshot_tx: oneshot::Sender<WriterMessage>,
    /// A sentinel value to let the compressor know that the BGZF stream needs an EOF.
    is_last: bool,
}
317
318impl CompressorMessage {
319    fn new_parts(writer_index: usize, buffer: Bytes) -> (Self, oneshot::Receiver<WriterMessage>) {
320        let (tx, rx) = oneshot::channel();
321        let new = Self { writer_index, buffer, oneshot_tx: tx, is_last: false };
322        (new, rx)
323    }
324}
325
/// The compressed bytes to be written to a file.
///
/// This is sent from the compressor threadpool to the writer queue in the writer threadpool
/// via the one-shot channel provided by the [`PooledWriter`].
#[derive(Debug)]
struct WriterMessage {
    /// The compressed bytes, ready to be written verbatim to the destination writer.
    buffer: Vec<u8>,
}
334
/// Internal enum used by worker threads to dispatch between compression and write work.
enum WorkItem {
    /// Compress the bytes carried by the message and forward the result.
    Compress(CompressorMessage),
    /// Compressed data is ready to be written for the writer at this index.
    Write(usize),
}
340
341////////////////////////////////////////////////////////////////////////////////
342// The PoolBuilder struct and impls
343////////////////////////////////////////////////////////////////////////////////
344
/// A struct to make building up a Pool simpler.  The builder should be constructed using
/// [`PoolBuilder::new`], which provides the user control over the sizes of the queues used for
/// compression and writing.  It should be noted that a single compression queue is created,
/// and one writer queue per writer exchanged.  A good starting point for these queue sizes is
/// two times the number of threads.
///
/// Once created various functions can configure aspects of the pool.  It is best practice, though
/// not required, to configure the builder _before_ exchanging writers.  The exception is
/// `queue_size` that may _not_ be set after any writers have been exchanged.  If not set manually
/// then `queue_size` defaults to the number of threads multiplied by
/// [`PoolBuilder::QUEUE_SIZE_THREAD_MULTIPLES`].
///
/// Once the builder is configured writers may be exchanged for [`PooledWriter`]s using the
/// [`PoolBuilder::exchange`] function, which consumes the provided writer and returns a new
/// writer that can be used in it's place.
///
/// After exchanging all writers the pool may be created and started with [`PoolBuilder::build`]
/// which consumes the builder and after which no more writers may be exchanged.
pub struct PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    /// Serial number assigned to the next exchanged writer.
    writer_index: usize,
    /// Validated compression level handed to each per-thread compressor.
    compression_level: C::CompressionLevel,
    /// Bounded queue capacity; `None` until set explicitly or defaulted lazily.
    queue_size: Option<usize>,
    /// Number of worker threads the pool will spawn.
    threads: usize,
    /// Send half of the shared compression channel (created lazily).
    compressor_tx: Option<Sender<CompressorMessage>>,
    /// Receive half of the shared compression channel (created lazily).
    compressor_rx: Option<Receiver<CompressorMessage>>,
    /// The original writers handed over via `exchange`.
    writers: Vec<W>,
    /// Per-writer send halves for enqueueing one-shot receivers.
    writer_txs: Vec<Sender<oneshot::Receiver<WriterMessage>>>,
    /// Per-writer receive halves consumed by the pool worker threads.
    writer_rxs: Vec<Receiver<oneshot::Receiver<WriterMessage>>>,
}
378
379impl<W, C> PoolBuilder<W, C>
380where
381    W: Write + Send + 'static,
382    C: Compressor,
383{
384    /// By default queue sizes will be set to threads * this constant.
385    pub const QUEUE_SIZE_THREAD_MULTIPLES: usize = 50;
386
387    /// The default number of threads that will be used if not otherwise configured
388    pub const DEFAULT_THREADS: usize = 4;
389
390    /// Creates a new PoolBuilder that can be used to configure and build a [`Pool`].
391    pub fn new() -> Self {
392        PoolBuilder {
393            writer_index: 0,
394            compression_level: C::default_compression_level(),
395            queue_size: None,
396            threads: Self::DEFAULT_THREADS,
397            compressor_tx: None,
398            compressor_rx: None,
399            writers: vec![],
400            writer_txs: vec![],
401            writer_rxs: vec![],
402        }
403    }
404
405    /// Sets the number of threads that will be used by the [[Pool]].
406    ///
407    /// Will panic if set to 0.
408    pub fn threads(mut self, threads: usize) -> Self {
409        assert!(threads > 0, "Must provide a number of threads greater than 0.");
410        self.threads = threads;
411        self
412    }
413
414    /// Sets the size of queues used by the pool [[Pool]].  The same size is used for
415    /// a) the queue of byte buffers to be compressed, b) the per-sample queues to receive
416    /// compressed bytes, and c) a control queue to manage writing to the underlying writers.
417    ///
418    /// In the worst case scenario the pool can be holding both queue_size uncompressed blocks
419    /// _and_ queue_size compressed blocks in memory when it cannot keep up with the incoming
420    /// load of writes.
421    ///
422    ///
423    ///
424    /// Will panic if called _after_ writers have been created because queues will already have
425    /// been created.
426    pub fn queue_size(mut self, queue_size: usize) -> Self {
427        assert!(self.writers.is_empty(), "Cannot set queue_size after writers are exchanged.");
428        self.queue_size.insert(queue_size);
429        self
430    }
431
432    /// Sets the compression level that will be used by the [[Pool]].
433    pub fn compression_level(mut self, level: u8) -> PoolResult<Self> {
434        self.compression_level = C::new_compression_level(level)
435            .map_err(|e| PoolError::CompressionError(e.to_string()))?;
436        Ok(self)
437    }
438
439    /// If queues/channels are not yet setup, initialize them.
440    fn ensure_queue_is_setup(&mut self) {
441        if self.compressor_tx.is_none() && self.compressor_rx.is_none() {
442            if self.queue_size.is_none() {
443                self.queue_size.insert(self.threads * Self::QUEUE_SIZE_THREAD_MULTIPLES);
444            }
445
446            let (tx, rx) = bounded(self.queue_size.unwrap());
447            self.compressor_tx.insert(tx);
448            self.compressor_rx.insert(rx);
449        }
450    }
451
452    /// Exchanges a writer for a [[PooledWriter]].
453    pub fn exchange(&mut self, writer: W) -> PooledWriter {
454        // Make sure queue/channel configuration is done
455        self.ensure_queue_is_setup();
456
457        let (tx, rx): (
458            Sender<oneshot::Receiver<WriterMessage>>,
459            Receiver<oneshot::Receiver<WriterMessage>>,
460        ) = flume::bounded(self.queue_size.expect("Unreachable"));
461
462        let p = PooledWriter::new::<C>(
463            self.writer_index,
464            self.compressor_tx.as_ref().expect("Unreachable").clone(),
465            tx.clone(),
466        );
467
468        self.writer_index += 1;
469        self.writers.push(writer);
470        self.writer_txs.push(tx);
471        self.writer_rxs.push(rx);
472        p
473    }
474
475    /// Consumes the builder and generates the [[Pool]] ready for use.
476    pub fn build(mut self) -> PoolResult<Pool> {
477        // Make sure the queue/channel configuration is done - this could be necessary if
478        // a pool is created by zero writers exchanged.
479        self.ensure_queue_is_setup();
480
481        // Create the channel to gracefully signal a shutdown of the pool
482        let (shutdown_tx, shutdown_rx) = flume::unbounded();
483
484        // Start the pool manager thread and thread pools
485        let handle = std::thread::spawn(move || {
486            Pool::pool_main::<W, C>(
487                self.threads,
488                self.compression_level,
489                self.compressor_rx.expect("Unreachable."),
490                self.writer_rxs,
491                self.writers,
492                shutdown_rx,
493            )
494        });
495
496        let mut pool = Pool {
497            compressor_tx: self.compressor_tx,
498            shutdown_tx: Some(shutdown_tx),
499            pool_handle: Some(handle),
500        };
501
502        Ok(pool)
503    }
504}
505
506impl<W, C> Default for PoolBuilder<W, C>
507where
508    W: Write + Send + 'static,
509    C: Compressor,
510{
511    fn default() -> Self {
512        Self::new()
513    }
514}
515
516////////////////////////////////////////////////////////////////////////////////
517// The Pool struct and impls
518////////////////////////////////////////////////////////////////////////////////
519
/// A [`Pool`] orchestrates two different threadpools, a compressor pool and a writer pool.
///
/// The pool is suitable for scenarios where there are many more writers than threads, efficiently
/// managing resources for M writers to N threads.
#[derive(Debug)]
pub struct Pool {
    /// The join handle for the thread that manages all pool resources and coordination.
    /// `None` after `stop_pool` has taken and joined it.
    pool_handle: Option<JoinHandle<PoolResult<()>>>,
    /// The send end of the channel for communicating with the compressor pool.
    /// Dropped (taken) by `stop_pool` to disconnect the channel.
    compressor_tx: Option<Sender<CompressorMessage>>,
    /// Sentinel channel to tell the pool management thread to shutdown.
    shutdown_tx: Option<Sender<()>>,
}
533
534impl Pool {
535    /// The main "run" method for the pool that orchestrates all the pieces.
536    ///
537    /// The [`PooledWriter`]s are sending to the compressor, the compressor compresses them, then forwards the compressed bytes.
538    /// The bytes are forwarded to a queue per writer and the writer threads are iterating over that queue pulling down
539    /// all values in the queue at once and writing till the queue is empty.
540    ///
541    /// # Arguments
542    /// - `num_threads` - The number of threads to use.
543    /// - `compression_level` - The compression level to use for the [`Compressor`] pool.
544    /// - `compressor_rx ` - The receiving end of the channel for communicating with the compressor pool.
545    /// - `writer_rxs ` - The receive halves of the channels for the [`PooledWriter`]s to enqueue the one-shot channels.
546    /// - `writers` - The writers that were exchanged for [`PooledWriter`]s.
547    /// - `shutdown_rx` - Sentinel channel to tell the pool management thread to shutdown.
548    #[allow(clippy::unnecessary_wraps, clippy::needless_collect, clippy::needless_pass_by_value)]
549    fn pool_main<W, C>(
550        num_threads: usize,
551        compression_level: C::CompressionLevel,
552        compressor_rx: Receiver<CompressorMessage>,
553        writer_rxs: Vec<Receiver<oneshot::Receiver<WriterMessage>>>, // must be pass by value to allow for easy sharing between threads
554        writers: Vec<W>,
555        shutdown_rx: Receiver<()>,
556    ) -> PoolResult<()>
557    where
558        W: Write + Send + 'static,
559        C: Compressor,
560    {
561        // Add locks to the writers
562        let writers: Arc<Vec<_>> =
563            Arc::new(writers.into_iter().map(|w| Arc::new(Mutex::new(w))).collect());
564
565        // Generate one more channel for queuing up information about when a writer has data
566        // available to be written
567        let (write_available_tx, write_available_rx): (Sender<usize>, Receiver<usize>) =
568            flume::unbounded();
569
570        let thread_handles: Vec<JoinHandle<PoolResult<()>>> = (0..num_threads)
571            .map(|thread_idx| {
572                let compressor_rx = compressor_rx.clone();
573                let mut compressor = C::new(compression_level.clone());
574                let writer_rxs = writer_rxs.clone();
575                let writers = writers.clone();
576                let shutdown_rx = shutdown_rx.clone();
577                let write_available_tx = write_available_tx.clone();
578                let write_available_rx = write_available_rx.clone();
579                let select_timeout = Duration::from_millis(100);
580
581                std::thread::spawn(move || {
582                    // Reuse a single compression buffer per thread to avoid
583                    // re-allocating ~70KB on every block.
584                    let mut compress_buf = Vec::new();
585
586                    loop {
587                        // Try non-blocking receives first (fast path under load).
588                        // Then fall through to Selector which blocks until work
589                        // arrives, avoiding the old sleep(25ms) polling delay.
590                        let item = if let Ok(msg) = compressor_rx.try_recv() {
591                            Some(WorkItem::Compress(msg))
592                        } else if let Ok(idx) = write_available_rx.try_recv() {
593                            Some(WorkItem::Write(idx))
594                        } else {
595                            flume::Selector::new()
596                                .recv(&compressor_rx, |r| r.ok().map(WorkItem::Compress))
597                                .recv(&write_available_rx, |r| r.ok().map(WorkItem::Write))
598                                .wait_timeout(select_timeout)
599                                .ok()
600                                .flatten()
601                        };
602
603                        match item {
604                            Some(WorkItem::Compress(message)) => {
605                                let chunk = &message.buffer;
606                                compress_buf.clear();
607                                compressor
608                                    .compress(chunk, &mut compress_buf, message.is_last)
609                                    .map_err(|e| PoolError::CompressionError(e.to_string()))?;
610                                message
611                                    .oneshot_tx
612                                    .send(WriterMessage { buffer: compress_buf.clone() })
613                                    .map_err(|_e| PoolError::ChannelSend);
614                                write_available_tx.send(message.writer_index);
615                            }
616                            Some(WorkItem::Write(writer_index)) => {
617                                let mut writer = writers[writer_index].lock();
618                                let writer_rx = &writer_rxs[writer_index];
619                                let one_shot_rx = writer_rx.recv()?;
620                                let write_message =
621                                    one_shot_rx.recv().map_err(|_| PoolError::ChannelSend)?;
622                                writer.write_all(&write_message.buffer)?;
623                            }
624                            None => {
625                                // Timeout or channel disconnect. Check if all work
626                                // is drained and shutdown was requested.
627                                if shutdown_rx.is_disconnected()
628                                    && write_available_rx.is_empty()
629                                    && compressor_rx.is_empty()
630                                    && writer_rxs.iter().all(|w| w.is_empty())
631                                {
632                                    break;
633                                }
634                            }
635                        }
636                    }
637
638                    Ok(())
639                })
640            })
641            .collect();
642
643        // Close writer handles
644        thread_handles.into_iter().try_for_each(|handle| match handle.join() {
645            Ok(result) => result,
646            Err(e) => std::panic::resume_unwind(e),
647        });
648
649        // Flush each writer
650        writers.iter().try_for_each(|w| w.lock().flush())?;
651
652        Ok(())
653    }
654
655    /// Shutdown all pool resources and close all channels.
656    ///
657    /// Ideally the [`PooledWriter`]s should all have been flushed first, that is up to the user. Any
658    /// further attempts to send to the [`Pool`] will return an error.
659    pub fn stop_pool(&mut self) -> Result<(), PoolError> {
660        // Drop the compressor sender to disconnect the channel.  Buffered
661        // messages are preserved by flume and will be drained by the worker
662        // threads before they observe the disconnect and shut down.
663        drop(self.compressor_tx.take().unwrap());
664
665        // Shutdown called to force writers to start checking their receivers for disconnection / empty
666        drop(self.shutdown_tx.take());
667
668        // Wait on the pool thread to finish and pull any errors from it
669        match self.pool_handle.take().unwrap().join() {
670            Ok(result) => result,
671            Err(e) => std::panic::resume_unwind(e),
672        }
673    }
674}
675
676impl Drop for Pool {
677    fn drop(&mut self) {
678        // Check if `stop_pool` has already been called. If it hasn't, call it.
679        if self.compressor_tx.is_some() && self.pool_handle.is_some() {
680            self.stop_pool().unwrap();
681        }
682    }
683}
684
685////////////////////////////////////////////////////////////////////////////////
686// Tests
687////////////////////////////////////////////////////////////////////////////////
688
#[cfg(test)]
mod test {
    use std::{
        assert_eq, format,
        fs::File,
        io::{BufReader, BufWriter},
        path::{Path, PathBuf},
        vec,
    };

    use crate::bgzf::BgzfCompressor;

    use super::*;
    use ::bgzf::Reader;
    use proptest::prelude::*;
    use tempfile::tempdir;

    /// Creates a buffered file writer for the given path, panicking on failure.
    fn create_output_writer<P: AsRef<Path>>(path: P) -> BufWriter<File> {
        BufWriter::new(File::create(path).unwrap())
    }

    /// Joins `name` onto `dir` to produce an output file path.
    fn create_output_file_name(name: impl AsRef<Path>, dir: impl AsRef<Path>) -> PathBuf {
        let path = dir.as_ref().to_path_buf();
        path.join(name)
    }

    #[test]
    fn test_simple() {
        let dir = tempdir().unwrap();
        let output_names: Vec<PathBuf> = (0..20)
            .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), dir.path()))
            .collect();

        let output_writers: Vec<BufWriter<File>> =
            output_names.iter().map(create_output_writer).collect();
        let mut builder =
            PoolBuilder::<_, BgzfCompressor>::new().threads(8).compression_level(2).unwrap();
        let mut pooled_writers: Vec<PooledWriter> =
            output_writers.into_iter().map(|w| builder.exchange(w)).collect();
        let mut pool = builder.build().unwrap();

        for (i, writer) in pooled_writers.iter_mut().enumerate() {
            writer.write_all(format!("This is writer {}.", i).as_bytes()).unwrap();
        }
        // flush() only dispatches full blocks; dropping the writers here sends the
        // remaining partial buffers as final blocks.
        pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
        // Unwrap so that shutdown errors fail the test instead of being silently
        // discarded (the Result was previously ignored).
        pool.stop_pool().unwrap();

        for (i, path) in output_names.iter().enumerate() {
            let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
            let mut actual = vec![];
            reader.read_to_end(&mut actual).unwrap();
            assert_eq!(actual, format!("This is writer {}.", i).as_bytes());
        }
    }

    proptest! {
        // This test takes around 20 minutes on a 32 core machine to run but is very comprehensive.
        // Run with `cargo test -- --ignored`
        #[ignore]
        #[test]
        fn test_complete(
            input_size in 1..=BUFSIZE * 4,
            buf_size in 1..=BUFSIZE,
            num_output_files in 1..2*num_cpus::get(),
            threads in 1..=2+num_cpus::get(),
            comp_level in 1..=8_u8,
            write_size in 1..=2*BUFSIZE,
        ) {
            let dir = tempdir().unwrap();
            let output_names: Vec<PathBuf> = (0..num_output_files)
                .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), dir.path()))
                .collect();
            let output_writers: Vec<_> = output_names.iter().map(create_output_writer).collect();

            let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
                .threads(threads)
                .compression_level(comp_level)?;

            let mut pooled_writers: Vec<_> = output_writers.into_iter().map(|w| builder.exchange(w)).collect();
            let mut pool = builder.build()?;

            // One random input payload per output file.
            let inputs: Vec<Vec<u8>> = (0..num_output_files).map(|_| {
                (0..input_size).map(|_| rand::random::<u8>()).collect()
            }).collect();

            let chunks = (input_size as f64 / write_size as f64).ceil() as usize;

            // write a chunk to each writer (could randomly select the writers?)
            for i in 0..chunks {
                for (j, writer) in pooled_writers.iter_mut().enumerate() {
                    let input = &inputs[j];
                    let bytes = &input[write_size * i..std::cmp::min(write_size * (i + 1), input.len())];
                    writer.write_all(bytes).unwrap()
                }
            }

            pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
            // Unwrap so that shutdown errors fail the test instead of being
            // silently discarded (the Result was previously ignored).
            pool.stop_pool().unwrap();

            for (i, path) in output_names.iter().enumerate() {
                let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
                let mut actual = vec![];
                reader.read_to_end(&mut actual).unwrap();
                assert_eq!(actual, inputs[i]);
            }
        }
    }
}