1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#![doc = include_str!("../README.md")]
use std::{future::Future, task::Poll};

mod implementations;

/// Represents a type which drives reconnects
///
/// Since the [`disconnected`](TetherResolver::disconnected) method is asynchronous and is invoked
/// when the underlying stream disconnects, asynchronous operations such as `tokio::time::sleep`
/// can be awaited inside it (for example, to delay a reconnect).
// TODO: Remove the Unpin restriction
pub trait TetherResolver: Unpin {
    /// Error type carried by [`State::Err`].
    ///
    /// `Tether`'s reconnect logic requires this to be the same type as the
    /// `TetherIo::Error` of the wrapped I/O object.
    type Error;

    /// Invoked by Tether when an error/disconnect is encountered.
    ///
    /// Returning `true` will result in a reconnect being attempted via `<T as TetherIo>::reconnect`,
    /// returning `false` will result in the error being returned from the originating call.
    ///
    /// # Note
    ///
    /// The [`State`] will describe the type of the underlying error. It can either be `State::Eof`,
    /// in which case the end of file was reached, or an error. This information can be leveraged
    /// in this function to determine whether to attempt to reconnect.
    fn disconnected(
        &mut self,
        context: &Context,
        state: &State<Self::Error>,
    ) -> impl Future<Output = bool> + Send;
}

/// An I/O source that knows how to (re)establish its own connection.
///
/// `T` is the initializer: the value handed to [`TetherIo::connect`] and
/// [`TetherIo::reconnect`] to open a connection. The library ships implementations of this
/// trait for several types, each gated behind a feature flag.
pub trait TetherIo<T>
where
    Self: Sized + Unpin,
{
    /// Error produced when a connection attempt fails.
    type Error;

    /// Opens a new connection to the I/O source described by `initializer`.
    fn connect(initializer: &T) -> impl Future<Output = Result<Self, Self::Error>> + Send;

    /// Opens a replacement connection after a disconnect.
    ///
    /// By default this simply delegates to [`TetherIo::connect`]; override it when reconnecting
    /// should behave differently from the initial connection.
    fn reconnect(initializer: &T) -> impl Future<Output = Result<Self, Self::Error>> + Send {
        <Self as TetherIo<T>>::connect(initializer)
    }
}

enum Status<E> {
    Success,
    Failover(State<E>),
}

/// Describes why the wrapped I/O stopped producing data.
///
/// A disconnect is currently reported either as an end of file or as an error value of type `E`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum State<E> {
    /// The I/O source reached end of file.
    Eof,
    /// The I/O source failed with an error.
    Err(E),
}

impl Into<std::io::Error> for State<std::io::Error> {
    fn into(self) -> std::io::Error {
        match self {
            State::Eof => std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Eof error"),
            State::Err(error) => error,
        }
    }
}

