// pooled_writer/lib.rs
1//! A pooled writer and compressor.
2//!
3//! # Overview
4//!
5//! `pooled-writer` solves the problem of compressing and writing data to a set of writers using
6//! multiple threads, where the number of writers and threads cannot easily be equal. For example
//! writing to hundreds of gzipped files using 16 threads, or writing to four gzipped files
8//! using 32 threads.
9//!
10//! To accomplish this, a pool is configured and writers are exchanged for [`PooledWriter`]s
11//! that can be used in place of the original writers. This is accomplished using the
12//! [`PoolBuilder`] which is the preferred way to configure and create a pool. The [`Pool`] and
13//! builder require two generic types: the `W` Writer type and the `C` compressor type. `W` may
14//! usually be elided if calls to [`PoolBuilder::exchange`] may be used to infer the type. `C`
15//! must be specified as something that implements [`Compressor`].
16//!
17//! The [`Pool`] consists of a single thread pool that consumes work from both a compression queue
18//! and a writing queue. All concurrency is managed via message passing over channels.
19//!
20//! Every time the internal buffer of a [`PooledWriter`] reaches capacity (defined by
21//! [`Compressor::BLOCK_SIZE`]) it sends two messages:
22//! 1. It sends a message over the corresponding writer's channel to the writer pool, enqueueing
23//! a one-shot receiver channel in the writers queue that will receive the compressed bytes
24//! once the compressor is done. This is done to maintain the output order.
25//! 2. It sends a message to the compressor pool that contains a buffer of bytes to compress
26//! as well as the sender side of the one-shot channel to send the compressed bytes on.
27//!
//! The threads in the thread pool loop continuously until the pool is shut down, and attempt
//! first to receive and compress one block, then to receive and write one compressed block.
30//! A third internal channel is used to manage the queue of writes to be performed so that the
31//! individual per-writer channels (of which there may be many) are only polled if there is likely
32//! to be data available for writing. When data is available to be written, the appropriate
33//! underlying writer is locked, and the data written.
34//!
35//! When all writing to [`PooledWriter`]s is complete, the writers should be close()'d or drop()'d
36//! and then the pool should be stopped using [`Pool::stop_pool`]. Writers that are not closed
37//! may have data buffered that is never written!
38//!
39//! [`Pool::stop_pool`] will shutdown channels in a safe order ensuring that data submitted to the
40//! pool is compressed and written before threads are stopped. After initiating the pool shutdown
41//! any subsequent attempts to write to [`PooledWriter`]s will result in errors. Likewise any
//! calls to [`PooledWriter::close`] that cause data to be flushed into the compression queue will
43//! raise errors.
44//!
45//! # Example
46//!
47//! ```rust
48//! use std::{
49//! error::Error,
50//! fs::File,
51//! io::{BufWriter, Write},
52//! path::Path,
53//! };
54//!
55//! use pooled_writer::{Compressor, PoolBuilder, Pool, bgzf::BgzfCompressor};
56//!
57//! type DynError = Box<dyn Error + 'static>;
58//!
59//! fn create_writer<P: AsRef<Path>>(name: P) -> Result<BufWriter<File>, DynError> {
60//! Ok(BufWriter::new(File::create(name)?))
61//! }
62//!
63//! fn main() -> Result<(), DynError> {
64//! let writers = vec![
65//! create_writer("/tmp/test1.txt.gz")?,
66//! create_writer("/tmp/test2.txt.gz")?,
67//! create_writer("/tmp/test3.txt.gz")?,
68//! ];
69//!
70//! let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
71//! .threads(8)
72//! .compression_level(5)?;
73//!
74//! let mut pooled_writers = writers.into_iter().map(|w| builder.exchange(w)).collect::<Vec<_>>();
75//! let mut pool = builder.build()?;
76//!
77//! writeln!(&mut pooled_writers[1], "This is writer2")?;
78//! writeln!(&mut pooled_writers[0], "This is writer1")?;
79//! writeln!(&mut pooled_writers[2], "This is writer3")?;
80//! pooled_writers.into_iter().try_for_each(|w| w.close())?;
81//! pool.stop_pool()?;
82//!
83//! Ok(())
84//! }
85//! ```
86#![forbid(unsafe_code)]
87#![allow(
88 unused,
89 clippy::missing_panics_doc,
90 clippy::missing_errors_doc,
91 clippy::must_use_candidate,
92 clippy::module_name_repetitions
93)]
94
95#[cfg(feature = "bgzf_compressor")]
96pub mod bgzf;
97
98use std::time::Duration;
99use std::{
100 error::Error,
101 io::{self, Read, Write},
102 sync::Arc,
103 thread::JoinHandle,
104};
105
106use bytes::{Bytes, BytesMut};
107use flume::{self, Receiver, Sender, bounded};
108use parking_lot::{Mutex, lock_api::RawMutex};
109use thiserror::Error;
110
/// 128 KB default buffer size, same as pigz.
// NOTE(review): within this file BUFSIZE is only referenced by the tests; the
// PooledWriter buffer is sized by `Compressor::BLOCK_SIZE` instead.
pub(crate) const BUFSIZE: usize = 128 * 1024;

