1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A pinned, heap-allocated, type-erased future that is `Send` and `'static`.
///
/// Used as the return type of every asynchronous hook in [`Resolver`] and [`Io`].
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
14
15pub trait Resolver<C> {
62 fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68 fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75 self.disconnected(context, connector)
76 }
77
78 fn established(&mut self, context: &Context) -> PinFut<()> {
82 self.reconnected(context)
83 }
84
85 fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88 Box::pin(std::future::ready(()))
89 }
90}
91
92pub trait Io {
97 type Output;
98
99 fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102 fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104 self.connect()
105 }
106}
107
/// Why a [`Tether`] lost, or could not obtain, its connection.
#[non_exhaustive]
#[derive(Debug)]
pub enum Reason {
    /// End-of-file was observed on the I/O object.
    Eof,
    /// An I/O error was returned, either by the I/O object or by the connector.
    Err(std::io::Error),
}
121
122impl std::fmt::Display for Reason {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 match self {
125 Reason::Eof => f.write_str("End of file detected"),
126 Reason::Err(error) => error.fmt(f),
127 }
128 }
129}
130
131impl std::error::Error for Reason {}
132
133impl Reason {
134 pub fn retryable(&self) -> bool {
136 use std::io::ErrorKind as Kind;
137
138 match self {
139 Reason::Eof => true,
140 Reason::Err(error) => matches!(
141 error.kind(),
142 Kind::NotFound
143 | Kind::PermissionDenied
144 | Kind::ConnectionRefused
145 | Kind::ConnectionAborted
146 | Kind::ConnectionReset
147 | Kind::NotConnected
148 | Kind::AlreadyExists
149 | Kind::HostUnreachable
150 | Kind::AddrNotAvailable
151 | Kind::NetworkDown
152 | Kind::BrokenPipe
153 | Kind::TimedOut
154 | Kind::UnexpectedEof
155 | Kind::NetworkUnreachable
156 | Kind::AddrInUse
157 ),
158 }
159 }
160}
161
162impl From<Reason> for std::io::Error {
163 fn from(value: Reason) -> Self {
164 match value {
165 Reason::Eof => std::io::Error::new(ErrorKind::UnexpectedEof, "Eof error"),
166 Reason::Err(error) => error,
167 }
168 }
169}
170
171pub struct Tether<C: Io, R> {
235 state: State<C::Output>,
236 inner: TetherInner<C, R>,
237}
238
239struct TetherInner<C: Io, R> {
244 context: Context,
245 connector: C,
246 io: C::Output,
247 resolver: R,
248}
249
250impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
251 fn disconnected(&mut self) -> PinFut<bool> {
252 self.resolver
253 .disconnected(&self.context, &mut self.connector)
254 }
255
256 fn reconnected(&mut self) -> PinFut<()> {
257 self.resolver.reconnected(&self.context)
258 }
259}
260
261impl<C: Io, R> Tether<C, R> {
262 pub fn resolver(&self) -> &R {
264 &self.inner.resolver
265 }
266
267 pub fn connector(&self) -> &C {
269 &self.inner.connector
270 }
271}
272
273impl<C, R> Tether<C, R>
274where
275 C: Io,
276 R: Resolver<C>,
277{
278 pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
285 Self::new_with_context(connector, io, resolver, Context::default())
286 }
287
288 fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
289 Self {
290 state: Default::default(),
291 inner: TetherInner {
292 context,
293 io,
294 resolver,
295 connector,
296 },
297 }
298 }
299
300 fn set_connected(&mut self) {
301 self.state = State::Connected;
302 self.inner.context.reset();
303 }
304
305 fn set_reconnected(&mut self) {
306 let fut = self.inner.reconnected();
307 self.state = State::Reconnected(fut);
308 }
309
310 fn set_reconnecting(&mut self) {
311 let fut = self.inner.connector.reconnect();
312 self.state = State::Reconnecting(fut);
313 }
314
315 fn set_disconnected(&mut self, reason: Reason) {
316 self.inner.context.reason = Some(reason);
317 let fut = self.inner.disconnected();
318 self.state = State::Disconnected(fut);
319 }
320
321 #[inline]
323 pub fn into_inner(self) -> C::Output {
324 self.inner.io
325 }
326
327 pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
329 let mut context = Context::default();
330
331 loop {
332 let state = match connector.connect().await {
333 Ok(io) => {
334 resolver.established(&context).await;
335 context.reset();
336 return Ok(Self::new_with_context(connector, io, resolver, context));
337 }
338 Err(error) => Reason::Err(error),
339 };
340
341 context.increment_attempts();
342
343 if !resolver.unreachable(&context, &mut connector).await {
344 let Reason::Err(error) = state else {
345 unreachable!("state is immutable and established as Err above");
346 };
347
348 return Err(error);
349 }
350 }
351 }
352
353 pub async fn connect_without_retry(
358 mut connector: C,
359 mut resolver: R,
360 ) -> Result<Self, std::io::Error> {
361 let context = Context::default();
362
363 let io = connector.connect().await?;
364 resolver.established(&context).await;
365 Ok(Self::new_with_context(connector, io, resolver, context))
366 }
367}
368
369#[derive(Default)]
371enum State<T> {
372 #[default]
373 Connected,
374 Disconnected(PinFut<bool>),
375 Reconnecting(PinFut<Result<T, std::io::Error>>),
376 Reconnected(PinFut<()>),
377}
378
379#[derive(Default, Debug)]
384pub struct Context {
385 total_attempts: usize,
386 current_attempts: usize,
387 reason: Option<Reason>,
388}
389
390impl Context {
391 #[inline]
394 pub fn current_reconnect_attempts(&self) -> usize {
395 self.current_attempts
396 }
397
398 #[inline]
403 pub fn total_reconnect_attempts(&self) -> usize {
404 self.total_attempts
405 }
406
407 fn increment_attempts(&mut self) {
408 self.current_attempts += 1;
409 self.total_attempts += 1;
410 }
411
412 #[inline]
419 pub fn reason(&self) -> &Reason {
420 self.try_reason().unwrap()
421 }
422
423 #[inline]
425 pub fn try_reason(&self) -> Option<&Reason> {
426 self.reason.as_ref()
427 }
428
429 #[inline]
431 fn reset(&mut self) {
432 self.current_attempts = 0;
433 }
434}
435
/// Crate-internal access to the `ready!` macro.
///
/// Re-exports [`std::task::ready!`] (stable since Rust 1.64). It evaluates a
/// `Poll<T>` expression: `Poll::Ready(t)` yields `t`, while `Poll::Pending`
/// early-returns `Poll::Pending` from the enclosing function. Callers keep using the
/// `crate::ready::ready!` path.
pub(crate) mod ready {
    pub(crate) use std::task::ready;
}
448
#[cfg(test)]
mod tests {
    #[cfg(feature = "net")]
    use tokio::net::TcpListener;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
    use tokio_test::io::{Builder, Mock};

