pooled_writer/
lib.rs

//! A pooled writer and compressor.
//!
//! # Overview
//!
//! `pooled-writer` solves the problem of compressing and writing data to a set of writers using
//! multiple threads, where the number of writers and threads cannot easily be equal.  For example,
//! writing to hundreds of gzipped files using 16 threads, or writing to four gzipped files
//! using 32 threads.
//!
//! To accomplish this, a pool is configured and writers are exchanged for [`PooledWriter`]s
//! that can be used in place of the original writers.  This is accomplished using the
//! [`PoolBuilder`], which is the preferred way to configure and create a pool.  The [`Pool`] and
//! builder require two generic types: the `W` writer type and the `C` compressor type.  `W` can
//! usually be elided when calls to [`PoolBuilder::exchange`] allow the type to be inferred.  `C`
//! must be specified as something that implements [`Compressor`].
//!
//! The [`Pool`] consists of a single thread pool that consumes work from both a compression queue
//! and a writing queue.  All concurrency is managed via message passing over channels.
//!
//! Every time the internal buffer of a [`PooledWriter`] reaches capacity (defined by
//! [`Compressor::BLOCK_SIZE`]) it sends two messages:
//! 1. It sends a message over the corresponding writer's channel to the writer pool, enqueueing
//!    a one-shot receiver channel in the writer's queue that will receive the compressed bytes
//!    once the compressor is done. This is done to maintain the output order.
//! 2. It sends a message to the compressor pool that contains a buffer of bytes to compress
//!    as well as the sender side of the one-shot channel on which to send the compressed bytes.
//!
//! The threads in the thread pool loop continuously until the pool is shut down; on each
//! iteration they first attempt to receive and compress one block, and then attempt to receive
//! and write one compressed block.  A third internal channel is used to manage the queue of
//! writes to be performed so that the individual per-writer channels (of which there may be many)
//! are only polled if there is likely to be data available for writing.  When data is available
//! to be written, the appropriate underlying writer is locked and the data is written.
//!
//! When all writing to [`PooledWriter`]s is complete, the writers should be closed (via
//! [`PooledWriter::close`]) or dropped, and then the pool should be stopped using
//! [`Pool::stop_pool`].  Writers that are not closed may have buffered data that is never
//! written!
//!
//! [`Pool::stop_pool`] will shut down channels in a safe order, ensuring that data submitted to
//! the pool is compressed and written before threads are stopped.  After initiating the pool
//! shutdown, any subsequent attempts to write to [`PooledWriter`]s will result in errors.
//! Likewise, any calls to [`PooledWriter::close`] that cause data to be flushed into the
//! compression queue will return errors.
//!
//! # Example
//!
//! ```rust
//! use std::{
//!     error::Error,
//!     fs::File,
//!     io::{BufWriter, Write},
//!     path::Path,
//! };
//!
//! use pooled_writer::{Compressor, PoolBuilder, Pool, bgzf::BgzfCompressor};
//!
//! type DynError = Box<dyn Error + 'static>;
//!
//! fn create_writer<P: AsRef<Path>>(name: P) -> Result<BufWriter<File>, DynError> {
//!     Ok(BufWriter::new(File::create(name)?))
//! }
//!
//! fn main() -> Result<(), DynError> {
//!     let writers = vec![
//!         create_writer("/tmp/test1.txt.gz")?,
//!         create_writer("/tmp/test2.txt.gz")?,
//!         create_writer("/tmp/test3.txt.gz")?,
//!     ];
//!
//!     let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
//!         .threads(8)
//!         .compression_level(5)?;
//!
//!     let mut pooled_writers = writers.into_iter().map(|w| builder.exchange(w)).collect::<Vec<_>>();
//!     let mut pool = builder.build()?;
//!
//!     writeln!(&mut pooled_writers[1], "This is writer2")?;
//!     writeln!(&mut pooled_writers[0], "This is writer1")?;
//!     writeln!(&mut pooled_writers[2], "This is writer3")?;
//!     pooled_writers.into_iter().try_for_each(|w| w.close())?;
//!     pool.stop_pool()?;
//!
//!     Ok(())
//! }
//! ```
#![forbid(unsafe_code)]
#![allow(
    unused,
    clippy::missing_panics_doc,
    clippy::missing_errors_doc,
    clippy::must_use_candidate,
    clippy::module_name_repetitions
)]

#[cfg(feature = "bgzf_compressor")]
pub mod bgzf;

use std::time::Duration;
use std::{
    error::Error,
    io::{self, Read, Write},
    sync::Arc,
    thread::JoinHandle,
};

use bytes::{Bytes, BytesMut};
use flume::{self, bounded, Receiver, Sender};
use parking_lot::{lock_api::RawMutex, Mutex};
use thiserror::Error;

/// 128 KB default buffer size, same as pigz.
pub(crate) const BUFSIZE: usize = 128 * 1024;

/// Convenience type for functions that return [`PoolError`].
type PoolResult<T> = Result<T, PoolError>;

/// Represents errors that may be generated by any `Pool`-related functionality.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum PoolError {
    #[error("Failed to send over channel")]
    ChannelSend,
    #[error(transparent)]
    ChannelReceive(#[from] flume::RecvError),

    // TODO: figure out how to better pass in a generic / dynamic error type to this.
    #[error("Error compressing data: {0}")]
    CompressionError(String),
    #[error(transparent)]
    Io(#[from] io::Error),
}

////////////////////////////////////////////////////////////////////////////////
// The PooledWriter and its impls
////////////////////////////////////////////////////////////////////////////////

/// A [`PooledWriter`] is created by exchanging a writer via [`PoolBuilder::exchange`].
///
/// The pooled writer will internally buffer writes, sending bytes to the [`Pool`]
/// after the internal buffer has been filled.
///
/// Note that the `compressor_tx` channel is shared by all pooled writers, whereas the `writer_tx`
/// is specific to the _underlying_ writer that this pooled writer encapsulates.
#[derive(Debug)]
pub struct PooledWriter {
    /// The index/serial number of the pooled writer within the pool.
    writer_index: usize,
    /// Channel to send messages containing bytes to compress to the compressor pool.
    compressor_tx: Sender<CompressorMessage>,
    /// Channel to send the receiving end of the one-shot channel that will be
    /// used to send the compressed bytes. This effectively holds the place of the
    /// compressed bytes in the writer's queue until the compressed bytes
    /// are ready.
    writer_tx: Sender<Receiver<WriterMessage>>,
    /// The internal buffer to gather bytes to send.
    buffer: BytesMut,
    /// The desired size of the internal buffer.
    buffer_size: usize,
}

impl PooledWriter {
    /// Create a new [`PooledWriter`] with an internal buffer capacity that matches [`Compressor::BLOCK_SIZE`].
    ///
    /// # Arguments
    /// - `index` - a `usize` indicating that this is the nth pooled writer created within the pool.
    /// - `compressor_tx` - The channel to send uncompressed bytes to the compressor pool.
    /// - `writer_tx` - The `Send` end of the channel that transmits the `Receiver` end of the one-shot
    ///                 channel, which will be consumed when the compressor sends the compressed bytes.
    fn new<C>(
        index: usize,
        compressor_tx: Sender<CompressorMessage>,
        writer_tx: Sender<Receiver<WriterMessage>>,
    ) -> Self
    where
        C: Compressor,
    {
        Self {
            writer_index: index,
            compressor_tx,
            writer_tx,
            buffer: BytesMut::with_capacity(C::BLOCK_SIZE),
            buffer_size: C::BLOCK_SIZE,
        }
    }

    /// Test whether the internal buffer has reached capacity.
    #[inline]
    fn buffer_full(&self) -> bool {
        self.buffer.len() == self.buffer_size
    }

    /// Send all bytes in the current buffer to the compressor.
    ///
    /// If `is_last` is `true`, the message sent to the compressor will also have its `is_last` flag set
    /// and the compressor will finish the stream (e.g. by writing a BGZF EOF block).
    ///
    /// If `is_last` is not true then only a full buffer will be sent. If `is_last` is true, an incomplete
    /// block may be sent as the final block.
    fn flush_bytes(&mut self, is_last: bool) -> std::io::Result<()> {
        if is_last || self.buffer_full() {
            self.send_block(is_last)?;
        }
        Ok(())
    }

    /// Send a single block.
    fn send_block(&mut self, is_last: bool) -> std::io::Result<()> {
        let bytes = self.buffer.split_to(self.buffer.len()).freeze();
        let (mut m, r) = CompressorMessage::new_parts(self.writer_index, bytes);
        m.is_last = is_last;
        self.writer_tx
            .send(r)
            .map_err(|_e| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))?;
        self.compressor_tx
            .send(m)
            .map_err(|_e| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))
    }

    /// Flush any remaining bytes and consume self, triggering drops of the senders.
    pub fn close(mut self) -> std::io::Result<()> {
        self.flush_bytes(true)
    }
}

impl Drop for PooledWriter {
    /// Drop [`PooledWriter`].
    ///
    /// This will flush the writer.
    fn drop(&mut self) {
        self.flush_bytes(true).unwrap();
    }
}

impl Write for PooledWriter {
    /// Send all bytes in `buf` to the [`Pool`].
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut bytes_added = 0;

        while bytes_added < buf.len() {
            let bytes_to_append =
                std::cmp::min(buf.len() - bytes_added, self.buffer_size - self.buffer.len());

            self.buffer.extend_from_slice(&buf[bytes_added..bytes_added + bytes_to_append]);
            bytes_added += bytes_to_append;
            if self.buffer_full() {
                self.send_block(false)?;
            }
        }

        Ok(buf.len())
    }

    /// Send the contents of the current buffer to be compressed, if the buffer is full.
    ///
    /// Note that a partially filled buffer is only sent when the writer is closed or dropped.
    fn flush(&mut self) -> std::io::Result<()> {
        self.flush_bytes(false)
    }
}

////////////////////////////////////////////////////////////////////////////////
// The Compressor trait
////////////////////////////////////////////////////////////////////////////////

/// A [`Compressor`] is used in the compressor pool to compress bytes.
///
/// An implementation must be provided as the `C` type parameter to the [`PoolBuilder`] so that
/// the pool knows what kind of compression to use.
///
/// See the module level example for more details.
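///
/// # Example
///
/// A minimal sketch of an implementation, assuming only the standard library: a pass-through
/// "compressor" that copies bytes unchanged. The `PassThrough` and `PassThroughError` names are
/// illustrative only and are not part of this crate.
///
/// ```rust
/// use std::{error::Error, fmt};
///
/// use pooled_writer::Compressor;
///
/// // Hypothetical error type for the sketch below.
/// #[derive(Debug)]
/// struct PassThroughError;
///
/// impl fmt::Display for PassThroughError {
///     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
///         write!(f, "pass-through error")
///     }
/// }
///
/// impl Error for PassThroughError {}
///
/// // Hypothetical "compressor" that performs no compression at all.
/// struct PassThrough;
///
/// impl Compressor for PassThrough {
///     type Error = PassThroughError;
///     type CompressionLevel = ();
///
///     fn new(_level: Self::CompressionLevel) -> Self {
///         PassThrough
///     }
///
///     fn default_compression_level() -> Self::CompressionLevel {}
///
///     fn new_compression_level(_level: u8) -> Result<Self::CompressionLevel, Self::Error> {
///         // A real implementation would validate the level here.
///         Ok(())
///     }
///
///     fn compress(
///         &mut self,
///         input: &[u8],
///         output: &mut Vec<u8>,
///         _is_last: bool,
///     ) -> Result<(), Self::Error> {
///         // A real implementation would compress `input` into `output` and, when `is_last`
///         // is true, finalize the stream (e.g. append an EOF block).
///         output.extend_from_slice(input);
///         Ok(())
///     }
/// }
/// ```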
pub trait Compressor: Sized + Send + 'static
where
    Self::CompressionLevel: Clone + Send + 'static,
    Self::Error: Error + Send + 'static,
{
    type Error;
    type CompressionLevel;

    /// The `BLOCK_SIZE` is used to set the buffer size of the [`PooledWriter`]s and should match the max
    /// size allowed by the block compression format being used.
    const BLOCK_SIZE: usize = 65280;

    /// Create a new compressor with the given compression level.
    fn new(compression_level: Self::CompressionLevel) -> Self;

    /// Returns the default compression level for the compressor.
    fn default_compression_level() -> Self::CompressionLevel;

    /// Create an instance of the compression level.
    ///
    /// The validity of the compression level should be checked here.
    fn new_compression_level(compression_level: u8) -> Result<Self::CompressionLevel, Self::Error>;

    /// Compress a set of bytes into the `output` vec. If `is_last` is true, and depending on the
    /// block compression format, an EOF block may be appended as well.
    fn compress(
        &mut self,
        input: &[u8],
        output: &mut Vec<u8>,
        is_last: bool,
    ) -> Result<(), Self::Error>;
}

////////////////////////////////////////////////////////////////////////////////
// The messages passed between threads
////////////////////////////////////////////////////////////////////////////////

/// A message that is sent from a [`PooledWriter`] to the compressor threadpool within a [`Pool`].
#[derive(Debug)]
struct CompressorMessage {
    /// The index of the destination writer.
    writer_index: usize,
    /// The bytes to compress.
    buffer: Bytes,
    /// Where the compressed bytes will be sent after compression.
    oneshot: Sender<WriterMessage>,
    /// A sentinel value to let the compressor know that this is the last block and the
    /// compressed stream should be finished (e.g. by writing a BGZF EOF block).
    is_last: bool,
}

impl CompressorMessage {
    fn new_parts(writer_index: usize, buffer: Bytes) -> (Self, Receiver<WriterMessage>) {
        let (tx, rx) = flume::unbounded(); // oneshot channel
        let new = Self { writer_index, buffer, oneshot: tx, is_last: false };
        (new, rx)
    }
}

/// The compressed bytes to be written to an underlying writer.
///
/// This is sent from the compressor threadpool to the writer queue in the writer threadpool
/// via the one-shot channel provided by the [`PooledWriter`].
#[derive(Debug)]
struct WriterMessage {
    buffer: Vec<u8>,
}

////////////////////////////////////////////////////////////////////////////////
// The PoolBuilder struct and impls
////////////////////////////////////////////////////////////////////////////////

/// A struct to make building up a [`Pool`] simpler.  The builder should be constructed using
/// [`PoolBuilder::new`], which provides the user control over the sizes of the queues used for
/// compression and writing.  It should be noted that a single compression queue is created,
/// and one writer queue per writer exchanged.  A good starting point for these queue sizes is
/// two times the number of threads.
///
/// Once created, various functions can configure aspects of the pool.  It is best practice, though
/// not required, to configure the builder _before_ exchanging writers.  The exception is
/// `queue_size`, which may _not_ be set after any writers have been exchanged.  If not set manually
/// then `queue_size` defaults to the number of threads multiplied by
/// [`PoolBuilder::QUEUE_SIZE_THREAD_MULTIPLES`].
///
/// Once the builder is configured, writers may be exchanged for [`PooledWriter`]s using the
/// [`PoolBuilder::exchange`] function, which consumes the provided writer and returns a new
/// writer that can be used in its place.
///
/// After exchanging all writers the pool may be created and started with [`PoolBuilder::build`],
/// which consumes the builder and after which no more writers may be exchanged.
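///
/// # Example
///
/// A minimal sketch of the configure/exchange/build/shutdown sequence using in-memory writers
/// (`Vec<u8>` implements `Write`); the specific thread count and queue size below are arbitrary
/// illustrative values, not recommendations.
///
/// ```rust
/// use std::io::Write;
///
/// use pooled_writer::{bgzf::BgzfCompressor, PoolBuilder, PooledWriter};
///
/// // Configure the builder before exchanging writers; `queue_size` in particular cannot be
/// // changed once a writer has been exchanged.
/// let mut builder = PoolBuilder::<Vec<u8>, BgzfCompressor>::new()
///     .threads(2)
///     .queue_size(8);
///
/// // Exchange two in-memory writers for pooled writers.
/// let mut pooled: Vec<PooledWriter> =
///     (0..2).map(|_| builder.exchange(Vec::new())).collect();
///
/// // Build (and start) the pool, then write, close the writers, and stop the pool.
/// let mut pool = builder.build().unwrap();
/// writeln!(&mut pooled[0], "hello").unwrap();
/// writeln!(&mut pooled[1], "world").unwrap();
/// pooled.into_iter().try_for_each(|w| w.close()).unwrap();
/// pool.stop_pool().unwrap();
/// ```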
pub struct PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    writer_index: usize,
    compression_level: C::CompressionLevel,
    queue_size: Option<usize>,
    threads: usize,
    compressor_tx: Option<Sender<CompressorMessage>>,
    compressor_rx: Option<Receiver<CompressorMessage>>,
    writers: Vec<W>,
    writer_txs: Vec<Sender<Receiver<WriterMessage>>>,
    writer_rxs: Vec<Receiver<Receiver<WriterMessage>>>,
}

impl<W, C> PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    /// By default queue sizes will be set to threads * this constant.
    pub const QUEUE_SIZE_THREAD_MULTIPLES: usize = 50;

    /// The default number of threads that will be used if not otherwise configured.
    pub const DEFAULT_THREADS: usize = 4;

    /// Creates a new [`PoolBuilder`] that can be used to configure and build a [`Pool`].
    pub fn new() -> Self {
        PoolBuilder {
            writer_index: 0,
            compression_level: C::default_compression_level(),
            queue_size: None,
            threads: Self::DEFAULT_THREADS,
            compressor_tx: None,
            compressor_rx: None,
            writers: vec![],
            writer_txs: vec![],
            writer_rxs: vec![],
        }
    }

    /// Sets the number of threads that will be used by the [`Pool`].
    ///
    /// Will panic if set to 0.
    pub fn threads(mut self, threads: usize) -> Self {
        assert!(threads > 0, "Must provide a number of threads greater than 0.");
        self.threads = threads;
        self
    }

    /// Sets the size of the queues used by the [`Pool`].  The same size is used for
    /// a) the queue of byte buffers to be compressed and b) the per-writer queues that receive
    /// compressed bytes; an unbounded control queue is also used internally to manage writing to
    /// the underlying writers.
    ///
    /// In the worst case scenario the pool can be holding both `queue_size` uncompressed blocks
    /// _and_ `queue_size` compressed blocks in memory when it cannot keep up with the incoming
    /// load of writes.
    ///
    /// Will panic if called _after_ writers have been exchanged because the queues will already
    /// have been created.
    pub fn queue_size(mut self, queue_size: usize) -> Self {
        assert!(self.writers.is_empty(), "Cannot set queue_size after writers are exchanged.");
        self.queue_size = Some(queue_size);
        self
    }

    /// Sets the compression level that will be used by the [`Pool`].
    pub fn compression_level(mut self, level: u8) -> PoolResult<Self> {
        self.compression_level = C::new_compression_level(level)
            .map_err(|e| PoolError::CompressionError(e.to_string()))?;
        Ok(self)
    }

    /// If the queues/channels are not yet set up, initialize them.
    fn ensure_queue_is_setup(&mut self) {
        if self.compressor_tx.is_none() && self.compressor_rx.is_none() {
            let queue_size =
                *self.queue_size.get_or_insert(self.threads * Self::QUEUE_SIZE_THREAD_MULTIPLES);

            let (tx, rx) = bounded(queue_size);
            self.compressor_tx = Some(tx);
            self.compressor_rx = Some(rx);
        }
    }

    /// Exchanges a writer for a [`PooledWriter`].
    pub fn exchange(&mut self, writer: W) -> PooledWriter {
        // Make sure queue/channel configuration is done
        self.ensure_queue_is_setup();

        let (tx, rx): (Sender<Receiver<WriterMessage>>, Receiver<Receiver<WriterMessage>>) =
            flume::bounded(self.queue_size.expect("Unreachable"));

        let p = PooledWriter::new::<C>(
            self.writer_index,
            self.compressor_tx.as_ref().expect("Unreachable").clone(),
            tx.clone(),
        );

        self.writer_index += 1;
        self.writers.push(writer);
        self.writer_txs.push(tx);
        self.writer_rxs.push(rx);
        p
    }

    /// Consumes the builder and generates the [`Pool`] ready for use.
    pub fn build(mut self) -> PoolResult<Pool> {
        // Make sure the queue/channel configuration is done - this can be necessary if
        // a pool is created with zero writers exchanged.
        self.ensure_queue_is_setup();

        // Create the channel used to gracefully signal a shutdown of the pool
        let (shutdown_tx, shutdown_rx) = flume::unbounded();

        // Start the pool manager thread and thread pools
        let handle = std::thread::spawn(move || {
            Pool::pool_main::<W, C>(
                self.threads,
                self.compression_level,
                self.compressor_rx.expect("Unreachable."),
                self.writer_rxs,
                self.writers,
                shutdown_rx,
            )
        });

        let pool = Pool {
            compressor_tx: self.compressor_tx,
            shutdown_tx: Some(shutdown_tx),
            pool_handle: Some(handle),
        };

        Ok(pool)
    }
}

impl<W, C> Default for PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    fn default() -> Self {
        Self::new()
    }
}

////////////////////////////////////////////////////////////////////////////////
// The Pool struct and impls
////////////////////////////////////////////////////////////////////////////////

/// A [`Pool`] orchestrates a single pool of threads that both compress data and write the
/// compressed data to the underlying writers.
///
/// The pool is suitable for scenarios where there are many more writers than threads, efficiently
/// managing resources for M writers to N threads.
#[derive(Debug)]
pub struct Pool {
    /// The join handle for the thread that manages all pool resources and coordination.
    pool_handle: Option<JoinHandle<PoolResult<()>>>,
    /// The send end of the channel for communicating with the compressor pool.
    compressor_tx: Option<Sender<CompressorMessage>>,
    /// Sentinel channel to tell the pool management thread to shut down.
    shutdown_tx: Option<Sender<()>>,
}

impl Pool {
    /// The main "run" method for the pool that orchestrates all the pieces.
    ///
    /// The [`PooledWriter`]s send uncompressed blocks to the compression queue; the worker
    /// threads compress each block and forward the compressed bytes over that block's one-shot
    /// channel.  The same worker threads then pick up pending writes, receive the compressed
    /// bytes from the per-writer queues, and write them to the underlying writers in order.
    ///
    /// # Arguments
    /// - `num_threads` - The number of threads to use.
    /// - `compression_level` - The compression level to use for the [`Compressor`] pool.
    /// - `compressor_rx` - The receiving end of the channel for communicating with the compressor pool.
    /// - `writer_rxs` - The receive halves of the channels for the [`PooledWriter`]s to enqueue the one-shot channels.
    /// - `writers` - The writers that were exchanged for [`PooledWriter`]s.
    /// - `shutdown_rx` - Sentinel channel to tell the pool management thread to shut down.
    #[allow(clippy::unnecessary_wraps, clippy::needless_collect, clippy::needless_pass_by_value)]
    fn pool_main<W, C>(
        num_threads: usize,
        compression_level: C::CompressionLevel,
        compressor_rx: Receiver<CompressorMessage>,
        writer_rxs: Vec<Receiver<Receiver<WriterMessage>>>, // must be pass by value to allow for easy sharing between threads
        writers: Vec<W>,
        shutdown_rx: Receiver<()>,
    ) -> PoolResult<()>
    where
        W: Write + Send + 'static,
        C: Compressor,
    {
        // Add locks to the writers
        let writers: Arc<Vec<_>> =
            Arc::new(writers.into_iter().map(|w| Arc::new(Mutex::new(w))).collect());

        // Generate one more channel for queuing up information about when a writer has data
        // available to be written
        let (write_available_tx, write_available_rx): (Sender<usize>, Receiver<usize>) =
            flume::unbounded();

        let thread_handles: Vec<JoinHandle<PoolResult<()>>> = (0..num_threads)
            .map(|thread_idx| {
                let compressor_rx = compressor_rx.clone();
                let mut compressor = C::new(compression_level.clone());
                let writer_rxs = writer_rxs.clone();
                let writers = writers.clone();
                let shutdown_rx = shutdown_rx.clone();
                let sleep_delay = Duration::from_millis(25);
                let write_available_tx = write_available_tx.clone();
                let write_available_rx = write_available_rx.clone();

                std::thread::spawn(move || {
                    loop {
                        let mut did_something = false;

                        // Try to process one compression message
                        if let Ok(message) = compressor_rx.try_recv() {
                            // Compress the buffer in the message
                            let chunk = &message.buffer;
                            // Compress will correctly resize the compressed vec.
                            let mut compressed = Vec::new();
                            compressor
                                .compress(chunk, &mut compressed, message.is_last)
                                .map_err(|e| PoolError::CompressionError(e.to_string()))?;
                            // Send the compressed bytes over the one-shot channel and note that
                            // this writer now has a block ready to be written.
                            message
                                .oneshot
                                .send(WriterMessage { buffer: compressed })
                                .map_err(|_e| PoolError::ChannelSend)?;
                            write_available_tx
                                .send(message.writer_index)
                                .map_err(|_e| PoolError::ChannelSend)?;
                            did_something = true;
                        }

                        // Then try to process one write message
                        if let Ok(writer_index) = write_available_rx.try_recv() {
                            let mut writer = writers[writer_index].lock();
                            let writer_rx = &writer_rxs[writer_index];
                            let one_shot_rx = writer_rx.recv()?;
                            let write_message = one_shot_rx.recv()?;
                            writer.write_all(&write_message.buffer)?;
                            did_something = true;
                        }

                        // If we didn't do anything, either sleep for a few ms to avoid
                        // busy-waiting, or, if shutdown is requested and all the channels are
                        // empty, terminate.
                        if !did_something {
                            if shutdown_rx.is_disconnected()
                                && write_available_rx.is_empty()
                                && compressor_rx.is_empty()
                                && writer_rxs.iter().all(|w| w.is_empty())
                            {
                                break;
                            } else {
                                std::thread::sleep(sleep_delay);
                            }
                        }
                    }

                    Ok(())
                })
            })
            .collect();

        // Join the worker threads, propagating any errors or panics they produced
        thread_handles.into_iter().try_for_each(|handle| match handle.join() {
            Ok(result) => result,
            Err(e) => std::panic::resume_unwind(e),
        })?;

        // Flush each writer
        writers.iter().try_for_each(|w| w.lock().flush())?;

        Ok(())
    }

    /// Shut down all pool resources and close all channels.
    ///
    /// Ideally the [`PooledWriter`]s should all have been closed or flushed first; that is up to
    /// the user. Any further attempts to send to the [`Pool`] will return an error.
    pub fn stop_pool(&mut self) -> Result<(), PoolError> {
        let compressor_queue = self.compressor_tx.take().unwrap();
        while !compressor_queue.is_empty() {
            // Wait for compression to finish before dropping the sender
        }
        drop(compressor_queue);

        // Dropping the shutdown sender forces the worker threads to start checking their
        // receivers for disconnection / emptiness
        drop(self.shutdown_tx.take());

        // Wait on the pool thread to finish and pull any errors from it
        match self.pool_handle.take().unwrap().join() {
            Ok(result) => result,
            Err(e) => std::panic::resume_unwind(e),
        }
    }
}

impl Drop for Pool {
    fn drop(&mut self) {
        // Check if `stop_pool` has already been called. If it hasn't, call it.
        if self.compressor_tx.is_some() && self.pool_handle.is_some() {
            self.stop_pool().unwrap();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////

#[cfg(test)]
mod test {
    use std::{
        assert_eq, format,
        fs::File,
        io::{BufReader, BufWriter},
        path::{Path, PathBuf},
        vec,
    };

    use crate::bgzf::BgzfCompressor;

    use super::*;
    use ::bgzf::Reader;
    use proptest::prelude::*;
    use tempfile::tempdir;

    fn create_output_writer<P: AsRef<Path>>(path: P) -> BufWriter<File> {
        BufWriter::new(File::create(path).unwrap())
    }

    fn create_output_file_name(name: impl AsRef<Path>, dir: impl AsRef<Path>) -> PathBuf {
        let path = dir.as_ref().to_path_buf();
        path.join(name)
    }

    #[test]
    fn test_simple() {
        let dir = tempdir().unwrap();
        let output_names: Vec<PathBuf> = (0..20)
            .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), &dir.path()))
            .collect();

        let output_writers: Vec<BufWriter<File>> =
            output_names.iter().map(create_output_writer).collect();
        let mut builder =
            PoolBuilder::<_, BgzfCompressor>::new().threads(8).compression_level(2).unwrap();
        let mut pooled_writers: Vec<PooledWriter> =
            output_writers.into_iter().map(|w| builder.exchange(w)).collect();
        let mut pool = builder.build().unwrap();

        for (i, writer) in pooled_writers.iter_mut().enumerate() {
            writer.write_all(format!("This is writer {}.", i).as_bytes()).unwrap();
        }
        pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
        pool.stop_pool().unwrap();

        for (i, path) in output_names.iter().enumerate() {
            let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
            let mut actual = vec![];
            reader.read_to_end(&mut actual).unwrap();
            assert_eq!(actual, format!("This is writer {}.", i).as_bytes());
        }
    }

    proptest! {
        // This test takes around 20 minutes on a 32 core machine to run but is very comprehensive.
        // Run with `cargo test -- --ignored`
        #[ignore]
        #[test]
        fn test_complete(
            input_size in 1..=BUFSIZE * 4,
            buf_size in 1..=BUFSIZE,
            num_output_files in 1..2*num_cpus::get(),
            threads in 1..=2+num_cpus::get(),
            comp_level in 1..=8_u8,
            write_size in 1..=2*BUFSIZE,
        ) {
            let dir = tempdir().unwrap();
            let output_names: Vec<PathBuf> = (0..num_output_files)
                .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), &dir.path()))
                .collect();
            let output_writers: Vec<_> = output_names.iter().map(create_output_writer).collect();

            let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
                .threads(threads)
                .compression_level(comp_level)?;

            let mut pooled_writers: Vec<_> = output_writers.into_iter().map(|w| builder.exchange(w)).collect();
            let mut pool = builder.build()?;

            let inputs: Vec<Vec<u8>> = (0..num_output_files).map(|_| {
                (0..input_size).map(|_| rand::random::<u8>()).collect()
            }).collect();

            let chunks = (input_size as f64 / write_size as f64).ceil() as usize;

            // write a chunk to each writer (could randomly select the writers?)
            for i in 0..chunks {
                for (j, writer) in pooled_writers.iter_mut().enumerate() {
                    let input = &inputs[j];
                    let bytes = &input[write_size * i..std::cmp::min(write_size * (i + 1), input.len())];
                    writer.write_all(bytes).unwrap();
                }
            }

            pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
            pool.stop_pool().unwrap();

            for (i, path) in output_names.iter().enumerate() {
                let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
                let mut actual = vec![];
                reader.read_to_end(&mut actual).unwrap();
                assert_eq!(actual, inputs[i]);
            }
        }
    }
}