/// Convenience type for functions that return [`PoolError`].
type PoolResult<T> = Result<T, PoolError>;
116
/// Represents errors that may be generated by any `Pool` related functionality.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum PoolError {
    /// A message could not be sent because the receiving side of a channel was closed.
    #[error("Failed to send over channel")]
    ChannelSend,
    /// A receive failed because the sending side of a channel was closed.
    #[error(transparent)]
    ChannelReceive(#[from] flume::RecvError),

    // TODO: figure out how to better pass in a generic / dynamic error type to this.
    /// The [`Compressor`] reported an error; stringified because the compressor's
    /// error type is an associated type not known to this enum.
    #[error("Error compressing data: {0}")]
    CompressionError(String),
    /// An I/O error from one of the underlying writers.
    #[error(transparent)]
    Io(#[from] io::Error),
}
132
133////////////////////////////////////////////////////////////////////////////////
// The PooledWriter and its impls
135////////////////////////////////////////////////////////////////////////////////
136
/// A [`PooledWriter`] is created by exchanging a writer with a [`Pool`].
///
/// The pooled writer will internally buffer writes, sending bytes to the [`Pool`]
/// after the internal buffer has been filled.
///
/// Note that the `compressor_tx` channel is shared by all pooled writers, whereas the `writer_tx`
/// is specific to the _underlying_ writer that this pooled writer encapsulates.
#[derive(Debug)]
pub struct PooledWriter {
    /// The index/serial number of the pooled writer within the pool.
    writer_index: usize,
    /// Channel to send messages containing bytes to compress to the compressors' pool.
    compressor_tx: Sender<CompressorMessage>,
    /// Channel to send the receiving end of the one-shot channel that will be
    /// used to send the compressed bytes. This effectively "place holds" the
    /// position of the compressed bytes in the writers queue until the compressed bytes
    /// are ready.
    writer_tx: Sender<oneshot::Receiver<WriterMessage>>,
    /// The internal buffer to gather bytes to send.
    buffer: BytesMut,
    /// The desired size of the internal buffer (set to `C::BLOCK_SIZE` at construction).
    buffer_size: usize,
}
160
161impl PooledWriter {
162 /// Create a new [`PooledWriter`] that has an internal buffer capacity that matches [`bgzf::BGZF_BLOCK_SIZE`].
163 ///
164 /// # Arguments
165 /// - `index` - a usize representing that this is the nth pooled writer created within the pool
166 /// - `compressor_tx` - The channel to send uncompressed bytes to the compressor pool.
167 /// - `writer_tx` - The `Send` end of the channel that transmits the `Receiver` end of the one-shot
168 /// channel, which will be consumed when the compressor sends the compressed bytes.
169 fn new<C>(
170 index: usize,
171 compressor_tx: Sender<CompressorMessage>,
172 writer_tx: Sender<oneshot::Receiver<WriterMessage>>,
173 ) -> Self
174 where
175 C: Compressor,
176 {
177 Self {
178 writer_index: index,
179 compressor_tx,
180 writer_tx,
181 buffer: BytesMut::with_capacity(C::BLOCK_SIZE),
182 buffer_size: C::BLOCK_SIZE,
183 }
184 }
185
186 /// Test whether the internal buffer has reached capacity.
187 #[inline]
188 fn buffer_full(&self) -> bool {
189 self.buffer.len() == self.buffer_size
190 }
191
192 /// Send all bytes in the current buffer to the compressor.
193 ///
194 /// If `is_last` is `true`, the message sent to the compressor will also have the `is_last` true flag set
195 /// and the compressor will finish the BGZF stream.
196 ///
197 /// If `is_last` is not true then only full block will be sent. If `is_last` is true, an incomplete block may be set
198 /// as the final block.
199 fn flush_bytes(&mut self, is_last: bool) -> std::io::Result<()> {
200 if is_last || self.buffer_full() {
201 self.send_block(is_last)?;
202 }
203 Ok(())
204 }
205
206 /// Send a single block
207 fn send_block(&mut self, is_last: bool) -> std::io::Result<()> {
208 let bytes = self.buffer.split_to(self.buffer.len()).freeze();
209 let (mut m, r) = CompressorMessage::new_parts(self.writer_index, bytes);
210 m.is_last = is_last;
211 self.writer_tx
212 .send(r)
213 .map_err(|_e| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))?;
214 self.compressor_tx
215 .send(m)
216 .map_err(|_e_| io::Error::new(io::ErrorKind::Other, PoolError::ChannelSend))
217 }
218
219 /// Flush any remaining bytes and consume self, triggering drops of the senders.
220 pub fn close(mut self) -> std::io::Result<()> {
221 self.flush_bytes(true)
222 }
223}
224
impl Drop for PooledWriter {
    /// Drop [`PooledWriter`].
    ///
    /// This will flush the writer, sending any buffered bytes as a final (`is_last`) block.
    ///
    /// NOTE(review): `close()` also calls `flush_bytes(true)` and then drops `self`, so a
    /// closed writer sends a second, empty `is_last` block from here — confirm the
    /// compressor tolerates trailing empty blocks.
    ///
    /// NOTE(review): the `unwrap` will panic (inside drop) if the pool has already been
    /// stopped and the channels are disconnected.
    fn drop(&mut self) {
        self.flush_bytes(true).unwrap();
    }
}
233
234impl Write for PooledWriter {
235 /// Send all bytes in `buf` to the [`Pool`].
236 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
237 let mut bytes_added = 0;
238
239 while bytes_added < buf.len() {
240 let bytes_to_append =
241 std::cmp::min(buf.len() - bytes_added, self.buffer_size - self.buffer.len());
242
243 self.buffer.extend_from_slice(&buf[bytes_added..bytes_added + bytes_to_append]);
244 bytes_added += bytes_to_append;
245 if self.buffer_full() {
246 self.send_block(false)?;
247 }
248 }
249
250 Ok(buf.len())
251 }
252
253 /// Send whatever is in the current buffer even if it is not a full buffer.
254 fn flush(&mut self) -> std::io::Result<()> {
255 self.flush_bytes(false)
256 }
257}
258
259////////////////////////////////////////////////////////////////////////////////
260// The Compressor trait
261////////////////////////////////////////////////////////////////////////////////
262
/// A [`Compressor`] is used in the compressor pool to compress bytes.
///
/// An implementation must be provided as the `C` type parameter of the [`PoolBuilder`] so
/// that the pool knows what kind of compression to use.
///
/// See the module level example for more details.
pub trait Compressor: Sized + Send + 'static
where
    Self::CompressionLevel: Clone + Send + 'static,
    Self::Error: Error + Send + 'static,
{
    /// Error type produced by compression and by compression-level validation.
    type Error;
    /// Opaque, cloneable representation of a validated compression level.
    type CompressionLevel;

    /// The `BLOCK_SIZE` is used to set the buffer size of the [`PooledWriter`]s and should match
    /// the max size allowed by the block compression format being used.
    // NOTE(review): 65280 presumably matches BGZF's maximum uncompressed payload —
    // confirm before relying on this default for other block formats.
    const BLOCK_SIZE: usize = 65280;

    /// Create a new compressor with the given compression level.
    fn new(compression_level: Self::CompressionLevel) -> Self;

    /// Returns the default compression level for the compressor.
    fn default_compression_level() -> Self::CompressionLevel;

    /// Create an instance of the compression level.
    ///
    /// The validity of the compression level should be checked here.
    fn new_compression_level(compression_level: u8) -> Result<Self::CompressionLevel, Self::Error>;

    /// Compress a set of bytes into the `output` vec. If `is_last` is true, and depending on the
    /// block compression format, an EOF block may be appended as well.
    fn compress(
        &mut self,
        input: &[u8],
        output: &mut Vec<u8>,
        is_last: bool,
    ) -> Result<(), Self::Error>;
}
301
302////////////////////////////////////////////////////////////////////////////////
303// The messages passed between threads
304////////////////////////////////////////////////////////////////////////////////
305
/// A message that is sent from a [`PooledWriter`] to the compressor threadpool within a [`Pool`].
struct CompressorMessage {
    /// The index of the destination writer.
    writer_index: usize,
    /// The bytes to compress.
    buffer: Bytes,
    /// Where the compressed bytes will be sent after compression.
    oneshot_tx: oneshot::Sender<WriterMessage>,
    /// A sentinel value to let the compressor know that the stream needs to be finished
    /// (e.g. a BGZF EOF block appended).
    is_last: bool,
}
317
318impl CompressorMessage {
319 fn new_parts(writer_index: usize, buffer: Bytes) -> (Self, oneshot::Receiver<WriterMessage>) {
320 let (tx, rx) = oneshot::channel();
321 let new = Self { writer_index, buffer, oneshot_tx: tx, is_last: false };
322 (new, rx)
323 }
324}
325
/// The compressed bytes to be written to a file.
///
/// This is sent from the compressor threadpool to the writer queue in the writer threadpool
/// via the one-shot channel provided by the [`PooledWriter`].
#[derive(Debug)]
struct WriterMessage {
    /// Compressed bytes, ready to be written verbatim to the destination writer.
    buffer: Vec<u8>,
}
334
/// Internal enum used by worker threads to dispatch between compression and write work.
enum WorkItem {
    /// A block of bytes to compress, together with its reply channel.
    Compress(CompressorMessage),
    /// Index of a writer that has a compressed block queued and ready to write.
    Write(usize),
}
340
341////////////////////////////////////////////////////////////////////////////////
342// The PoolBuilder struct and impls
343////////////////////////////////////////////////////////////////////////////////
344
/// A struct to make building up a Pool simpler. The builder should be constructed using
/// [`PoolBuilder::new`], which provides the user control over the sizes of the queues used for
/// compression and writing. It should be noted that a single compression queue is created,
/// and one writer queue per writer exchanged. A good starting point for these queue sizes is
/// two times the number of threads.
///
/// Once created various functions can configure aspects of the pool. It is best practice, though
/// not required, to configure the builder _before_ exchanging writers. The exception is
/// `queue_size` that may _not_ be set after any writers have been exchanged. If not set manually
/// then `queue_size` defaults to the number of threads multiplied by
/// [`PoolBuilder::QUEUE_SIZE_THREAD_MULTIPLES`].
///
/// Once the builder is configured writers may be exchanged for [`PooledWriter`]s using the
/// [`PoolBuilder::exchange`] function, which consumes the provided writer and returns a new
/// writer that can be used in its place.
///
/// After exchanging all writers the pool may be created and started with [`PoolBuilder::build`]
/// which consumes the builder and after which no more writers may be exchanged.
pub struct PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    /// Serial number assigned to the next exchanged writer.
    writer_index: usize,
    /// Validated compression level handed to every compressor instance.
    compression_level: C::CompressionLevel,
    /// Capacity of the compression queue and of each per-writer queue; defaulted lazily.
    queue_size: Option<usize>,
    /// Number of worker threads the pool will spawn.
    threads: usize,
    /// Send half of the shared compression queue (created lazily).
    compressor_tx: Option<Sender<CompressorMessage>>,
    /// Receive half of the shared compression queue (created lazily).
    compressor_rx: Option<Receiver<CompressorMessage>>,
    /// The original writers, in exchange order.
    writers: Vec<W>,
    /// Send halves of the per-writer ordering queues.
    writer_txs: Vec<Sender<oneshot::Receiver<WriterMessage>>>,
    /// Receive halves of the per-writer ordering queues.
    writer_rxs: Vec<Receiver<oneshot::Receiver<WriterMessage>>>,
}
378
379impl<W, C> PoolBuilder<W, C>
380where
381 W: Write + Send + 'static,
382 C: Compressor,
383{
384 /// By default queue sizes will be set to threads * this constant.
385 pub const QUEUE_SIZE_THREAD_MULTIPLES: usize = 50;
386
387 /// The default number of threads that will be used if not otherwise configured
388 pub const DEFAULT_THREADS: usize = 4;
389
390 /// Creates a new PoolBuilder that can be used to configure and build a [`Pool`].
391 pub fn new() -> Self {
392 PoolBuilder {
393 writer_index: 0,
394 compression_level: C::default_compression_level(),
395 queue_size: None,
396 threads: Self::DEFAULT_THREADS,
397 compressor_tx: None,
398 compressor_rx: None,
399 writers: vec![],
400 writer_txs: vec![],
401 writer_rxs: vec![],
402 }
403 }
404
405 /// Sets the number of threads that will be used by the [[Pool]].
406 ///
407 /// Will panic if set to 0.
408 pub fn threads(mut self, threads: usize) -> Self {
409 assert!(threads > 0, "Must provide a number of threads greater than 0.");
410 self.threads = threads;
411 self
412 }
413
414 /// Sets the size of queues used by the pool [[Pool]]. The same size is used for
415 /// a) the queue of byte buffers to be compressed, b) the per-sample queues to receive
416 /// compressed bytes, and c) a control queue to manage writing to the underlying writers.
417 ///
418 /// In the worst case scenario the pool can be holding both queue_size uncompressed blocks
419 /// _and_ queue_size compressed blocks in memory when it cannot keep up with the incoming
420 /// load of writes.
421 ///
422 ///
423 ///
424 /// Will panic if called _after_ writers have been created because queues will already have
425 /// been created.
426 pub fn queue_size(mut self, queue_size: usize) -> Self {
427 assert!(self.writers.is_empty(), "Cannot set queue_size after writers are exchanged.");
428 self.queue_size.insert(queue_size);
429 self
430 }
431
432 /// Sets the compression level that will be used by the [[Pool]].
433 pub fn compression_level(mut self, level: u8) -> PoolResult<Self> {
434 self.compression_level = C::new_compression_level(level)
435 .map_err(|e| PoolError::CompressionError(e.to_string()))?;
436 Ok(self)
437 }
438
439 /// If queues/channels are not yet setup, initialize them.
440 fn ensure_queue_is_setup(&mut self) {
441 if self.compressor_tx.is_none() && self.compressor_rx.is_none() {
442 if self.queue_size.is_none() {
443 self.queue_size.insert(self.threads * Self::QUEUE_SIZE_THREAD_MULTIPLES);
444 }
445
446 let (tx, rx) = bounded(self.queue_size.unwrap());
447 self.compressor_tx.insert(tx);
448 self.compressor_rx.insert(rx);
449 }
450 }
451
452 /// Exchanges a writer for a [[PooledWriter]].
453 pub fn exchange(&mut self, writer: W) -> PooledWriter {
454 // Make sure queue/channel configuration is done
455 self.ensure_queue_is_setup();
456
457 let (tx, rx): (
458 Sender<oneshot::Receiver<WriterMessage>>,
459 Receiver<oneshot::Receiver<WriterMessage>>,
460 ) = flume::bounded(self.queue_size.expect("Unreachable"));
461
462 let p = PooledWriter::new::<C>(
463 self.writer_index,
464 self.compressor_tx.as_ref().expect("Unreachable").clone(),
465 tx.clone(),
466 );
467
468 self.writer_index += 1;
469 self.writers.push(writer);
470 self.writer_txs.push(tx);
471 self.writer_rxs.push(rx);
472 p
473 }
474
475 /// Consumes the builder and generates the [[Pool]] ready for use.
476 pub fn build(mut self) -> PoolResult<Pool> {
477 // Make sure the queue/channel configuration is done - this could be necessary if
478 // a pool is created by zero writers exchanged.
479 self.ensure_queue_is_setup();
480
481 // Create the channel to gracefully signal a shutdown of the pool
482 let (shutdown_tx, shutdown_rx) = flume::unbounded();
483
484 // Start the pool manager thread and thread pools
485 let handle = std::thread::spawn(move || {
486 Pool::pool_main::<W, C>(
487 self.threads,
488 self.compression_level,
489 self.compressor_rx.expect("Unreachable."),
490 self.writer_rxs,
491 self.writers,
492 shutdown_rx,
493 )
494 });
495
496 let mut pool = Pool {
497 compressor_tx: self.compressor_tx,
498 shutdown_tx: Some(shutdown_tx),
499 pool_handle: Some(handle),
500 };
501
502 Ok(pool)
503 }
504}
505
impl<W, C> Default for PoolBuilder<W, C>
where
    W: Write + Send + 'static,
    C: Compressor,
{
    /// Equivalent to [`PoolBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
515
516////////////////////////////////////////////////////////////////////////////////
517// The Pool struct and impls
518////////////////////////////////////////////////////////////////////////////////
519
/// A [`Pool`] orchestrates two different threadpools, a compressor pool and a writer pool.
///
/// The pool is suitable for scenarios where there are many more writers than threads, efficiently
/// managing resources for M writers to N threads.
#[derive(Debug)]
pub struct Pool {
    /// The join handle for the thread that manages all pool resources and coordination.
    /// `None` once [`Pool::stop_pool`] has taken it.
    pool_handle: Option<JoinHandle<PoolResult<()>>>,
    /// The send end of the channel for communicating with the compressor pool.
    /// `None` once [`Pool::stop_pool`] has dropped it to disconnect the channel.
    compressor_tx: Option<Sender<CompressorMessage>>,
    /// Sentinel channel to tell the pool management thread to shutdown.
    shutdown_tx: Option<Sender<()>>,
}
533
534impl Pool {
535 /// The main "run" method for the pool that orchestrates all the pieces.
536 ///
537 /// The [`PooledWriter`]s are sending to the compressor, the compressor compresses them, then forwards the compressed bytes.
538 /// The bytes are forwarded to a queue per writer and the writer threads are iterating over that queue pulling down
539 /// all values in the queue at once and writing till the queue is empty.
540 ///
541 /// # Arguments
542 /// - `num_threads` - The number of threads to use.
543 /// - `compression_level` - The compression level to use for the [`Compressor`] pool.
544 /// - `compressor_rx ` - The receiving end of the channel for communicating with the compressor pool.
545 /// - `writer_rxs ` - The receive halves of the channels for the [`PooledWriter`]s to enqueue the one-shot channels.
546 /// - `writers` - The writers that were exchanged for [`PooledWriter`]s.
547 /// - `shutdown_rx` - Sentinel channel to tell the pool management thread to shutdown.
548 #[allow(clippy::unnecessary_wraps, clippy::needless_collect, clippy::needless_pass_by_value)]
549 fn pool_main<W, C>(
550 num_threads: usize,
551 compression_level: C::CompressionLevel,
552 compressor_rx: Receiver<CompressorMessage>,
553 writer_rxs: Vec<Receiver<oneshot::Receiver<WriterMessage>>>, // must be pass by value to allow for easy sharing between threads
554 writers: Vec<W>,
555 shutdown_rx: Receiver<()>,
556 ) -> PoolResult<()>
557 where
558 W: Write + Send + 'static,
559 C: Compressor,
560 {
561 // Add locks to the writers
562 let writers: Arc<Vec<_>> =
563 Arc::new(writers.into_iter().map(|w| Arc::new(Mutex::new(w))).collect());
564
565 // Generate one more channel for queuing up information about when a writer has data
566 // available to be written
567 let (write_available_tx, write_available_rx): (Sender<usize>, Receiver<usize>) =
568 flume::unbounded();
569
570 let thread_handles: Vec<JoinHandle<PoolResult<()>>> = (0..num_threads)
571 .map(|thread_idx| {
572 let compressor_rx = compressor_rx.clone();
573 let mut compressor = C::new(compression_level.clone());
574 let writer_rxs = writer_rxs.clone();
575 let writers = writers.clone();
576 let shutdown_rx = shutdown_rx.clone();
577 let write_available_tx = write_available_tx.clone();
578 let write_available_rx = write_available_rx.clone();
579 let select_timeout = Duration::from_millis(100);
580
581 std::thread::spawn(move || {
582 // Reuse a single compression buffer per thread to avoid
583 // re-allocating ~70KB on every block.
584 let mut compress_buf = Vec::new();
585
586 loop {
587 // Try non-blocking receives first (fast path under load).
588 // Then fall through to Selector which blocks until work
589 // arrives, avoiding the old sleep(25ms) polling delay.
590 let item = if let Ok(msg) = compressor_rx.try_recv() {
591 Some(WorkItem::Compress(msg))
592 } else if let Ok(idx) = write_available_rx.try_recv() {
593 Some(WorkItem::Write(idx))
594 } else {
595 flume::Selector::new()
596 .recv(&compressor_rx, |r| r.ok().map(WorkItem::Compress))
597 .recv(&write_available_rx, |r| r.ok().map(WorkItem::Write))
598 .wait_timeout(select_timeout)
599 .ok()
600 .flatten()
601 };
602
603 match item {
604 Some(WorkItem::Compress(message)) => {
605 let chunk = &message.buffer;
606 compress_buf.clear();
607 compressor
608 .compress(chunk, &mut compress_buf, message.is_last)
609 .map_err(|e| PoolError::CompressionError(e.to_string()))?;
610 message
611 .oneshot_tx
612 .send(WriterMessage { buffer: compress_buf.clone() })
613 .map_err(|_e| PoolError::ChannelSend);
614 write_available_tx.send(message.writer_index);
615 }
616 Some(WorkItem::Write(writer_index)) => {
617 let mut writer = writers[writer_index].lock();
618 let writer_rx = &writer_rxs[writer_index];
619 let one_shot_rx = writer_rx.recv()?;
620 let write_message =
621 one_shot_rx.recv().map_err(|_| PoolError::ChannelSend)?;
622 writer.write_all(&write_message.buffer)?;
623 }
624 None => {
625 // Timeout or channel disconnect. Check if all work
626 // is drained and shutdown was requested.
627 if shutdown_rx.is_disconnected()
628 && write_available_rx.is_empty()
629 && compressor_rx.is_empty()
630 && writer_rxs.iter().all(|w| w.is_empty())
631 {
632 break;
633 }
634 }
635 }
636 }
637
638 Ok(())
639 })
640 })
641 .collect();
642
643 // Close writer handles
644 thread_handles.into_iter().try_for_each(|handle| match handle.join() {
645 Ok(result) => result,
646 Err(e) => std::panic::resume_unwind(e),
647 });
648
649 // Flush each writer
650 writers.iter().try_for_each(|w| w.lock().flush())?;
651
652 Ok(())
653 }
654
655 /// Shutdown all pool resources and close all channels.
656 ///
657 /// Ideally the [`PooledWriter`]s should all have been flushed first, that is up to the user. Any
658 /// further attempts to send to the [`Pool`] will return an error.
659 pub fn stop_pool(&mut self) -> Result<(), PoolError> {
660 // Drop the compressor sender to disconnect the channel. Buffered
661 // messages are preserved by flume and will be drained by the worker
662 // threads before they observe the disconnect and shut down.
663 drop(self.compressor_tx.take().unwrap());
664
665 // Shutdown called to force writers to start checking their receivers for disconnection / empty
666 drop(self.shutdown_tx.take());
667
668 // Wait on the pool thread to finish and pull any errors from it
669 match self.pool_handle.take().unwrap().join() {
670 Ok(result) => result,
671 Err(e) => std::panic::resume_unwind(e),
672 }
673 }
674}
675
impl Drop for Pool {
    /// Stops the pool on drop if [`Pool::stop_pool`] has not already been called.
    ///
    /// NOTE(review): the `unwrap` will panic inside drop if shutting the pool down
    /// returns an error; panicking during an unwind aborts the process.
    fn drop(&mut self) {
        // Check if `stop_pool` has already been called. If it hasn't, call it.
        if self.compressor_tx.is_some() && self.pool_handle.is_some() {
            self.stop_pool().unwrap();
        }
    }
}
684
685////////////////////////////////////////////////////////////////////////////////
686// Tests
687////////////////////////////////////////////////////////////////////////////////
688
#[cfg(test)]
mod test {
    use std::{
        assert_eq, format,
        fs::File,
        io::{BufReader, BufWriter},
        path::{Path, PathBuf},
        vec,
    };

    use crate::bgzf::BgzfCompressor;

    use super::*;
    use ::bgzf::Reader;
    use proptest::prelude::*;
    use tempfile::tempdir;

    /// Creates a buffered file writer for the given path, panicking on failure.
    fn create_output_writer<P: AsRef<Path>>(path: P) -> BufWriter<File> {
        BufWriter::new(File::create(path).unwrap())
    }

    /// Joins `name` onto `dir` to form an output file path.
    fn create_output_file_name(name: impl AsRef<Path>, dir: impl AsRef<Path>) -> PathBuf {
        let path = dir.as_ref().to_path_buf();
        path.join(name)
    }

    /// Round-trips one small line per writer through a pool and verifies the output.
    #[test]
    fn test_simple() {
        let dir = tempdir().unwrap();
        let output_names: Vec<PathBuf> = (0..20)
            .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), dir.path()))
            .collect();

        let output_writers: Vec<BufWriter<File>> =
            output_names.iter().map(create_output_writer).collect();
        let mut builder =
            PoolBuilder::<_, BgzfCompressor>::new().threads(8).compression_level(2).unwrap();
        let mut pooled_writers: Vec<PooledWriter> =
            output_writers.into_iter().map(|w| builder.exchange(w)).collect();
        let mut pool = builder.build().unwrap();

        for (i, writer) in pooled_writers.iter_mut().enumerate() {
            writer.write_all(format!("This is writer {}.", i).as_bytes()).unwrap();
        }
        pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
        // BUGFIX: the Result of stop_pool() was previously ignored, so a failed
        // shutdown (and therefore lost data) would not fail the test.
        pool.stop_pool().unwrap();

        for (i, path) in output_names.iter().enumerate() {
            let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
            let mut actual = vec![];
            reader.read_to_end(&mut actual).unwrap();
            assert_eq!(actual, format!("This is writer {}.", i).as_bytes());
        }
    }

    proptest! {
        // This test takes around 20 minutes on a 32 core machine to run but is very comprehensive.
        // Run with `cargo test -- --ignored`
        #[ignore]
        #[test]
        fn test_complete(
            input_size in 1..=BUFSIZE * 4,
            buf_size in 1..=BUFSIZE,
            num_output_files in 1..2*num_cpus::get(),
            threads in 1..=2+num_cpus::get(),
            comp_level in 1..=8_u8,
            write_size in 1..=2*BUFSIZE,
        ) {
            let dir = tempdir().unwrap();
            let output_names: Vec<PathBuf> = (0..num_output_files)
                .map(|i| create_output_file_name(format!("test.{}.txt.gz", i), dir.path()))
                .collect();
            let output_writers: Vec<_> = output_names.iter().map(create_output_writer).collect();

            let mut builder = PoolBuilder::<_, BgzfCompressor>::new()
                .threads(threads)
                .compression_level(comp_level)?;

            let mut pooled_writers: Vec<_> = output_writers.into_iter().map(|w| builder.exchange(w)).collect();
            let mut pool = builder.build()?;

            let inputs: Vec<Vec<u8>> = (0..num_output_files).map(|_| {
                (0..input_size).map(|_| rand::random::<u8>()).collect()
            }).collect();

            let chunks = (input_size as f64 / write_size as f64).ceil() as usize;

            // write a chunk to each writer (could randomly select the writers?)
            for i in 0..chunks {
                for (j, writer) in pooled_writers.iter_mut().enumerate() {
                    let input = &inputs[j];
                    let bytes = &input[write_size * i..std::cmp::min(write_size * (i + 1), input.len())];
                    writer.write_all(bytes).unwrap();
                }
            }

            pooled_writers.into_iter().try_for_each(|mut w| w.flush()).unwrap();
            // BUGFIX: propagate shutdown errors instead of silently discarding them.
            pool.stop_pool()?;

            for (i, path) in output_names.iter().enumerate() {
                let mut reader = Reader::new(BufReader::new(File::open(path).unwrap()));
                let mut actual = vec![];
                reader.read_to_end(&mut actual).unwrap();
                assert_eq!(actual, inputs[i]);
            }
        }
    }
}