// io_tether/lib.rs
#![doc = include_str!("../README.md")]
use std::{future::Future, io::ErrorKind, pin::Pin};

mod implementations;

/// A heap-allocated, pinned, type-erased future which is `Send + 'static` and resolves to `O`.
///
/// This is the return type of every [`Resolver`] method.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + 'static + Send>>;

/// Represents a type which drives reconnects
///
/// Since the `disconnected` method is asynchronous, and is invoked when the underlying stream
/// disconnects, calling asynchronous functions like
/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
/// returned future works.
///
/// # Return Type
///
/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
/// future be `'static` (it cannot hold references to `self`, or to any of the arguments).
/// However, you are still free to mutate data outside of the returned future.
///
/// Additionally, this method is invoked each time the I/O fails to establish a connection, so
/// writing futures which do not reference their environment is a little easier than it may seem.
///
/// # Example
///
/// A very simple implementation may look something like the following:
///
/// ```no_run
/// # use std::time::Duration;
/// # use io_tether::{Context, State, Resolver, PinFut};
/// pub struct RetryResolver(bool);
///
/// impl Resolver for RetryResolver {
///     fn disconnected(&mut self, context: &Context, state: &State) -> PinFut<bool> {
///         println!("WARN: Disconnected from server {:?}", state);
///         self.0 = true;
///
///         if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
///             return Box::pin(async move {false});
///         }
///
///         Box::pin(async move {
///             tokio::time::sleep(Duration::from_secs(10)).await;
///             true
///         })
///     }
/// }
/// ```
pub trait Resolver: Unpin {
    /// Invoked by Tether when an error/disconnect is encountered.
    ///
    /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
    /// returning `false` will result in the error being returned from the originating call.
    ///
    /// # Note
    ///
    /// The [`State`] describes the underlying cause. It is either `State::Eof`, in which case the
    /// end of file was reached, or `State::Err` carrying the I/O error. This information can be
    /// leveraged in this function to determine whether to attempt to reconnect.
    fn disconnected(&mut self, context: &Context, state: &State) -> PinFut<bool>;

    /// Invoked within [`Tether::connect`] if the initial connection attempt fails
    ///
    /// Resolving to `true` makes [`Tether::connect`] try again, resolving to `false` makes it
    /// return the connection error.
    ///
    /// The default implementation forwards to [`Resolver::disconnected`].
    fn unreachable(&mut self, context: &Context, state: &State) -> PinFut<bool> {
        self.disconnected(context, state)
    }

    /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
    ///
    /// The default implementation forwards to [`Resolver::reconnected`].
    fn established(&mut self, context: &Context) -> PinFut<()> {
        self.reconnected(context)
    }

    /// Invoked by Tether when the underlying I/O connection has been re-established
    ///
    /// The default implementation does nothing: it returns a future which is already complete.
    fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
        Box::pin(std::future::ready(()))
    }
}

/// Represents an I/O source capable of reconnecting
///
/// The type parameter `T` is the initializer: the value handed to [`Io::connect`] and
/// [`Io::reconnect`] in order to build the I/O object.
///
/// This trait is implemented for a number of types in the library, with the implementations placed
/// behind feature flags
pub trait Io<T>: Sized + Unpin {
    /// Initializes the connection to the I/O source
    ///
    /// The returned future must be `'static + Send`, so it has to own the `initializer` rather
    /// than borrow from it.
    fn connect(
        initializer: T,
    ) -> impl Future<Output = Result<Self, std::io::Error>> + 'static + Send;

    /// Re-establishes the connection to the I/O source
    ///
    /// The default implementation simply calls [`Io::connect`] with the same initializer.
    fn reconnect(
        initializer: T,
    ) -> impl Future<Output = Result<Self, std::io::Error>> + 'static + Send {
        Self::connect(initializer)
    }
}

/// The underlying cause of the I/O disconnect
///
/// Currently this is either an error, or an 'end of file'.
#[derive(Debug)]
pub enum State {
    /// End of File
    ///
    /// # Note
    ///
    /// This is also emitted when the other half of a TCP connection is closed.
    Eof,
    /// An I/O Error occurred
    Err(std::io::Error),
}

