io_tether/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A pinned, heap-allocated, type-erased future that is `Send` and `'static` and resolves to `O`.
///
/// This is the return type used throughout the [`Resolver`] and [`Io`] traits.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
14
15/// Represents a type which drives reconnects
16///
17/// Since the disconnected method asynchronous, and is invoked when the underlying stream
18/// disconnects, calling asynchronous functions like
19/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
20/// body, work.
21///
22/// # Unpin
23///
24/// Since the method provides `&mut Self`, Self must be [`Unpin`]
25///
26/// # Return Type
27///
28/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
29/// future be 'static (cannot hold references to self, or any of the arguments). However, you are
30/// still free to mutate data outside of the returned future.
31///
32/// Additionally, this method is invoked each time the I/O fails to establish a connection so
33/// writing futures which do not reference their environment is a little easier than it may seem.
34///
35/// # Example
36///
37/// A very simple implementation may look something like the following:
38///
39/// ```no_run
40/// # use std::time::Duration;
41/// # use io_tether::{Context, Reason, Resolver, PinFut};
42/// pub struct RetryResolver(bool);
43///
44/// impl<C> Resolver<C> for RetryResolver {
45///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
46///         let reason = context.reason();
47///         println!("WARN: Disconnected from server {:?}", reason);
48///         self.0 = true;
49///
50///         if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
51///             return Box::pin(async move {false});
52///         }
53///
54///         Box::pin(async move {
55///             tokio::time::sleep(Duration::from_secs(10)).await;
56///             true
57///         })
58///     }
59/// }
60/// ```
61pub trait Resolver<C> {
62    /// Invoked by Tether when an error/disconnect is encountered.
63    ///
64    /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
65    /// returning `false` will result in the error being returned from the originating call.
66    fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68    /// Invoked within [`Tether::connect`] if the initial connection attempt fails
69    ///
70    /// As with [`Self::disconnected`] the returned boolean determines whether the initial
71    /// connection attempt is retried
72    ///
73    /// Defaults to invoking [`Self::disconnected`]
74    fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75        self.disconnected(context, connector)
76    }
77
78    /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
79    ///
80    /// Defaults to invoking [`Self::reconnected`]
81    fn established(&mut self, context: &Context) -> PinFut<()> {
82        self.reconnected(context)
83    }
84
85    /// Invoked by Tether whenever the connection to the underlying I/O source has been
86    /// re-established
87    fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88        Box::pin(std::future::ready(()))
89    }
90}
91
92/// Represents an I/O source capable of reconnecting
93///
94/// This trait is implemented for a number of types in the library, with the implementations placed
95/// behind feature flags
96pub trait Io {
97    type Output;
98
99    /// Initializes the connection to the I/O source
100    fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102    /// Re-establishes the connection to the I/O source
103    fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104        self.connect()
105    }
106}
107
/// Enum representing reasons for a disconnect
///
/// The current reason is made available to a [`Resolver`] through [`Context::reason`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Reason {
    /// Represents the end of the file for the underlying io
    ///
    /// This can occur when the end of a file is read from the file system, when the remote socket
    /// on a TCP connection is closed, etc. Generally it indicates a successful end of the
    /// connection
    Eof,
    /// An I/O Error occurred
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Eof => f.write_str("End of file detected"),
            // Delegate so the wrapped error's own message (and any formatter flags) is used.
            Self::Err(inner) => std::fmt::Display::fmt(inner, f),
        }
    }
}

// `Reason` can be used anywhere a standard error is expected; its message is produced by the
// `Display` impl above.
impl std::error::Error for Reason {}

