io_tether/lib.rs
1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A boxed, pinned, type-erased future that is `Send` and `'static`
///
/// Used as the return type of the [`Resolver`] and [`Io`] methods so implementors can return any
/// future without naming its concrete type.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
14
15/// Represents a type which drives reconnects
16///
17/// Since the disconnected method asynchronous, and is invoked when the underlying stream
18/// disconnects, calling asynchronous functions like
19/// [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html) from within the
20/// body, work.
21///
22/// # Unpin
23///
24/// Since the method provides `&mut Self`, Self must be [`Unpin`]
25///
26/// # Return Type
27///
28/// The return types of the methods are [`PinFut`]. This has the requirement that the returned
29/// future be 'static (cannot hold references to self, or any of the arguments). However, you are
30/// still free to mutate data outside of the returned future.
31///
32/// Additionally, this method is invoked each time the I/O fails to establish a connection so
33/// writing futures which do not reference their environment is a little easier than it may seem.
34///
35/// # Example
36///
37/// A very simple implementation may look something like the following:
38///
39/// ```no_run
40/// # use std::time::Duration;
41/// # use io_tether::{Context, Reason, Resolver, PinFut};
42/// pub struct RetryResolver(bool);
43///
44/// impl<C> Resolver<C> for RetryResolver {
45/// fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
46/// let reason = context.reason();
47/// println!("WARN: Disconnected from server {:?}", reason);
48/// self.0 = true;
49///
50/// if context.current_reconnect_attempts() >= 5 || context.total_reconnect_attempts() >= 50 {
51/// return Box::pin(async move {false});
52/// }
53///
54/// Box::pin(async move {
55/// tokio::time::sleep(Duration::from_secs(10)).await;
56/// true
57/// })
58/// }
59/// }
60/// ```
61pub trait Resolver<C> {
62 /// Invoked by Tether when an error/disconnect is encountered.
63 ///
64 /// Returning `true` will result in a reconnect being attempted via `<T as Io>::reconnect`,
65 /// returning `false` will result in the error being returned from the originating call.
66 fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68 /// Invoked within [`Tether::connect`] if the initial connection attempt fails
69 ///
70 /// As with [`Self::disconnected`] the returned boolean determines whether the initial
71 /// connection attempt is retried
72 ///
73 /// Defaults to invoking [`Self::disconnected`]
74 fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75 self.disconnected(context, connector)
76 }
77
78 /// Invoked within [`Tether::connect`] if the initial connection attempt succeeds
79 ///
80 /// Defaults to invoking [`Self::reconnected`]
81 fn established(&mut self, context: &Context) -> PinFut<()> {
82 self.reconnected(context)
83 }
84
85 /// Invoked by Tether whenever the connection to the underlying I/O source has been
86 /// re-established
87 fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88 Box::pin(std::future::ready(()))
89 }
90}
91
92/// Represents an I/O source capable of reconnecting
93///
94/// This trait is implemented for a number of types in the library, with the implementations placed
95/// behind feature flags
96pub trait Io {
97 type Output;
98
99 /// Initializes the connection to the I/O source
100 fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102 /// Re-establishes the connection to the I/O source
103 fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104 self.connect()
105 }
106}
107
/// Why a connection was lost (or could not be established)
#[non_exhaustive]
#[derive(Debug)]
pub enum Reason {
    /// The underlying I/O reported end-of-file
    ///
    /// This happens when a file is read to its end, when the remote side of a TCP connection
    /// closes the socket, and so on. It generally indicates an orderly end of the connection
    /// rather than a failure.
    Eof,
    /// The underlying I/O returned an error
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Eof => f.write_str("End of file detected"),
            // Delegate entirely to the wrapped error's own user-facing message
            Self::Err(error) => std::fmt::Display::fmt(error, f),
        }
    }
}

impl std::error::Error for Reason {}

impl Reason {
    /// Reports whether this disconnect is of a kind that is usually worth retrying
    ///
    /// [`Reason::Eof`] is always considered retryable. For [`Reason::Err`] the answer depends on
    /// the [`ErrorKind`] of the wrapped error.
    ///
    /// - Connection-level failures return `true`: refused, aborted, reset, not connected,
    ///   unreachable host or network, network down, broken pipe, timed out, unexpected EOF,
    ///   address in use, and address not available.
    /// - A few filesystem-style kinds also return `true`: not found, permission denied, and
    ///   already exists.
    /// - Every other kind returns `false`.
    pub fn retryable(&self) -> bool {
        let error = match self {
            Self::Eof => return true,
            Self::Err(error) => error,
        };

        matches!(
            error.kind(),
            ErrorKind::NotFound
                | ErrorKind::PermissionDenied
                | ErrorKind::AlreadyExists
                | ErrorKind::AddrInUse
                | ErrorKind::AddrNotAvailable
                | ErrorKind::ConnectionRefused
                | ErrorKind::ConnectionAborted
                | ErrorKind::ConnectionReset
                | ErrorKind::NotConnected
                | ErrorKind::HostUnreachable
                | ErrorKind::NetworkDown
                | ErrorKind::NetworkUnreachable
                | ErrorKind::BrokenPipe
                | ErrorKind::TimedOut
                | ErrorKind::UnexpectedEof
        )
    }