    use super::*;

    /// Loopback address with an OS-assigned port for the TCP test listeners.
    ///
    /// Binding `0.0.0.0` and then connecting to the returned `0.0.0.0:port` is not
    /// portable (Windows refuses to connect to the unspecified address). It also listens
    /// on every interface. Loopback avoids both problems.
    #[cfg(feature = "net")]
    const LOOPBACK: &str = "127.0.0.1:0";

    /// Resolver whose `disconnected` hook always answers with the wrapped value.
    struct Value(bool);

    impl<T> Resolver<T> for Value {
        fn disconnected(&mut self, _context: &Context, _connector: &mut T) -> PinFut<bool> {
            let val = self.0;
            Box::pin(async move { val })
        }
    }

    /// Resolver that only allows reconnecting while no attempt has been counted yet.
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            let retry = context.total_reconnect_attempts() < 1;

            Box::pin(async move { retry })
        }
    }

    /// Shorthand for an `ErrorKind::Other` I/O error.
    fn other(err: &'static str) -> std::io::Error {
        std::io::Error::other(err)
    }

    /// Object-safe bundle of the stream bounds the tests need.
    trait ReadWrite: 'static + AsyncRead + AsyncWrite + Unpin {}
    impl<T: 'static + AsyncRead + AsyncWrite + Unpin> ReadWrite for T {}

    /// Connector producing a fresh mock from the closure on every (re)connect.
    struct MockConnector<F>(F);

    impl<F: FnMut() -> Mock> Io for MockConnector<F> {
        type Output = Mock;

        fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
            let value = self.0();

            Box::pin(async move { Ok(value) })
        }
    }

    /// Runs `test` against `mock` and against `tether` and asserts both produce the same
    /// output, i.e. that the tether behaves like the stream it wraps.
    async fn tester<A>(test: A, mock: impl ReadWrite, tether: impl ReadWrite)
    where
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock_result = (test)(Box::new(mock)).await;
        let tether_result = (test)(Box::new(tether)).await;

        assert_eq!(mock_result, tether_result);
    }

    /// Compares one mock from `gener` against a never-reconnecting tether whose connector
    /// uses the same generator.
    async fn mock_acts_as_tether_mock<F, A>(mut gener: F, test: A)
    where
        F: FnMut() -> Mock + 'static + Unpin,
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock = gener();
        let tether_mock = Tether::connect(MockConnector(gener), Value(false))
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn single_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        mock_acts_as_tether_mock(|| Builder::new().read(b"foobar").read(b"").build(), test).await;
    }

    #[tokio::test]
    async fn two_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let builder = || Builder::new().read(b"foo").read(b"bar").read(b"").build();

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn immediate_error() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            let result = reader.read_to_string(&mut output).await;
            format!("{:?}", result)
        };

        let builder = || {
            Builder::new()
                .read_error(std::io::Error::other("oops!"))
                .build()
        };

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn basic_write() {
        let mock = || Builder::new().write(b"foo").write(b"bar").build();

        let mut tether = Tether::connect(MockConnector(mock), Once).await.unwrap();
        tether.write_all(b"foo").await.unwrap();
        tether.write_all(b"bar").await.unwrap();
    }

    #[tokio::test]
    async fn read_then_disconnect() {
        // Reconnects after errors, but treats end-of-file as final.
        struct AllowEof;

        impl<T> Resolver<T> for AllowEof {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
                let value = !matches!(context.reason(), Reason::Eof);
                Box::pin(async move { value })
            }
        }

        let mock = Builder::new().read(b"foobarbaz").read(b"").build();
        let mut count = 0;
        // Each of the first three connections yields one chunk and then errors, forcing
        // a reconnect; the fourth connection reports end-of-file.
        let b = move |v: &[u8]| Builder::new().read(v).read_error(other("error")).build();

        let gener = move || {
            let result = match count {
                0 => b(b"foo"),
                1 => b(b"bar"),
                2 => b(b"baz"),
                _ => Builder::new().read(b"").build(),
            };

            count += 1;
            result
        };

        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let tether_mock = Tether::connect(MockConnector(gener), AllowEof)
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[cfg(feature = "net")]
    #[tokio::test]
    async fn split_works() {
        let listener = TcpListener::bind(LOOPBACK).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_all(b"foobar").await.unwrap();
                stream.shutdown().await.unwrap();
            }
        });

        let stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let (mut read, mut write) = tokio::io::split(stream);
        let mut buf = [0u8; 6];
        read.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foobar");
        write.write_all(b"foobar").await.unwrap();
    }

    #[cfg(feature = "net")]
    #[tokio::test]
    async fn reconnect_value_is_respected() {
        let listener = TcpListener::bind(LOOPBACK).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });

        // `Value(false)` refuses to reconnect, so EOF ends the read after one connection.
        let mut stream = Tether::connect_tcp(addr, Value(false)).await.unwrap();
        let mut output = String::new();
        stream.read_to_string(&mut output).await.unwrap();
        assert_eq!(&output, "foobar");
    }

    #[cfg(feature = "net")]
    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = TcpListener::bind(LOOPBACK).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut connections = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf.as_slice(), &[0, 1])
    }
}