io_tether/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A heap-allocated, type-erased future that is `Send` and owns everything it uses (`'static`).
///
/// Used as the return type of the [`Resolver`] and [`Io`] methods so implementors never have to
/// name a concrete future type.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
14
15/// Represents a type which drives reconnects
16///
17/// Since the disconnected method asynchronous, and is invoked when the underlying stream
18/// disconnects, calling asynchronous functions like
19/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
20/// body, work.
21///
22/// # Unpin
23///
24/// Since the method provides `&mut Self`, Self must be [`Unpin`]
25///
26/// # Return Type
27///
28/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
29/// future be 'static (cannot hold references to self, or any of the arguments). However, you are
30/// still free to mutate data outside of the returned future.
31///
32/// Additionally, this method is invoked each time the I/O fails to establish a connection so
33/// writing futures which do not reference their environment is a little easier than it may seem.
34///
35/// # Example
36///
37/// A very simple implementation may look something like the following:
38///
39/// ```no_run
40/// # use std::time::Duration;
41/// # use io_tether::{Context, Reason, Resolver, PinFut};
42/// pub struct RetryResolver(bool);
43///
44/// impl<C> Resolver<C> for RetryResolver {
45///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
46///         let reason = context.reason();
47///         println!("WARN: Disconnected from server {:?}", reason);
48///         self.0 = true;
49///
50///         if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
51///             return Box::pin(async move {false});
52///         }
53///
54///         Box::pin(async move {
55///             tokio::time::sleep(Duration::from_secs(10)).await;
56///             true
57///         })
58///     }
59/// }
60/// ```
61pub trait Resolver<C> {
62    /// Invoked by Tether when an error/disconnect is encountered.
63    ///
64    /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
65    /// returning `false` will result in the error being returned from the originating call.
66    fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68    /// Invoked within [`Tether::connect`] if the initial connection attempt fails
69    ///
70    /// As with [`Self::disconnected`] the returned boolean determines whether the initial
71    /// connection attempt is retried
72    ///
73    /// Defaults to invoking [`Self::disconnected`]
74    fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75        self.disconnected(context, connector)
76    }
77
78    /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
79    ///
80    /// Defaults to invoking [`Self::reconnected`]
81    fn established(&mut self, context: &Context) -> PinFut<()> {
82        self.reconnected(context)
83    }
84
85    /// Invoked by Tether whenever the connection to the underlying I/O source has been
86    /// re-established
87    fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88        Box::pin(std::future::ready(()))
89    }
90}
91
92/// Represents an I/O source capable of reconnecting
93///
94/// This trait is implemented for a number of types in the library, with the implementations placed
95/// behind feature flags
96pub trait Io {
97    type Output;
98
99    /// Initializes the connection to the I/O source
100    fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102    /// Re-establishes the connection to the I/O source
103    fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104        self.connect()
105    }
106}
107
/// Why a connection was lost, or could not be established.
#[derive(Debug)]
#[non_exhaustive]
pub enum Reason {
    /// The underlying I/O reached end of file.
    ///
    /// This happens when the end of a file is read from the file system, when the remote socket
    /// of a TCP connection is closed, etc. Generally it indicates a successful end of the
    /// connection.
    Eof,
    /// The underlying I/O reported an error.
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Delegate so the wrapped error renders its own message with the caller's formatter.
            Reason::Err(inner) => std::fmt::Display::fmt(inner, f),
            Reason::Eof => f.write_str("End of file detected"),
        }
    }
}

impl std::error::Error for Reason {}

impl Reason {
    /// A convenience function which reports whether the original failure is worth retrying.
    ///
    /// End of file is always retryable. An I/O error is retryable when its
    /// [`ErrorKind`](std::io::ErrorKind) is one of a fixed set of connection-, network- and
    /// resource-related kinds.
    pub fn retryable(&self) -> bool {
        use std::io::ErrorKind as Kind;

        // Error kinds this library treats as worth another connection attempt.
        const RETRYABLE_KINDS: &[Kind] = &[
            Kind::NotFound,
            Kind::PermissionDenied,
            Kind::ConnectionRefused,
            Kind::ConnectionAborted,
            Kind::ConnectionReset,
            Kind::NotConnected,
            Kind::AlreadyExists,
            Kind::HostUnreachable,
            Kind::AddrNotAvailable,
            Kind::NetworkDown,
            Kind::BrokenPipe,
            Kind::TimedOut,
            Kind::UnexpectedEof,
            Kind::NetworkUnreachable,
            Kind::AddrInUse,
        ];

        match self {
            Reason::Eof => true,
            Reason::Err(error) => RETRYABLE_KINDS.contains(&error.kind()),
        }
    }

