io_tether/
implementations.rs

1use std::{ops::ControlFlow, pin::Pin, task::Poll};
2
3use tokio::io::{AsyncRead, AsyncWrite};
4
5use crate::{Reason, State, TetherInner};
6
7use super::{Io, Resolver, Tether, ready::ready};
8
/// Crate-private, consuming conversion used in place of `From<Reason>`.
///
/// A blanket `From<Reason> for Result<T, std::io::Error>` is deliberately not provided: the
/// mapping is only meaningful for the `AsyncRead`/`AsyncWrite` return types. For an arbitrary
/// `Result<usize, std::io::Error>` there is no guarantee that `Ok(0)` means end-of-file, so the
/// conversion is kept behind this narrowly scoped trait instead.
trait IoInto<T>
where
    Self: Sized,
{
    /// Consumes `self` and produces the I/O-flavoured value `T`.
    fn io_into(self) -> T;
}
16
17impl IoInto<Result<usize, std::io::Error>> for Reason {
18    fn io_into(self) -> Result<usize, std::io::Error> {
19        match self {
20            Reason::Eof => Ok(0),
21            Reason::Err(error) => Err(error),
22        }
23    }
24}
25
26impl IoInto<Result<(), std::io::Error>> for Reason {
27    fn io_into(self) -> Result<(), std::io::Error> {
28        match self {
29            Reason::Eof => Ok(()),
30            Reason::Err(error) => Err(error),
31        }
32    }
33}
34
/// Drives the tether's state machine for a single poll call.
///
/// Arguments:
/// - `$me`: the tether being polled; its `state` and `inner` fields are accessed.
/// - `$poll_method`: name of the method to invoke on the pinned `$me.inner` while connected.
/// - `$cx`: the task context forwarded to every poll.
/// - `$args`: any extra arguments forwarded to `$poll_method` after `$cx`.
///
/// The macro expands to a `loop` that only exits through `return` (directly, or through the
/// `ready!` helper — presumably an early `Poll::Pending` return like `std::task::ready!`;
/// confirm against `super::ready`). It must therefore be used inside a function returning
/// `Poll<_>`, with a payload type that `Reason::io_into()` can produce.
macro_rules! connected {
    ($me:expr, $poll_method:ident, $cx:expr, $($args:expr),*) => {
        loop {
            match $me.state {
                State::Connected => {
                    // Delegate to the inner I/O operation on a freshly pinned `inner`.
                    let new = Pin::new(&mut $me.inner);
                    let cont = ready!(new.$poll_method($cx, $($args),*));

                    match cont {
                        // The inner call asked for a state change: store it and loop again.
                        ControlFlow::Continue(fut) => $me.state = fut,
                        // The inner call produced a final value for the caller.
                        ControlFlow::Break(val) => return Poll::Ready(val),
                    }
                }
                State::Disconnected(ref mut fut) => {
                    // The future resolves to a bool that decides whether to retry.
                    let retry = ready!(fut.as_mut().poll($cx));

                    if retry {
                        $me.set_reconnecting();
                    } else {
                        // No retry: surface the stored disconnect reason to the caller,
                        // converted into the poll method's result type.
                        let opt_reason = $me.inner.context.reason.take();
                        let reason = opt_reason.expect("Can only enter Disconnected state with Reason");
                        return Poll::Ready(reason.io_into());
                    }
                }
                State::Reconnecting(ref mut fut) => {
                    let result = ready!(fut.as_mut().poll($cx));
                    // The attempt is counted whether it succeeded or failed.
                    $me.inner.context.increment_attempts();

                    match result {
                        Ok(new_io) => {
                            // Swap in the new I/O object before moving to the next state.
                            $me.inner.io = new_io;
                            $me.set_reconnected();
                        }
                        Err(error) => {
                            // A failed attempt goes back through the disconnected state.
                            $me.set_disconnected(Reason::Err(error));
                        },
                    }
                }
                State::Reconnected(ref mut fut) => {
                    // Wait for the stored future to finish, then resume normal operation.
                    ready!(fut.as_mut().poll($cx));
                    $me.set_connected();
                }
            }
        }
    };
}
81
82impl<C, R> TetherInner<C, R>
83where
84    C: Io + Unpin,
85    C::Output: AsyncRead + Unpin,
86    R: Resolver<C> + Unpin,
87{
88    fn poll_read_inner(
89        mut self: Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91        buf: &mut tokio::io::ReadBuf<'_>,
92    ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
93        let mut me = self.as_mut();
94
95        let result = {
96            let depth = buf.filled().len();
97            let inner_pin = std::pin::pin!(&mut me.io);
98            let result = ready!(inner_pin.poll_read(cx, buf));
99            let read_bytes = buf.filled().len().saturating_sub(depth);
100            result.map(|_| read_bytes)
101        };
102
103        match result {
104            Ok(0) => {
105                me.context.reason = Some(Reason::Eof);
106                let fut = self.disconnected();
107                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
108            }
109            Ok(_) => Poll::Ready(ControlFlow::Break(Ok(()))),
110            Err(error) => {
111                me.context.reason = Some(Reason::Err(error));
112                let fut = self.disconnected();
113                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
114            }
115        }
116    }
117}
118
119impl<C, R> AsyncRead for Tether<C, R>
120where
121    C: Io + Unpin,
122    C::Output: AsyncRead + Unpin,
123    R: Resolver<C> + Unpin,
124{
125    fn poll_read(
126        mut self: Pin<&mut Self>,
127        cx: &mut std::task::Context<'_>,
128        buf: &mut tokio::io::ReadBuf<'_>,
129    ) -> Poll<std::io::Result<()>> {
130        let mut me = self.as_mut();
131
132        connected!(me, poll_read_inner, cx, buf);
133    }
134}
135
136impl<C, R> TetherInner<C, R>
137where
138    C: Io + Unpin,
139    C::Output: AsyncWrite + Unpin,
140    R: Resolver<C> + Unpin,
141{
142    fn poll_write_inner(
143        mut self: Pin<&mut Self>,
144        cx: &mut std::task::Context<'_>,
145        buf: &[u8],
146    ) -> Poll<ControlFlow<std::io::Result<usize>, State<C::Output>>> {
147        let mut me = self.as_mut();
148
149        let result = {
150            let inner_pin = std::pin::pin!(&mut me.io);
151            ready!(inner_pin.poll_write(cx, buf))
152        };
153
154        match result {
155            Ok(0) => {
156                me.context.reason = Some(Reason::Eof);
157                let fut = me.disconnected();
158                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
159            }
160            Ok(wrote) => Poll::Ready(ControlFlow::Break(Ok(wrote))),
161            Err(error) => {
162                me.context.reason = Some(Reason::Err(error));
163                let fut = me.disconnected();
164                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
165            }
166        }
167    }
168
169    fn poll_flush_inner(
170        mut self: Pin<&mut Self>,
171        cx: &mut std::task::Context<'_>,
172    ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
173        let mut me = self.as_mut();
174
175        let result = {
176            let inner_pin = std::pin::pin!(&mut me.io);
177            ready!(inner_pin.poll_flush(cx))
178        };
179
180        match result {
181            Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
182            Err(error) => {
183                me.context.reason = Some(Reason::Err(error));
184                let fut = me.disconnected();
185                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
186            }
187        }
188    }
189
190    fn poll_shutdown_inner(
191        mut self: Pin<&mut Self>,
192        cx: &mut std::task::Context<'_>,
193    ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
194        let mut me = self.as_mut();
195
196        let result = {
197            let inner_pin = std::pin::pin!(&mut me.io);
198            ready!(inner_pin.poll_shutdown(cx))
199        };
200
201        match result {
202            Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
203            Err(error) => {
204                me.context.reason = Some(Reason::Err(error));
205                let fut = me.disconnected();
206                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
207            }
208        }
209    }
210}
211
212impl<C, R> AsyncWrite for Tether<C, R>
213where
214    C: Io + Unpin,
215    C::Output: AsyncWrite + Unpin,
216    R: Resolver<C> + Unpin,
217{
218    fn poll_write(
219        mut self: Pin<&mut Self>,
220        cx: &mut std::task::Context<'_>,
221        buf: &[u8],
222    ) -> Poll<Result<usize, std::io::Error>> {
223        let mut me = self.as_mut();
224
225        connected!(me, poll_write_inner, cx, buf);
226    }
227
228    fn poll_flush(
229        mut self: Pin<&mut Self>,
230        cx: &mut std::task::Context<'_>,
231    ) -> Poll<Result<(), std::io::Error>> {
232        let mut me = self.as_mut();
233
234        connected!(me, poll_flush_inner, cx,);
235    }
236
237    fn poll_shutdown(
238        mut self: Pin<&mut Self>,
239        cx: &mut std::task::Context<'_>,
240    ) -> Poll<Result<(), std::io::Error>> {
241        let mut me = self.as_mut();
242
243        connected!(me, poll_shutdown_inner, cx,);
244    }
245}