io_tether/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A boxed, pinned, type-erased future that is `Send` and `'static`, resolving to `O`.
///
/// Every asynchronous hook in this crate (the [`Resolver`] and [`Io`] methods) returns this type.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
14
15/// Represents a type which drives reconnects
16///
17/// Since the disconnected method asynchronous, and is invoked when the underlying stream
18/// disconnects, calling asynchronous functions like
19/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
20/// body, work.
21///
22/// # Unpin
23///
24/// Since the method provides `&mut Self`, Self must be [`Unpin`]
25///
26/// # Return Type
27///
28/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
29/// future be 'static (cannot hold references to self, or any of the arguments). However, you are
30/// still free to mutate data outside of the returned future.
31///
32/// Additionally, this method is invoked each time the I/O fails to establish a connection so
33/// writing futures which do not reference their environment is a little easier than it may seem.
34///
35/// # Example
36///
37/// A very simple implementation may look something like the following:
38///
39/// ```no_run
40/// # use std::time::Duration;
41/// # use io_tether::{Context, Reason, Resolver, PinFut};
42/// pub struct RetryResolver(bool);
43///
44/// impl<C> Resolver<C> for RetryResolver {
45///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
46///         let reason = context.reason();
47///         println!("WARN: Disconnected from server {:?}", reason);
48///         self.0 = true;
49///
50///         if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
51///             return Box::pin(async move {false});
52///         }
53///
54///         Box::pin(async move {
55///             tokio::time::sleep(Duration::from_secs(10)).await;
56///             true
57///         })
58///     }
59/// }
60/// ```
61pub trait Resolver<C> {
62    /// Invoked by Tether when an error/disconnect is encountered.
63    ///
64    /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
65    /// returning `false` will result in the error being returned from the originating call.
66    fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68    /// Invoked within [`Tether::connect`] if the initial connection attempt fails
69    ///
70    /// As with [`Self::disconnected`] the returned boolean determines whether the initial
71    /// connection attempt is retried
72    ///
73    /// Defaults to invoking [`Self::disconnected`]
74    fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75        self.disconnected(context, connector)
76    }
77
78    /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
79    ///
80    /// Defaults to invoking [`Self::reconnected`]
81    fn established(&mut self, context: &Context) -> PinFut<()> {
82        self.reconnected(context)
83    }
84
85    /// Invoked by Tether whenever the connection to the underlying I/O source has been
86    /// re-established
87    fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88        Box::pin(std::future::ready(()))
89    }
90}
91
92/// Represents an I/O source capable of reconnecting
93///
94/// This trait is implemented for a number of types in the library, with the implementations placed
95/// behind feature flags
96pub trait Io {
97    type Output;
98
99    /// Initializes the connection to the I/O source
100    fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102    /// Re-establishes the connection to the I/O source
103    fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104        self.connect()
105    }
106}
107
/// Why the underlying I/O stopped working.
///
/// The most recent value is stored in the [`Context`] handed to the [`Resolver`], so the resolver
/// can decide whether to reconnect.
#[non_exhaustive]
#[derive(Debug)]
pub enum Reason {
    /// The underlying I/O reached end-of-file.
    ///
    /// Both reading the end of a file from the file system and the remote peer closing a TCP
    /// connection surface as this variant. It generally indicates that the connection ended
    /// successfully rather than failing.
    Eof,
    /// The underlying I/O reported an error.
    Err(std::io::Error),
}
121
122impl std::fmt::Display for Reason {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        match self {
125            Reason::Eof => f.write_str("End of file detected"),
126            Reason::Err(error) => error.fmt(f),
127        }
128    }
129}
130
131impl std::error::Error for Reason {}
132
133impl Reason {
134    /// A convenience function which returns whether the original error is capable of being retried
135    pub fn retryable(&self) -> bool {
136        use std::io::ErrorKind as Kind;
137
138        match self {
139            Reason::Eof => true,
140            Reason::Err(error) => matches!(
141                error.kind(),
142                Kind::NotFound
143                    | Kind::PermissionDenied
144                    | Kind::ConnectionRefused
145                    | Kind::ConnectionAborted
146                    | Kind::ConnectionReset
147                    | Kind::NotConnected
148                    | Kind::AlreadyExists
149                    | Kind::HostUnreachable
150                    | Kind::AddrNotAvailable
151                    | Kind::NetworkDown
152                    | Kind::BrokenPipe
153                    | Kind::TimedOut
154                    | Kind::UnexpectedEof
155                    | Kind::NetworkUnreachable
156                    | Kind::AddrInUse
157            ),
158        }
159    }
160}
161
162impl From<Reason> for std::io::Error {
163    fn from(value: Reason) -> Self {
164        match value {
165            Reason::Eof => std::io::Error::new(ErrorKind::UnexpectedEof, "Eof error"),
166            Reason::Err(error) => error,
167        }
168    }
169}
170
171/// A wrapper type which contains the underlying I/O object, it's initializer, and resolver.
172///
173/// This in the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
174/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
175///
176/// Calling things like
177/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
178/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
179/// call.
180///
181/// # Example
182///
183/// ## Basic Resolver
184///
185/// Below is an example of a basic resolver which just logs the error and retries
186///
187/// ```no_run
188/// # use io_tether::*;
189/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
190/// struct MyResolver;
191///
192/// impl<C> Resolver<C> for MyResolver {
193///     fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
194///         println!("WARN(disconnect): {:?}", context);
195///         Box::pin(async move { true }) // always immediately retry the connection
196///     }
197/// }
198///
199/// let stream = Tether::connect_tcp("localhost:8080", MyResolver).await?;
200///
201/// // Regardless of which half detects the disconnect, a reconnect will be attempted
202/// let (read, write) = tokio::io::split(stream);
203/// # Ok(()) }
204/// ```
205///
206/// # Specialized Resolver
207///
208/// For more specialized use cases we can implement [`Resolver`] only for certain connectors to give
209/// us extra control over the reconnect process.
210///
211/// ```
212/// # use io_tether::{*, tcp::TcpConnector};
213/// # use std::net::{SocketAddrV4, Ipv4Addr};
214/// struct MyResolver;
215///
216/// type Connector = TcpConnector<SocketAddrV4>;
217///
218/// impl Resolver<Connector> for MyResolver {
219///     fn disconnected(&mut self, context: &Context, conn: &mut Connector) -> PinFut<bool> {
220///         // Because we've specialized our resolver to act on TcpConnector for IPv4, we can alter
221///         // the address in between the disconnect, and the reconnect, to try a different host
222///         conn.get_addr_mut().set_ip(Ipv4Addr::LOCALHOST);
223///         conn.get_addr_mut().set_port(8082);
224///
225///         Box::pin(async move { true }) // always immediately retry the connection
226///     }
227/// }
228/// ```
229///
230/// # Note
231///
232/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
233/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`].
234pub struct Tether<C: Io, R> {
235    state: State<C::Output>,
236    inner: TetherInner<C, R>,
237}
238
239/// The inner type for tether.
240///
241/// Helps satisfy the borrow checker when we need to mutate this while holding a mutable ref to the
242/// larger futs state machine
243struct TetherInner<C: Io, R> {
244    context: Context,
245    connector: C,
246    io: C::Output,
247    resolver: R,
248}
249
250impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
251    fn disconnected(&mut self) -> PinFut<bool> {
252        self.resolver
253            .disconnected(&self.context, &mut self.connector)
254    }
255
256    fn reconnected(&mut self) -> PinFut<()> {
257        self.resolver.reconnected(&self.context)
258    }
259}
260
261impl<C: Io, R> Tether<C, R> {
262    /// Returns a reference to the inner resolver
263    pub fn resolver(&self) -> &R {
264        &self.inner.resolver
265    }
266
267    /// Returns a reference to the inner connector
268    pub fn connector(&self) -> &C {
269        &self.inner.connector
270    }
271}
272
273impl<C, R> Tether<C, R>
274where
275    C: Io,
276    R: Resolver<C>,
277{
278    /// Construct a tether object from an existing I/O source
279    ///
280    /// # Warning
281    ///
282    /// Unlike [`Tether::connect`], this method does not invoke the resolver's `established` method.
283    /// It is generally recommended that you use [`Tether::connect`].
284    pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
285        Self::new_with_context(connector, io, resolver, Context::default())
286    }
287
288    fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
289        Self {
290            state: Default::default(),
291            inner: TetherInner {
292                context,
293                io,
294                resolver,
295                connector,
296            },
297        }
298    }
299
300    fn set_connected(&mut self) {
301        self.state = State::Connected;
302        self.inner.context.reset();
303    }
304
305    fn set_reconnected(&mut self) {
306        let fut = self.inner.reconnected();
307        self.state = State::Reconnected(fut);
308    }
309
310    fn set_reconnecting(&mut self) {
311        let fut = self.inner.connector.reconnect();
312        self.state = State::Reconnecting(fut);
313    }
314
315    fn set_disconnected(&mut self, reason: Reason) {
316        self.inner.context.reason = Some(reason);
317        let fut = self.inner.disconnected();
318        self.state = State::Disconnected(fut);
319    }
320
321    /// Consume the Tether, and return the underlying I/O type
322    #[inline]
323    pub fn into_inner(self) -> C::Output {
324        self.inner.io
325    }
326
327    /// Connect to the I/O source, retrying on a failure.
328    pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
329        let mut context = Context::default();
330
331        loop {
332            let state = match connector.connect().await {
333                Ok(io) => {
334                    resolver.established(&context).await;
335                    context.reset();
336                    return Ok(Self::new_with_context(connector, io, resolver, context));
337                }
338                Err(error) => Reason::Err(error),
339            };
340
341            context.increment_attempts();
342
343            if !resolver.unreachable(&context, &mut connector).await {
344                let Reason::Err(error) = state else {
345                    unreachable!("state is immutable and established as Err above");
346                };
347
348                return Err(error);
349            }
350        }
351    }
352
353    /// Connect to the I/O source, bypassing [`Resolver::unreachable`] implementation on a failure.
354    ///
355    /// This does still invoke [`Resolver::established`] if the connection is made successfully.
356    /// To bypass both, construct the IO source and pass it to [`Self::new`].
357    pub async fn connect_without_retry(
358        mut connector: C,
359        mut resolver: R,
360    ) -> Result<Self, std::io::Error> {
361        let context = Context::default();
362
363        let io = connector.connect().await?;
364        resolver.established(&context).await;
365        Ok(Self::new_with_context(connector, io, resolver, context))
366    }
367}
368
369/// The internal state machine which drives the connection and reconnect logic
370#[derive(Default)]
371enum State<T> {
372    #[default]
373    Connected,
374    Disconnected(PinFut<bool>),
375    Reconnecting(PinFut<Result<T, std::io::Error>>),
376    Reconnected(PinFut<()>),
377}
378
379/// Contains additional information about the disconnect
380///
381/// This type internally tracks the number of times a disconnect has occurred, and the reason for
382/// the disconnect.
383#[derive(Default, Debug)]
384pub struct Context {
385    total_attempts: usize,
386    current_attempts: usize,
387    reason: Option<Reason>,
388}
389
390impl Context {
391    /// The number of reconnect attempts since the last successful connection. Reset each time
392    /// the connection is established
393    #[inline]
394    pub fn current_reconnect_attempts(&self) -> usize {
395        self.current_attempts
396    }
397
398    /// The total number of times a reconnect has been attempted.
399    ///
400    /// The first time [`Resolver::disconnected`] or [`Resolver::unreachable`] is invoked this will
401    /// return `0`, each subsequent time it will be incremented by 1.
402    #[inline]
403    pub fn total_reconnect_attempts(&self) -> usize {
404        self.total_attempts
405    }
406
407    fn increment_attempts(&mut self) {
408        self.current_attempts += 1;
409        self.total_attempts += 1;
410    }
411
412    /// Get the current reason for the disconnect
413    ///
414    /// # Panics
415    ///
416    /// Might, panic if called outside of the methods in resolver. Will also panic if called AFTER
417    /// and error has been returned
418    #[inline]
419    pub fn reason(&self) -> &Reason {
420        self.try_reason().unwrap()
421    }
422
423    /// Get the current optional reason for the disconnect
424    #[inline]
425    pub fn try_reason(&self) -> Option<&Reason> {
426        self.reason.as_ref()
427    }
428
429    /// Resets the current attempts, leaving the total reconnect attempts unchanged
430    #[inline]
431    fn reset(&mut self) {
432        self.current_attempts = 0;
433    }
434}
435
pub(crate) mod ready {
    /// Unwraps a `Poll::Ready` value, or returns `Poll::Pending` from the enclosing function.
    ///
    /// Accepts an optional trailing comma, like `std::task::ready!`.
    macro_rules! ready {
        ($poll:expr $(,)?) => {
            match $poll {
                ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
                ::std::task::Poll::Ready(value) => value,
            }
        };
    }

    pub(crate) use ready;
}
448
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_test::io::{Builder, Mock};

    use super::*;

    /// Binds a listener on an ephemeral loopback port.
    ///
    /// Loopback (rather than `0.0.0.0`) keeps the listener off external interfaces, and its
    /// `local_addr` can be connected to on every platform; Windows rejects connects to the
    /// unspecified address.
    async fn loopback_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").await.unwrap()
    }

    struct Value(bool);

    impl<T> Resolver<T> for Value {
        fn disconnected(&mut self, _context: &Context, _connector: &mut T) -> PinFut<bool> {
            let val = self.0;
            Box::pin(async move { val })
        }
    }

    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            let retry = context.total_reconnect_attempts() < 1;

            Box::pin(async move { retry })
        }
    }

    fn other(err: &'static str) -> std::io::Error {
        std::io::Error::other(err)
    }

    trait ReadWrite: 'static + AsyncRead + AsyncWrite + Unpin {}
    impl<T: 'static + AsyncRead + AsyncWrite + Unpin> ReadWrite for T {}

    struct MockConnector<F>(F);

    impl<F: FnMut() -> Mock> Io for MockConnector<F> {
        type Output = Mock;

        fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
            let value = self.0();

            Box::pin(async move { Ok(value) })
        }
    }

    async fn tester<A>(test: A, mock: impl ReadWrite, tether: impl ReadWrite)
    where
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock_result = (test)(Box::new(mock)).await;
        let tether_result = (test)(Box::new(tether)).await;

        assert_eq!(mock_result, tether_result);
    }

    async fn mock_acts_as_tether_mock<F, A>(mut gener: F, test: A)
    where
        F: FnMut() -> Mock + 'static + Unpin,
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock = gener();
        let tether_mock = Tether::connect(MockConnector(gener), Value(false))
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn single_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        mock_acts_as_tether_mock(|| Builder::new().read(b"foobar").read(b"").build(), test).await;
    }

    #[tokio::test]
    async fn two_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let builder = || Builder::new().read(b"foo").read(b"bar").read(b"").build();

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn immediate_error() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            let result = reader.read_to_string(&mut output).await;
            format!("{:?}", result)
        };

        let builder = || {
            Builder::new()
                .read_error(std::io::Error::other("oops!"))
                .build()
        };

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn basic_write() {
        let mock = || Builder::new().write(b"foo").write(b"bar").build();

        let mut tether = Tether::connect(MockConnector(mock), Once).await.unwrap();
        tether.write_all(b"foo").await.unwrap();
        // Both writes are scripted on the mock, so both are expected to succeed.
        tether.write_all(b"bar").await.unwrap();
    }

    #[tokio::test]
    async fn read_then_disconnect() {
        struct AllowEof;
        impl<T> Resolver<T> for AllowEof {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
                let value = !matches!(context.reason(), Reason::Eof); // Don't reconnect on EoF
                Box::pin(async move { value })
            }
        }

        let mock = Builder::new().read(b"foobarbaz").read(b"").build();
        let mut count = 0;
        // After each read call we error
        let b = move |v: &[u8]| Builder::new().read(v).read_error(other("error")).build();
        let gener = move || {
            let result = match count {
                0 => b(b"foo"),
                1 => b(b"bar"),
                2 => b(b"baz"),
                _ => Builder::new().read(b"").build(),
            };

            count += 1;
            result
        };

        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let tether_mock = Tether::connect(MockConnector(gener), AllowEof)
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn split_works() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_all(b"foobar").await.unwrap();
                stream.shutdown().await.unwrap();
            }
        });

        let stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let (mut read, mut write) = tokio::io::split(stream);
        let mut buf = [0u8; 6];
        read.read_exact(&mut buf).await.unwrap(); // Disconnect happens here
        assert_eq!(&buf, b"foobar");
        write.write_all(b"foobar").await.unwrap(); // Reconnect is triggered
    }

    #[tokio::test]
    async fn reconnect_value_is_respected() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });

        // We set it to not reconnect, thus we expect this to work exactly as though we had not
        // wrapped the connector in a tether at all
        let mut stream = Tether::connect_tcp(addr, Value(false)).await.unwrap();
        let mut output = String::new();
        stream.read_to_string(&mut output).await.unwrap();
        assert_eq!(&output, "foobar");
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut connections = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf.as_slice(), &[0, 1])
    }
}