    /// Moves the reason out, leaving [`Reason::Eof`] in its place.
    fn take(&mut self) -> Self {
        std::mem::replace(self, Reason::Eof)
    }
}

impl From<Reason> for std::io::Error {
    /// `Eof` becomes an `UnexpectedEof` error; `Err` hands back the wrapped error unchanged.
    fn from(reason: Reason) -> Self {
        match reason {
            Reason::Err(inner) => inner,
            Reason::Eof => Self::new(ErrorKind::UnexpectedEof, "Eof error"),
        }
    }
}
174
175/// A wrapper type which contains the underlying I/O object, it's initializer, and resolver.
176///
177/// This in the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
178/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
179///
180/// Calling things like
181/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
182/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
183/// call.
184///
185/// # Example
186///
187/// ## Basic Resolver
188///
189/// Below is an example of a basic resolver which just logs the error and retries
190///
191/// ```no_run
192/// # use io_tether::*;
193/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
194/// struct MyResolver;
195///
196/// impl<C> Resolver<C> for MyResolver {
197///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
198///         println!("WARN(disconnect): {:?}", context);
199///         Box::pin(async move { true }) // always immediately retry the connection
200///     }
201/// }
202///
203/// let stream = Tether::connect_tcp("localhost:8080", MyResolver).await?;
204///
205/// // Regardless of which half detects the disconnect, a reconnect will be attempted
206/// let (read, write) = tokio::io::split(stream);
207/// # Ok(()) }
208/// ```
209///
210/// # Specialized Resolver
211///
212/// For more specialized use cases we can implement [`Resolver`] only for certain connectors to give
213/// us extra control over the reconnect process.
214///
215/// ```
216/// # use io_tether::{*, tcp::TcpConnector};
217/// # use std::net::{SocketAddrV4, Ipv4Addr};
218/// struct MyResolver;
219///
220/// type Connector = TcpConnector<SocketAddrV4>;
221///
222/// impl Resolver<Connector> for MyResolver {
223///     fn disconnected(&mut self, context: &Context, conn: &mut Connector) -> PinFut<bool> {
224///         // Because we've specialized our resolver to act on TcpConnector for IPv4, we can alter
225///         // the address in between the disconnect, and the reconnect, to try a different host
226///         conn.get_addr_mut().set_ip(Ipv4Addr::LOCALHOST);
227///         conn.get_addr_mut().set_port(8082);
228///
229///         Box::pin(async move { true }) // always immediately retry the connection
230///     }
231/// }
232/// ```
233///
234/// # Note
235///
236/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
237/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`].
238pub struct Tether<C: Io, R> {
239    state: State<C::Output>,
240    inner: TetherInner<C, R>,
241}
242
/// The inner type for tether.
///
/// Split out from [`Tether`] to satisfy the borrow checker. These fields can then be mutated
/// while a mutable reference into the tether's `state` (the larger future state machine) is held.
///
/// NOTE: fields are dropped in declaration order, so keep the order stable.
struct TetherInner<C: Io, R> {
    // Attempt counters and the last disconnect reason; handed to every resolver callback.
    context: Context,
    // Produces the I/O object; passed mutably to `Resolver::disconnected` so it can be adjusted
    // between attempts.
    connector: C,
    // The live I/O object; this is what `Tether::into_inner` returns.
    io: C::Output,
    // Decides whether and when to reconnect.
    resolver: R,
}
253
254impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
255    fn disconnected(&mut self) -> PinFut<bool> {
256        self.resolver
257            .disconnected(&self.context, &mut self.connector)
258    }
259
260    fn reconnected(&mut self) -> PinFut<()> {
261        self.resolver.reconnected(&self.context)
262    }
263}
264
265impl<C, R> Tether<C, R>
266where
267    C: Io,
268    R: Resolver<C>,
269{
270    /// Construct a tether object from an existing I/O source
271    ///
272    /// # Note
273    ///
274    /// Often a simpler way to construct a [`Tether`] object is through [`Tether::connect`]
275    pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
276        Self::new_with_context(connector, io, resolver, Context::default())
277    }
278
279    fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
280        Self {
281            state: Default::default(),
282            inner: TetherInner {
283                context,
284                io,
285                resolver,
286                connector,
287            },
288        }
289    }
290
291    fn reconnect(&mut self) {
292        self.state = State::Connected;
293        self.inner.context.reset();
294    }
295
296    /// Consume the Tether, and return the underlying I/O type
297    #[inline]
298    pub fn into_inner(self) -> C::Output {
299        self.inner.io
300    }
301
302    /// Connect to the I/O source, retrying on a failure.
303    pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
304        let mut context = Context::default();
305
306        loop {
307            let state = match connector.connect().await {
308                Ok(io) => {
309                    resolver.established(&context).await;
310                    context.reset();
311                    return Ok(Self::new_with_context(connector, io, resolver, context));
312                }
313                Err(error) => Reason::Err(error),
314            };
315
316            context.increment_attempts();
317
318            if !resolver.unreachable(&context, &mut connector).await {
319                let Reason::Err(error) = state else {
320                    unreachable!("state is immutable and established as Err above");
321                };
322
323                return Err(error);
324            }
325        }
326    }
327
328    /// Connect to the I/O source, bypassing [`Resolver::unreachable`] implementation on a failure.
329    ///
330    /// This does still invoke [`Resolver::established`] if the connection is made successfully.
331    /// To bypass both, construct the IO source and pass it to [`Self::new`].
332    pub async fn connect_without_retry(
333        mut connector: C,
334        mut resolver: R,
335    ) -> Result<Self, std::io::Error> {
336        let context = Context::default();
337
338        let io = connector.connect().await?;
339        resolver.established(&context).await;
340        Ok(Self::new_with_context(connector, io, resolver, context))
341    }
342}
343
344/// The internal state machine which drives the connection and reconnect logic
345#[derive(Default)]
346enum State<T> {
347    #[default]
348    Connected,
349    Disconnected(PinFut<bool>),
350    Reconnecting(PinFut<Result<T, std::io::Error>>),
351    Reconnected(PinFut<()>),
352}
353
/// Contains additional information about the disconnect
///
/// This type internally tracks how many reconnect attempts have been made, both in total and
/// since the last successful connection, and the reason for the most recent disconnect.
///
/// NOTE: the derived `Debug` output lists fields in declaration order.
#[derive(Debug)]
pub struct Context {
    // Attempts over the whole lifetime; bumped by `increment_attempts`, never cleared by `reset`.
    total_attempts: usize,
    // Attempts since the last successful connection; cleared back to 0 by `reset`.
    current_attempts: usize,
    // Reason exposed through `Context::reason`; `Reason::Eof` in a fresh context.
    reason: Reason,
}
364
365impl Default for Context {
366    fn default() -> Self {
367        Self {
368            total_attempts: 0,
369            current_attempts: 0,
370            reason: Reason::Eof,
371        }
372    }
373}
374
375impl Context {
376    /// The number of reconnect attempts since the last successful connection. Reset each time
377    /// the connection is established
378    #[inline]
379    pub fn current_reconnect_attempts(&self) -> usize {
380        self.current_attempts
381    }
382
383    /// The total number of times a reconnect has been attempted.
384    ///
385    /// The first time [`Resolver::disconnected`] or [`Resolver::unreachable`] is invoked this will
386    /// return `0`, each subsequent time it will be incremented by 1.
387    #[inline]
388    pub fn total_reconnect_attempts(&self) -> usize {
389        self.total_attempts
390    }
391
392    fn increment_attempts(&mut self) {
393        self.current_attempts += 1;
394        self.total_attempts += 1;
395    }
396
397    /// Get the current reason for the disconnect
398    #[inline]
399    pub fn reason(&self) -> &Reason {
400        &self.reason
401    }
402
403    /// Resets the current attempts, leaving the total reconnect attempts unchanged
404    #[inline]
405    fn reset(&mut self) {
406        self.current_attempts = 0;
407    }
408}
409
pub(crate) mod ready {
    /// Yields the value inside `Poll::Ready`, or returns `Poll::Pending` from the enclosing
    /// function.
    ///
    /// Paths are written from the crate root (`::std`) so the expansion cannot be captured by a
    /// local item named `std` at the call site.
    macro_rules! ready {
        ($poll:expr $(,)?) => {
            match $poll {
                ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
                ::std::task::Poll::Ready(value) => value,
            }
        };
    }

    pub(crate) use ready;
}
422
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
    };

    use super::*;

    /// Resolver that permits exactly one reconnect over the tether's lifetime.
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            // Only the very first disconnect (no attempts made yet) is retried.
            let first_disconnect = context.total_reconnect_attempts() == 0;
            Box::pin(async move { first_disconnect })
        }
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Every accepted connection is sent its sequence number and then dropped (EOF).
        tokio::spawn(async move {
            for sequence in 0u8.. {
                let (mut socket, _peer) = listener.accept().await.unwrap();
                socket.write_u8(sequence).await.unwrap();
            }
        });

        let mut tether = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut received = Vec::new();

        // The first EOF triggers one reconnect; `Once` declines the second, surfacing an error.
        let outcome = tether.read_to_end(&mut received).await;
        assert!(outcome.is_err());
        assert_eq!(received.as_slice(), &[0, 1])
    }
}