1use std::{ops::ControlFlow, pin::Pin, task::Poll};
2
3use tokio::io::{AsyncRead, AsyncWrite};
4
5use crate::{Reason, State, TetherInner};
6
7use super::{Io, Resolver, Tether, ready::ready};
8
/// Converts a value into the return type of an I/O poll method.
///
/// The target type is chosen by the caller's expected type, so a single source value
/// (such as a disconnect reason) can be turned into whichever `Result` shape the
/// surrounding `poll_*` method returns.
trait IoInto<Target>: Sized {
    /// Consumes `self`, producing the I/O-level result.
    fn io_into(self) -> Target;
}
16
17impl IoInto<Result<usize, std::io::Error>> for Reason {
18 fn io_into(self) -> Result<usize, std::io::Error> {
19 match self {
20 Reason::Eof => Ok(0),
21 Reason::Err(error) => Err(error),
22 }
23 }
24}
25
26impl IoInto<Result<(), std::io::Error>> for Reason {
27 fn io_into(self) -> Result<(), std::io::Error> {
28 match self {
29 Reason::Eof => Ok(()),
30 Reason::Err(error) => Err(error),
31 }
32 }
33}
34
/// Drives a `Tether`'s connection state machine around one I/O operation.
///
/// - `$me` is the `Tether` being polled.
/// - `$poll_method` is a `TetherInner` method called on `Pin::new(&mut $me.inner)`.
/// - `$cx` is the task context.
/// - `$args` are extra arguments forwarded to `$poll_method`.
///
/// The pattern requires a comma after `$cx` even when there are no extra arguments,
/// e.g. `connected!(me, poll_flush_inner, cx,)`.
///
/// The expansion is an unconditional `loop` that only exits via `return`, so it must
/// be used as the body of a function returning `Poll<_>`.
macro_rules! connected {
    ($me:expr, $poll_method:ident, $cx:expr, $($args:expr),*) => {
        loop {
            match $me.state {
                // Live connection: attempt the operation.
                State::Connected => {
                    let new = Pin::new(&mut $me.inner);
                    let cont = ready!(new.$poll_method($cx, $($args),*));

                    match cont {
                        // The inner method reported a failure and handed back the next
                        // state (a `State`, despite the binding name); install it and loop.
                        ControlFlow::Continue(fut) => $me.state = fut,
                        // The operation produced a final value for the caller.
                        ControlFlow::Break(val) => return Poll::Ready(val),
                    }
                }
                // Wait on the stored future, which yields whether to try reconnecting.
                State::Disconnected(ref mut fut) => {
                    let retry = ready!(fut.as_mut().poll($cx));

                    if retry {
                        $me.set_reconnecting();
                    } else {
                        // Give up: surface the recorded reason, converted via `IoInto`.
                        // NOTE(review): `reason` is taken here and `state` stays
                        // `Disconnected` with an already-completed future; polling again
                        // would re-poll that future and may reach the `expect` below —
                        // confirm callers never poll after a terminal result.
                        let opt_reason = $me.inner.context.reason.take();
                        let reason = opt_reason.expect("Can only enter Disconnected state with Reason");
                        return Poll::Ready(reason.io_into());
                    }
                }
                // Wait for the reconnect attempt; it is counted whether or not it succeeds.
                State::Reconnecting(ref mut fut) => {
                    let result = ready!(fut.as_mut().poll($cx));
                    $me.inner.context.increment_attempts();

                    match result {
                        Ok(new_io) => {
                            // Swap in the freshly established I/O object.
                            $me.inner.io = new_io;
                            $me.set_reconnected();
                        }
                        Err(error) => {
                            $me.set_disconnected(Reason::Err(error));
                        },
                    }
                }
                // Wait on the stored future, then resume normal operation.
                State::Reconnected(ref mut fut) => {
                    ready!(fut.as_mut().poll($cx));
                    $me.set_connected();
                }
            }
        }
    };
}
81
82impl<C, R> TetherInner<C, R>
83where
84 C: Io + Unpin,
85 C::Output: AsyncRead + Unpin,
86 R: Resolver<C> + Unpin,
87{
88 fn poll_read_inner(
89 mut self: Pin<&mut Self>,
90 cx: &mut std::task::Context<'_>,
91 buf: &mut tokio::io::ReadBuf<'_>,
92 ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
93 let mut me = self.as_mut();
94
95 let result = {
96 let depth = buf.filled().len();
97 let inner_pin = std::pin::pin!(&mut me.io);
98 let result = ready!(inner_pin.poll_read(cx, buf));
99 let read_bytes = buf.filled().len().saturating_sub(depth);
100 result.map(|_| read_bytes)
101 };
102
103 match result {
104 Ok(0) => {
105 me.context.reason = Some(Reason::Eof);
106 let fut = self.disconnected();
107 Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
108 }
109 Ok(_) => Poll::Ready(ControlFlow::Break(Ok(()))),
110 Err(error) => {
111 me.context.reason = Some(Reason::Err(error));
112 let fut = self.disconnected();
113 Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
114 }
115 }
116 }
117}
118
119impl<C, R> AsyncRead for Tether<C, R>
120where
121 C: Io + Unpin,
122 C::Output: AsyncRead + Unpin,
123 R: Resolver<C> + Unpin,
124{
125 fn poll_read(
126 mut self: Pin<&mut Self>,
127 cx: &mut std::task::Context<'_>,
128 buf: &mut tokio::io::ReadBuf<'_>,
129 ) -> Poll<std::io::Result<()>> {
130 let mut me = self.as_mut();
131
132 connected!(me, poll_read_inner, cx, buf);
133 }
134}
135
136impl<C, R> TetherInner<C, R>
137where
138 C: Io + Unpin,
139 C::Output: AsyncWrite + Unpin,
140 R: Resolver<C> + Unpin,
141{
142 fn poll_write_inner(
143 mut self: Pin<&mut Self>,
144 cx: &mut std::task::Context<'_>,
145 buf: &[u8],
146 ) -> Poll<ControlFlow<std::io::Result<usize>, State<C::Output>>> {
147 let mut me = self.as_mut();
148
149 let result = {
150 let inner_pin = std::pin::pin!(&mut me.io);
151 ready!(inner_pin.poll_write(cx, buf))
152 };
153
154 match result {
155 Ok(0) => {
156 me.context.reason = Some(Reason::Eof);
157 let fut = me.disconnected();
158 Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
159 }
160 Ok(wrote) => Poll::Ready(ControlFlow::Break(Ok(wrote))),
161 Err(error) => {
162 me.context.reason = Some(Reason::Err(error));
163 let fut = me.disconnected();
164 Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
165 }
166 }
167 }
168
169 fn poll_flush_inner(
170 mut self: Pin<&mut Self>,
171 cx: &mut std::task::Context<'_>,
172 ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
173 let mut me = self.as_mut();
174
175 let result = {
176 let inner_pin = std::pin::pin!(&mut me.io);
177 ready!(inner_pin.poll_flush(cx))
178 };
179
180 match result {
181 Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
182 Err(error) => {
183 me.context.reason = Some(Reason::Err(error));
184 let fut = me.disconnected();
185 Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
186 }
187 }
188 }
189
190 fn poll_shutdown_inner(
191 mut self: Pin<&mut Self>,
192 cx: &mut std::task::Context<'_>,
193 ) -> Poll<ControlFlow<std::io::Result<()>, State<C::Output>>> {
194 let mut me = self.as_mut();
195
196 let result = {
197 let inner_pin = std::pin::pin!(&mut me.io);
198 ready!(inner_pin.poll_shutdown(cx))
199 };
200
201 match result {
202 Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
203 Err(error) => {
204 me.context.reason = Some(Reason::Err(error));
205 let fut = me.disconnected();
206 Poll::Ready(ControlFlow::Continue(State::Disconnected(fut)))
207 }
208 }
209 }
210}
211
212impl<C, R> AsyncWrite for Tether<C, R>
213where
214 C: Io + Unpin,
215 C::Output: AsyncWrite + Unpin,
216 R: Resolver<C> + Unpin,
217{
218 fn poll_write(
219 mut self: Pin<&mut Self>,
220 cx: &mut std::task::Context<'_>,
221 buf: &[u8],
222 ) -> Poll<Result<usize, std::io::Error>> {
223 let mut me = self.as_mut();
224
225 connected!(me, poll_write_inner, cx, buf);
226 }
227
228 fn poll_flush(
229 mut self: Pin<&mut Self>,
230 cx: &mut std::task::Context<'_>,
231 ) -> Poll<Result<(), std::io::Error>> {
232 let mut me = self.as_mut();
233
234 connected!(me, poll_flush_inner, cx,);
235 }
236
237 fn poll_shutdown(
238 mut self: Pin<&mut Self>,
239 cx: &mut std::task::Context<'_>,
240 ) -> Poll<Result<(), std::io::Error>> {
241 let mut me = self.as_mut();
242
243 connected!(me, poll_shutdown_inner, cx,);
244 }
245}