cancel_rw/
lib.rs

//! New type to cancel synchronous reads and writes.
//!
//! This crate provides a _new-type_ [`Cancellable`] that can be used to wrap
//! a `Read`, `Write` or `Seek`, so that its operations can be cancelled at
//! any time.
//!
//! Cancellation is checked before each fallible operation is forwarded to the
//! wrapped value: a call that is already blocked inside it is not interrupted,
//! but the next call fails with `std::io::ErrorKind::BrokenPipe`.
//!
//! To signal the cancellation event, you first create a [`CancellationToken`],
//! and then call its [`CancellationToken::cancel`] member function.
//!
//! You can use the same `CancellationToken` for as many `Cancellable` objects
//! as you need.
12use std::sync::{
13    atomic::{AtomicBool, Ordering},
14    Arc,
15};
16
/// This type signals a cancellation event.
///
/// It is `Sync` and `Send` so you can share it between threads freely.
///
/// It also implements `Eq`, `Ord` and `Hash`, with some arbitrary ordering,
/// so that you can use it as a cheap identifier for your interruptible actions.
/// All clones of the same token compare equal; independently created tokens
/// never do.
#[derive(Clone, Default, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl PartialEq for CancellationToken {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.cancelled, &other.cancelled)
    }
}

impl Eq for CancellationToken {}

impl Ord for CancellationToken {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Order by the address of the shared flag: arbitrary, but stable while
        // any clone is alive and consistent with `eq` (which compares the same
        // address via `Arc::ptr_eq`).
        Arc::as_ptr(&self.cancelled).cmp(&Arc::as_ptr(&other.cancelled))
    }
}

impl PartialOrd for CancellationToken {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl std::hash::Hash for CancellationToken {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash exactly what `eq` compares, so equal tokens hash equally.
        Arc::as_ptr(&self.cancelled).hash(state);
    }
}

impl CancellationToken {
    /// Creates a new `CancellationToken`, in a non-cancelled state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Signals this token, and every clone of it, as _cancelled_.
    ///
    /// Note that it takes a non-mutable `self`, so you are able to cancel a
    /// shared token. Cancelling an already cancelled token has no further
    /// effect; this type offers no way to reset it.
    pub fn cancel(&self) {
        // `Release` pairs with the `Acquire` load in `is_cancelled`: anything
        // this thread wrote before cancelling is visible to a thread that
        // observes the cancellation.
        self.cancelled.store(true, Ordering::Release);
    }

    /// Returns `true` if this token, or any clone of it, has been cancelled.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Checks whether a token is cancelled.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::BrokenPipe`] if the
    /// token has been cancelled, and `Ok(())` otherwise.
    pub fn check(&self) -> std::io::Result<()> {
        if self.is_cancelled() {
            Err(std::io::ErrorKind::BrokenPipe.into())
        } else {
            Ok(())
        }
    }
}
79
80/// A newtype around `CancellationToken` that automatically cancels on `drop`.
81pub struct CancellationGuard(pub CancellationToken);
82
83impl Drop for CancellationGuard {
84    fn drop(&mut self) {
85        self.0.cancel();
86    }
87}
88
89/// A newtype around any `Read`, `Write` or `Seek` value, that makes it cancellable.
90pub struct Cancellable<T> {
91    inner: T,
92    token: CancellationToken,
93}
94
95impl<T> Cancellable<T> {
96    /// Wraps a value as `Cancellable`.
97    pub fn new(inner: T, token: CancellationToken) -> Self {
98        Self { inner, token }
99    }
100    /// Gets the inner token.
101    ///
102    /// You will probably need to clone it if you want store it somewhere.
103    pub fn token(&self) -> &CancellationToken {
104        &self.token
105    }
106    /// Unwraps the inner value.
107    pub fn into_inner(self) -> T {
108        self.inner
109    }
110    /// Gets a reference to the inner value.
111    pub fn get_ref(&self) -> &T {
112        &self.inner
113    }
114    /// Gets a mutable reference to the inner value.
115    pub fn get_mut(&mut self) -> &mut T {
116        &mut self.inner
117    }
118}
119
120impl<T: std::io::Read> std::io::Read for Cancellable<T> {
121    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
122        self.token.check()?;
123        self.inner.read(buf)
124    }
125
126    fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
127        self.token.check()?;
128        self.inner.read_vectored(bufs)
129    }
130
131    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
132        self.token.check()?;
133        self.inner.read_to_end(buf)
134    }
135
136    fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
137        self.token.check()?;
138        self.inner.read_to_string(buf)
139    }
140
141    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
142        self.token.check()?;
143        self.inner.read_exact(buf)
144    }
145}
146
147impl<T: std::io::Write> std::io::Write for Cancellable<T> {
148    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
149        self.token.check()?;
150        self.inner.write(buf)
151    }
152
153    fn flush(&mut self) -> std::io::Result<()> {
154        self.token.check()?;
155        self.inner.flush()
156    }
157    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
158        self.token.check()?;
159        self.inner.write_vectored(bufs)
160    }
161
162    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
163        self.token.check()?;
164        self.inner.write_all(buf)
165    }
166
167    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
168        self.token.check()?;
169        self.inner.write_fmt(fmt)
170    }
171}
172
173impl<T: std::io::Seek> std::io::Seek for Cancellable<T> {
174    fn seek(&mut self, from: std::io::SeekFrom) -> std::io::Result<u64> {
175        self.token.check()?;
176        self.inner.seek(from)
177    }
178
179    fn rewind(&mut self) -> std::io::Result<()> {
180        self.token.check()?;
181        self.inner.rewind()
182    }
183
184    fn stream_position(&mut self) -> std::io::Result<u64> {
185        self.token.check()?;
186        self.inner.stream_position()
187    }
188
189    fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> {
190        self.token.check()?;
191        self.inner.seek_relative(offset)
192    }
193}
194
195impl<T: std::io::BufRead> std::io::BufRead for Cancellable<T> {
196    // Provided methods are not wrapped, probably not worth it
197    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
198        self.token.check()?;
199        self.inner.fill_buf()
200    }
201    fn consume(&mut self, amt: usize) {
202        self.inner.consume(amt)
203    }
204}
205
#[cfg(test)]
mod test {
    use super::*;
    use std::io::{self, Read, Seek, Write};
    use std::thread::JoinHandle;
    use std::time::Duration;

