Skip to main content

interrupt_read/
lib.rs

//! An interruptible [`Read`]er
//!
//! This crate provides the [`InterruptReader`], which can have its
//! `read` operations interrupted by an [`Interruptor`]. They are
//! acquired from the [`pair`] function, which
//! returns an [`mpsc`] channel backed pair.
//!
//! When [`Interruptor::interrupt`] is called, the `InterruptReader`
//! will return an error of kind [`ErrorKind::Interrupted`]. Otherwise,
//! it will act like any normal `Read` struct.
//!
//! Some things to note about this crate:
//!
//! - It functions by spawning a separate thread, which will actually
//!   read from the original `Read`er, so keep that in mind.
//! - There is some (light) overhead over the read operations.
//! - You should _not_ wrap this struct in a [`BufReader`] since the
//!   struct already has its own internal buffer.
//! - This reader doesn't assume that `Ok(0)` is the end of input, and
//!   the spawned thread will only terminate if the
//!   [`InterruptReader`] is dropped.
//!
//! [`BufReader`]: std::io::BufReader
use std::{
    io::{BufRead, Cursor, Error, ErrorKind, Read, Take},
    sync::mpsc,
};
28
29/// Returns a pair of an [`InterruptReader`] and an [`Interruptor`].
30///
31/// When you call any of the reading methods of `InterruptReader`, the
32/// current thread will block, being unblocked only if:
33///
34/// - The underlying [`Read`]er has more bytes or returned an
35///   [`Error`].
36/// - The [`Interruptor::interrupt`] function was called.
37///
38/// In the former case, it works just like a regular read, giving an
39/// [`std::io::Result`], depending on the operation.
40/// If the latter happens, however, an [`Error`] of type
41/// [`ErrorKind::Interrupted`] will be received, meaning that reading
42/// operations have been interrupted for some user defined reason.
43///
44/// If the channel was interrupted this way, further reads will work
45/// just fine, until another interrupt comes through, creating a
46/// read/interrupt cycle.
47///
48/// Behind the scenes, this is done through channels and a spawned
49/// thread, but no timeout is used, all operations are blocking.
50///
51/// [`Error`]: std::io::Error
52/// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted
53pub fn pair(mut reader: impl Read + Send + 'static) -> (InterruptReader, Interruptor) {
54    let (event_tx, event_rx) = mpsc::channel();
55    let (buffer_tx, buffer_rx) = mpsc::channel();
56
57    std::thread::spawn({
58        let event_tx = event_tx.clone();
59        move || {
60            // Same capacity as BufReader
61            let mut buf = vec![0; 8 * 1024];
62
63            loop {
64                match reader.read(&mut buf) {
65                    Ok(num_bytes) => {
66                        // This means the InterruptReader has been dropped, so no more reading
67                        // will be done.
68                        let event = Event::Buf(std::mem::take(&mut buf), num_bytes);
69                        if event_tx.send(event).is_err() {
70                            break;
71                        }
72
73                        buf = match buffer_rx.recv() {
74                            Ok(buf) => buf,
75                            // Same as before.
76                            Err(_) => break,
77                        }
78                    }
79                    Err(err) => {
80                        if event_tx.send(Event::Err(err)).is_err() {
81                            break;
82                        }
83                    }
84                }
85            }
86        }
87    });
88
89    let interrupt_reader = InterruptReader { cursor: None, buffer_tx, event_rx };
90    let interruptor = Interruptor(event_tx);
91
92    (interrupt_reader, interruptor)
93}
94
95#[derive(Debug)]
96pub struct InterruptReader {
97    cursor: Option<Take<Cursor<Vec<u8>>>>,
98    buffer_tx: mpsc::Sender<Vec<u8>>,
99    event_rx: mpsc::Receiver<Event>,
100}
101
102impl Read for InterruptReader {
103    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
104        if let Some(cursor) = self.cursor.as_mut() {
105            deal_with_interrupt(&self.event_rx)?;
106
107            match cursor.read(buf) {
108                Ok(0) => {
109                    let buffer = self.cursor.take().unwrap().into_inner().into_inner();
110                    match self.buffer_tx.send(buffer) {
111                        Ok(()) => self.read(buf),
112                        // Now we handle that.
113                        Err(_) => Ok(0),
114                    }
115                }
116                Ok(num_bytes) => Ok(num_bytes),
117                Err(_) => unreachable!("Afaik, this shouldn't happen if T is Vec<u8>"),
118            }
119        } else {
120            match self.event_rx.recv() {
121                Ok(Event::Buf(buffer, len)) => {
122                    self.cursor = Some(Cursor::new(buffer).take(len as u64));
123                    if len == 0 { Ok(0) } else { self.read(buf) }
124                }
125                Ok(Event::Err(err)) => Err(err),
126                Ok(Event::Interrupt) => Err(interrupt_error()),
127                Err(_) => Ok(0),
128            }
129        }
130    }
131}
132
133impl BufRead for InterruptReader {
134    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
135        if let Some(cursor) = self.cursor.as_mut() {
136            deal_with_interrupt(&self.event_rx)?;
137
138            let (addr, len) = {
139                let buf = cursor.fill_buf()?;
140                ((buf as *const [u8]).addr(), buf.len())
141            };
142
143            if len == 0 {
144                let buffer = self.cursor.take().unwrap().into_inner().into_inner();
145                match self.buffer_tx.send(buffer) {
146                    Ok(()) => self.fill_buf(),
147                    Err(_) => Ok(&[]),
148                }
149            } else {
150                let buffer = self.cursor.as_ref().unwrap().get_ref().get_ref();
151                let buf_addr = (buffer.as_slice() as *const [u8]).addr();
152
153                // First time the borrow checker actually forced me to do something
154                // inconvenient, instead of the safe alternative.
155                Ok(&buffer[addr - buf_addr..(addr - buf_addr) + len])
156            }
157        } else {
158            match self.event_rx.recv() {
159                Ok(Event::Buf(buffer, len)) => {
160                    self.cursor = Some(Cursor::new(buffer).take(len as u64));
161                    if len == 0 { Ok(&[]) } else { self.fill_buf() }
162                }
163                Ok(Event::Err(err)) => Err(err),
164                Ok(Event::Interrupt) => Err(interrupt_error()),
165                Err(_) => Ok(&[]),
166            }
167        }
168    }
169
170    fn consume(&mut self, amount: usize) {
171        if let Some(cursor) = self.cursor.as_mut() {
172            cursor.consume(amount);
173        }
174    }
175}
176
177/// An interruptor for an [`InterruptReader`].
178///
179/// This struct serves the purpose of interrupting any of the [`Read`]
180/// or [`ReadBuf`] functions being performend on the `InterruptReader`
181///
182/// If it is dropped, the `InterruptReader` will no longer be able to
183/// be interrupted.
184#[derive(Debug, Clone)]
185pub struct Interruptor(mpsc::Sender<Event>);
186
187impl Interruptor {
188    /// Interrupts the [`InterruptReader`]
189    ///
190    /// This will send an interrupt event to the reader, which makes
191    /// the next `read` operation return [`Err`], with an
192    /// [`ErrorKind::Interrupted`].
193    ///
194    /// Subsequent `read` operations proceed as normal.
195    pub fn interrupt(&self) -> Result<(), InterruptError> {
196        self.0.send(Event::Interrupt).map_err(|_| InterruptError)
197    }
198}
199
/// An error occurred while calling [`Interruptor::interrupt`].
///
/// This means that the receiving [`InterruptReader`] has been
/// dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InterruptError;

impl std::fmt::Display for InterruptError {
    /// Writes the fixed text `InterruptError`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("InterruptError")
    }
}

impl std::error::Error for InterruptError {}
214
/// A message delivered to the [`InterruptReader`] through its event
/// channel.
#[derive(Debug)]
enum Event {
    /// A buffer handed over by the reading thread, together with the
    /// number of bytes at its start that were filled by the read.
    Buf(Vec<u8>, usize),
    /// An error returned by the underlying reader.
    Err(std::io::Error),
    /// An interrupt request sent by an [`Interruptor`].
    Interrupt,
}
221
/// Builds the [`Error`] of kind [`ErrorKind::Interrupted`] that is
/// returned from read operations when an interrupt was requested.
fn interrupt_error() -> Error {
    Error::new(
        ErrorKind::Interrupted,
        "An Interruptor has interrupted this operation.",
    )
}
228
229fn deal_with_interrupt(event_rx: &mpsc::Receiver<Event>) -> std::io::Result<()> {
230    match event_rx.try_recv() {
231        Ok(Event::Interrupt) => Err(interrupt_error()),
232        Ok(_) => unreachable!("This should not be possible"),
233        // The channel was dropped, but no need to handle that right now.
234        Err(_) => Ok(()),
235    }
236}