impl State {
    /// A convenience function which returns whether the original error is capable of being retried
    ///
    /// [`State::Eof`] is always reported as retryable. A [`State::Err`] is reported as retryable
    /// only when its [`std::io::ErrorKind`] is one of the connection-level kinds listed in the
    /// body; every other kind yields `false`.
    pub fn retryable(&self) -> bool {
        use std::io::ErrorKind as Kind;

        match self {
            State::Eof => true,
            State::Err(error) => matches!(
                error.kind(),
                Kind::NotFound
                    | Kind::PermissionDenied
                    | Kind::ConnectionRefused
                    | Kind::ConnectionAborted
                    | Kind::ConnectionReset
                    | Kind::NotConnected
                    | Kind::AlreadyExists
                    | Kind::HostUnreachable
                    | Kind::AddrNotAvailable
                    | Kind::NetworkDown
                    | Kind::BrokenPipe
                    | Kind::TimedOut
                    | Kind::UnexpectedEof
                    | Kind::NetworkUnreachable
                    | Kind::AddrInUse
            ),
        }
    }
}

impl From<&State> for std::io::Error {
    /// Builds an owned [`std::io::Error`] describing the borrowed `State`.
    ///
    /// `std::io::Error` is not `Clone`, so the error is reconstructed rather than copied:
    ///
    /// * `State::Eof` becomes an `UnexpectedEof` error with the message `"Eof error"`.
    /// * An error backed by an OS error code is rebuilt from that code, so `raw_os_error()`,
    ///   `kind()` and the `Display` output all match the original.
    /// * Any other error is rebuilt from its kind and its `Display` text; the original payload
    ///   (and therefore any `source()` chain) is not carried over.
    fn from(value: &State) -> Self {
        match value {
            State::Eof => std::io::Error::new(ErrorKind::UnexpectedEof, "Eof error"),
            State::Err(error) => match error.raw_os_error() {
                // Going through the OS code keeps the errno available to callers instead of
                // flattening it into a string.
                Some(code) => std::io::Error::from_raw_os_error(code),
                None => std::io::Error::new(error.kind(), error.to_string()),
            },
        }
    }
}

/// A wrapper type which contains the underlying I/O object, its initializer, and resolver.
///
/// This is the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
///
/// Calling things like
/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
/// call.
///
/// # Note
///
/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`]. This is by design, since
/// in the future there may be reason to add unsafe code which cannot be guaranteed if outside
/// callers can obtain references. In the future I may add these as unsafe functions if those cases
/// can be described.
pub struct Tether<I, T: Io<I>, R> {
    /// Which phase of the connect/reconnect cycle the tether is currently in
    state: StateMachine<T>,
    /// Everything else (I/O object, initializer, resolver, context), kept in a separate struct so
    /// it can be borrowed mutably while `state` is also borrowed
    inner: TetherInner<I, T, R>,
}

/// The inner type for tether.
///
/// Helps satisfy the borrow checker when we need to mutate this while holding a mutable ref to the
/// larger futs state machine
struct TetherInner<I, T: Io<I>, R> {
    /// Attempt counters handed to the resolver callbacks
    context: Context,
    /// The value used to (re)create the I/O object
    initializer: I,
    /// The live I/O object
    io: T,
    /// User supplied policy deciding whether/when to reconnect
    resolver: R,
    /// The disconnect cause passed to `Resolver::disconnected`; starts out as `State::Eof`
    // NOTE(review): presumably overwritten by the poll implementations before each
    // `disconnected` call — confirm in the `implementations` module.
    state: State,
}

impl<I, T: Io<I>, R: Resolver> TetherInner<I, T, R> {
    fn disconnected(&mut self) -> PinFut<bool> {
        self.resolver.disconnected(&self.context, &self.state)
    }

    fn reconnected(&mut self) -> PinFut<()> {
        self.resolver.reconnected(&self.context)
    }
}

impl<I, T: Io<I>, R: Resolver> Tether<I, T, R> {
    /// Wraps an already connected I/O source in a tether
    ///
    /// The tether starts out in the connected state, with [`State::Eof`] recorded as a placeholder
    /// disconnect cause.
    ///
    /// # Note
    ///
    /// [`Tether::connect`] is usually the more convenient way to obtain a [`Tether`].
    pub fn new(inner: T, initializer: I, resolver: R, context: Context) -> Self {
        let inner = TetherInner {
            context,
            initializer,
            io: inner,
            resolver,
            state: State::Eof,
        };

        Self {
            // `Connected` is the `#[default]` variant of the state machine.
            state: StateMachine::Connected,
            inner,
        }
    }

    /// Marks the tether as connected again and clears the per-connection attempt counter
    fn reconnect(&mut self) {
        let Self { state, inner } = self;

        *state = StateMachine::Connected;
        inner.context.reset();
    }

    /// Borrows the initializer
    pub fn get_initializer(&self) -> &I {
        let TetherInner { initializer, .. } = &self.inner;
        initializer
    }

    /// Mutably borrows the initializer
    pub fn get_initializer_mut(&mut self) -> &mut I {
        let TetherInner { initializer, .. } = &mut self.inner;
        initializer
    }

    /// Tears the tether apart, handing back the wrapped I/O object
    pub fn into_inner(self) -> T {
        self.inner.io
    }
}

