Skip to main content

io_tether/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4pub mod config;
5#[cfg(feature = "fs")]
6pub mod fs;
7mod implementations;
8#[cfg(feature = "net")]
9pub mod tcp;
10#[cfg(all(feature = "net", target_family = "unix"))]
11pub mod unix;
12
13#[cfg(test)]
14mod tests;
15
16use config::Config;
17
/// A heap-allocated, pinned, dynamically dispatched future which is `Send`, borrows nothing
/// (`'static`), and resolves to a value of type `O`
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
20
21/// Represents a type which drives reconnects
22///
23/// Since the disconnected method asynchronous, and is invoked when the underlying stream
24/// disconnects, calling asynchronous functions like
25/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
26/// body, work.
27///
28/// # Unpin
29///
30/// Since the method provides `&mut Self`, Self must be [`Unpin`]
31///
32/// # Return Type
33///
34/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
35/// future be 'static (cannot hold references to self, or any of the arguments). However, you are
36/// still free to mutate data outside of the returned future.
37///
38/// Additionally, this method is invoked each time the I/O fails to establish a connection so
39/// writing futures which do not reference their environment is a little easier than it may seem.
40///
41/// # Example
42///
43/// A very simple implementation may look something like the following:
44///
45/// ```no_run
46/// # use std::time::Duration;
47/// # use io_tether::{Action, Context, Reason, Resolver, PinFut};
48/// pub struct RetryResolver(bool);
49///
50/// impl<C> Resolver<C> for RetryResolver {
51///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<Action> {
52///         let reason = context.reason();
53///         println!("WARN: Disconnected from server {:?}", reason);
54///         self.0 = true;
55///
56///         if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
57///             return Box::pin(async move { Action::Exhaust });
58///         }
59///
60///         Box::pin(async move {
61///             tokio::time::sleep(Duration::from_secs(10)).await;
62///             Action::AttemptReconnect
63///         })
64///     }
65/// }
66/// ```
67pub trait Resolver<C> {
68    /// Invoked by Tether when an error/disconnect is encountered.
69    ///
70    /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
71    /// returning `false` will result in the error being returned from the originating call.
72    fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<Action>;
73
74    /// Invoked within [`Tether::connect`] if the initial connection attempt fails
75    ///
76    /// As with [`Self::disconnected`] the returned boolean determines whether the initial
77    /// connection attempt is retried
78    ///
79    /// Defaults to invoking [`Self::disconnected`] where [`Action::Ignore`] results in a disconnect
80    fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
81        let fut = self.disconnected(context, connector);
82        Box::pin(async move {
83            match fut.await {
84                Action::AttemptReconnect => true,
85                Action::Exhaust | Action::Ignore => false,
86            }
87        })
88    }
89
90    /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
91    ///
92    /// Defaults to invoking [`Self::reconnected`]
93    fn established(&mut self, context: &Context) -> PinFut<()> {
94        self.reconnected(context)
95    }
96
97    /// Invoked by Tether whenever the connection to the underlying I/O source has been
98    /// re-established
99    fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
100        Box::pin(std::future::ready(()))
101    }
102}
103
104/// Represents an I/O source capable of reconnecting
105///
106/// This trait is implemented for a number of types in the library, with the implementations placed
107/// behind feature flags
108pub trait Connector {
109    type Output;
110
111    /// Initializes the connection to the I/O source
112    fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
113
114    /// Re-establishes the connection to the I/O source
115    fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
116        self.connect()
117    }
118}
119
/// Describes why the underlying I/O source was considered disconnected
#[non_exhaustive]
#[derive(Debug)]
pub enum Reason {
    /// The underlying I/O reached its end of file
    ///
    /// This can occur when the end of a file is read from the file system, when the remote socket
    /// on a TCP connection is closed, etc. Generally it indicates a successful end of the
    /// connection
    Eof,
    /// The underlying I/O failed with the contained error
    Err(std::io::Error),
}
133
134impl std::fmt::Display for Reason {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        match self {
137            Reason::Eof => f.write_str("End of file detected"),
138            Reason::Err(error) => error.fmt(f),
139        }
140    }
141}
142
143impl std::error::Error for Reason {}
144
145impl Reason {
146    pub(crate) fn clone_private(&self) -> Self {
147        match self {
148            Reason::Eof => Self::Eof,
149            Reason::Err(error) => {
150                let kind = error.kind();
151                let error = std::io::Error::new(kind, error.to_string());
152                Self::Err(error)
153            }
154        }
155    }
156
157    /// A convenience function which returns whether the original error is capable of being retried
158    pub fn retryable(&self) -> bool {
159        use std::io::ErrorKind as Kind;
160
161        match self {
162            Reason::Eof => true,
163            Reason::Err(error) => matches!(
164                error.kind(),
165                Kind::NotFound
166                    | Kind::PermissionDenied
167                    | Kind::ConnectionRefused
168                    | Kind::ConnectionAborted
169                    | Kind::ConnectionReset
170                    | Kind::NotConnected
171                    | Kind::AlreadyExists
172                    | Kind::HostUnreachable
173                    | Kind::AddrNotAvailable
174                    | Kind::NetworkDown
175                    | Kind::BrokenPipe
176                    | Kind::TimedOut
177                    | Kind::UnexpectedEof
178                    | Kind::NetworkUnreachable
179                    | Kind::AddrInUse
180            ),
181        }
182    }
183}
184
185impl From<Reason> for std::io::Error {
186    fn from(value: Reason) -> Self {
187        match value {
188            Reason::Eof => std::io::Error::new(ErrorKind::UnexpectedEof, "Eof error"),
189            Reason::Err(error) => error,
190        }
191    }
192}
193
194/// A wrapper type which contains the underlying I/O object, it's initializer, and resolver.
195///
196/// This in the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
197/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
198///
199/// Calling things like
200/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
201/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
202/// call.
203///
204/// # Example
205///
206/// ## Basic Resolver
207///
208/// Below is an example of a basic resolver which just logs the error and retries
209///
210/// ```no_run
211/// # use io_tether::*;
212/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
213/// struct MyResolver;
214///
215/// impl<C> Resolver<C> for MyResolver {
216///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<Action> {
217///         println!("WARN(disconnect): {:?}", context);
218///
219///         // always immediately retry the connection
220///         Box::pin(async move { Action::AttemptReconnect })
221///     }
222/// }
223///
224/// let stream = Tether::connect_tcp("localhost:8080", MyResolver).await?;
225///
226/// // Regardless of which half detects the disconnect, a reconnect will be attempted
227/// let (read, write) = tokio::io::split(stream);
228/// # Ok(()) }
229/// ```
230///
231/// # Specialized Resolver
232///
233/// For more specialized use cases we can implement [`Resolver`] only for certain connectors to give
234/// us extra control over the reconnect process.
235///
236/// ```
237/// # use io_tether::{*, tcp::TcpConnector};
238/// # use std::net::{SocketAddrV4, Ipv4Addr};
239/// struct MyResolver;
240///
241/// type Connector = TcpConnector<SocketAddrV4>;
242///
243/// impl Resolver<Connector> for MyResolver {
244///     fn disconnected(&mut self, context: &Context, conn: &mut Connector) -> PinFut<Action> {
245///         // Because we've specialized our resolver to act on TcpConnector for IPv4, we can alter
246///         // the address in between the disconnect, and the reconnect, to try a different host
247///         conn.get_addr_mut().set_ip(Ipv4Addr::LOCALHOST);
248///         conn.get_addr_mut().set_port(8082);
249///
250///         // always immediately retry the connection
251///         Box::pin(async move { Action::AttemptReconnect })
252///     }
253/// }
254/// ```
255///
256/// # Note
257///
258/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
259/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`].
260pub struct Tether<C: Connector, R> {
261    state: State<C::Output>,
262    inner: TetherInner<C, R>,
263}
264
265/// The inner type for tether.
266///
267/// Helps satisfy the borrow checker when we need to mutate this while holding a mutable ref to the
268/// larger futs state machine
269struct TetherInner<C: Connector, R> {
270    config: Config,
271    connector: C,
272    context: Context,
273    io: C::Output,
274    resolver: R,
275    // Should only be acted on when Config::keep_data_on_failed_write is false
276    last_write: Option<Reason>,
277}
278
279impl<C: Connector, R: Resolver<C>> TetherInner<C, R> {
280    fn set_connected(&mut self, state: &mut State<C::Output>) {
281        *state = State::Connected;
282        self.context.reset();
283    }
284
285    fn set_reconnected(&mut self, state: &mut State<C::Output>, new_io: <C as Connector>::Output) {
286        self.io = new_io;
287        let fut = self.resolver.reconnected(&self.context);
288        *state = State::Reconnected(fut);
289    }
290
291    fn set_reconnecting(&mut self, state: &mut State<C::Output>) {
292        let fut = self.connector.reconnect();
293        *state = State::Reconnecting(fut);
294    }
295
296    fn set_disconnected(&mut self, state: &mut State<C::Output>, reason: Reason, source: Source) {
297        self.context.reason = Some((reason, source));
298        let fut = self
299            .resolver
300            .disconnected(&self.context, &mut self.connector);
301        *state = State::Disconnected(fut);
302    }
303}
304
305impl<C: Connector, R> Tether<C, R> {
306    /// Returns a reference to the inner resolver
307    pub fn resolver(&self) -> &R {
308        &self.inner.resolver
309    }
310
311    /// Returns a reference to the inner connector
312    pub fn connector(&self) -> &C {
313        &self.inner.connector
314    }
315
316    /// Returns a reference to the context
317    pub fn context(&self) -> &Context {
318        &self.inner.context
319    }
320}
321
322impl<C, R> Tether<C, R>
323where
324    C: Connector,
325    R: Resolver<C>,
326{
327    /// Construct a tether object from an existing I/O source
328    ///
329    /// # Warning
330    ///
331    /// Unlike [`Tether::connect`], this method does not invoke the resolver's `established` method.
332    /// It is generally recommended that you use [`Tether::connect`].
333    pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
334        Self::new_with_config(connector, io, resolver, Config::default())
335    }
336
337    pub fn new_with_config(connector: C, io: C::Output, resolver: R, config: Config) -> Self {
338        Self::new_with_context(connector, io, resolver, Context::default(), config)
339    }
340
341    fn new_with_context(
342        connector: C,
343        io: C::Output,
344        resolver: R,
345        context: Context,
346        config: Config,
347    ) -> Self {
348        Self {
349            state: Default::default(),
350            inner: TetherInner {
351                config,
352                connector,
353                context,
354                io,
355                resolver,
356                last_write: None,
357            },
358        }
359    }
360
361    /// Overrides the default configuration of the Tether object
362    pub fn set_config(&mut self, config: Config) {
363        self.inner.config = config;
364    }
365
366    /// Consume the Tether, and return the underlying I/O type
367    #[inline]
368    pub fn into_inner(self) -> C::Output {
369        self.inner.io
370    }
371
372    /// Connect to the I/O source, retrying on a failure.
373    pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
374        let mut context = Context::default();
375
376        loop {
377            let state = match connector.connect().await {
378                Ok(io) => {
379                    resolver.established(&context).await;
380                    context.reset();
381                    return Ok(Self::new_with_context(
382                        connector,
383                        io,
384                        resolver,
385                        context,
386                        Config::default(),
387                    ));
388                }
389                Err(error) => error,
390            };
391
392            context.reason = Some((Reason::Err(state), Source::Reconnect));
393            context.increment_attempts();
394
395            if !resolver.unreachable(&context, &mut connector).await {
396                let Some((Reason::Err(error), _)) = context.reason else {
397                    unreachable!("state is immutable and established as Err above");
398                };
399
400                return Err(error);
401            }
402        }
403    }
404
405    /// Connect to the I/O source, bypassing [`Resolver::unreachable`] implementation on a failure.
406    ///
407    /// This does still invoke [`Resolver::established`] if the connection is made successfully.
408    /// To bypass both, construct the IO source and pass it to [`Self::new`].
409    pub async fn connect_without_retry(
410        mut connector: C,
411        mut resolver: R,
412    ) -> Result<Self, std::io::Error> {
413        let context = Context::default();
414
415        let io = connector.connect().await?;
416        resolver.established(&context).await;
417        Ok(Self::new_with_context(
418            connector,
419            io,
420            resolver,
421            context,
422            Config::default(),
423        ))
424    }
425}
426
/// The decision a [`Resolver`] hands back to the Tether after a disconnect
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Action {
    /// Ask the Tether object to attempt to reconnect to the underlying I/O resource
    AttemptReconnect,
    /// Ask the Tether object to not attempt to reconnect to the underlying I/O resource, and
    /// instead propagate the error up to the callsite.
    Exhaust,
    /// Disregard the reason for the disconnect: the same I/O instance will be preserved and
    /// its waker will be registered with the underlying poll method.
    ///
    /// # Warning
    ///
    /// Some implementations may panic if they provided an EOF, and are subsequently polled again.
    /// Use caution when returning this
    Ignore,
}
443
444/// The internal state machine which drives the connection and reconnect logic
445#[derive(Default)]
446enum State<T> {
447    #[default]
448    Connected,
449    Disconnected(PinFut<Action>),
450    Reconnecting(PinFut<Result<T, std::io::Error>>),
451    Reconnected(PinFut<()>),
452    /// Terminal state: resolver returned `Action::Exhaust`. No further reconnects or resolver
453    /// calls will occur. Subsequent polls return a result derived from the stored reason.
454    Exhausted(Reason, Source),
455}
456
/// Where a recorded disconnect [`Reason`] came from
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Source {
    /// NOTE(review): presumably a failure observed while using the live I/O object; the code
    /// assigning this variant is not part of this file, confirm against the poll implementation
    Io,
    /// A connection attempt failed (this is what `Tether::connect` records when the connector
    /// returns an error)
    Reconnect,
}
462
463/// Contains additional information about the disconnect
464///
465/// This type internally tracks the number of times a disconnect has occurred, and the reason for
466/// the disconnect.
467#[derive(Default, Debug)]
468pub struct Context {
469    total_attempts: usize,
470    current_attempts: usize,
471    reason: Option<(Reason, Source)>,
472}
473
474impl Context {
475    /// The number of reconnect attempts since the last successful connection. Reset each time
476    /// the connection is established
477    #[inline]
478    pub fn current_reconnect_attempts(&self) -> usize {
479        self.current_attempts
480    }
481
482    /// The total number of times a reconnect has been attempted.
483    ///
484    /// The first time [`Resolver::disconnected`] or [`Resolver::unreachable`] is invoked this will
485    /// return `0`, each subsequent time it will be incremented by 1.
486    #[inline]
487    pub fn total_reconnect_attempts(&self) -> usize {
488        self.total_attempts
489    }
490
491    fn increment_attempts(&mut self) {
492        self.current_attempts += 1;
493        self.total_attempts += 1;
494    }
495
496    /// Get the current reason for the disconnect
497    ///
498    /// # Panics
499    ///
500    /// Might, panic if called outside of the methods in resolver. Will also panic if called AFTER
501    /// and error has been returned
502    #[inline]
503    pub fn reason(&self) -> &Reason {
504        self.try_reason().unwrap()
505    }
506
507    /// Get the current optional reason for the disconnect
508    #[inline]
509    pub fn try_reason(&self) -> Option<&Reason> {
510        self.reason.as_ref().map(|val| &val.0)
511    }
512
513    /// Resets the current attempts, leaving the total reconnect attempts unchanged
514    #[inline]
515    fn reset(&mut self) {
516        self.current_attempts = 0;
517    }
518}
519
pub(crate) mod ready {
    /// Unwraps a [`std::task::Poll`]: evaluates to the inner value when it is `Ready`, and makes
    /// the enclosing function return `Poll::Pending` otherwise
    macro_rules! ready {
        ($poll:expr $(,)?) => {
            match $poll {
                std::task::Poll::Pending => return std::task::Poll::Pending,
                std::task::Poll::Ready(value) => value,
            }
        };
    }

    // Re-exported so the macro can be reached by path from the rest of the crate
    pub(crate) use ready;
}