// resumable_io/lib.rs

1use std::{
2    io,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll, Waker},
6    time::Duration,
7};
8
9use error::ResumableIOError;
10use futures::{future::select, FutureExt};
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    sync::{
14        mpsc::{UnboundedReceiver, UnboundedSender},
15        oneshot::{self, Receiver, Sender},
16    },
17    time::Sleep,
18};
19mod error;
20pub struct ResumableIO<IO> {
21    bytes_read: usize,
22    bytes_written: usize,
23    timeout_duration: Duration,
24    error_reporter: UnboundedSender<IntruptedIO<IO>>,
25    current_io: ResumableCurrentIO<IO>,
26    reliable: bool,
27}
28
29impl<IO> ResumableIO<IO>
30where
31    IO: AsyncRead + AsyncWrite,
32{
33    pub fn new(
34        io: Option<IO>,
35        timeout_duration: Duration,
36    ) -> (Self, UnboundedReceiver<IntruptedIO<IO>>) {
37        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
38        (
39            Self {
40                current_io: io.map(ResumableCurrentIO::Ok).unwrap_or_default(),
41                timeout_duration,
42                error_reporter: tx,
43                bytes_read: 0,
44                bytes_written: 0,
45                reliable: true,
46            },
47            rx,
48        )
49    }
50}
51
52impl<IO> AsyncRead for ResumableIO<IO>
53where
54    IO: AsyncRead + AsyncWrite + std::marker::Unpin,
55{
56    fn poll_read(
57        mut self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        buf: &mut tokio::io::ReadBuf<'_>,
60    ) -> Poll<io::Result<()>> {
61        match &mut self.current_io {
62            ResumableCurrentIO::Uninitialized => {
63                let e = Arc::new(io::Error::from(io::ErrorKind::NotConnected));
64                let (intrupted_io, rx) = IntruptedIO::new(e.clone(), 0, 0, cx.waker().clone());
65                self.error_reporter
66                    .send(intrupted_io)
67                    .or(Err(io::Error::from(e.kind())))?;
68                self.current_io = ResumableCurrentIO::Err(
69                    e,
70                    rx,
71                    Box::pin(tokio::time::sleep(self.timeout_duration)),
72                );
73                Poll::Pending
74            }
75            ResumableCurrentIO::Ok(ref mut io) => match Pin::new(io).poll_read(cx, buf) {
76                Poll::Ready(Ok(_)) => {
77                    self.bytes_read += buf.filled().len();
78                    Poll::Ready(Ok(()))
79                }
80                Poll::Ready(Err(e)) => {
81                    let error = Arc::new(e);
82                    let (intrupted_io, rx) = IntruptedIO::new(
83                        error.clone(),
84                        self.bytes_read,
85                        self.bytes_written,
86                        cx.waker().clone(),
87                    );
88                    self.error_reporter
89                        .send(intrupted_io)
90                        .or(Err(io::Error::from(error.kind())))?;
91                    self.current_io = ResumableCurrentIO::Err(
92                        error,
93                        rx,
94                        Box::pin(tokio::time::sleep(self.timeout_duration)),
95                    );
96                    Poll::Pending
97                }
98                Poll::Pending => Poll::Pending,
99            },
100            ResumableCurrentIO::Err(e, io_receiver, timeout) => {
101                match select(io_receiver, timeout).poll_unpin(cx) {
102                    Poll::Ready(either) => match either {
103                        futures::future::Either::Left((io, _timeout)) => match io {
104                            Ok(Some(io)) => {
105                                self.current_io = ResumableCurrentIO::Ok(io);
106                                self.poll_read(cx, buf)
107                            }
108                            Err(_) | Ok(None) => Poll::Ready(Err(io::Error::from(e.kind()))),
109                        },
110                        futures::future::Either::Right((_timeout, io)) => {
111                            io.close();
112                            Poll::Ready(Err(io::Error::from(e.kind())))
113                        }
114                    },
115                    Poll::Pending => Poll::Pending,
116                }
117            }
118        }
119    }
120}
121
122impl<IO> AsyncWrite for ResumableIO<IO>
123where
124    IO: AsyncRead + AsyncWrite + std::marker::Unpin,
125{
126    fn poll_write(
127        mut self: Pin<&mut Self>,
128        cx: &mut Context<'_>,
129        buf: &[u8],
130    ) -> Poll<io::Result<usize>> {
131        match &mut self.current_io {
132            ResumableCurrentIO::Uninitialized => {
133                let e = Arc::new(io::Error::from(io::ErrorKind::NotConnected));
134                let (intrupted_io, rx) = IntruptedIO::new(e.clone(), 0, 0, cx.waker().clone());
135                self.error_reporter
136                    .send(intrupted_io)
137                    .or(Err(io::Error::from(e.kind())))?;
138                self.current_io = ResumableCurrentIO::Err(
139                    e,
140                    rx,
141                    Box::pin(tokio::time::sleep(self.timeout_duration)),
142                );
143                Poll::Pending
144            }
145            ResumableCurrentIO::Ok(ref mut io) => match Pin::new(io).poll_write(cx, buf) {
146                Poll::Ready(Ok(n)) => {
147                    self.bytes_written += n;
148                    Poll::Ready(Ok(n))
149                }
150                Poll::Ready(Err(e)) => {
151                    let error = Arc::new(e);
152                    let (intrupted_io, rx) = IntruptedIO::new(
153                        error.clone(),
154                        self.bytes_read,
155                        self.bytes_written,
156                        cx.waker().clone(),
157                    );
158                    self.error_reporter
159                        .send(intrupted_io)
160                        .or(Err(io::Error::from(error.kind())))?;
161                    self.current_io = ResumableCurrentIO::Err(
162                        error,
163                        rx,
164                        Box::pin(tokio::time::sleep(self.timeout_duration)),
165                    );
166                    Poll::Pending
167                }
168                Poll::Pending => Poll::Pending,
169            },
170            ResumableCurrentIO::Err(e, io_receiver, timeout) => {
171                match select(io_receiver, timeout).poll_unpin(cx) {
172                    Poll::Ready(either) => match either {
173                        futures::future::Either::Left((io, _timeout)) => match io {
174                            Ok(Some(io)) => {
175                                self.current_io = ResumableCurrentIO::Ok(io);
176                                self.poll_write(cx, buf)
177                            }
178                            Err(_) | Ok(None) => Poll::Ready(Err(io::Error::from(e.kind()))),
179                        },
180                        futures::future::Either::Right((_timeout, io)) => {
181                            io.close();
182                            Poll::Ready(Err(io::Error::from(e.kind())))
183                        }
184                    },
185                    Poll::Pending => Poll::Pending,
186                }
187            }
188        }
189    }
190
191    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192        match &mut self.current_io {
193            ResumableCurrentIO::Uninitialized => Poll::Ready(Ok(())),
194            ResumableCurrentIO::Ok(io) => match Pin::new(io).poll_flush(cx) {
195                Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
196                Poll::Ready(Err(e)) => {
197                    if self.reliable {
198                        return Poll::Ready(Err(e));
199                    }
200                    let error = Arc::new(e);
201                    let (intrupted_io, rx) = IntruptedIO::new(
202                        error.clone(),
203                        self.bytes_read,
204                        self.bytes_written,
205                        cx.waker().clone(),
206                    );
207                    self.error_reporter
208                        .send(intrupted_io)
209                        .or(Err(io::Error::from(error.kind())))?;
210                    self.current_io = ResumableCurrentIO::Err(
211                        error,
212                        rx,
213                        Box::pin(tokio::time::sleep(self.timeout_duration)),
214                    );
215                    Poll::Pending
216                }
217                Poll::Pending => Poll::Pending,
218            },
219            ResumableCurrentIO::Err(e, io_receiver, timeout) => {
220                match select(io_receiver, timeout).poll_unpin(cx) {
221                    Poll::Ready(either) => match either {
222                        futures::future::Either::Left((io, _timeout)) => match io {
223                            Ok(Some(io)) => {
224                                self.current_io = ResumableCurrentIO::Ok(io);
225                                Poll::Ready(Ok(()))
226                            }
227                            Err(_) | Ok(None) => Poll::Ready(Err(io::Error::from(e.kind()))),
228                        },
229                        futures::future::Either::Right((_timeout, io)) => {
230                            io.close();
231                            Poll::Ready(Err(io::Error::from(e.kind())))
232                        }
233                    },
234                    Poll::Pending => Poll::Pending,
235                }
236            }
237        }
238    }
239
240    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
241        match &mut self.current_io {
242            ResumableCurrentIO::Uninitialized => Poll::Ready(Ok(())),
243            ResumableCurrentIO::Ok(io) => Pin::new(io).poll_shutdown(cx),
244            ResumableCurrentIO::Err(e, _, _) => return Poll::Ready(Err(io::Error::from(e.kind()))),
245        }
246    }
247}
248#[derive(Default)]
249enum ResumableCurrentIO<IO> {
250    #[default]
251    Uninitialized,
252    Ok(IO),
253    Err(Arc<io::Error>, Receiver<Option<IO>>, Pin<Box<Sleep>>),
254}
255
256pub struct IntruptedIO<IO> {
257    new_io_sender: Option<Sender<Option<IO>>>,
258    error: Arc<io::Error>,
259    bytes_read: usize,
260    bytes_written: usize,
261    wake: Waker,
262}
263
264impl<IO> IntruptedIO<IO> {
265    fn new(
266        error: Arc<io::Error>,
267        bytes_read: usize,
268        bytes_written: usize,
269        wake: Waker,
270    ) -> (Self, Receiver<Option<IO>>) {
271        let (new_io_sender, new_io_receiver) = oneshot::channel();
272        (
273            Self {
274                new_io_sender: Some(new_io_sender),
275                error,
276                bytes_read,
277                bytes_written,
278                wake,
279            },
280            new_io_receiver,
281        )
282    }
283    pub fn send_new_io(mut self, new_io: Option<IO>) -> Result<(), ResumableIOError> {
284        let sender = self
285            .new_io_sender
286            .take()
287            .ok_or(ResumableIOError::SenderIsUsed)?;
288        sender
289            .send(new_io)
290            .or(Err(ResumableIOError::ChannelIsClosed))?;
291        self.wake.wake();
292        Ok(())
293    }
294    pub fn error(&self) -> &io::Error {
295        &self.error
296    }
297    pub fn bytes_read(&self) -> usize {
298        self.bytes_read
299    }
300    pub fn bytes_written(&self) -> usize {
301        self.bytes_written
302    }
303}