1use std::sync::{
13 atomic::{AtomicBool, Ordering},
14 Arc,
15};
16
/// Shared, thread-safe flag for requesting cancellation of I/O work.
///
/// Every clone refers to the same underlying flag, so [`cancel`](Self::cancel)
/// on any clone is seen by all of them. Equality, ordering and hashing are by
/// *identity* of that shared flag rather than by its current value: two tokens
/// are equal exactly when they share the same allocation (one was cloned,
/// directly or indirectly, from the other).
#[derive(Clone, Default, Debug)]
pub struct CancellationToken {
    /// `true` once cancellation has been requested; nothing in this type resets it.
    cancelled: Arc<AtomicBool>,
}

impl PartialEq for CancellationToken {
    /// Tokens are equal when they share the same flag allocation.
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.cancelled, &other.cancelled)
    }
}

impl Eq for CancellationToken {}

impl Ord for CancellationToken {
    /// Orders by the address of the shared flag: arbitrary, but total and
    /// consistent with `eq` while the tokens are alive.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let this = Arc::as_ptr(&self.cancelled);
        let that = Arc::as_ptr(&other.cancelled);
        this.cmp(&that)
    }
}

impl PartialOrd for CancellationToken {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl std::hash::Hash for CancellationToken {
    /// Hashes the flag's address, matching the identity-based `eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&Arc::as_ptr(&self.cancelled), state);
    }
}

impl CancellationToken {
    /// Creates a fresh, not-yet-cancelled token with its own flag.
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Requests cancellation for this token and every clone of it.
    ///
    /// Calling it more than once has no additional effect.
    pub fn cancel(&self) {
        // Relaxed is enough: the flag is the only state shared through the
        // token, so no other memory needs to be ordered with respect to it.
        self.cancelled.store(true, Ordering::Relaxed);
    }

    /// Returns `Ok(())` unless cancellation has been requested.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::Error`] of kind [`std::io::ErrorKind::BrokenPipe`]
    /// once [`cancel`](Self::cancel) has been called on this token or any clone.
    pub fn check(&self) -> std::io::Result<()> {
        if self.cancelled.load(Ordering::Relaxed) {
            return Err(std::io::ErrorKind::BrokenPipe.into());
        }
        Ok(())
    }
}
79
80pub struct CancellationGuard(pub CancellationToken);
82
83impl Drop for CancellationGuard {
84 fn drop(&mut self) {
85 self.0.cancel();
86 }
87}
88
89pub struct Cancellable<T> {
91 inner: T,
92 token: CancellationToken,
93}
94
95impl<T> Cancellable<T> {
96 pub fn new(inner: T, token: CancellationToken) -> Self {
98 Self { inner, token }
99 }
100 pub fn token(&self) -> &CancellationToken {
104 &self.token
105 }
106 pub fn into_inner(self) -> T {
108 self.inner
109 }
110 pub fn get_ref(&self) -> &T {
112 &self.inner
113 }
114 pub fn get_mut(&mut self) -> &mut T {
116 &mut self.inner
117 }
118}
119
120impl<T: std::io::Read> std::io::Read for Cancellable<T> {
121 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
122 self.token.check()?;
123 self.inner.read(buf)
124 }
125
126 fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
127 self.token.check()?;
128 self.inner.read_vectored(bufs)
129 }
130
131 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
132 self.token.check()?;
133 self.inner.read_to_end(buf)
134 }
135
136 fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
137 self.token.check()?;
138 self.inner.read_to_string(buf)
139 }
140
141 fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
142 self.token.check()?;
143 self.inner.read_exact(buf)
144 }
145}
146
147impl<T: std::io::Write> std::io::Write for Cancellable<T> {
148 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
149 self.token.check()?;
150 self.inner.write(buf)
151 }
152
153 fn flush(&mut self) -> std::io::Result<()> {
154 self.token.check()?;
155 self.inner.flush()
156 }
157 fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
158 self.token.check()?;
159 self.inner.write_vectored(bufs)
160 }
161
162 fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
163 self.token.check()?;
164 self.inner.write_all(buf)
165 }
166
167 fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
168 self.token.check()?;
169 self.inner.write_fmt(fmt)
170 }
171}
172
173impl<T: std::io::Seek> std::io::Seek for Cancellable<T> {
174 fn seek(&mut self, from: std::io::SeekFrom) -> std::io::Result<u64> {
175 self.token.check()?;
176 self.inner.seek(from)
177 }
178
179 fn rewind(&mut self) -> std::io::Result<()> {
180 self.token.check()?;
181 self.inner.rewind()
182 }
183
184 fn stream_position(&mut self) -> std::io::Result<u64> {
185 self.token.check()?;
186 self.inner.stream_position()
187 }
188
189 fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> {
190 self.token.check()?;
191 self.inner.seek_relative(offset)
192 }
193}
194
195impl<T: std::io::BufRead> std::io::BufRead for Cancellable<T> {
196 fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
198 self.token.check()?;
199 self.inner.fill_buf()
200 }
201 fn consume(&mut self, amt: usize) {
202 self.inner.consume(amt)
203 }
204}
205
#[cfg(test)]
mod test {
    use super::*;
    use std::io::{self, Read, Seek, Write};
    use std::thread::JoinHandle;
    use std::time::Duration;

