Skip to main content

interrupt_read/
lib.rs

//! An interruptable [`Read`]er
//!
//! This crate provides the [`InterruptReader`], which can have its
//! `read` operations interrupted by an [`Interruptor`]. They are
//! acquired from the [`interrupt_read::pair`] function, which
//! returns an [`mpsc`] channel backed pair.
//!
//! When [`Interruptor::interrupt`] is called, the `InterruptReader`
//! will return an error of kind [`ErrorKind::Other`] with a payload of
//! [`InterruptReceived`] (you can check for that using the
//! [`is_interrupt`] function). Otherwise, it will act like any normal
//! `Read` struct.
//!
//! When an interrupt is received, _the underlying data is not lost_,
//! it still exists, and if you call a reading function again, it will
//! be retrieved, unless another interrupt is sent before that.
//!
//! Some things to note about this crate:
//!
//! - It functions by spawning a separate thread, which will actually
//!   read from the original `Read`er, so keep that in mind.
//! - There is some (light) overhead over the read operations.
//! - You should _not_ wrap this struct in a [`BufReader`] since the
//!   struct already has its own internal buffer.
//! - This reader doesn't assume that `Ok(0)` is the end of input, and
//!   the spawned thread will only terminate if the
//!   [`InterruptReader`] is dropped.
//!
//! # Note
//!
//! The reason why this function returns [`ErrorKind::Other`], rather
//! than [`ErrorKind::Interrupted`] is that the latter error is
//! ignored by functions like [`BufRead::read_line`] and
//! [`BufRead::read_until`], which is probably not what you want to
//! happen.
//!
//! [`BufReader`]: std::io::BufReader
//! [`ErrorKind::Other`]: std::io::ErrorKind::Other
//! [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted
//! [`interrupt_read::pair`]: pair
41use std::{
42    io::{BufRead, Cursor, Error, Read, Take},
43    sync::{
44        Arc,
45        atomic::{AtomicBool, Ordering::Relaxed},
46        mpsc,
47    },
48    thread::JoinHandle,
49};
50
51/// Returns a pair of an [`InterruptReader`] and an [`Interruptor`].
52///
53/// When you call any of the reading methods of `InterruptReader`, the
54/// current thread will block, being unblocked only if:
55///
56/// - The underlying [`Read`]er has more bytes or returned an
57///   [`Error`].
58/// - The [`Interruptor::interrupt`] function was called.
59///
60/// In the former case, it works just like a regular read, giving an
61/// [`std::io::Result`], depending on the operation.
62/// If the latter happens, however, an [`Error`] of type
63/// [`ErrorKind::Other`] with a payload of [`InterruptReceived`],
64/// meaning that reading operations have been interrupted for some
65/// user defined reason.
66///
67/// You can check if an [`std::io::Error`] is of this type by
68/// calling the [`is_interrupt`] function.
69///
70/// If the channel was interrupted this way, further reads will work
71/// just fine, until another interrupt comes through, creating a
72/// read/interrupt cycle.
73///
74/// Behind the scenes, this is done through channels and a spawned
75/// thread, but no timeout is used, all operations are blocking.
76///
77/// [`Error`]: std::io::Error
78/// [`ErrorKind::Other`]: std::io::ErrorKind::Other
79pub fn pair<R: Read + Send + 'static>(mut reader: R) -> (InterruptReader<R>, Interruptor) {
80    let (event_tx, event_rx) = mpsc::channel();
81    let (buffer_tx, buffer_rx) = mpsc::channel();
82    let is_reading = Arc::new(AtomicBool::new(true));
83
84    let join_handle = std::thread::spawn({
85        let event_tx = event_tx.clone();
86        let is_reading = is_reading.clone();
87        move || {
88            // Same capacity as BufReader
89            let mut buf = vec![0; 8 * 1024];
90            is_reading.store(true, Relaxed);
91
92            let reader = loop {
93                match reader.read(&mut buf) {
94                    Ok(num_bytes) => {
95                        // This means the InterruptReader has been dropped, so no more reading
96                        // will be done.
97                        let event = Event::Buf(std::mem::take(&mut buf), num_bytes);
98                        if event_tx.send(event).is_err() {
99                            break reader;
100                        }
101
102                        buf = match buffer_rx.recv() {
103                            Ok(buf) => buf,
104                            // Same as before.
105                            Err(_) => break reader,
106                        }
107                    }
108                    Err(err) => {
109                        if event_tx.send(Event::Err(err)).is_err() {
110                            break reader;
111                        }
112                    }
113                }
114            };
115            is_reading.store(false, Relaxed);
116            reader
117        }
118    });
119
120    let interrupt_reader = InterruptReader {
121        is_reading,
122        cursor: None,
123        buffer_tx,
124        event_rx,
125        join_handle,
126    };
127    let interruptor = Interruptor(event_tx);
128
129    (interrupt_reader, interruptor)
130}
131
132/// An interruptable, buffered [`Read`]er.
133///
134/// This reader is created by wrapping a `Read` struct in the
135/// [`interrupt_read::pair`] function, which also returns an
136/// [`Interruptor`], which is capable of sending interrupt signals,
137/// which make any `read` operations on the `InterruptReader` return
138/// an error of kind [`ErrorKind::Other`], with a payload of
139/// [`InterruptReceived`].
140///
141/// When an interrupt is received, _the underlying data is not lost_,
142/// it still exists, and if you call a reading function again, it will
143/// be retrieved, unless another interrupt is sent before that.
144///
145/// You can check if an [`std::io::Error`] is of this type by
146/// calling the [`is_interrupt`] function.
147///
148/// # Examples
149///
150/// One potential application of this struct is if you want to stop a
151/// thread that is reading from the stdout of a child process without
152/// necessarily terminating said childrop_:
153///
154/// ```rust
155/// use std::{
156///     io::{BufRead, ErrorKind},
157///     process::{Child, Command, Stdio},
158///     time::Duration,
159/// };
160///
161/// use interrupt_read::{is_interrupt, pair};
162///
163/// struct ChildKiller(Child);
164/// impl Drop for ChildKiller {
165///     fn drop(&mut self) {
166///         _ = self.0.kill();
167///     }
168/// }
169///
170/// # match main() {
171/// #     Ok(()) => {}
172/// #     Err(err) => panic!("{err}")
173/// # }
174/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
175/// // Prints "hello\n" every second forever.
176/// let mut child = Command::new("bash")
177///     .args(["-c", r#"while true; do echo "hello"; sleep 1; done"#])
178///     .stdout(Stdio::piped())
179///     .spawn()
180///     .unwrap();
181///
182/// let (mut stdout, interruptor) = pair(child.stdout.take().unwrap());
183/// let _child_killer = ChildKiller(child);
184///
185/// let join_handle = std::thread::spawn(move || {
186///     let mut string = String::new();
187///     loop {
188///         match stdout.read_line(&mut string) {
189///             Ok(0) => break Ok(string),
190///             Ok(_) => {}
191///             Err(err) if is_interrupt(&err) => {
192///                 break Ok(string);
193///             }
194///             Err(err) => break Err(err),
195///         }
196///     }
197/// });
198///
199/// std::thread::sleep(Duration::new(3, 1_000_000));
200///
201/// interruptor.interrupt()?;
202///
203/// let result = join_handle.join().unwrap()?;
204///
205/// assert_eq!(result, "hello\nhello\nhello\n");
206///
207/// Ok(())
208/// # }
209/// ```
210///
211/// [`interrupt_read::pair`]: pair
212/// [`ErrorKind::Other`]: std::io::ErrorKind::Other
213#[derive(Debug)]
214pub struct InterruptReader<R> {
215    is_reading: Arc<AtomicBool>,
216    cursor: Option<Take<Cursor<Vec<u8>>>>,
217    buffer_tx: mpsc::Sender<Vec<u8>>,
218    event_rx: mpsc::Receiver<Event>,
219    join_handle: JoinHandle<R>,
220}
221
222impl<R: Read> InterruptReader<R> {
223    /// Unwraps this `InterruptReader`, returning the underlying
224    /// reader.
225    ///
226    /// Note that any leftover data in the internal buffer is lost.
227    /// Therefore, a following read from the underlying reader may
228    /// lead to data loss.
229    ///
230    /// This may return [`Err`] if the underlying joined thread has
231    /// panicked, probably because the [`Read`]er has done so.
232    pub fn into_inner(self) -> std::thread::Result<R> {
233        let Self { buffer_tx, event_rx, join_handle, .. } = self;
234        drop((event_rx, buffer_tx));
235        join_handle.join()
236    }
237
238    /// Wether the reader thread is still active.
239    pub fn is_reading(&self) -> bool {
240        self.is_reading.load(Relaxed)
241    }
242
243    /// A function that returns `true` if the reader thread is still
244    /// active.
245    pub fn is_reading_fn(&self) -> impl Fn() -> bool + Send + Sync + 'static {
246        let is_reading = self.is_reading.clone();
247        move || is_reading.load(Relaxed)
248    }
249}
250
251impl<R: Read> Read for InterruptReader<R> {
252    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
253        if let Some(cursor) = self.cursor.as_mut() {
254            deal_with_interrupt(&self.event_rx)?;
255
256            match cursor.read(buf) {
257                Ok(0) => {
258                    let buffer = self.cursor.take().unwrap().into_inner().into_inner();
259                    match self.buffer_tx.send(buffer) {
260                        Ok(()) => self.read(buf),
261                        // Now we handle that.
262                        Err(_) => Ok(0),
263                    }
264                }
265                Ok(num_bytes) => Ok(num_bytes),
266                Err(_) => unreachable!("Afaik, this shouldn't happen if T is Vec<u8>"),
267            }
268        } else {
269            match self.event_rx.recv() {
270                Ok(Event::Buf(buffer, len)) => {
271                    self.cursor = Some(Cursor::new(buffer).take(len as u64));
272                    if len == 0 { Ok(0) } else { self.read(buf) }
273                }
274                Ok(Event::Err(err)) => Err(err),
275                Ok(Event::Interrupt) => Err(interrupt_error()),
276                Err(_) => Ok(0),
277            }
278        }
279    }
280}
281
282impl<R: Read> BufRead for InterruptReader<R> {
283    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
284        if let Some(cursor) = self.cursor.as_mut() {
285            deal_with_interrupt(&self.event_rx)?;
286
287            let (addr, len) = {
288                let buf = cursor.fill_buf()?;
289                ((buf as *const [u8]).addr(), buf.len())
290            };
291
292            if len == 0 {
293                let buffer = self.cursor.take().unwrap().into_inner().into_inner();
294                match self.buffer_tx.send(buffer) {
295                    Ok(()) => self.fill_buf(),
296                    Err(_) => Ok(&[]),
297                }
298            } else {
299                let buffer = self.cursor.as_ref().unwrap().get_ref().get_ref();
300                let buf_addr = (buffer.as_slice() as *const [u8]).addr();
301
302                // First time the borrow checker actually forced me to do something
303                // inconvenient, instead of the safe alternative.
304                Ok(&buffer[addr - buf_addr..(addr - buf_addr) + len])
305            }
306        } else {
307            match self.event_rx.recv() {
308                Ok(Event::Buf(buffer, len)) => {
309                    self.cursor = Some(Cursor::new(buffer).take(len as u64));
310                    if len == 0 { Ok(&[]) } else { self.fill_buf() }
311                }
312                Ok(Event::Err(err)) => Err(err),
313                Ok(Event::Interrupt) => Err(interrupt_error()),
314                Err(_) => Ok(&[]),
315            }
316        }
317    }
318
319    fn consume(&mut self, amount: usize) {
320        if let Some(cursor) = self.cursor.as_mut() {
321            cursor.consume(amount);
322        }
323    }
324}
325
326/// An interruptor for an [`InterruptReader`].
327///
328/// This struct serves the purpose of interrupting any of the [`Read`]
329/// or [`BufRead`] functions being performend on the `InterruptReader`
330///
331/// If it is dropped, the `InterruptReader` will no longer be able to
332/// be interrupted.
333#[derive(Debug, Clone)]
334pub struct Interruptor(mpsc::Sender<Event>);
335
336impl Interruptor {
337    /// Interrupts the [`InterruptReader`]
338    ///
339    /// This will send an interrupt event to the reader, which makes
340    /// the next `read` operation return [`Err`], with an
341    /// [`ErrorKind::Other`] with a payload of [`InterruptReceived`].
342    ///
343    /// You can check if an [`std::io::Error`] is of this type by
344    /// calling the [`is_interrupt`] function.
345    ///
346    /// Subsequent `read` operations proceed as normal.
347    ///
348    /// [`ErrorKind::Other`]: std::io::ErrorKind::Other
349    pub fn interrupt(&self) -> Result<(), InterruptSendError> {
350        self.0
351            .send(Event::Interrupt)
352            .map_err(|_| InterruptSendError)
353    }
354}
355
/// The error returned when [`Interruptor::interrupt`] fails.
///
/// It means that the receiving [`InterruptReader`] has been
/// dropped.
#[derive(Debug, Clone, Copy)]
pub struct InterruptSendError;

