1use std::{
2 io,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll, Waker},
6 time::Duration,
7};
8
9use error::ResumableIOError;
10use futures::{future::select, FutureExt};
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 sync::{
14 mpsc::{UnboundedReceiver, UnboundedSender},
15 oneshot::{self, Receiver, Sender},
16 },
17 time::Sleep,
18};
19mod error;
20pub struct ResumableIO<IO> {
21 bytes_read: usize,
22 bytes_written: usize,
23 timeout_duration: Duration,
24 error_reporter: UnboundedSender<IntruptedIO<IO>>,
25 current_io: ResumableCurrentIO<IO>,
26 reliable: bool,
27}
28
29impl<IO> ResumableIO<IO>
30where
31 IO: AsyncRead + AsyncWrite,
32{
33 pub fn new(
34 io: Option<IO>,
35 timeout_duration: Duration,
36 ) -> (Self, UnboundedReceiver<IntruptedIO<IO>>) {
37 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
38 (
39 Self {
40 current_io: io.map(ResumableCurrentIO::Ok).unwrap_or_default(),
41 timeout_duration,
42 error_reporter: tx,
43 bytes_read: 0,
44 bytes_written: 0,
45 reliable: true,
46 },
47 rx,
48 )
49 }
50}
51
52impl<IO> AsyncRead for ResumableIO<IO>
53where
54 IO: AsyncRead + AsyncWrite + std::marker::Unpin,
55{
56 fn poll_read(
57 mut self: Pin<&mut Self>,
58 cx: &mut Context<'_>,
59 buf: &mut tokio::io::ReadBuf<'_>,
60 ) -> Poll<io::Result<()>> {
61 match &mut self.current_io {
62 ResumableCurrentIO::Uninitialized => {
63 let e = Arc::new(io::Error::from(io::ErrorKind::NotConnected));
64 let (intrupted_io, rx) = IntruptedIO::new(e.clone(), 0, 0, cx.waker().clone());
65 self.error_reporter
66 .send(intrupted_io)
67 .or(Err(io::Error::from(e.kind())))?;
68 self.current_io = ResumableCurrentIO::Err(
69 e,
70 rx,
71 Box::pin(tokio::time::sleep(self.timeout_duration)),
72 );
73 Poll::Pending
74 }
75 ResumableCurrentIO::Ok(ref mut io) => match Pin::new(io).poll_read(cx, buf) {
76 Poll::Ready(Ok(_)) => {
77 self.bytes_read += buf.filled().len();
78 Poll::Ready(Ok(()))
79 }
80 Poll::Ready(Err(e)) => {
81 let error = Arc::new(e);
82 let (intrupted_io, rx) = IntruptedIO::new(
83 error.clone(),
84 self.bytes_read,
85 self.bytes_written,
86 cx.waker().clone(),
87 );
88 self.error_reporter
89 .send(intrupted_io)
90 .or(Err(io::Error::from(error.kind())))?;
91 self.current_io = ResumableCurrentIO::Err(
92 error,
93 rx,
94 Box::pin(tokio::time::sleep(self.timeout_duration)),
95 );
96 Poll::Pending
97 }
98 Poll::Pending => Poll::Pending,
99 },
100 ResumableCurrentIO::Err(e, io_receiver, timeout) => {
101 match select(io_receiver, timeout).poll_unpin(cx) {
102 Poll::Ready(either) => match either {
103 futures::future::Either::Left((io, _timeout)) => match io {
104 Ok(Some(io)) => {
105 self.current_io = ResumableCurrentIO::Ok(io);
106 self.poll_read(cx, buf)
107 }
108 Err(_) | Ok(None) => Poll::Ready(Err(io::Error::from(e.kind()))),
109 },
110 futures::future::Either::Right((_timeout, io)) => {
111 io.close();
112 Poll::Ready(Err(io::Error::from(e.kind())))
113 }
114 },
115 Poll::Pending => Poll::Pending,
116 }
117 }
118 }
119 }
120}
121
122impl<IO> AsyncWrite for ResumableIO<IO>
123where
124 IO: AsyncRead + AsyncWrite + std::marker::Unpin,
125{
126 fn poll_write(
127 mut self: Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 buf: &[u8],
130 ) -> Poll<io::Result<usize>> {
131 match &mut self.current_io {
132 ResumableCurrentIO::Uninitialized => {
133 let e = Arc::new(io::Error::from(io::ErrorKind::NotConnected));
134 let (intrupted_io, rx) = IntruptedIO::new(e.clone(), 0, 0, cx.waker().clone());
135 self.error_reporter
136 .send(intrupted_io)
137 .or(Err(io::Error::from(e.kind())))?;
138 self.current_io = ResumableCurrentIO::Err(
139 e,
140 rx,
141 Box::pin(tokio::time::sleep(self.timeout_duration)),
142 );
143 Poll::Pending
144 }
145 ResumableCurrentIO::Ok(ref mut io) => match Pin::new(io).poll_write(cx, buf) {
146 Poll::Ready(Ok(n)) => {
147 self.bytes_written += n;
148 Poll::Ready(Ok(n))
149 }
150 Poll::Ready(Err(e)) => {
151 let error = Arc::new(e);
152 let (intrupted_io, rx) = IntruptedIO::new(
153 error.clone(),
154 self.bytes_read,
155 self.bytes_written,
156 cx.waker().clone(),
157 );
158 self.error_reporter
159 .send(intrupted_io)
160 .or(Err(io::Error::from(error.kind())))?;
161 self.current_io = ResumableCurrentIO::Err(
162 error,
163 rx,
164 Box::pin(tokio::time::sleep(self.timeout_duration)),
165 );
166 Poll::Pending
167 }
168 Poll::Pending => Poll::Pending,
169 },
170 ResumableCurrentIO::Err(e, io_receiver, timeout) => {
171 match select(io_receiver, timeout).poll_unpin(cx) {
172 Poll::Ready(either) => match either {
173 futures::future::Either::Left((io, _timeout)) => match io {
174 Ok(Some(io)) => {
175 self.current_io = ResumableCurrentIO::Ok(io);
176 self.poll_write(cx, buf)
177 }
178 Err(_) | Ok(None) => Poll::Ready(Err(io::Error::from(e.kind()))),
179 },
180 futures::future::Either::Right((_timeout, io)) => {
181 io.close();
182 Poll::Ready(Err(io::Error::from(e.kind())))
183 }
184 },
185 Poll::Pending => Poll::Pending,
186 }
187 }
188 }
189 }
190
191 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192 match &mut self.current_io {
193 ResumableCurrentIO::Uninitialized => Poll::Ready(Ok(())),
194 ResumableCurrentIO::Ok(io) => match Pin::new(io).poll_flush(cx) {
195 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
196 Poll::Ready(Err(e)) => {
197 if self.reliable {
198 return Poll::Ready(Err(e));
199 }
200 let error = Arc::new(e);
201 let (intrupted_io, rx) = IntruptedIO::new(
202 error.clone(),
203 self.bytes_read,
204 self.bytes_written,
205 cx.waker().clone(),
206 );
207 self.error_reporter
208 .send(intrupted_io)
209 .or(Err(io::Error::from(error.kind())))?;
210 self.current_io = ResumableCurrentIO::Err(
211 error,
212 rx,
213 Box::pin(tokio::time::sleep(self.timeout_duration)),
214 );
215 Poll::Pending
216 }
217 Poll::Pending => Poll::Pending,
218 },
219 ResumableCurrentIO::Err(e, io_receiver, timeout) => {
220 match select(io_receiver, timeout).poll_unpin(cx) {
221 Poll::Ready(either) => match either {
222 futures::future::Either::Left((io, _timeout)) => match io {
223 Ok(Some(io)) => {
224 self.current_io = ResumableCurrentIO::Ok(io);
225 Poll::Ready(Ok(()))
226 }
227 Err(_) | Ok(None) => Poll::Ready(Err(io::Error::from(e.kind()))),
228 },
229 futures::future::Either::Right((_timeout, io)) => {
230 io.close();
231 Poll::Ready(Err(io::Error::from(e.kind())))
232 }
233 },
234 Poll::Pending => Poll::Pending,
235 }
236 }
237 }
238 }
239
240 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
241 match &mut self.current_io {
242 ResumableCurrentIO::Uninitialized => Poll::Ready(Ok(())),
243 ResumableCurrentIO::Ok(io) => Pin::new(io).poll_shutdown(cx),
244 ResumableCurrentIO::Err(e, _, _) => return Poll::Ready(Err(io::Error::from(e.kind()))),
245 }
246 }
247}
248#[derive(Default)]
249enum ResumableCurrentIO<IO> {
250 #[default]
251 Uninitialized,
252 Ok(IO),
253 Err(Arc<io::Error>, Receiver<Option<IO>>, Pin<Box<Sleep>>),
254}
255
256pub struct IntruptedIO<IO> {
257 new_io_sender: Option<Sender<Option<IO>>>,
258 error: Arc<io::Error>,
259 bytes_read: usize,
260 bytes_written: usize,
261 wake: Waker,
262}
263
264impl<IO> IntruptedIO<IO> {
265 fn new(
266 error: Arc<io::Error>,
267 bytes_read: usize,
268 bytes_written: usize,
269 wake: Waker,
270 ) -> (Self, Receiver<Option<IO>>) {
271 let (new_io_sender, new_io_receiver) = oneshot::channel();
272 (
273 Self {
274 new_io_sender: Some(new_io_sender),
275 error,
276 bytes_read,
277 bytes_written,
278 wake,
279 },
280 new_io_receiver,
281 )
282 }
283 pub fn send_new_io(mut self, new_io: Option<IO>) -> Result<(), ResumableIOError> {
284 let sender = self
285 .new_io_sender
286 .take()
287 .ok_or(ResumableIOError::SenderIsUsed)?;
288 sender
289 .send(new_io)
290 .or(Err(ResumableIOError::ChannelIsClosed))?;
291 self.wake.wake();
292 Ok(())
293 }
294 pub fn error(&self) -> &io::Error {
295 &self.error
296 }
297 pub fn bytes_read(&self) -> usize {
298 self.bytes_read
299 }
300 pub fn bytes_written(&self) -> usize {
301 self.bytes_written
302 }
303}