gzp/par/decompress.rs

//! Parallel decompression for block-based gzip formats (mgzip, bgzf)
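//!
//! A minimal usage sketch (assuming the crate's `Bgzf` block format from `gzp::deflate`; any
//! type implementing `BlockFormatSpec` works the same way):
//!
//! ```no_run
//! use std::{fs::File, io::Read};
//!
//! use gzp::{deflate::Bgzf, par::decompress::ParDecompressBuilder};
//!
//! let file = File::open("input.bgzf").unwrap();
//! let mut decompressor = ParDecompressBuilder::<Bgzf>::new()
//!     .num_threads(4)
//!     .unwrap()
//!     .from_reader(file);
//!
//! let mut contents = Vec::new();
//! decompressor.read_to_end(&mut contents).unwrap();
//! ```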

use std::{
    io::{self, Read},
    thread::JoinHandle,
};

use bytes::{BufMut, Bytes, BytesMut};
use flate2::read::MultiGzDecoder;
pub use flate2::Compression;
use flume::{bounded, unbounded, Receiver, Sender};
use log::warn;

use crate::{BlockFormatSpec, Check, GzpError, BUFSIZE, DICT_SIZE};

#[derive(Debug)]
pub struct ParDecompressBuilder<F>
where
    F: BlockFormatSpec,
{
    buffer_size: usize,
    num_threads: usize,
    format: F,
    pin_threads: Option<usize>,
}

impl<F> ParDecompressBuilder<F>
where
    F: BlockFormatSpec,
{
    pub fn new() -> Self {
        Self {
            buffer_size: BUFSIZE,
            num_threads: num_cpus::get(),
            format: F::new(),
            pin_threads: None,
        }
    }

    pub fn buffer_size(mut self, buffer_size: usize) -> Result<Self, GzpError> {
        if buffer_size < DICT_SIZE {
            return Err(GzpError::BufferSize(buffer_size, DICT_SIZE));
        }
        self.buffer_size = buffer_size;
        Ok(self)
    }

    /// Set the number of threads and verify that it is > 0, ensuring that multi-threaded decompression will be attempted.
    pub fn num_threads(mut self, num_threads: usize) -> Result<Self, GzpError> {
        if num_threads == 0 {
            return Err(GzpError::NumThreads(num_threads));
        }
        self.num_threads = num_threads;
        Ok(self)
    }

    /// Set [`pin_threads`](ParDecompressBuilder.pin_threads): the core index at which to start pinning decompression threads, or `None` to disable pinning.
    pub fn pin_threads(mut self, pin_threads: Option<usize>) -> Self {
        if core_affinity::get_core_ids().is_none() {
            warn!("Pinning threads is not supported on your platform. Please see core_affinity_rs. No threads will be pinned, but everything will work.");
            self.pin_threads = None;
        } else {
            self.pin_threads = pin_threads;
        }
        self
    }

    /// Build a guaranteed multi-threaded decompressor
    pub fn from_reader<R: Read + Send + 'static>(self, reader: R) -> ParDecompress<F> {
        let (tx_reader, rx_reader) = bounded(self.num_threads * 2);
        let buffer_size = self.buffer_size;
        let format = self.format;
        let pin_threads = self.pin_threads;
        let handle = std::thread::spawn(move || {
            ParDecompress::run(&tx_reader, reader, self.num_threads, format, pin_threads)
        });
        ParDecompress {
            handle: Some(handle),
            rx_reader: Some(rx_reader),
            buffer: BytesMut::new(),
            buffer_size,
            format,
        }
    }

    /// Set the number of threads and allow 0 threads.
    pub fn maybe_num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// If `num_threads` is 0, this returns a single-threaded decompressor
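    ///
    /// A sketch of the fallback behavior (assuming the crate's `Bgzf` format from `gzp::deflate`):
    ///
    /// ```no_run
    /// use std::{fs::File, io::Read};
    ///
    /// use gzp::{deflate::Bgzf, par::decompress::ParDecompressBuilder};
    ///
    /// // With 0 threads this is a plain `MultiGzDecoder`; otherwise it is a `ParDecompress`.
    /// let mut reader = ParDecompressBuilder::<Bgzf>::new()
    ///     .maybe_num_threads(0)
    ///     .maybe_par_from_reader(File::open("input.bgzf").unwrap());
    ///
    /// let mut contents = Vec::new();
    /// reader.read_to_end(&mut contents).unwrap();
    /// ```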
    pub fn maybe_par_from_reader<R: Read + Send + 'static>(self, reader: R) -> Box<dyn Read> {
        if self.num_threads == 0 {
            Box::new(MultiGzDecoder::new(reader))
        } else {
            Box::new(self.from_reader(reader))
        }
    }
}

impl<F> Default for ParDecompressBuilder<F>
where
    F: BlockFormatSpec,
{
    fn default() -> Self {
        Self::new()
    }
}

#[allow(unused)]
pub struct ParDecompress<F>
where
    F: BlockFormatSpec,
{
    handle: Option<std::thread::JoinHandle<Result<(), GzpError>>>,
    rx_reader: Option<Receiver<Receiver<BytesMut>>>,
    buffer: BytesMut,
    buffer_size: usize,
    format: F,
}

impl<F> ParDecompress<F>
where
    F: BlockFormatSpec,
{
    pub fn builder() -> ParDecompressBuilder<F> {
        ParDecompressBuilder::new()
    }

    #[allow(clippy::needless_collect)]
    fn run<R>(
        tx_reader: &Sender<Receiver<BytesMut>>,
        mut reader: R,
        num_threads: usize,
        format: F,
        pin_threads: Option<usize>,
    ) -> Result<(), GzpError>
    where
        R: Read + Send + 'static,
    {
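        // Pipeline layout (as implemented below): the reader loop sends each compressed block to
        // the worker pool over `tx`, and for every block also sends the matching result receiver
        // over `tx_reader` in read order so the consumer can reassemble the output in sequence.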
        let (tx, rx): (Sender<DMessage>, Receiver<DMessage>) = bounded(num_threads * 2);

        let (core_ids, pin_threads) = if let Some(core_ids) = core_affinity::get_core_ids() {
            (core_ids, pin_threads)
        } else {
            // Handle the case where core affinity is not supported on this platform.
            // The builder's `pin_threads` method already checks and warns for this case, so no warning is needed here.
            (vec![], None)
        };
        let handles: Vec<JoinHandle<Result<(), GzpError>>> = (0..num_threads)
            .map(|i| {
                let rx = rx.clone();
                let core_ids = core_ids.clone();
                std::thread::spawn(move || -> Result<(), GzpError> {
                    if let Some(pin_at) = pin_threads {
                        if let Some(id) = core_ids.get(pin_at + i) {
                            core_affinity::set_for_current(*id);
                        }
                    }
                    let mut decompressor = format.create_decompressor();
                    while let Ok(m) = rx.recv() {
                        let check_values = format.get_footer_values(&m.buffer[..]);
                        let result = if check_values.amount != 0 {
                            format.decode_block(
                                &mut decompressor,
                                &m.buffer[..m.buffer.len() - 8],
                                check_values.amount as usize,
                            )?
                        } else {
                            vec![]
                        };

                        let mut check = F::B::new();
                        check.update(&result);

                        if check.sum() != check_values.sum {
                            return Err(GzpError::InvalidCheck {
                                found: check.sum(),
                                expected: check_values.sum,
                            });
                        }
                        m.oneshot
                            .send(BytesMut::from(&result[..]))
                            .map_err(|_e| GzpError::ChannelSend)?;
                    }
                    Ok(())
                })
            })
            // This collect is needed to force evaluation of the lazy iterator; otherwise no worker
            // threads would be spawned and the reader loop below would block on sends that are never received.
            .collect();

        // Reader loop: read one block at a time from `reader` and dispatch it to the workers
        loop {
            // Read gzip header
            let mut buf = vec![0; F::HEADER_SIZE];
            if let Ok(()) = reader.read_exact(&mut buf) {
                format.check_header(&buf)?;
                let size = format.get_block_size(&buf)?;
                let mut remainder = vec![0; size - F::HEADER_SIZE];
                reader.read_exact(&mut remainder)?;
                let (m, r) = DMessage::new_parts(Bytes::from(remainder));

                tx_reader.send(r).map_err(|_e| GzpError::ChannelSend)?;
                tx.send(m).map_err(|_e| GzpError::ChannelSend)?;
            } else {
                break; // EOF
            }
        }
        drop(tx);

        // Gracefully shut down the decompression threads
        handles
            .into_iter()
            .try_for_each(|handle| match handle.join() {
                Ok(result) => result,
                Err(e) => std::panic::resume_unwind(e),
            })
    }

    /// Shut down the background reader thread and surface any error it encountered
    pub fn finish(&mut self) -> Result<(), GzpError> {
        if self.rx_reader.is_some() {
            drop(self.rx_reader.take());
        }
        if self.handle.is_some() {
            match self.handle.take().unwrap().join() {
                Ok(result) => result,
                Err(e) => std::panic::resume_unwind(e),
            }
        } else {
            Ok(())
        }
    }
}

/// A unit of work sent to a decompression thread: the raw block bytes along with a
/// single-use channel used to hand the decompressed block back to the reader.
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct DMessage {
    buffer: Bytes,
    oneshot: Sender<BytesMut>,
    is_last: bool,
}

impl DMessage {
    pub(crate) fn new_parts(buffer: Bytes) -> (Self, Receiver<BytesMut>) {
        let (tx, rx) = unbounded();
        (
            DMessage {
                buffer,
                oneshot: tx,
                is_last: false,
            },
            rx,
        )
    }
}

impl<F> Read for ParDecompress<F>
where
    F: BlockFormatSpec,
{
    // Ok(0) means done
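    // Reads drain any bytes left over in `self.buffer` first; once it is empty, the next
    // per-block receiver is pulled from `rx_reader`, which yields blocks in the same order
    // they were read from the underlying reader.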
    fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
        let mut bytes_copied = 0;
        let asked_for_bytes = buf.len();
        loop {
            if bytes_copied == asked_for_bytes {
                break;
            }

            // First try to use up anything in current buffer
            if !self.buffer.is_empty() {
                let curr_len = self.buffer.len();
                let to_copy = &self
                    .buffer
                    .split_to(std::cmp::min(buf.remaining_mut(), curr_len));

                buf.put(&to_copy[..]);
                bytes_copied += to_copy.len();
            } else if self.rx_reader.is_some() {
                // Then pull from channel of buffers
                match self.rx_reader.as_mut().unwrap().recv() {
                    Ok(new_buffer_chan) => {
                        self.buffer = match new_buffer_chan.recv() {
                            Ok(b) => b,
                            Err(_recv_error) => {
                                // If an error occurred receiving, that means the senders have been dropped and the
                                // decompressor thread hit an error. Collect that error here, and if it was an Io
                                // error, preserve it.
                                let error = match self.handle.take().unwrap().join() {
                                    Ok(result) => result,
                                    Err(e) => std::panic::resume_unwind(e),
                                };

                                let err = match error {
                                    Ok(()) => {
                                        self.rx_reader.take();
                                        break;
                                    } // finished reading file
                                    Err(GzpError::Io(ioerr)) => ioerr,
                                    Err(err) => io::Error::other(err),
                                };
                                self.rx_reader.take();
                                return Err(err);
                            }
                        };
                    }
                    Err(_recv_error) => {
                        // If an error occurred receiving, that means the senders have been dropped and the
                        // decompressor thread hit an error. Collect that error here, and if it was an Io
                        // error, preserve it.
                        let error = match self.handle.take().unwrap().join() {
                            Ok(result) => result,
                            Err(e) => std::panic::resume_unwind(e),
                        };

                        let err = match error {
                            Ok(()) => {
                                self.rx_reader.take();
                                break;
                            } // finished reading file
                            Err(GzpError::Io(ioerr)) => ioerr,
                            Err(err) => io::Error::other(err),
                        };
                        self.rx_reader.take();
                        return Err(err);
                    }
                }
            } else {
                break;
            }
        }
        Ok(bytes_copied)
    }
}

impl<F> Drop for ParDecompress<F>
where
    F: BlockFormatSpec,
{
    fn drop(&mut self) {
        if self.rx_reader.is_some() {
            match self.finish() {
                // ChannelSend errors are acceptable since we just dropped the receiver to cause the shutdown
                Ok(()) | Err(GzpError::ChannelSend) => (),
                Err(err) => std::panic::resume_unwind(Box::new(err)),
            }
        }
    }
}