impl std::error::Error for InterruptSendError {}

impl std::fmt::Display for InterruptSendError {
    /// Writes a fixed, human readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "InterruptReader has been dropped")
    }
}
370
/// The payload of the error returned by a read operation that was
/// interrupted through [`Interruptor::interrupt`].
#[derive(Debug, Clone, Copy)]
pub struct InterruptReceived;

impl std::error::Error for InterruptReceived {}

impl std::fmt::Display for InterruptReceived {
    /// Writes a fixed, human readable description of the interrupt.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Interruptor has interrupted")
    }
}
384
/// A message delivered to the [`InterruptReader`].
#[derive(Debug)]
enum Event {
    /// A buffer, together with the number of bytes at its start that
    /// the underlying reader filled in.
    Buf(Vec<u8>, usize),
    /// An error returned by the underlying reader.
    Err(Error),
    /// An interrupt requested through an [`Interruptor`].
    Interrupt,
}
391
392/// Wether the error in question originated from an [`Interruptor`]
393/// calling [`Interruptor::interrupt`].
394///
395/// This just checks if the error is of type [`InterruptReceived`].
396///
397/// # Examples
398///
399/// ```
400/// use std::io::{BufRead, Read, Result};
401///
402/// use interrupt_read::{InterruptReader, is_interrupt};
403///
404/// // Read until either `Ok(0)` or an interrupt occurred.
405/// fn interrupt_read_loop(mut reader: InterruptReader<impl Read>) -> Result<String> {
406///     let mut string = String::new();
407///     loop {
408///         match reader.read_line(&mut string) {
409///             Ok(0) => break Ok(string),
410///             Ok(_) => {}
411///             Err(err) if is_interrupt(&err) => break Ok(string),
412///             Err(err) => break Err(err),
413///         }
414///     }
415/// }
416/// ```
417pub fn is_interrupt(err: &Error) -> bool {
418    err.get_ref()
419        .is_some_and(|err| err.is::<InterruptReceived>())
420}
421
422fn interrupt_error() -> Error {
423    Error::other(InterruptReceived)
424}
425
426fn deal_with_interrupt(event_rx: &mpsc::Receiver<Event>) -> std::io::Result<()> {
427    match event_rx.try_recv() {
428        Ok(Event::Interrupt) => Err(interrupt_error()),
429        Ok(_) => unreachable!("This should not be possible"),
430        // The channel was dropped, but no need to handle that right now.
431        Err(_) => Ok(()),
432    }
433}