    /// Number of operations each worker attempts before giving up.
    const ITERATIONS: usize = 10;

    /// Pause between operations so the main thread can cancel the token while
    /// the worker is still looping.
    const STEP: Duration = Duration::from_millis(100);

    /// Repeatedly writes a byte through a cancellable sink until cancelled.
    fn inf_write(ct: CancellationToken) -> io::Result<()> {
        let mut w = Cancellable::new(io::empty(), ct);
        for _ in 0..ITERATIONS {
            w.write_all(&[0])?;
            std::thread::sleep(STEP);
        }
        Ok(())
    }

    /// Repeatedly reads from a cancellable empty source until cancelled.
    fn inf_read(ct: CancellationToken) -> io::Result<()> {
        let mut r = Cancellable::new(io::empty(), ct);
        let mut data = [0];
        for _ in 0..ITERATIONS {
            // Check the byte count instead of discarding it (clippy's
            // `unused_io_amount`); `io::empty()` is always at EOF.
            let n = r.read(&mut data)?;
            assert_eq!(n, 0);
            std::thread::sleep(STEP);
        }
        Ok(())
    }

    /// Repeatedly seeks a cancellable empty source until cancelled.
    fn inf_seek(ct: CancellationToken) -> io::Result<()> {
        let mut r = Cancellable::new(io::empty(), ct);
        for _ in 0..ITERATIONS {
            r.seek(io::SeekFrom::Start(0))?;
            std::thread::sleep(STEP);
        }
        Ok(())
    }

    /// Joins `worker` and asserts it stopped because of cancellation.
    ///
    /// Checks the returned `io::ErrorKind` directly rather than the text of a
    /// panic payload, whose format is not a stable contract.
    fn assert_cancelled(worker: JoinHandle<io::Result<()>>) {
        let result = worker.join().expect("worker thread panicked");
        let err = result.expect_err("worker finished without observing cancellation");
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[test]
    fn test_write() {
        let ct = CancellationToken::new();
        let th = std::thread::spawn({
            let ct = ct.clone();
            move || inf_write(ct)
        });
        ct.cancel();
        assert_cancelled(th);
    }

    #[test]
    fn test_guard() {
        let guard = CancellationGuard(CancellationToken::new());
        let th = std::thread::spawn({
            let ct = guard.0.clone();
            move || inf_write(ct)
        });
        // Dropping the guard is what cancels the worker's token.
        drop(guard);
        assert_cancelled(th);
    }

    #[test]
    fn test_read() {
        let ct = CancellationToken::new();
        let th = std::thread::spawn({
            let ct = ct.clone();
            move || inf_read(ct)
        });
        ct.cancel();
        assert_cancelled(th);
    }

    #[test]
    fn test_seek() {
        let ct = CancellationToken::new();
        let th = std::thread::spawn({
            let ct = ct.clone();
            move || inf_seek(ct)
        });
        ct.cancel();
        assert_cancelled(th);
    }
}