impl Reason {
    /// A convenience function which returns whether the original error is capable of being retried
    ///
    /// [`Reason::Eof`] is always considered retryable. For [`Reason::Err`], the wrapped error's
    /// [`ErrorKind`] is checked against a fixed list of connection-, network/address- and
    /// file-system-related kinds; every other kind is reported as not retryable.
    pub fn retryable(&self) -> bool {
        let error = match self {
            Self::Eof => return true,
            Self::Err(error) => error,
        };

        matches!(
            error.kind(),
            // connection-level failures
            ErrorKind::ConnectionRefused
                | ErrorKind::ConnectionAborted
                | ErrorKind::ConnectionReset
                | ErrorKind::NotConnected
                | ErrorKind::BrokenPipe
                | ErrorKind::TimedOut
                | ErrorKind::UnexpectedEof
                // network / address availability
                | ErrorKind::HostUnreachable
                | ErrorKind::NetworkUnreachable
                | ErrorKind::NetworkDown
                | ErrorKind::AddrNotAvailable
                | ErrorKind::AddrInUse
                // file-system conditions
                | ErrorKind::NotFound
                | ErrorKind::PermissionDenied
                | ErrorKind::AlreadyExists
        )
    }
}

impl From<Reason> for std::io::Error {
    /// Converts a disconnect reason back into an I/O error.
    ///
    /// An end-of-file becomes an [`ErrorKind::UnexpectedEof`] error with the message
    /// `"Eof error"`; a wrapped error is returned unchanged.
    fn from(value: Reason) -> Self {
        match value {
            Reason::Err(error) => error,
            Reason::Eof => Self::new(ErrorKind::UnexpectedEof, "Eof error"),
        }
    }
}
170
171/// A wrapper type which contains the underlying I/O object, it's initializer, and resolver.
172///
173/// This in the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
174/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
175///
176/// Calling things like
177/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
178/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
179/// call.
180///
181/// # Example
182///
183/// ## Basic Resolver
184///
185/// Below is an example of a basic resolver which just logs the error and retries
186///
187/// ```no_run
188/// # use io_tether::*;
189/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
190/// struct MyResolver;
191///
192/// impl<C> Resolver<C> for MyResolver {
193///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
194///         println!("WARN(disconnect): {:?}", context);
195///         Box::pin(async move { true }) // always immediately retry the connection
196///     }
197/// }
198///
199/// let stream = Tether::connect_tcp("localhost:8080", MyResolver).await?;
200///
201/// // Regardless of which half detects the disconnect, a reconnect will be attempted
202/// let (read, write) = tokio::io::split(stream);
203/// # Ok(()) }
204/// ```
205///
206/// # Specialized Resolver
207///
208/// For more specialized use cases we can implement [`Resolver`] only for certain connectors to give
209/// us extra control over the reconnect process.
210///
211/// ```
212/// # use io_tether::{*, tcp::TcpConnector};
213/// # use std::net::{SocketAddrV4, Ipv4Addr};
214/// struct MyResolver;
215///
216/// type Connector = TcpConnector<SocketAddrV4>;
217///
218/// impl Resolver<Connector> for MyResolver {
219///     fn disconnected(&mut self, context: &Context, conn: &mut Connector) -> PinFut<bool> {
220///         // Because we've specialized our resolver to act on TcpConnector for IPv4, we can alter
221///         // the address in between the disconnect, and the reconnect, to try a different host
222///         conn.get_addr_mut().set_ip(Ipv4Addr::LOCALHOST);
223///         conn.get_addr_mut().set_port(8082);
224///
225///         Box::pin(async move { true }) // always immediately retry the connection
226///     }
227/// }
228/// ```
229///
230/// # Note
231///
232/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
233/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`].
234pub struct Tether<C: Io, R> {
235    state: State<C::Output>,
236    inner: TetherInner<C, R>,
237}
238
/// The inner type for tether.
///
/// Keeping these fields in a separate struct from [`Tether`]'s `state` field (the in-flight
/// future state machine) lets the code mutate them while also holding a mutable reference into
/// `state`, which helps satisfy the borrow checker.
struct TetherInner<C: Io, R> {
    /// Attempt counters and the most recent disconnect reason, passed to the resolver.
    context: Context,
    /// Produces the I/O object on connect and reconnect.
    connector: C,
    /// The current underlying I/O object (returned by `Tether::into_inner`).
    io: C::Output,
    /// Decides whether, and when, to reconnect.
    resolver: R,
}
249
250impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
251    fn disconnected(&mut self) -> PinFut<bool> {
252        self.resolver
253            .disconnected(&self.context, &mut self.connector)
254    }
255
256    fn reconnected(&mut self) -> PinFut<()> {
257        self.resolver.reconnected(&self.context)
258    }
259}
260
261impl<C: Io, R> Tether<C, R> {
262    /// Returns a reference to the inner resolver
263    pub fn resolver(&self) -> &R {
264        &self.inner.resolver
265    }
266
267    /// Returns a reference to the inner connector
268    pub fn connector(&self) -> &C {
269        &self.inner.connector
270    }
271}
272
273impl<C, R> Tether<C, R>
274where
275    C: Io,
276    R: Resolver<C>,
277{
278    /// Construct a tether object from an existing I/O source
279    ///
280    /// # Warning
281    ///
282    /// Unlike [`Tether::connect`], this method does not invoke the resolver's `established` method.
283    /// It is generally recommended that you use [`Tether::connect`].
284    pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
285        Self::new_with_context(connector, io, resolver, Context::default())
286    }
287
288    fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
289        Self {
290            state: Default::default(),
291            inner: TetherInner {
292                context,
293                io,
294                resolver,
295                connector,
296            },
297        }
298    }
299
300    fn set_connected(&mut self) {
301        self.state = State::Connected;
302        self.inner.context.reset();
303    }
304
305    fn set_reconnected(&mut self) {
306        let fut = self.inner.reconnected();
307        self.state = State::Reconnected(fut);
308    }
309
310    fn set_reconnecting(&mut self) {
311        let fut = self.inner.connector.reconnect();
312        self.state = State::Reconnecting(fut);
313    }
314
315    fn set_disconnected(&mut self, reason: Reason) {
316        self.inner.context.reason = Some(reason);
317        let fut = self.inner.disconnected();
318        self.state = State::Disconnected(fut);
319    }
320
321    /// Consume the Tether, and return the underlying I/O type
322    #[inline]
323    pub fn into_inner(self) -> C::Output {
324        self.inner.io
325    }
326
327    /// Connect to the I/O source, retrying on a failure.
328    pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
329        let mut context = Context::default();
330
331        loop {
332            let state = match connector.connect().await {
333                Ok(io) => {
334                    resolver.established(&context).await;
335                    context.reset();
336                    return Ok(Self::new_with_context(connector, io, resolver, context));
337                }
338                Err(error) => error,
339            };
340
341            context.reason = Some(Reason::Err(state));
342            context.increment_attempts();
343
344            if !resolver.unreachable(&context, &mut connector).await {
345                let Some(Reason::Err(error)) = context.reason else {
346                    unreachable!("state is immutable and established as Err above");
347                };
348
349                return Err(error);
350            }
351        }
352    }
353
354    /// Connect to the I/O source, bypassing [`Resolver::unreachable`] implementation on a failure.
355    ///
356    /// This does still invoke [`Resolver::established`] if the connection is made successfully.
357    /// To bypass both, construct the IO source and pass it to [`Self::new`].
358    pub async fn connect_without_retry(
359        mut connector: C,
360        mut resolver: R,
361    ) -> Result<Self, std::io::Error> {
362        let context = Context::default();
363
364        let io = connector.connect().await?;
365        resolver.established(&context).await;
366        Ok(Self::new_with_context(connector, io, resolver, context))
367    }
368}
369
/// The internal state machine which drives the connection and reconnect logic
#[derive(Default)]
enum State<T> {
    /// The I/O object is usable. This is the default, and therefore the state a freshly
    /// constructed tether starts in.
    #[default]
    Connected,
    /// Awaiting the resolver's `disconnected` future, whose boolean decides whether to
    /// reconnect (entered via `Tether::set_disconnected`).
    Disconnected(PinFut<bool>),
    /// Awaiting the connector's `reconnect` future, which yields a new I/O object or an error
    /// (entered via `Tether::set_reconnecting`).
    Reconnecting(PinFut<Result<T, std::io::Error>>),
    /// Awaiting the resolver's `reconnected` future (entered via `Tether::set_reconnected`).
    Reconnected(PinFut<()>),
}
379
380/// Contains additional information about the disconnect
381///
382/// This type internally tracks the number of times a disconnect has occurred, and the reason for
383/// the disconnect.
384#[derive(Default, Debug)]
385pub struct Context {
386    total_attempts: usize,
387    current_attempts: usize,
388    reason: Option<Reason>,
389}
390
391impl Context {
392    /// The number of reconnect attempts since the last successful connection. Reset each time
393    /// the connection is established
394    #[inline]
395    pub fn current_reconnect_attempts(&self) -> usize {
396        self.current_attempts
397    }
398
399    /// The total number of times a reconnect has been attempted.
400    ///
401    /// The first time [`Resolver::disconnected`] or [`Resolver::unreachable`] is invoked this will
402    /// return `0`, each subsequent time it will be incremented by 1.
403    #[inline]
404    pub fn total_reconnect_attempts(&self) -> usize {
405        self.total_attempts
406    }
407
408    fn increment_attempts(&mut self) {
409        self.current_attempts += 1;
410        self.total_attempts += 1;
411    }
412
413    /// Get the current reason for the disconnect
414    ///
415    /// # Panics
416    ///
417    /// Might, panic if called outside of the methods in resolver. Will also panic if called AFTER
418    /// and error has been returned
419    #[inline]
420    pub fn reason(&self) -> &Reason {
421        self.try_reason().unwrap()
422    }
423
424    /// Get the current optional reason for the disconnect
425    #[inline]
426    pub fn try_reason(&self) -> Option<&Reason> {
427        self.reason.as_ref()
428    }
429
430    /// Resets the current attempts, leaving the total reconnect attempts unchanged
431    #[inline]
432    fn reset(&mut self) {
433        self.current_attempts = 0;
434    }
435}
436
437pub(crate) mod ready {
438    macro_rules! ready {
439        ($e:expr $(,)?) => {
440            match $e {
441                std::task::Poll::Ready(t) => t,
442                std::task::Poll::Pending => return std::task::Poll::Pending,
443            }
444        };
445    }
446
447    pub(crate) use ready;
448}
449
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_test::io::{Builder, Mock};

    use super::*;

    // Resolver that always answers `disconnected` (and, through the default, `unreachable`)
    // with the wrapped boolean.
    struct Value(bool);

    impl<T> Resolver<T> for Value {
        fn disconnected(&mut self, _context: &Context, _connector: &mut T) -> PinFut<bool> {
            let val = self.0;
            Box::pin(async move { val })
        }
    }

    // Resolver that asks for a reconnect only while `total_reconnect_attempts()` is still 0.
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            let retry = context.total_reconnect_attempts() < 1;

            Box::pin(async move { retry })
        }
    }

    // Shorthand for an `ErrorKind::Other` I/O error carrying `err` as its message.
    fn other(err: &'static str) -> std::io::Error {
        std::io::Error::other(err)
    }

    // Combination of the traits the tests need, so plain mocks and tethers can both be boxed
    // behind the same `dyn ReadWrite` type.
    trait ReadWrite: 'static + AsyncRead + AsyncWrite + Unpin {}
    impl<T: 'static + AsyncRead + AsyncWrite + Unpin> ReadWrite for T {}

    // Connector whose every connect (and, via the default, reconnect) succeeds with the mock
    // produced by the wrapped closure.
    struct MockConnector<F>(F);

    impl<F: FnMut() -> Mock> Io for MockConnector<F> {
        type Output = Mock;

        fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
            let value = self.0();

            Box::pin(async move { Ok(value) })
        }
    }

    // Runs `test` against `mock`, then against `tether`, and asserts both produce the same
    // output string.
    async fn tester<A>(test: A, mock: impl ReadWrite, tether: impl ReadWrite)
    where
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock_result = (test)(Box::new(mock)).await;
        let tether_result = (test)(Box::new(tether)).await;

        assert_eq!(mock_result, tether_result);
    }

    // Builds one mock directly and wraps further mocks from `gener` in a tether that never
    // reconnects (`Value(false)`), then checks `test` observes the same result from both.
    async fn mock_acts_as_tether_mock<F, A>(mut gener: F, test: A)
    where
        F: FnMut() -> Mock + 'static + Unpin,
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock = gener();
        let tether_mock = Tether::connect(MockConnector(gener), Value(false))
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn single_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        mock_acts_as_tether_mock(|| Builder::new().read(b"foobar").read(b"").build(), test).await;
    }

    #[tokio::test]
    async fn two_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let builder = || Builder::new().read(b"foo").read(b"bar").read(b"").build();

        mock_acts_as_tether_mock(builder, test).await;
    }

    // With reconnects disabled, a read error must reach the caller exactly as it would from
    // the bare mock (compared via the `Debug` output of the result).
    #[tokio::test]
    async fn immediate_error() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            let result = reader.read_to_string(&mut output).await;
            format!("{:?}", result)
        };

        let builder = || {
            Builder::new()
                .read_error(std::io::Error::other("oops!"))
                .build()
        };

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn basic_write() {
        let mock = || Builder::new().write(b"foo").write(b"bar").build();

        let mut tether = Tether::connect(MockConnector(mock), Once).await.unwrap();
        tether.write_all(b"foo").await.unwrap();
        // NOTE(review): the mock above expects both "foo" and "bar", so this write appears to
        // succeed rather than error; confirm what the trailing comment is meant to describe.
        tether.write_all(b"bar").await.unwrap(); // should trigger error which is propagated
    }

    // Assumes nothing is listening on port 3150 locally, so every connect attempt fails and
    // the resolver (reached through the default `unreachable`) sees a recorded reason.
    #[tokio::test]
    async fn failure_to_connect_doesnt_panic() {
        struct Unreachable;
        impl<T> Resolver<T> for Unreachable {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
                let _reason = context.reason(); // This should not panic
                Box::pin(async move { false })
            }
        }

        let result = Tether::connect_tcp("0.0.0.0:3150", Unreachable).await;
        assert!(result.is_err());
    }

    // The reference mock yields "foobarbaz" from a single stream. The tether instead receives
    // it in three pieces from successive connections, each of which fails after one read,
    // and must reassemble the same output before stopping at EOF.
    #[tokio::test]
    async fn read_then_disconnect() {
        struct AllowEof;
        impl<T> Resolver<T> for AllowEof {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
                let value = !matches!(context.reason(), Reason::Eof); // Don't reconnect on EoF
                Box::pin(async move { value })
            }
        }

        let mock = Builder::new().read(b"foobarbaz").read(b"").build();
        let mut count = 0;
        // After each read call we error
        let b = move |v: &[u8]| Builder::new().read(v).read_error(other("error")).build();
        let gener = move || {
            let result = match count {
                0 => b(b"foo"),
                1 => b(b"bar"),
                2 => b(b"baz"),
                _ => Builder::new().read(b"").build(),
            };

            count += 1;
            result
        };

        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let tether_mock = Tether::connect(MockConnector(gener), AllowEof)
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    // Both halves produced by `tokio::io::split` must keep working across a reconnect.
    #[tokio::test]
    async fn split_works() {
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_all(b"foobar").await.unwrap();
                stream.shutdown().await.unwrap();
            }
        });

        let stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let (mut read, mut write) = tokio::io::split(stream);
        let mut buf = [0u8; 6];
        read.read_exact(&mut buf).await.unwrap(); // Disconnect happens here
        assert_eq!(&buf, b"foobar");
        write.write_all(b"foobar").await.unwrap(); // Reconnect is triggered
    }

    #[tokio::test]
    async fn reconnect_value_is_respected() {
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });

        // We set it to not reconnect, thus we expect this to work exactly as though we had not
        // wrapped the connector in a tether at all
        let mut stream = Tether::connect_tcp(addr, Value(false)).await.unwrap();
        let mut output = String::new();
        stream.read_to_string(&mut output).await.unwrap();
        assert_eq!(&output, "foobar");
    }

    // The server writes one byte numbering each connection. With `Once`, the tether reconnects
    // exactly once, so bytes from the first two connections are read before the final EOF.
    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut connections = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf.as_slice(), &[0, 1])
    }
}