/// A wrapper type which contains the underlying I/O object, its initializer, and resolver.
///
/// This is the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
///
/// Calling things like `read_buf` will result in the I/O automatically reconnecting if an error
/// is detected during the underlying I/O call.
///
/// # Note
///
/// Currently, there is no way to obtain a reference into the underlying I/O object, and the only
/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`]. This is by design, since
/// in the future there may be reason to add unsafe code which cannot be guaranteed if outside
/// callers can obtain references. In the future I may add these as unsafe functions if those
/// cases can be described.
pub struct Tether<I, T, R> {
    /// Reconnect bookkeeping that is handed to the resolver.
    context: Context,
    /// Value passed to `TetherIo::connect` / `TetherIo::reconnect` to open a connection.
    initializer: I,
    /// The wrapped I/O object; replaced after a successful reconnect.
    inner: T,
    /// Decides whether a reconnect should be attempted after a disconnect.
    resolver: R,
}

impl<I, T, R> Tether<I, T, R> {
    /// Wraps an already-connected I/O object.
    ///
    /// No connection is made here. `inner` is used as-is, `initializer` is kept for later
    /// reconnects, and the reconnect bookkeeping starts from [`Context::default`].
    pub fn new(inner: T, initializer: I, resolver: R) -> Self {
        let context = Context::default();
        Self {
            inner,
            initializer,
            resolver,
            context,
        }
    }

    /// Borrows the resolver that decides whether to reconnect.
    pub fn get_resolver(&self) -> &R {
        &self.resolver
    }

    /// Mutably borrows the resolver.
    pub fn get_resolver_mut(&mut self) -> &mut R {
        &mut self.resolver
    }

    /// Borrows the value used to (re)establish the connection.
    pub fn get_initializer(&self) -> &I {
        &self.initializer
    }

    /// Mutably borrows the initializer; the current value is what the next reconnect uses.
    pub fn get_initializer_mut(&mut self) -> &mut I {
        &mut self.initializer
    }

    /// Consumes the `Tether`, returning the wrapped I/O object and dropping everything else.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the reconnect bookkeeping.
    pub fn get_context(&self) -> &Context {
        &self.context
    }

    /// Mutably borrows the reconnect bookkeeping.
    pub fn get_context_mut(&mut self) -> &mut Context {
        &mut self.context
    }
}

impl<I, T, R> Tether<I, T, R>
where
    T: TetherIo<I>,
{
    /// Establishes a fresh connection and wraps it in a `Tether`.
    ///
    /// The I/O object is created with [`TetherIo::connect`], which is the method the default
    /// [`TetherIo::reconnect`] delegates to. `initializer` is stored for later reconnects.
    ///
    /// # Errors
    ///
    /// Returns whatever error `T::connect` produces. The resolver is not consulted for this
    /// first attempt.
    pub async fn connect(initializer: I, resolver: R) -> Result<Self, T::Error> {
        let io = T::connect(&initializer).await?;
        let tether = Self::new(io, initializer, resolver);
        Ok(tether)
    }
}

/// Contains metrics about the underlying connection
///
/// Passed to [`TetherResolver::disconnected`] on every call.
///
/// Currently tracks the number of reconnect attempts, but in the future may be expanded to include
/// additional metrics.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Context {
    /// Incremented by `Tether::poll_reconnect` before each call to the resolver.
    reconnection_attempts: usize,
}

impl Context {
    /// The number of times a reconnect has been attempted.
    ///
    /// The first time [`TetherResolver::disconnected`] is invoked this will return `1`.
    pub fn reconnect_count(&self) -> usize {
        self.reconnection_attempts
    }
}

impl<I, T, R> Tether<I, T, R>
where
    T: TetherIo<I, Error = R::Error>,
    R: TetherResolver,
{
    /// Runs the resolve/reconnect cycle after the wrapped I/O reported `state`.
    ///
    /// Each loop iteration increments the reconnect counter in `self.context`. It then asks the
    /// resolver (via [`TetherResolver::disconnected`]) whether to retry. If the resolver agrees,
    /// it calls [`TetherIo::reconnect`] with the stored initializer. The possible outcomes are:
    ///
    /// * resolver declines → `Poll::Ready(Status::Failover(state))` with the latest state;
    /// * reconnect succeeds → the new I/O object replaces `self.inner` and
    ///   `Poll::Ready(Status::Success)` is returned;
    /// * reconnect fails → its error becomes the new `State::Err` and the loop repeats.
    ///
    /// Returns `Poll::Pending` whenever the resolver future or the reconnect future is pending.
    ///
    /// NOTE(review): both futures are pinned on this function's stack. Returning `Poll::Pending`
    /// therefore drops them, and a later call starts from scratch (new futures, counter
    /// incremented again) instead of resuming. Futures that need more than one poll to finish,
    /// such as a timer or a non-blocking connect, may therefore never complete. Confirm how the
    /// callers in `implementations` re-enter this method. A fix would need to keep the in-flight
    /// futures on `Tether` itself.
    pub(crate) fn poll_reconnect(
        &mut self,
        cx: &mut std::task::Context<'_>,
        mut state: State<R::Error>,
    ) -> Poll<Status<R::Error>> {
        loop {
            self.context.reconnection_attempts += 1;

            // NOTE: Scoped so the resolver future's borrow of `state` ends before `state` is
            // moved (Failover) or overwritten (failed reconnect) below.
            let retry = {
                // `resolver` and `context` are disjoint fields, so the method can be called on
                // the field directly; pinning the `&mut R` itself is unnecessary.
                let resolver_fut = self.resolver.disconnected(&self.context, &state);
                let resolver_fut = std::pin::pin!(resolver_fut);
                ready::ready!(resolver_fut.poll(cx))
            };

            if !retry {
                return Poll::Ready(Status::Failover(state));
            }

            let reconnect_fut = std::pin::pin!(T::reconnect(&self.initializer));
            match ready::ready!(reconnect_fut.poll(cx)) {
                Ok(new_stream) => {
                    // NOTE: This is why we need the underlying stream to be Unpin, since we swap
                    // it with a new one of the same type. Not aware of a safe alternative
                    self.inner = new_stream;
                    return Poll::Ready(Status::Success);
                }
                Err(new_error) => state = State::Err(new_error),
            }
        }
    }
}

/// Poll helpers shared by the crate's hand-written `poll_*` code.
pub(crate) mod ready {
    // `ready!(e)` unwraps `Poll::Ready(t)` to `t`, or returns `Poll::Pending` from the enclosing
    // function. This is the standard library's `std::task::ready!` (stable since Rust 1.64),
    // re-exported under the crate's original path so `ready::ready!(...)` call sites keep
    // compiling without carrying a hand-rolled copy of the macro.
    pub(crate) use std::task::ready;
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Binds a listener on an ephemeral loopback port, connects to it, and returns the
    /// `(client, server)` ends of the resulting connection.
    async fn create_tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();

        let (connect_res, accept_res) =
            tokio::join!(TcpStream::connect(bound_addr), listener.accept());
        let client = connect_res.unwrap();
        let (server, _) = accept_res.unwrap();
        (client, server)
    }

    /// Resolver that signals over a channel every time it is consulted and always declines to
    /// reconnect.
    pub struct CallbackResolver {
        notifier: mpsc::Sender<()>,
    }

    impl TetherResolver for CallbackResolver {
        type Error = std::io::Error;

        async fn disconnected(&mut self, _context: &Context, _state: &State<Self::Error>) -> bool {
            // Report the call to the test body, then decline the reconnect.
            self.notifier.send(()).await.unwrap();
            false
        }
    }

    #[tokio::test]
    async fn disconnect_triggers_callback() {
        let (notify_tx, mut notify_rx) = mpsc::channel(1);
        let (client, mut server) = create_tcp_pair().await;
        let mut tether = Tether::new(client, "", CallbackResolver { notifier: notify_tx });
        let mut received = Vec::new();

        server.write_all(b"foo-bar").await.unwrap();
        tether.read_buf(&mut received).await.unwrap();
        assert_eq!(&received, b"foo-bar");

        // Shutting down the server side should surface as a disconnect on the next read,
        // which is expected to reach the resolver.
        server.shutdown().await.unwrap();
        tether.read_buf(&mut received).await.unwrap();
        assert!(notify_rx.recv().await.is_some());
    }
}