io_tether/lib.rs
1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A boxed, type-erased future that is both `Send` and `'static`.
///
/// Every asynchronous hook in this crate returns this type, so the future it wraps has to own
/// everything it uses rather than borrowing from the hook's arguments.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
14
15/// Represents a type which drives reconnects
16///
17/// Since the disconnected method asynchronous, and is invoked when the underlying stream
18/// disconnects, calling asynchronous functions like
19/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
20/// body, work.
21///
22/// # Unpin
23///
24/// Since the method provides `&mut Self`, Self must be [`Unpin`]
25///
26/// # Return Type
27///
28/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
29/// future be 'static (cannot hold references to self, or any of the arguments). However, you are
30/// still free to mutate data outside of the returned future.
31///
32/// Additionally, this method is invoked each time the I/O fails to establish a connection so
33/// writing futures which do not reference their environment is a little easier than it may seem.
34///
35/// # Example
36///
37/// A very simple implementation may look something like the following:
38///
39/// ```no_run
40/// # use std::time::Duration;
41/// # use io_tether::{Context, Reason, Resolver, PinFut};
42/// pub struct RetryResolver(bool);
43///
44/// impl<C> Resolver<C> for RetryResolver {
45/// fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
46/// let reason = context.reason();
47/// println!("WARN: Disconnected from server {:?}", reason);
48/// self.0 = true;
49///
50/// if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
51/// return Box::pin(async move {false});
52/// }
53///
54/// Box::pin(async move {
55/// tokio::time::sleep(Duration::from_secs(10)).await;
56/// true
57/// })
58/// }
59/// }
60/// ```
61pub trait Resolver<C> {
62 /// Invoked by Tether when an error/disconnect is encountered.
63 ///
64 /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
65 /// returning `false` will result in the error being returned from the originating call.
66 fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68 /// Invoked within [`Tether::connect`] if the initial connection attempt fails
69 ///
70 /// As with [`Self::disconnected`] the returned boolean determines whether the initial
71 /// connection attempt is retried
72 ///
73 /// Defaults to invoking [`Self::disconnected`]
74 fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75 self.disconnected(context, connector)
76 }
77
78 /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
79 ///
80 /// Defaults to invoking [`Self::reconnected`]
81 fn established(&mut self, context: &Context) -> PinFut<()> {
82 self.reconnected(context)
83 }
84
85 /// Invoked by Tether whenever the connection to the underlying I/O source has been
86 /// re-established
87 fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88 Box::pin(std::future::ready(()))
89 }
90}
91
92/// Represents an I/O source capable of reconnecting
93///
94/// This trait is implemented for a number of types in the library, with the implementations placed
95/// behind feature flags
96pub trait Io {
97 type Output;
98
99 /// Initializes the connection to the I/O source
100 fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102 /// Re-establishes the connection to the I/O source
103 fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104 self.connect()
105 }
106}
107
/// Why the connection to the underlying I/O was lost.
#[derive(Debug)]
#[non_exhaustive]
pub enum Reason {
    /// The underlying I/O reached end of file.
    ///
    /// This happens, for example, when a read from the file system hits the end of a file or
    /// when the remote side of a TCP connection closes it. It generally marks an orderly end of
    /// the connection.
    Eof,
    /// The underlying I/O reported an error.
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Forward straight to the wrapped error's own user-facing message.
            Reason::Err(inner) => std::fmt::Display::fmt(inner, f),
            Reason::Eof => f.write_str("End of file detected"),
        }
    }
}

impl std::error::Error for Reason {}

impl Reason {
    /// Convenience check for whether the disconnect is worth retrying.
    ///
    /// [`Reason::Eof`] is always retryable. An I/O error is retryable when its [`ErrorKind`] is
    /// one of a fixed set of connection/availability kinds (refused, reset, aborted, timed out,
    /// unreachable, address in use, not found, permission denied, ...).
    pub fn retryable(&self) -> bool {
        const RETRYABLE_KINDS: &[ErrorKind] = &[
            ErrorKind::NotFound,
            ErrorKind::PermissionDenied,
            ErrorKind::ConnectionRefused,
            ErrorKind::ConnectionAborted,
            ErrorKind::ConnectionReset,
            ErrorKind::NotConnected,
            ErrorKind::AlreadyExists,
            ErrorKind::HostUnreachable,
            ErrorKind::AddrNotAvailable,
            ErrorKind::NetworkDown,
            ErrorKind::BrokenPipe,
            ErrorKind::TimedOut,
            ErrorKind::UnexpectedEof,
            ErrorKind::NetworkUnreachable,
            ErrorKind::AddrInUse,
        ];

        match self {
            Reason::Err(error) => RETRYABLE_KINDS.contains(&error.kind()),
            Reason::Eof => true,
        }
    }