    /// How many operations a worker attempts before finishing successfully.
    const ROUNDS: usize = 10;
    /// Pause after each operation, leaving the main thread time to cancel.
    const PAUSE: Duration = Duration::from_millis(100);

    /// Applies `op` to `target` `ROUNDS` times, sleeping `PAUSE` after each
    /// call, and stops at the first error.
    fn repeat<T>(mut target: T, mut op: impl FnMut(&mut T) -> io::Result<()>) -> io::Result<()> {
        for _ in 0..ROUNDS {
            op(&mut target)?;
            std::thread::sleep(PAUSE);
        }
        Ok(())
    }

    fn inf_write(ct: CancellationToken) -> io::Result<()> {
        repeat(Cancellable::new(io::empty(), ct), |w| w.write_all(&[0]))
    }

    fn inf_read(ct: CancellationToken) -> io::Result<()> {
        let mut data = [0];
        repeat(Cancellable::new(io::empty(), ct), |r| r.read(&mut data).map(drop))
    }

    fn inf_seek(ct: CancellationToken) -> io::Result<()> {
        repeat(Cancellable::new(io::empty(), ct), |s| {
            s.seek(io::SeekFrom::Start(0)).map(drop)
        })
    }

    /// Runs `work` on a new thread with its own clone of `ct`; the thread
    /// panics (via `unwrap`) if `work` fails.
    fn spawn_worker(
        ct: &CancellationToken,
        work: fn(CancellationToken) -> io::Result<()>,
    ) -> JoinHandle<()> {
        let ct = ct.clone();
        std::thread::spawn(move || work(ct).unwrap())
    }

    /// Joins `th` and asserts that it panicked with a `BrokenPipe` error.
    fn assert_broken_pipe(th: JoinHandle<()>) {
        let payload = th.join().expect_err("worker should have failed");
        let message = payload
            .downcast::<String>()
            .expect("panic payload should be a String");
        assert!(message.contains("BrokenPipe"));
    }

    #[test]
    fn test_write() {
        let ct = CancellationToken::new();
        let th = spawn_worker(&ct, inf_write);
        ct.cancel();
        assert_broken_pipe(th);
    }

    #[test]
    fn test_guard() {
        let th = {
            let guard = CancellationGuard(CancellationToken::new());
            spawn_worker(&guard.0, inf_write)
            // `guard` is dropped at the end of this block, cancelling the token.
        };
        assert_broken_pipe(th);
    }

    #[test]
    fn test_read() {
        let ct = CancellationToken::new();
        let th = spawn_worker(&ct, inf_read);
        ct.cancel();
        assert_broken_pipe(th);
    }

    #[test]
    fn test_seek() {
        let ct = CancellationToken::new();
        let th = spawn_worker(&ct, inf_seek);
        ct.cancel();
        assert_broken_pipe(th);
    }
}