noodles_bgzf/io/multithreaded_reader.rs

1use std::{
2    io::{self, BufRead, Read, Seek, SeekFrom},
3    mem,
4    num::NonZeroUsize,
5    thread::{self, JoinHandle},
6};
7
8use crossbeam_channel::{Receiver, Sender};
9
10use super::Block;
11use crate::{gzi, VirtualPosition};
12
/// Sends the inflate result of a single block back to the consumer (one-shot, capacity 1).
type BufferedTx = Sender<io::Result<Buffer>>;
/// Receives the inflate result of a single block.
type BufferedRx = Receiver<io::Result<Buffer>>;
/// Sends a raw frame buffer, paired with the sender for its result, to the inflater threads.
type InflateTx = Sender<(Buffer, BufferedTx)>;
/// Receives raw frame buffers on an inflater thread.
type InflateRx = Receiver<(Buffer, BufferedTx)>;
/// Queues per-block result receivers, in stream order, for the consumer.
type ReadTx = Sender<BufferedRx>;
/// Receives per-block result receivers in stream order.
type ReadRx = Receiver<BufferedRx>;
/// Returns spent buffers from the consumer to the reader thread for reuse.
type RecycleTx = Sender<Buffer>;
/// Receives reusable buffers on the reader thread.
type RecycleRx = Receiver<Buffer>;
21
/// The lifecycle state of a [`MultithreadedReader`].
enum State<R> {
    /// No worker threads are running; the inner reader is held directly.
    Paused(R),
    /// The reader thread and the inflater threads are running.
    Running {
        /// The reader thread. When it stops, it hands back the inner reader, together with the
        /// error if it stopped because a read failed.
        reader_handle: JoinHandle<Result<R, ReadError<R>>>,
        /// The inflater threads.
        inflater_handles: Vec<JoinHandle<()>>,
        /// Per-block result receivers, in stream order.
        read_rx: ReadRx,
        /// Returns spent buffers to the reader thread. Dropping it makes the reader thread stop
        /// once the buffers already in the channel are used up.
        recycle_tx: RecycleTx,
    },
    /// The reader has been finished. Operations that need the inner reader or the worker
    /// threads panic in this state.
    Done,
}
32
/// A reusable work unit that circulates between the reader thread, the inflaters, and the
/// consumer.
#[derive(Debug, Default)]
struct Buffer {
    /// The raw frame bytes, as read from the inner reader.
    buf: Vec<u8>,
    /// The block parsed from `buf` by an inflater thread.
    block: Block,
}
38
/// A multithreaded BGZF reader.
///
/// This is a multithreaded BGZF reader that uses a thread pool to decompress block data. It places
/// the inner reader on its own thread to read raw frames asynchronously.
///
/// The threads are started lazily on the first read. They are stopped by `get_mut`, `finish`, or
/// when the reader is dropped.
pub struct MultithreadedReader<R> {
    // Whether worker threads are running, and who currently owns the inner reader.
    state: State<R>,
    // The number of inflater threads (also the number of buffers seeded into the pipeline).
    worker_count: NonZeroUsize,
    // The compressed offset just past the most recently loaded block.
    position: u64,
    // The block currently being consumed.
    buffer: Buffer,
}
49
50impl<R> MultithreadedReader<R> {
51    /// Returns the current position of the stream.
52    ///
53    /// # Examples
54    ///
55    /// ```
56    /// # use std::io;
57    /// use noodles_bgzf as bgzf;
58    /// let reader = bgzf::MultithreadedReader::new(io::empty());
59    /// assert_eq!(reader.position(), 0);
60    /// ```
61    pub fn position(&self) -> u64 {
62        self.position
63    }
64
65    /// Returns the current virtual position of the stream.
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// # use std::io;
71    /// use noodles_bgzf as bgzf;
72    /// let reader = bgzf::MultithreadedReader::new(io::empty());
73    /// assert_eq!(reader.virtual_position(), bgzf::VirtualPosition::MIN);
74    /// ```
75    pub fn virtual_position(&self) -> VirtualPosition {
76        self.buffer.block.virtual_position()
77    }
78
79    /// Shuts down the reader.
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// # use std::io;
85    /// use noodles_bgzf as bgzf;
86    /// let mut reader = bgzf::MultithreadedReader::new(io::empty());
87    /// reader.finish()?;
88    /// # Ok::<_, io::Error>(())
89    /// ```
90    pub fn finish(&mut self) -> io::Result<R> {
91        let state = mem::replace(&mut self.state, State::Done);
92
93        match state {
94            State::Paused(inner) => Ok(inner),
95            State::Running {
96                reader_handle,
97                mut inflater_handles,
98                recycle_tx,
99                ..
100            } => {
101                drop(recycle_tx);
102
103                for handle in inflater_handles.drain(..) {
104                    handle.join().unwrap();
105                }
106
107                reader_handle.join().unwrap().map_err(|e| e.1)
108            }
109            State::Done => panic!("invalid state"),
110        }
111    }
112}
113
114impl<R> MultithreadedReader<R>
115where
116    R: Read + Send + 'static,
117{
118    /// Creates a multithreaded BGZF reader with a worker count of 1.
119    ///
120    /// # Examples
121    ///
122    /// ```
123    /// # use std::io;
124    /// use noodles_bgzf as bgzf;
125    /// let reader = bgzf::MultithreadedReader::new(io::empty());
126    /// ```
127    pub fn new(inner: R) -> Self {
128        Self::with_worker_count(NonZeroUsize::MIN, inner)
129    }
130
131    /// Creates a multithreaded BGZF reader with a worker count.
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// # use std::io;
137    /// use std::num::NonZeroUsize;
138    /// use noodles_bgzf as bgzf;
139    /// let reader = bgzf::MultithreadedReader::with_worker_count(NonZeroUsize::MIN, io::empty());
140    /// ```
141    pub fn with_worker_count(worker_count: NonZeroUsize, inner: R) -> Self {
142        Self {
143            state: State::Paused(inner),
144            worker_count,
145            position: 0,
146            buffer: Buffer::default(),
147        }
148    }
149
150    /// Returns a mutable reference to the underlying reader.
151    ///
152    /// # Examples
153    ///
154    /// ```
155    /// # use std::io;
156    /// use noodles_bgzf as bgzf;
157    /// let mut reader = bgzf::MultithreadedReader::new(io::empty());
158    /// let _inner = reader.get_mut();
159    /// ```
160    pub fn get_mut(&mut self) -> &mut R {
161        self.pause();
162
163        match &mut self.state {
164            State::Paused(inner) => inner,
165            _ => panic!("invalid state"),
166        }
167    }
168
169    fn resume(&mut self) {
170        if matches!(self.state, State::Running { .. }) {
171            return;
172        }
173
174        let state = mem::replace(&mut self.state, State::Done);
175
176        let State::Paused(inner) = state else {
177            panic!("invalid state");
178        };
179
180        let worker_count = self.worker_count.get();
181
182        let (inflate_tx, inflate_rx) = crossbeam_channel::bounded(worker_count);
183        let (read_tx, read_rx) = crossbeam_channel::bounded(worker_count);
184        let (recycle_tx, recycle_rx) = crossbeam_channel::bounded(worker_count);
185
186        for _ in 0..worker_count {
187            recycle_tx.send(Buffer::default()).unwrap();
188        }
189
190        let reader_handle = spawn_reader(inner, inflate_tx, read_tx, recycle_rx);
191        let inflater_handles = spawn_inflaters(self.worker_count, inflate_rx);
192
193        self.state = State::Running {
194            reader_handle,
195            inflater_handles,
196            read_rx,
197            recycle_tx,
198        };
199    }
200
201    fn pause(&mut self) {
202        if matches!(self.state, State::Paused(_)) {
203            return;
204        }
205
206        let state = mem::replace(&mut self.state, State::Done);
207
208        let State::Running {
209            reader_handle,
210            mut inflater_handles,
211            recycle_tx,
212            ..
213        } = state
214        else {
215            panic!("invalid state");
216        };
217
218        drop(recycle_tx);
219
220        for handle in inflater_handles.drain(..) {
221            handle.join().unwrap();
222        }
223
224        // Discard read errors.
225        let inner = match reader_handle.join().unwrap() {
226            Ok(inner) => inner,
227            Err(ReadError(inner, _)) => inner,
228        };
229
230        self.state = State::Paused(inner);
231    }
232
233    fn read_block(&mut self) -> io::Result<()> {
234        self.resume();
235
236        let State::Running {
237            read_rx,
238            recycle_tx,
239            ..
240        } = &self.state
241        else {
242            panic!("invalid state");
243        };
244
245        while let Some(mut buffer) = recv_buffer(read_rx)? {
246            buffer.block.set_position(self.position);
247            self.position += buffer.block.size();
248
249            let prev_buffer = mem::replace(&mut self.buffer, buffer);
250            recycle_tx.send(prev_buffer).ok();
251
252            if self.buffer.block.data().len() > 0 {
253                break;
254            }
255        }
256
257        Ok(())
258    }
259}
260
261impl<R> Drop for MultithreadedReader<R> {
262    fn drop(&mut self) {
263        if !matches!(self.state, State::Done) {
264            let _ = self.finish();
265        }
266    }
267}
268
269impl<R> Read for MultithreadedReader<R>
270where
271    R: Read + Send + 'static,
272{
273    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
274        let mut src = self.fill_buf()?;
275        let amt = src.read(buf)?;
276        self.consume(amt);
277        Ok(amt)
278    }
279
280    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
281        use super::reader::default_read_exact;
282
283        if let Some(src) = self.buffer.block.data().as_ref().get(..buf.len()) {
284            buf.copy_from_slice(src);
285            self.consume(src.len());
286            Ok(())
287        } else {
288            default_read_exact(self, buf)
289        }
290    }
291}
292
293impl<R> BufRead for MultithreadedReader<R>
294where
295    R: Read + Send + 'static,
296{
297    fn fill_buf(&mut self) -> io::Result<&[u8]> {
298        if !self.buffer.block.data().has_remaining() {
299            self.read_block()?;
300        }
301
302        Ok(self.buffer.block.data().as_ref())
303    }
304
305    fn consume(&mut self, amt: usize) {
306        self.buffer.block.data_mut().consume(amt);
307    }
308}
309
310impl<R> crate::io::Read for MultithreadedReader<R>
311where
312    R: Read + Send + 'static,
313{
314    fn virtual_position(&self) -> VirtualPosition {
315        self.buffer.block.virtual_position()
316    }
317}
318
319impl<R> crate::io::BufRead for MultithreadedReader<R> where R: Read + Send + 'static {}
320
321impl<R> crate::io::Seek for MultithreadedReader<R>
322where
323    R: Read + Send + Seek + 'static,
324{
325    fn seek_to_virtual_position(&mut self, pos: VirtualPosition) -> io::Result<VirtualPosition> {
326        let (cpos, upos) = pos.into();
327
328        self.get_mut().seek(SeekFrom::Start(cpos))?;
329        self.position = cpos;
330
331        self.read_block()?;
332
333        self.buffer.block.data_mut().set_position(usize::from(upos));
334
335        Ok(pos)
336    }
337
338    fn seek_with_index(&mut self, index: &gzi::Index, pos: SeekFrom) -> io::Result<u64> {
339        let SeekFrom::Start(pos) = pos else {
340            unimplemented!();
341        };
342
343        let virtual_position = index.query(pos)?;
344        self.seek_to_virtual_position(virtual_position)?;
345        Ok(pos)
346    }
347}
348
349fn recv_buffer(read_rx: &ReadRx) -> io::Result<Option<Buffer>> {
350    if let Ok(buffered_rx) = read_rx.recv() {
351        if let Ok(buffer) = buffered_rx.recv() {
352            return buffer.map(Some);
353        }
354    }
355
356    Ok(None)
357}
358
/// The inner reader handed back by the reader thread, together with the I/O error that stopped
/// it, so the reader can be recovered even when a read fails.
struct ReadError<R>(R, io::Error);
360
361fn spawn_reader<R>(
362    mut reader: R,
363    inflate_tx: InflateTx,
364    read_tx: ReadTx,
365    recycle_rx: RecycleRx,
366) -> JoinHandle<Result<R, ReadError<R>>>
367where
368    R: Read + Send + 'static,
369{
370    use super::reader::frame::read_frame_into;
371
372    thread::spawn(move || {
373        while let Ok(mut buffer) = recycle_rx.recv() {
374            match read_frame_into(&mut reader, &mut buffer.buf) {
375                Ok(result) if result.is_none() => break,
376                Ok(_) => {}
377                Err(e) => return Err(ReadError(reader, e)),
378            }
379
380            let (buffered_tx, buffered_rx) = crossbeam_channel::bounded(1);
381
382            inflate_tx.send((buffer, buffered_tx)).unwrap();
383            read_tx.send(buffered_rx).unwrap();
384        }
385
386        Ok(reader)
387    })
388}
389
390fn spawn_inflaters(worker_count: NonZeroUsize, inflate_rx: InflateRx) -> Vec<JoinHandle<()>> {
391    use super::reader::frame::parse_block;
392
393    (0..worker_count.get())
394        .map(|_| {
395            let inflate_rx = inflate_rx.clone();
396
397            thread::spawn(move || {
398                while let Ok((mut buffer, buffered_tx)) = inflate_rx.recv() {
399                    let result = parse_block(&buffer.buf, &mut buffer.block).map(|_| buffer);
400                    buffered_tx.send(result).unwrap();
401                }
402            })
403        })
404        .collect()
405}
406
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    #[test]
    fn test_seek_to_virtual_position() -> Result<(), Box<dyn std::error::Error>> {
        use crate::io::Seek;

        // One 35-byte data block holding b"noodles", followed by the 28-byte BGZF EOF block.
        #[rustfmt::skip]
        static DATA: &[u8] = &[
            // block 0 (b"noodles")
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x22, 0x00, 0xcb, 0xcb, 0xcf, 0x4f, 0xc9, 0x49, 0x2d, 0x06, 0x00, 0xa1,
            0x58, 0x2a, 0x80, 0x07, 0x00, 0x00, 0x00,
            // EOF block
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];

        // The end of the stream: 35 + 28 = 63 compressed bytes.
        let eof_virtual_position =
            VirtualPosition::new(63, 0).expect("valid virtual position");

        // Three bytes into the uncompressed data of block 0, i.e., b"noo|dles".
        let mid_block_position = VirtualPosition::new(0, 3).expect("valid virtual position");

        let mut reader =
            MultithreadedReader::with_worker_count(NonZeroUsize::MIN, Cursor::new(DATA));

        let mut everything = Vec::new();
        reader.read_to_end(&mut everything)?;
        assert_eq!(reader.virtual_position(), eof_virtual_position);

        reader.seek_to_virtual_position(mid_block_position)?;

        let mut remainder = Vec::new();
        reader.read_to_end(&mut remainder)?;

        assert_eq!(remainder, b"dles");
        assert_eq!(reader.virtual_position(), eof_virtual_position);

        Ok(())
    }
}