    /// Moves the current reason out, leaving [`Reason::Eof`] behind in its place.
    fn take(&mut self) -> Self {
        std::mem::replace(self, Reason::Eof)
    }
}

impl From<Reason> for std::io::Error {
    fn from(reason: Reason) -> Self {
        match reason {
            Reason::Err(inner) => inner,
            // There is no wrapped error for an EOF, so synthesize one.
            Reason::Eof => Self::new(ErrorKind::UnexpectedEof, "Eof error"),
        }
    }
}
173}
174
175/// A wrapper type which contains the underlying I/O object, it's initializer, and resolver.
176///
177/// This in the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
178/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
179///
180/// Calling things like
181/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
182/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
183/// call.
184///
185/// # Example
186///
187/// ## Basic Resolver
188///
189/// Below is an example of a basic resolver which just logs the error and retries
190///
191/// ```no_run
192/// # use io_tether::*;
193/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
194/// struct MyResolver;
195///
196/// impl<C> Resolver<C> for MyResolver {
197/// fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
198/// println!("WARN(disconnect): {:?}", context);
199/// Box::pin(async move { true }) // always immediately retry the connection
200/// }
201/// }
202///
203/// let stream = Tether::connect_tcp("localhost:8080", MyResolver).await?;
204///
205/// // Regardless of which half detects the disconnect, a reconnect will be attempted
206/// let (read, write) = tokio::io::split(stream);
207/// # Ok(()) }
208/// ```
209///
210/// # Specialized Resolver
211///
212/// For more specialized use cases we can implement [`Resolver`] only for certain connectors to give
213/// us extra control over the reconnect process.
214///
215/// ```
216/// # use io_tether::{*, tcp::TcpConnector};
217/// # use std::net::{SocketAddrV4, Ipv4Addr};
218/// struct MyResolver;
219///
220/// type Connector = TcpConnector<SocketAddrV4>;
221///
222/// impl Resolver<Connector> for MyResolver {
223/// fn disconnected(&mut self, context: &Context, conn: &mut Connector) -> PinFut<bool> {
224/// // Because we've specialized our resolver to act on TcpConnector for IPv4, we can alter
225/// // the address in between the disconnect, and the reconnect, to try a different host
226/// conn.get_addr_mut().set_ip(Ipv4Addr::LOCALHOST);
227/// conn.get_addr_mut().set_port(8082);
228///
229/// Box::pin(async move { true }) // always immediately retry the connection
230/// }
231/// }
232/// ```
233///
234/// # Note
235///
236/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
237/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`].
238pub struct Tether<C: Io, R> {
239 state: State<C::Output>,
240 inner: TetherInner<C, R>,
241}
242
243/// The inner type for tether.
244///
245/// Helps satisfy the borrow checker when we need to mutate this while holding a mutable ref to the
246/// larger futs state machine
247struct TetherInner<C: Io, R> {
248 context: Context,
249 connector: C,
250 io: C::Output,
251 resolver: R,
252}
253
254impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
255 fn disconnected(&mut self) -> PinFut<bool> {
256 self.resolver
257 .disconnected(&self.context, &mut self.connector)
258 }
259
260 fn reconnected(&mut self) -> PinFut<()> {
261 self.resolver.reconnected(&self.context)
262 }
263}
264
265impl<C, R> Tether<C, R>
266where
267 C: Io,
268 R: Resolver<C>,
269{
270 /// Construct a tether object from an existing I/O source
271 ///
272 /// # Note
273 ///
274 /// Often a simpler way to construct a [`Tether`] object is through [`Tether::connect`]
275 pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
276 Self::new_with_context(connector, io, resolver, Context::default())
277 }
278
279 fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
280 Self {
281 state: Default::default(),
282 inner: TetherInner {
283 context,
284 io,
285 resolver,
286 connector,
287 },
288 }
289 }
290
291 fn reconnect(&mut self) {
292 self.state = State::Connected;
293 self.inner.context.reset();
294 }
295
296 /// Consume the Tether, and return the underlying I/O type
297 #[inline]
298 pub fn into_inner(self) -> C::Output {
299 self.inner.io
300 }
301
302 /// Connect to the I/O source, retrying on a failure.
303 pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
304 let mut context = Context::default();
305
306 loop {
307 let state = match connector.connect().await {
308 Ok(io) => {
309 resolver.established(&context).await;
310 context.reset();
311 return Ok(Self::new_with_context(connector, io, resolver, context));
312 }
313 Err(error) => Reason::Err(error),
314 };
315
316 context.increment_attempts();
317
318 if !resolver.unreachable(&context, &mut connector).await {
319 let Reason::Err(error) = state else {
320 unreachable!("state is immutable and established as Err above");
321 };
322
323 return Err(error);
324 }
325 }
326 }
327
328 /// Connect to the I/O source, bypassing [`Resolver::unreachable`] implementation on a failure.
329 ///
330 /// This does still invoke [`Resolver::established`] if the connection is made successfully.
331 /// To bypass both, construct the IO source and pass it to [`Self::new`].
332 pub async fn connect_without_retry(
333 mut connector: C,
334 mut resolver: R,
335 ) -> Result<Self, std::io::Error> {
336 let context = Context::default();
337
338 let io = connector.connect().await?;
339 resolver.established(&context).await;
340 Ok(Self::new_with_context(connector, io, resolver, context))
341 }
342}
343
344/// The internal state machine which drives the connection and reconnect logic
345#[derive(Default)]
346enum State<T> {
347 #[default]
348 Connected,
349 Disconnected(PinFut<bool>),
350 Reconnecting(PinFut<Result<T, std::io::Error>>),
351 Reconnected(PinFut<()>),
352}
353
354/// Contains additional information about the disconnect
355///
356/// This type internally tracks the number of times a disconnect has occurred, and the reason for
357/// the disconnect.
358#[derive(Debug)]
359pub struct Context {
360 total_attempts: usize,
361 current_attempts: usize,
362 reason: Reason,
363}
364
365impl Default for Context {
366 fn default() -> Self {
367 Self {
368 total_attempts: 0,
369 current_attempts: 0,
370 reason: Reason::Eof,
371 }
372 }
373}
374
375impl Context {
376 /// The number of reconnect attempts since the last successful connection. Reset each time
377 /// the connection is established
378 #[inline]
379 pub fn current_reconnect_attempts(&self) -> usize {
380 self.current_attempts
381 }
382
383 /// The total number of times a reconnect has been attempted.
384 ///
385 /// The first time [`Resolver::disconnected`] or [`Resolver::unreachable`] is invoked this will
386 /// return `0`, each subsequent time it will be incremented by 1.
387 #[inline]
388 pub fn total_reconnect_attempts(&self) -> usize {
389 self.total_attempts
390 }
391
392 fn increment_attempts(&mut self) {
393 self.current_attempts += 1;
394 self.total_attempts += 1;
395 }
396
397 /// Get the current reason for the disconnect
398 #[inline]
399 pub fn reason(&self) -> &Reason {
400 &self.reason
401 }
402
403 /// Resets the current attempts, leaving the total reconnect attempts unchanged
404 #[inline]
405 fn reset(&mut self) {
406 self.current_attempts = 0;
407 }
408}
409
/// Home of the crate-internal `ready!` macro.
pub(crate) mod ready {
    /// Evaluates to the value inside `Poll::Ready`, or makes the enclosing function return
    /// `Poll::Pending` immediately.
    macro_rules! ready {
        ($poll:expr $(,)?) => {
            match $poll {
                std::task::Poll::Pending => return std::task::Poll::Pending,
                std::task::Poll::Ready(value) => value,
            }
        };
    }

    pub(crate) use ready;
}
422
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
    };

    use super::*;

    /// Resolver that only asks for a reconnect while `total_reconnect_attempts()` is below 1.
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            let retry = context.total_reconnect_attempts() < 1;

            Box::pin(async move { retry })
        }
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        // Bind to loopback rather than `0.0.0.0`. The unspecified address is fine to listen on,
        // but it is not a portable *connect* target (e.g. Windows rejects it), and the tether
        // below connects to exactly the address reported by `local_addr`.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            // Each accepted connection is sent its index and then dropped (closed), producing
            // an EOF on the client side.
            let mut connections: u8 = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        // The first disconnect is retried, the second is not, so bytes from both connections
        // arrive before the read ultimately fails.
        assert!(stream.read_to_end(&mut buf).await.is_err());
        assert_eq!(buf.as_slice(), &[0, 1])
    }
}