    /// Moves the current reason out, leaving [`Reason::Eof`] in its place
    fn take(&mut self) -> Self {
        std::mem::replace(self, Reason::Eof)
    }
}

impl From<Reason> for std::io::Error {
    /// Unwraps [`Reason::Err`] as-is; [`Reason::Eof`] becomes an
    /// [`ErrorKind::UnexpectedEof`] error.
    fn from(reason: Reason) -> Self {
        match reason {
            Reason::Err(inner) => inner,
            Reason::Eof => Self::new(ErrorKind::UnexpectedEof, "Eof error"),
        }
    }
}
174
175/// A wrapper type which contains the underlying I/O object, it's initializer, and resolver.
176///
177/// This in the main type exposed by the library. It implements [`AsyncRead`](tokio::io::AsyncRead)
178/// and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the underlying I/O object implements them.
179///
180/// Calling things like
181/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
182/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
183/// call.
184///
185/// # Example
186///
187/// ## Basic Resolver
188///
189/// Below is an example of a basic resolver which just logs the error and retries
190///
191/// ```no_run
192/// # use io_tether::*;
193/// # async fn foo() -> Result<(), Box<dyn std::error::Error>> {
194/// struct MyResolver;
195///
196/// impl<C> Resolver<C> for MyResolver {
197/// fn disconnected(&mut self, context: &Context, _: &mut C) -> PinFut<bool> {
198/// println!("WARN(disconnect): {:?}", context);
199/// Box::pin(async move { true }) // always immediately retry the connection
200/// }
201/// }
202///
203/// let stream = Tether::connect_tcp("localhost:8080", MyResolver).await?;
204///
205/// // Regardless of which half detects the disconnect, a reconnect will be attempted
206/// let (read, write) = tokio::io::split(stream);
207/// # Ok(()) }
208/// ```
209///
210/// # Specialized Resolver
211///
212/// For more specialized use cases we can implement [`Resolver`] only for certain connectors to give
213/// us extra control over the reconnect process.
214///
215/// ```
216/// # use io_tether::{*, tcp::TcpConnector};
217/// # use std::net::{SocketAddrV4, Ipv4Addr};
218/// struct MyResolver;
219///
220/// type Connector = TcpConnector<SocketAddrV4>;
221///
222/// impl Resolver<Connector> for MyResolver {
223/// fn disconnected(&mut self, context: &Context, conn: &mut Connector) -> PinFut<bool> {
224/// // Because we've specialized our resolver to act on TcpConnector for IPv4, we can alter
225/// // the address in between the disconnect, and the reconnect, to try a different host
226/// conn.get_addr_mut().set_ip(Ipv4Addr::LOCALHOST);
227/// conn.get_addr_mut().set_port(8082);
228///
229/// Box::pin(async move { true }) // always immediately retry the connection
230/// }
231/// }
232/// ```
233///
234/// # Note
235///
236/// Currently, there is no way to obtain a reference into the underlying I/O object. And the only
237/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`].
238pub struct Tether<C: Io, R> {
239 state: State<C::Output>,
240 inner: TetherInner<C, R>,
241}
242
243/// The inner type for tether.
244///
245/// Helps satisfy the borrow checker when we need to mutate this while holding a mutable ref to the
246/// larger futs state machine
247struct TetherInner<C: Io, R> {
248 context: Context,
249 connector: C,
250 io: C::Output,
251 resolver: R,
252}
253
254impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
255 fn disconnected(&mut self) -> PinFut<bool> {
256 self.resolver
257 .disconnected(&self.context, &mut self.connector)
258 }
259
260 fn reconnected(&mut self) -> PinFut<()> {
261 self.resolver.reconnected(&self.context)
262 }
263}
264
265impl<C: Io, R> Tether<C, R> {
266 /// Returns a reference to the inner resolver
267 pub fn resolver(&self) -> &R {
268 &self.inner.resolver
269 }
270
271 /// Returns a reference to the inner connector
272 pub fn connector(&self) -> &C {
273 &self.inner.connector
274 }
275}
276
277impl<C, R> Tether<C, R>
278where
279 C: Io,
280 R: Resolver<C>,
281{
282 /// Construct a tether object from an existing I/O source
283 ///
284 /// # Warning
285 ///
286 /// Unlike [`Tether::connect`], this method does not invoke the resolver's `established` method.
287 /// It is generally recommended that you use [`Tether::connect`].
288 pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
289 Self::new_with_context(connector, io, resolver, Context::default())
290 }
291
292 fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
293 Self {
294 state: Default::default(),
295 inner: TetherInner {
296 context,
297 io,
298 resolver,
299 connector,
300 },
301 }
302 }
303
304 fn reconnect(&mut self) {
305 self.state = State::Connected;
306 self.inner.context.reset();
307 }
308
309 /// Consume the Tether, and return the underlying I/O type
310 #[inline]
311 pub fn into_inner(self) -> C::Output {
312 self.inner.io
313 }
314
315 /// Connect to the I/O source, retrying on a failure.
316 pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
317 let mut context = Context::default();
318
319 loop {
320 let state = match connector.connect().await {
321 Ok(io) => {
322 resolver.established(&context).await;
323 context.reset();
324 return Ok(Self::new_with_context(connector, io, resolver, context));
325 }
326 Err(error) => Reason::Err(error),
327 };
328
329 context.increment_attempts();
330
331 if !resolver.unreachable(&context, &mut connector).await {
332 let Reason::Err(error) = state else {
333 unreachable!("state is immutable and established as Err above");
334 };
335
336 return Err(error);
337 }
338 }
339 }
340
341 /// Connect to the I/O source, bypassing [`Resolver::unreachable`] implementation on a failure.
342 ///
343 /// This does still invoke [`Resolver::established`] if the connection is made successfully.
344 /// To bypass both, construct the IO source and pass it to [`Self::new`].
345 pub async fn connect_without_retry(
346 mut connector: C,
347 mut resolver: R,
348 ) -> Result<Self, std::io::Error> {
349 let context = Context::default();
350
351 let io = connector.connect().await?;
352 resolver.established(&context).await;
353 Ok(Self::new_with_context(connector, io, resolver, context))
354 }
355}
356
357/// The internal state machine which drives the connection and reconnect logic
358#[derive(Default)]
359enum State<T> {
360 #[default]
361 Connected,
362 Disconnected(PinFut<bool>),
363 Reconnecting(PinFut<Result<T, std::io::Error>>),
364 Reconnected(PinFut<()>),
365}
366
367/// Contains additional information about the disconnect
368///
369/// This type internally tracks the number of times a disconnect has occurred, and the reason for
370/// the disconnect.
371#[derive(Debug)]
372pub struct Context {
373 total_attempts: usize,
374 current_attempts: usize,
375 reason: Reason,
376}
377
378impl Default for Context {
379 fn default() -> Self {
380 Self {
381 total_attempts: 0,
382 current_attempts: 0,
383 reason: Reason::Eof,
384 }
385 }
386}
387
388impl Context {
389 /// The number of reconnect attempts since the last successful connection. Reset each time
390 /// the connection is established
391 #[inline]
392 pub fn current_reconnect_attempts(&self) -> usize {
393 self.current_attempts
394 }
395
396 /// The total number of times a reconnect has been attempted.
397 ///
398 /// The first time [`Resolver::disconnected`] or [`Resolver::unreachable`] is invoked this will
399 /// return `0`, each subsequent time it will be incremented by 1.
400 #[inline]
401 pub fn total_reconnect_attempts(&self) -> usize {
402 self.total_attempts
403 }
404
405 fn increment_attempts(&mut self) {
406 self.current_attempts += 1;
407 self.total_attempts += 1;
408 }
409
410 /// Get the current reason for the disconnect
411 #[inline]
412 pub fn reason(&self) -> &Reason {
413 &self.reason
414 }
415
416 /// Resets the current attempts, leaving the total reconnect attempts unchanged
417 #[inline]
418 fn reset(&mut self) {
419 self.current_attempts = 0;
420 }
421}
422
/// Crate-internal access to the `ready!` macro used by the poll implementations.
///
/// This re-exports [`std::task::ready`], which evaluates to the value inside `Poll::Ready` or
/// returns `Poll::Pending` from the enclosing function. The `crate::ready::ready` path is kept so
/// existing call sites do not need to change.
pub(crate) mod ready {
    pub(crate) use std::task::ready;
}
435
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
    };

    use super::*;

    /// Resolver that asks for a retry only while no reconnect has been attempted yet
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            let retry = context.total_reconnect_attempts() < 1;

            Box::pin(async move { retry })
        }
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        // Bind to loopback rather than `0.0.0.0`, so the address reported by `local_addr` can be
        // connected to on every platform. Connecting to the unspecified address fails on Windows.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            // Each accepted connection is sent its index as a single byte. The stream is then
            // dropped at the end of the iteration, closing it.
            let mut connections = 0u8;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        // One reconnect is permitted, so the bytes from both connections arrive before the
        // second disconnect surfaces as an error.
        assert!(stream.read_to_end(&mut buf).await.is_err());
        assert_eq!(buf.as_slice(), &[0, 1]);
    }
}