impl<I, T, R> Tether<I, T, R>
where
    R: Resolver,
    T: Io<I>,
    I: Clone,
{
    /// Connect to the I/O source, retrying on a failure.
    pub async fn connect(initializer: I, mut resolver: R) -> Result<Self, std::io::Error> {
        let mut context = Context::default();

        loop {
            let state = match T::connect(initializer.clone()).await {
                Ok(io) => {
                    resolver.established(&context).await;
                    context.reset();
                    return Ok(Self::new(io, initializer, resolver, context));
                }
                Err(error) => State::Err(error),
            };

            context.increment_attempts();

            if !resolver.unreachable(&context, &state).await {
                let State::Err(error) = state else {
                    unreachable!("state is immutable and established as Err above");
                };

                return Err(error);
            }
        }
    }
}

/// The phases a tether moves through; each non-idle phase owns the future it is waiting on.
// NOTE(review): the transitions between phases are driven outside this part of the file
// (presumably by the poll implementations in `implementations`) — the per-variant notes below are
// inferred from the payload types and should be confirmed there.
#[derive(Default)]
enum StateMachine<T> {
    /// Initial/default phase: no reconnect work is pending
    #[default]
    Connected,
    /// Waiting on a `bool` future — presumably the resolver's `disconnected` decision
    Disconnected(PinFut<bool>),
    /// Waiting on a future that yields a fresh I/O object or an error — presumably
    /// `Io::reconnect`
    Reconnecting(PinFut<Result<T, std::io::Error>>),
    /// Waiting on a `()` future — presumably the resolver's `reconnected` callback
    Reconnected(PinFut<()>),
}

/// Bookkeeping about the underlying connection
///
/// A shared reference to this struct is handed to every [`Resolver`] callback, including
/// [`Resolver::disconnected`].
///
/// Only reconnect attempt counters are recorded for now; further metrics may be added in the
/// future.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Context {
    total_attempts: usize,
    current_attempts: usize,
}

impl Context {
    /// How many reconnects have been attempted over the whole lifetime of the connection.
    ///
    /// The first time [`Resolver::disconnected`] is invoked this will return `1`.
    pub fn total_reconnect_attempts(&self) -> usize {
        let Self { total_attempts, .. } = self;
        *total_attempts
    }

    /// Records one more attempt in both the lifetime and the per-connection counters.
    fn increment_attempts(&mut self) {
        let Self {
            total_attempts,
            current_attempts,
        } = self;

        *current_attempts += 1;
        *total_attempts += 1;
    }

    /// Zeroes the per-connection counter; the lifetime counter is left alone.
    fn reset(&mut self) {
        let Self {
            current_attempts, ..
        } = self;

        *current_attempts = 0;
    }

    /// How many reconnects have been attempted since the connection was last established
    /// successfully. Goes back to zero every time the connection is re-established.
    pub fn current_reconnect_attempts(&self) -> usize {
        let Self {
            current_attempts, ..
        } = self;
        *current_attempts
    }
}

/// Crate-private home of the `ready!` helper macro, re-exported so sibling modules can import it
/// by path.
pub(crate) mod ready {
    /// Unwraps a `std::task::Poll`.
    ///
    /// Evaluates to the inner value when the expression is `Poll::Ready`; otherwise returns
    /// `Poll::Pending` from the *enclosing function*. It can therefore only be used inside
    /// functions whose return type is a `Poll`.
    macro_rules! ready {
        // Accepts a single expression with an optional trailing comma.
        ($e:expr $(,)?) => {
            match $e {
                std::task::Poll::Ready(t) => t,
                std::task::Poll::Pending => return std::task::Poll::Pending,
            }
        };
    }

    // `macro_rules!` items are textually scoped; re-exporting with `use` makes the macro
    // reachable as `crate::ready::ready`.
    pub(crate) use ready;
}