1use std::{
4 io::{self, Read},
5 thread::JoinHandle,
6};
7
8use bytes::{BufMut, Bytes, BytesMut};
9use flate2::read::MultiGzDecoder;
10pub use flate2::Compression;
11use flume::{bounded, unbounded, Receiver, Sender};
12use log::warn;
13
14use crate::{BlockFormatSpec, Check, GzpError, BUFSIZE, DICT_SIZE};
15
16#[derive(Debug)]
17pub struct ParDecompressBuilder<F>
18where
19 F: BlockFormatSpec,
20{
21 buffer_size: usize,
22 num_threads: usize,
23 format: F,
24 pin_threads: Option<usize>,
25}
26
27impl<F> ParDecompressBuilder<F>
28where
29 F: BlockFormatSpec,
30{
31 pub fn new() -> Self {
32 Self {
33 buffer_size: BUFSIZE,
34 num_threads: num_cpus::get(),
35 format: F::new(),
36 pin_threads: None,
37 }
38 }
39
40 pub fn buffer_size(mut self, buffer_size: usize) -> Result<Self, GzpError> {
41 if buffer_size < DICT_SIZE {
42 return Err(GzpError::BufferSize(buffer_size, DICT_SIZE));
43 }
44 self.buffer_size = buffer_size;
45 Ok(self)
46 }
47
48 pub fn num_threads(mut self, num_threads: usize) -> Result<Self, GzpError> {
50 if num_threads == 0 {
51 return Err(GzpError::NumThreads(num_threads));
52 }
53 self.num_threads = num_threads;
54 Ok(self)
55 }
56
57 pub fn pin_threads(mut self, pin_threads: Option<usize>) -> Self {
59 if core_affinity::get_core_ids().is_none() {
60 warn!("Pinning threads is not supported on your platform. Please see core_affinity_rs. No threads will be pinned, but everything will work.");
61 self.pin_threads = None;
62 } else {
63 self.pin_threads = pin_threads;
64 }
65 self
66 }
67
68 pub fn from_reader<R: Read + Send + 'static>(self, reader: R) -> ParDecompress<F> {
70 let (tx_reader, rx_reader) = bounded(self.num_threads * 2);
71 let buffer_size = self.buffer_size;
72 let format = self.format;
73 let pin_threads = self.pin_threads;
74 let handle = std::thread::spawn(move || {
75 ParDecompress::run(&tx_reader, reader, self.num_threads, format, pin_threads)
76 });
77 ParDecompress {
78 handle: Some(handle),
79 rx_reader: Some(rx_reader),
80 buffer: BytesMut::new(),
81 buffer_size,
82 format,
83 }
84 }
85
86 pub fn maybe_num_threads(mut self, num_threads: usize) -> Self {
88 self.num_threads = num_threads;
89 self
90 }
91
92 pub fn maybe_par_from_reader<R: Read + Send + 'static>(self, reader: R) -> Box<dyn Read> {
94 if self.num_threads == 0 {
95 Box::new(MultiGzDecoder::new(reader))
96 } else {
97 Box::new(self.from_reader(reader))
98 }
99 }
100}
101
102impl<F> Default for ParDecompressBuilder<F>
103where
104 F: BlockFormatSpec,
105{
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111#[allow(unused)]
112pub struct ParDecompress<F>
113where
114 F: BlockFormatSpec,
115{
116 handle: Option<std::thread::JoinHandle<Result<(), GzpError>>>,
117 rx_reader: Option<Receiver<Receiver<BytesMut>>>,
118 buffer: BytesMut,
119 buffer_size: usize,
120 format: F,
121}
122
123impl<F> ParDecompress<F>
124where
125 F: BlockFormatSpec,
126{
127 pub fn builder() -> ParDecompressBuilder<F> {
128 ParDecompressBuilder::new()
129 }
130
131 #[allow(clippy::needless_collect)]
132 fn run<R>(
133 tx_reader: &Sender<Receiver<BytesMut>>,
134 mut reader: R,
135 num_threads: usize,
136 format: F,
137 pin_threads: Option<usize>,
138 ) -> Result<(), GzpError>
139 where
140 R: Read + Send + 'static,
141 {
142 let (tx, rx): (Sender<DMessage>, Receiver<DMessage>) = bounded(num_threads * 2);
143
144 let (core_ids, pin_threads) = if let Some(core_ids) = core_affinity::get_core_ids() {
145 (core_ids, pin_threads)
146 } else {
147 (vec![], None)
150 };
151 let handles: Vec<JoinHandle<Result<(), GzpError>>> = (0..num_threads)
152 .map(|i| {
153 let rx = rx.clone();
154 let core_ids = core_ids.clone();
155 std::thread::spawn(move || -> Result<(), GzpError> {
156 if let Some(pin_at) = pin_threads {
157 if let Some(id) = core_ids.get(pin_at + i) {
158 core_affinity::set_for_current(*id);
159 }
160 }
161 let mut decompressor = format.create_decompressor();
162 while let Ok(m) = rx.recv() {
163 let check_values = format.get_footer_values(&m.buffer[..]);
164 let result = if check_values.amount != 0 {
165 format.decode_block(
166 &mut decompressor,
167 &m.buffer[..m.buffer.len() - 8],
168 check_values.amount as usize,
169 )?
170 } else {
171 vec![]
172 };
173
174 let mut check = F::B::new();
175 check.update(&result);
176
177 if check.sum() != check_values.sum {
178 return Err(GzpError::InvalidCheck {
179 found: check.sum(),
180 expected: check_values.sum,
181 });
182 }
183 m.oneshot
184 .send(BytesMut::from(&result[..]))
185 .map_err(|_e| GzpError::ChannelSend)?;
186 }
187 Ok(())
188 })
189 })
190 .collect();
193
194 loop {
196 let mut buf = vec![0; F::HEADER_SIZE];
198 if let Ok(()) = reader.read_exact(&mut buf) {
199 format.check_header(&buf)?;
200 let size = format.get_block_size(&buf)?;
201 let mut remainder = vec![0; size - F::HEADER_SIZE];
202 reader.read_exact(&mut remainder)?;
203 let (m, r) = DMessage::new_parts(Bytes::from(remainder));
204
205 tx_reader.send(r).map_err(|_e| GzpError::ChannelSend)?;
206 tx.send(m).map_err(|_e| GzpError::ChannelSend)?;
207 } else {
208 break; }
210 }
211 drop(tx);
212
213 handles
215 .into_iter()
216 .try_for_each(|handle| match handle.join() {
217 Ok(result) => result,
218 Err(e) => std::panic::resume_unwind(e),
219 })
220 }
221
222 pub fn finish(&mut self) -> Result<(), GzpError> {
224 if self.rx_reader.is_some() {
225 drop(self.rx_reader.take());
226 }
227 if self.handle.is_some() {
228 match self.handle.take().unwrap().join() {
229 Ok(result) => result,
230 Err(e) => std::panic::resume_unwind(e),
231 }
232 } else {
233 Ok(())
234 }
235 }
236}
237
238#[derive(Debug)]
239#[allow(dead_code)]
240pub(crate) struct DMessage {
241 buffer: Bytes,
242 oneshot: Sender<BytesMut>,
243 is_last: bool,
244}
245
246impl DMessage {
247 pub(crate) fn new_parts(buffer: Bytes) -> (Self, Receiver<BytesMut>) {
248 let (tx, rx) = unbounded();
249 (
250 DMessage {
251 buffer,
252 oneshot: tx,
253 is_last: false,
254 },
255 rx,
256 )
257 }
258}
259
260impl<F> Read for ParDecompress<F>
261where
262 F: BlockFormatSpec,
263{
264 fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
266 let mut bytes_copied = 0;
267 let asked_for_bytes = buf.len();
268 loop {
269 if bytes_copied == asked_for_bytes {
270 break;
271 }
272
273 if !self.buffer.is_empty() {
275 let curr_len = self.buffer.len();
276 let to_copy = &self
277 .buffer
278 .split_to(std::cmp::min(buf.remaining_mut(), curr_len));
279
280 buf.put(&to_copy[..]);
281 bytes_copied += to_copy.len();
282 } else if self.rx_reader.is_some() {
283 match self.rx_reader.as_mut().unwrap().recv() {
285 Ok(new_buffer_chan) => {
286 self.buffer = match new_buffer_chan.recv() {
287 Ok(b) => b,
288 Err(_recv_error) => {
289 let error = match self.handle.take().unwrap().join() {
293 Ok(result) => result,
294 Err(e) => std::panic::resume_unwind(e),
295 };
296
297 let err = match error {
298 Ok(()) => {
299 self.rx_reader.take();
300 break;
301 } Err(GzpError::Io(ioerr)) => ioerr,
303 Err(err) => io::Error::other(err),
304 };
305 self.rx_reader.take();
306 return Err(err);
307 }
308 };
309 }
310 Err(_recv_error) => {
311 let error = match self.handle.take().unwrap().join() {
315 Ok(result) => result,
316 Err(e) => std::panic::resume_unwind(e),
317 };
318
319 let err = match error {
320 Ok(()) => {
321 self.rx_reader.take();
322 break;
323 } Err(GzpError::Io(ioerr)) => ioerr,
325 Err(err) => io::Error::other(err),
326 };
327 self.rx_reader.take();
328 return Err(err);
329 }
330 }
331 } else {
332 break;
333 }
334 }
335 Ok(bytes_copied)
336 }
337}
338
339impl<F> Drop for ParDecompress<F>
340where
341 F: BlockFormatSpec,
342{
343 fn drop(&mut self) {
344 if self.rx_reader.is_some() {
345 match self.finish() {
346 Ok(()) | Err(GzpError::ChannelSend) => (),
348 Err(err) => std::panic::resume_unwind(Box::new(err)),
349 }
350 }
351 }
352}