1use std::{
2 io::{self, BufRead, Read, Seek, SeekFrom},
3 mem,
4 num::NonZeroUsize,
5 thread::{self, JoinHandle},
6};
7
8use crossbeam_channel::{Receiver, Sender};
9
10use super::Block;
11use crate::{gzi, VirtualPosition};
12
// Channel endpoints linking the consumer, the reader thread, and the inflater threads.

/// Sends the outcome of inflating one frame: the processed buffer or an I/O error.
type BufferedTx = Sender<io::Result<Buffer>>;
/// Receives the outcome of inflating one frame.
type BufferedRx = Receiver<io::Result<Buffer>>;
/// Sends a buffer holding a raw frame to the inflaters, together with the sender on which the
/// result must be reported.
type InflateTx = Sender<(Buffer, BufferedTx)>;
/// Receives work items on an inflater thread.
type InflateRx = Receiver<(Buffer, BufferedTx)>;
/// Sends per-frame result receivers from the reader thread to the consumer.
type ReadTx = Sender<BufferedRx>;
/// Receives per-frame result receivers on the consumer side.
type ReadRx = Receiver<BufferedRx>;
/// Returns consumed buffers to the reader thread so they can be reused.
type RecycleTx = Sender<Buffer>;
/// Receives buffers on the reader thread that are ready to be refilled.
type RecycleRx = Receiver<Buffer>;
21
/// Lifecycle of the worker threads that sit between the inner reader and the consumer.
enum State<R> {
    /// No worker threads exist; the inner reader is owned directly.
    Paused(R),
    /// Worker threads are running and the inner reader is owned by the reader thread.
    Running {
        /// Handle of the reader thread; joining it yields the inner reader back, either as the
        /// success value or inside a `ReadError`.
        reader_handle: JoinHandle<Result<R, ReadError<R>>>,
        /// Handles of the inflater threads.
        inflater_handles: Vec<JoinHandle<()>>,
        /// Receives, for each frame read, the channel on which its processed buffer will arrive.
        read_rx: ReadRx,
        /// Returns consumed buffers to the reader thread; dropping it lets the reader thread's
        /// receive loop end once the channel is drained.
        recycle_tx: RecycleTx,
    },
    /// Placeholder stored while the state is being swapped out, and left in place by `finish`.
    Done,
}
32
/// A reusable unit of work: the raw bytes of one frame and the block parsed from them.
#[derive(Debug, Default)]
struct Buffer {
    /// Raw frame bytes, filled on the reader thread.
    buf: Vec<u8>,
    /// Block populated from `buf` on an inflater thread.
    block: Block,
}
38
/// A reader that reads frames on one thread and processes them into blocks on a pool of worker
/// threads.
///
/// Worker threads are spawned lazily when the first block is requested and are stopped again
/// whenever direct access to the inner reader is needed.
pub struct MultithreadedReader<R> {
    /// Whether worker threads are running, and who currently owns the inner reader.
    state: State<R>,
    /// Number of inflater threads to spawn; also used as the capacity of the work channels.
    worker_count: NonZeroUsize,
    /// Running compressed offset: the offset of the last seek (or 0) plus the sizes of all
    /// blocks received since.
    position: u64,
    /// The block currently being served to the caller.
    buffer: Buffer,
}
49
50impl<R> MultithreadedReader<R> {
51 pub fn position(&self) -> u64 {
62 self.position
63 }
64
65 pub fn virtual_position(&self) -> VirtualPosition {
76 self.buffer.block.virtual_position()
77 }
78
79 pub fn finish(&mut self) -> io::Result<R> {
91 let state = mem::replace(&mut self.state, State::Done);
92
93 match state {
94 State::Paused(inner) => Ok(inner),
95 State::Running {
96 reader_handle,
97 mut inflater_handles,
98 recycle_tx,
99 ..
100 } => {
101 drop(recycle_tx);
102
103 for handle in inflater_handles.drain(..) {
104 handle.join().unwrap();
105 }
106
107 reader_handle.join().unwrap().map_err(|e| e.1)
108 }
109 State::Done => panic!("invalid state"),
110 }
111 }
112}
113
114impl<R> MultithreadedReader<R>
115where
116 R: Read + Send + 'static,
117{
118 pub fn new(inner: R) -> Self {
128 Self::with_worker_count(NonZeroUsize::MIN, inner)
129 }
130
131 pub fn with_worker_count(worker_count: NonZeroUsize, inner: R) -> Self {
142 Self {
143 state: State::Paused(inner),
144 worker_count,
145 position: 0,
146 buffer: Buffer::default(),
147 }
148 }
149
150 pub fn get_mut(&mut self) -> &mut R {
161 self.pause();
162
163 match &mut self.state {
164 State::Paused(inner) => inner,
165 _ => panic!("invalid state"),
166 }
167 }
168
169 fn resume(&mut self) {
170 if matches!(self.state, State::Running { .. }) {
171 return;
172 }
173
174 let state = mem::replace(&mut self.state, State::Done);
175
176 let State::Paused(inner) = state else {
177 panic!("invalid state");
178 };
179
180 let worker_count = self.worker_count.get();
181
182 let (inflate_tx, inflate_rx) = crossbeam_channel::bounded(worker_count);
183 let (read_tx, read_rx) = crossbeam_channel::bounded(worker_count);
184 let (recycle_tx, recycle_rx) = crossbeam_channel::bounded(worker_count);
185
186 for _ in 0..worker_count {
187 recycle_tx.send(Buffer::default()).unwrap();
188 }
189
190 let reader_handle = spawn_reader(inner, inflate_tx, read_tx, recycle_rx);
191 let inflater_handles = spawn_inflaters(self.worker_count, inflate_rx);
192
193 self.state = State::Running {
194 reader_handle,
195 inflater_handles,
196 read_rx,
197 recycle_tx,
198 };
199 }
200
201 fn pause(&mut self) {
202 if matches!(self.state, State::Paused(_)) {
203 return;
204 }
205
206 let state = mem::replace(&mut self.state, State::Done);
207
208 let State::Running {
209 reader_handle,
210 mut inflater_handles,
211 recycle_tx,
212 ..
213 } = state
214 else {
215 panic!("invalid state");
216 };
217
218 drop(recycle_tx);
219
220 for handle in inflater_handles.drain(..) {
221 handle.join().unwrap();
222 }
223
224 let inner = match reader_handle.join().unwrap() {
226 Ok(inner) => inner,
227 Err(ReadError(inner, _)) => inner,
228 };
229
230 self.state = State::Paused(inner);
231 }
232
233 fn read_block(&mut self) -> io::Result<()> {
234 self.resume();
235
236 let State::Running {
237 read_rx,
238 recycle_tx,
239 ..
240 } = &self.state
241 else {
242 panic!("invalid state");
243 };
244
245 while let Some(mut buffer) = recv_buffer(read_rx)? {
246 buffer.block.set_position(self.position);
247 self.position += buffer.block.size();
248
249 let prev_buffer = mem::replace(&mut self.buffer, buffer);
250 recycle_tx.send(prev_buffer).ok();
251
252 if self.buffer.block.data().len() > 0 {
253 break;
254 }
255 }
256
257 Ok(())
258 }
259}
260
261impl<R> Drop for MultithreadedReader<R> {
262 fn drop(&mut self) {
263 if !matches!(self.state, State::Done) {
264 let _ = self.finish();
265 }
266 }
267}
268
269impl<R> Read for MultithreadedReader<R>
270where
271 R: Read + Send + 'static,
272{
273 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
274 let mut src = self.fill_buf()?;
275 let amt = src.read(buf)?;
276 self.consume(amt);
277 Ok(amt)
278 }
279
280 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
281 use super::reader::default_read_exact;
282
283 if let Some(src) = self.buffer.block.data().as_ref().get(..buf.len()) {
284 buf.copy_from_slice(src);
285 self.consume(src.len());
286 Ok(())
287 } else {
288 default_read_exact(self, buf)
289 }
290 }
291}
292
293impl<R> BufRead for MultithreadedReader<R>
294where
295 R: Read + Send + 'static,
296{
297 fn fill_buf(&mut self) -> io::Result<&[u8]> {
298 if !self.buffer.block.data().has_remaining() {
299 self.read_block()?;
300 }
301
302 Ok(self.buffer.block.data().as_ref())
303 }
304
305 fn consume(&mut self, amt: usize) {
306 self.buffer.block.data_mut().consume(amt);
307 }
308}
309
310impl<R> crate::io::Read for MultithreadedReader<R>
311where
312 R: Read + Send + 'static,
313{
314 fn virtual_position(&self) -> VirtualPosition {
315 self.buffer.block.virtual_position()
316 }
317}
318
// Marker impl: no methods are provided here for `crate::io::BufRead`.
impl<R> crate::io::BufRead for MultithreadedReader<R> where R: Read + Send + 'static {}
320
321impl<R> crate::io::Seek for MultithreadedReader<R>
322where
323 R: Read + Send + Seek + 'static,
324{
325 fn seek_to_virtual_position(&mut self, pos: VirtualPosition) -> io::Result<VirtualPosition> {
326 let (cpos, upos) = pos.into();
327
328 self.get_mut().seek(SeekFrom::Start(cpos))?;
329 self.position = cpos;
330
331 self.read_block()?;
332
333 self.buffer.block.data_mut().set_position(usize::from(upos));
334
335 Ok(pos)
336 }
337
338 fn seek_with_index(&mut self, index: &gzi::Index, pos: SeekFrom) -> io::Result<u64> {
339 let SeekFrom::Start(pos) = pos else {
340 unimplemented!();
341 };
342
343 let virtual_position = index.query(pos)?;
344 self.seek_to_virtual_position(virtual_position)?;
345 Ok(pos)
346 }
347}
348
349fn recv_buffer(read_rx: &ReadRx) -> io::Result<Option<Buffer>> {
350 if let Ok(buffered_rx) = read_rx.recv() {
351 if let Ok(buffer) = buffered_rx.recv() {
352 return buffer.map(Some);
353 }
354 }
355
356 Ok(None)
357}
358
/// An I/O error raised on the reader thread, paired with the inner reader so the reader can
/// still be recovered.
struct ReadError<R>(R, io::Error);
360
361fn spawn_reader<R>(
362 mut reader: R,
363 inflate_tx: InflateTx,
364 read_tx: ReadTx,
365 recycle_rx: RecycleRx,
366) -> JoinHandle<Result<R, ReadError<R>>>
367where
368 R: Read + Send + 'static,
369{
370 use super::reader::frame::read_frame_into;
371
372 thread::spawn(move || {
373 while let Ok(mut buffer) = recycle_rx.recv() {
374 match read_frame_into(&mut reader, &mut buffer.buf) {
375 Ok(result) if result.is_none() => break,
376 Ok(_) => {}
377 Err(e) => return Err(ReadError(reader, e)),
378 }
379
380 let (buffered_tx, buffered_rx) = crossbeam_channel::bounded(1);
381
382 inflate_tx.send((buffer, buffered_tx)).unwrap();
383 read_tx.send(buffered_rx).unwrap();
384 }
385
386 Ok(reader)
387 })
388}
389
390fn spawn_inflaters(worker_count: NonZeroUsize, inflate_rx: InflateRx) -> Vec<JoinHandle<()>> {
391 use super::reader::frame::parse_block;
392
393 (0..worker_count.get())
394 .map(|_| {
395 let inflate_rx = inflate_rx.clone();
396
397 thread::spawn(move || {
398 while let Ok((mut buffer, buffered_tx)) = inflate_rx.recv() {
399 let result = parse_block(&buffer.buf, &mut buffer.block).map(|_| buffer);
400 buffered_tx.send(result).unwrap();
401 }
402 })
403 })
404 .collect()
405}
406
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    // Reads a two-block stream to the end, seeks back into the first block, and checks both the
    // bytes read from there and the virtual position reported at the end of the stream.
    #[test]
    fn test_seek_to_virtual_position() -> Result<(), Box<dyn std::error::Error>> {
        use crate::io::Seek;

        #[rustfmt::skip]
        static DATA: &[u8] = &[
            // First block (35 bytes).
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x22, 0x00, 0xcb, 0xcb, 0xcf, 0x4f, 0xc9, 0x49, 0x2d, 0x06, 0x00, 0xa1,
            0x58, 0x2a, 0x80, 0x07, 0x00, 0x00, 0x00,
            // Second block (28 bytes); presumably the empty end-of-file marker block.
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];

        // 63 = 35 + 28, i.e., the compressed offset just past the second block.
        const EOF_VIRTUAL_POSITION: VirtualPosition = match VirtualPosition::new(63, 0) {
            Some(pos) => pos,
            None => unreachable!(),
        };

        // Uncompressed offset 3 within the block at compressed offset 0.
        const VIRTUAL_POSITION: VirtualPosition = match VirtualPosition::new(0, 3) {
            Some(pos) => pos,
            None => unreachable!(),
        };

        let mut reader =
            MultithreadedReader::with_worker_count(NonZeroUsize::MIN, Cursor::new(DATA));

        // Drain the stream once so the seek starts from the end.
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf)?;

        assert_eq!(reader.virtual_position(), EOF_VIRTUAL_POSITION);

        reader.seek_to_virtual_position(VIRTUAL_POSITION)?;

        buf.clear();
        reader.read_to_end(&mut buf)?;

        // Only the bytes from uncompressed offset 3 onward are returned after the seek.
        assert_eq!(buf, b"dles");
        assert_eq!(reader.virtual_position(), EOF_VIRTUAL_POSITION);

        Ok(())
    }
}