io_tether/
implementations.rs

1use std::{ops::ControlFlow, pin::Pin, task::Poll};
2
3use tokio::io::{AsyncRead, AsyncWrite};
4
5use crate::{Reason, State, TetherInner};
6
7use super::{ready::ready, Io, Resolver, Tether};
8
/// Drives a `Pin<&mut Tether<..>>` (`$me`) through its connection state machine until the
/// I/O operation `$poll_method` on the inner value yields a final result.
///
/// * `Connected`: calls `$poll_method($cx, $args..)` on the pinned inner value.
///   - `Break(val)` returns `Poll::Ready(val)`.
///   - `Continue(state)` installs the returned state and loops.
/// * `Disconnected`: awaits the pending future. If it yields `true`, a reconnect future is
///   obtained from the connector. If it yields `false`, the recorded disconnect reason is
///   returned as the error.
/// * `Reconnecting`: awaits the reconnect future and increments the attempt counter.
///   - On success, the new I/O object replaces the old one and the `reconnected()` future is
///     installed.
///   - On failure, the error is recorded and the state returns to `Disconnected`, so the
///     retry decision is made again.
/// * `Reconnected`: awaits the pending future, then calls `$me.reconnect()` and loops.
///
/// The expansion is a `loop` that only exits via `return`, so it must be the final
/// statement of the `poll_*` function using it. Any extra `$args` are forwarded verbatim to
/// `$poll_method` after `$cx`.
macro_rules! connected {
    ($me:expr, $poll_method:ident, $cx:expr, $($args:expr),*) => {
        loop {
            match $me.state {
                State::Connected => {
                    let new = Pin::new(&mut $me.inner);
                    let cont = ready!(new.$poll_method($cx, $($args),*));

                    match cont {
                        ControlFlow::Continue(fut) => $me.state = fut,
                        ControlFlow::Break(val) => return Poll::Ready(val),
                    }
                }
                State::Disconnected(ref mut fut) => {
                    let retry = ready!(fut.as_mut().poll($cx));

                    if retry {
                        let reconnect_fut = $me.inner.connector.reconnect();
                        $me.state = State::Reconnecting(reconnect_fut);
                    } else {
                        let err = $me.inner.context.reason.take().into();
                        return Poll::Ready(Err(err));
                    }
                }
                State::Reconnecting(ref mut fut) => {
                    let result = ready!(fut.as_mut().poll($cx));
                    $me.inner.context.increment_attempts();

                    match result {
                        Ok(new_io) => {
                            $me.inner.io = new_io;
                            let fut = $me.inner.reconnected();
                            $me.state = State::Reconnected(fut);
                        }
                        Err(error) => {
                            $me.inner.context.reason = Reason::Err(error);
                            // The reconnect future has completed and must not be polled
                            // again. Leaving the state as `Reconnecting` would re-poll it on
                            // the next iteration. Return to `Disconnected` so the retry
                            // decision is made afresh.
                            let fut = $me.inner.disconnected();
                            $me.state = State::Disconnected(fut);
                        }
                    }
                }
                State::Reconnected(ref mut fut) => {
                    ready!(fut.as_mut().poll($cx));
                    $me.reconnect();
                }
            }
        }
    };
}
54
55impl<C, R> TetherInner<C, R>
56where
57    C: Io + Unpin,
58    C::Output: AsyncRead + Unpin,
59    R: Resolver<C> + Unpin,
60{
61    fn poll_read_inner(
62        mut self: Pin<&mut Self>,
63        cx: &mut std::task::Context<'_>,
64        buf: &mut tokio::io::ReadBuf<'_>,
65    ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
66        let mut me = self.as_mut();
67
68        let result = {
69            let depth = buf.filled().len();
70            let inner_pin = std::pin::pin!(&mut me.io);
71            let result = ready!(inner_pin.poll_read(cx, buf));
72            let read_bytes = buf.filled().len().saturating_sub(depth);
73            result.map(|_| read_bytes)
74        };
75
76        match result {
77            Ok(0) => {
78                me.context.reason = Reason::Eof;
79                let fut = self.disconnected();
80                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
81            }
82            Ok(_) => Poll::Ready(ControlFlow::Break(Ok(()))),
83            Err(error) => {
84                me.context.reason = Reason::Err(error);
85                let fut = self.disconnected();
86                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
87            }
88        }
89    }
90}
91
92impl<C, R> AsyncRead for Tether<C, R>
93where
94    C: Io + Unpin,
95    C::Output: AsyncRead + Unpin,
96    R: Resolver<C> + Unpin,
97{
98    fn poll_read(
99        mut self: Pin<&mut Self>,
100        cx: &mut std::task::Context<'_>,
101        buf: &mut tokio::io::ReadBuf<'_>,
102    ) -> Poll<std::io::Result<()>> {
103        let mut me = self.as_mut();
104
105        connected!(me, poll_read_inner, cx, buf);
106    }
107}
108
109impl<C, R> TetherInner<C, R>
110where
111    C: Io + Unpin,
112    C::Output: AsyncWrite + Unpin,
113    R: Resolver<C> + Unpin,
114{
115    fn poll_write_inner(
116        mut self: Pin<&mut Self>,
117        cx: &mut std::task::Context<'_>,
118        buf: &[u8],
119    ) -> Poll<ControlFlow<std::io::Result<usize>, State<C::Output>>> {
120        let mut me = self.as_mut();
121
122        let result = {
123            let inner_pin = std::pin::pin!(&mut me.io);
124            ready!(inner_pin.poll_write(cx, buf))
125        };
126
127        match result {
128            Ok(0) => {
129                me.context.reason = Reason::Eof;
130                let fut = me.disconnected();
131                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
132            }
133            Ok(wrote) => Poll::Ready(ControlFlow::Break(Ok(wrote))),
134            Err(error) => {
135                me.context.reason = Reason::Err(error);
136                let fut = me.disconnected();
137                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
138            }
139        }
140    }
141
142    fn poll_flush_inner(
143        mut self: Pin<&mut Self>,
144        cx: &mut std::task::Context<'_>,
145    ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
146        let mut me = self.as_mut();
147
148        let result = {
149            let inner_pin = std::pin::pin!(&mut me.io);
150            ready!(inner_pin.poll_flush(cx))
151        };
152
153        match result {
154            Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
155            Err(error) => {
156                me.context.reason = Reason::Err(error);
157                let fut = me.disconnected();
158                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
159            }
160        }
161    }
162
163    fn poll_shutdown_inner(
164        mut self: Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166    ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
167        let mut me = self.as_mut();
168
169        let result = {
170            let inner_pin = std::pin::pin!(&mut me.io);
171            ready!(inner_pin.poll_shutdown(cx))
172        };
173
174        match result {
175            Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
176            Err(error) => {
177                me.context.reason = Reason::Err(error);
178                let fut = me.disconnected();
179                Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
180            }
181        }
182    }
183}
184
185impl<C, R> AsyncWrite for Tether<C, R>
186where
187    C: Io + Unpin,
188    C::Output: AsyncWrite + Unpin,
189    R: Resolver<C> + Unpin,
190{
191    fn poll_write(
192        mut self: Pin<&mut Self>,
193        cx: &mut std::task::Context<'_>,
194        buf: &[u8],
195    ) -> Poll<Result<usize, std::io::Error>> {
196        let mut me = self.as_mut();
197
198        connected!(me, poll_write_inner, cx, buf);
199    }
200
201    fn poll_flush(
202        mut self: Pin<&mut Self>,
203        cx: &mut std::task::Context<'_>,
204    ) -> Poll<Result<(), std::io::Error>> {
205        let mut me = self.as_mut();
206
207        connected!(me, poll_flush_inner, cx,);
208    }
209
210    fn poll_shutdown(
211        mut self: Pin<&mut Self>,
212        cx: &mut std::task::Context<'_>,
213    ) -> Poll<Result<(), std::io::Error>> {
214        let mut me = self.as_mut();
215
216        connected!(me, poll_shutdown_inner, cx,);
217    }
218}