io_tether/implementations/tokio.rs

1use std::{ops::ControlFlow, pin::Pin, task::Poll};
2
3use tokio::io::{AsyncRead, AsyncWrite};
4
5use crate::{Reason, Source, State, TetherInner, ready::ready};
6
7use super::{Connector, IoInto, Resolver, Tether};
8
9use super::connected::connected;
10
11impl<C, R> TetherInner<C, R>
12where
13    C: Connector + Unpin,
14    C::Output: AsyncRead + Unpin,
15    R: Resolver<C> + Unpin,
16{
17    fn poll_read_inner(
18        &mut self,
19        state: &mut State<C::Output>,
20        cx: &mut std::task::Context<'_>,
21        buf: &mut tokio::io::ReadBuf<'_>,
22    ) -> Poll<ControlFlow<std::io::Result<()>>> {
23        let result = {
24            let depth = buf.filled().len();
25            let inner_pin = std::pin::pin!(&mut self.io);
26            let result = ready!(inner_pin.poll_read(cx, buf));
27            let read_bytes = buf.filled().len().saturating_sub(depth);
28            result.map(|_| read_bytes)
29        };
30
31        match result {
32            Ok(0) => {
33                self.set_disconnected(state, Reason::Eof, Source::Io);
34                Poll::Ready(ControlFlow::Continue(()))
35            }
36            Ok(_) => Poll::Ready(ControlFlow::Break(Ok(()))),
37            Err(error) => {
38                self.set_disconnected(state, Reason::Err(error), Source::Io);
39                Poll::Ready(ControlFlow::Continue(()))
40            }
41        }
42    }
43}
44
45impl<C, R> AsyncRead for Tether<C, R>
46where
47    C: Connector + Unpin,
48    C::Output: AsyncRead + Unpin,
49    R: Resolver<C> + Unpin,
50{
51    fn poll_read(
52        self: Pin<&mut Self>,
53        cx: &mut std::task::Context<'_>,
54        buf: &mut tokio::io::ReadBuf<'_>,
55    ) -> Poll<std::io::Result<()>> {
56        let me = self.get_mut();
57
58        connected!(me, poll_read_inner, cx, Ok(()), buf);
59    }
60}
61
62impl<C, R> TetherInner<C, R>
63where
64    C: Connector + Unpin,
65    C::Output: AsyncWrite + Unpin,
66    R: Resolver<C> + Unpin,
67{
68    fn poll_write_inner(
69        &mut self,
70        state: &mut State<C::Output>,
71        cx: &mut std::task::Context<'_>,
72        buf: &[u8],
73    ) -> Poll<ControlFlow<std::io::Result<usize>>> {
74        if let Some(reason) = self.last_write.take() {
75            self.set_disconnected(state, reason, Source::Io);
76            return Poll::Ready(ControlFlow::Continue(()));
77        }
78
79        let result = {
80            let inner_pin = std::pin::pin!(&mut self.io);
81            ready!(inner_pin.poll_write(cx, buf))
82        };
83
84        // NOTE: It is important that in error branches we return ControlFlow::Continue. Otherwise,
85        // we will break out of the reconnect loop, and drop the data that was written by the caller
86        let reason = match result {
87            Ok(0) => Reason::Eof,
88            Ok(wrote) => return Poll::Ready(ControlFlow::Break(Ok(wrote))),
89            Err(error) => Reason::Err(error),
90        };
91
92        if !self.config.keep_data_on_failed_write {
93            self.last_write = Some(reason);
94            // NOTE: We have no control over the buffer that is passed to us. The only way we can
95            // ensure we are not passed the same buffer the next call, is by reporting that we
96            // successfully wrote the data to the underlying object.
97            //
98            // This is not ideal, but it is the best we can do for now
99            return Poll::Ready(ControlFlow::Break(Ok(buf.len())));
100        }
101
102        self.set_disconnected(state, reason, Source::Io);
103        Poll::Ready(ControlFlow::Continue(()))
104    }
105
106    fn poll_flush_inner(
107        &mut self,
108        state: &mut State<C::Output>,
109        cx: &mut std::task::Context<'_>,
110    ) -> Poll<ControlFlow<std::io::Result<()>>> {
111        let result = {
112            let inner_pin = std::pin::pin!(&mut self.io);
113            ready!(inner_pin.poll_flush(cx))
114        };
115
116        match result {
117            Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
118            Err(error) => {
119                self.set_disconnected(state, Reason::Err(error), Source::Io);
120                Poll::Ready(ControlFlow::Continue(()))
121            }
122        }
123    }
124
125    fn poll_shutdown_inner(
126        &mut self,
127        state: &mut State<C::Output>,
128        cx: &mut std::task::Context<'_>,
129    ) -> Poll<ControlFlow<std::io::Result<()>>> {
130        let result = {
131            let inner_pin = std::pin::pin!(&mut self.io);
132            ready!(inner_pin.poll_shutdown(cx))
133        };
134
135        match result {
136            Ok(()) => Poll::Ready(ControlFlow::Break(Ok(()))),
137            Err(error) => {
138                self.set_disconnected(state, Reason::Err(error), Source::Io);
139                Poll::Ready(ControlFlow::Continue(()))
140            }
141        }
142    }
143}
144
145impl<C, R> AsyncWrite for Tether<C, R>
146where
147    C: Connector + Unpin,
148    C::Output: AsyncWrite + Unpin,
149    R: Resolver<C> + Unpin,
150{
151    fn poll_write(
152        self: Pin<&mut Self>,
153        cx: &mut std::task::Context<'_>,
154        buf: &[u8],
155    ) -> Poll<Result<usize, std::io::Error>> {
156        let me = self.get_mut();
157
158        connected!(me, poll_write_inner, cx, Ok(0), buf);
159    }
160
161    fn poll_flush(
162        self: Pin<&mut Self>,
163        cx: &mut std::task::Context<'_>,
164    ) -> Poll<Result<(), std::io::Error>> {
165        let me = self.get_mut();
166
167        connected!(me, poll_flush_inner, cx, Ok(()),);
168    }
169
170    fn poll_shutdown(
171        self: Pin<&mut Self>,
172        cx: &mut std::task::Context<'_>,
173    ) -> Poll<Result<(), std::io::Error>> {
174        let me = self.get_mut();
175
176        connected!(me, poll_shutdown_inner, cx, Ok(()),);
177    }
178}