1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4#[cfg(feature = "fs")]
5pub mod fs;
6mod implementations;
7#[cfg(feature = "net")]
8pub mod tcp;
9#[cfg(all(feature = "net", target_family = "unix"))]
10pub mod unix;
11
/// A heap-allocated, pinned, `Send + 'static` future resolving to `O`.
///
/// This is the return type of every [`Resolver`] and [`Io`] callback, so implementations
/// can return `Box::pin(async move { ... })` without naming a concrete future type.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + 'static + Send>>;
14
15pub trait Resolver<C> {
62 fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<bool>;
67
68 fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
75 self.disconnected(context, connector)
76 }
77
78 fn established(&mut self, context: &Context) -> PinFut<()> {
82 self.reconnected(context)
83 }
84
85 fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
88 Box::pin(std::future::ready(()))
89 }
90}
91
92pub trait Io {
97 type Output;
98
99 fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
101
102 fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
104 self.connect()
105 }
106}
107
/// Why a connection was lost or could not be established.
#[derive(Debug)]
#[non_exhaustive]
pub enum Reason {
    /// The I/O object reported end-of-file.
    Eof,
    /// The I/O object, or a connection attempt, failed with this error.
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Eof => f.write_str("End of file detected"),
            // Delegate to the inner error's `Display` so formatter flags pass through.
            Self::Err(inner) => std::fmt::Display::fmt(inner, f),
        }
    }
}

impl std::error::Error for Reason {}

impl Reason {
    /// Reports whether this reason is considered worth retrying.
    ///
    /// [`Reason::Eof`] always is. [`Reason::Err`] is retryable only when its
    /// [`std::io::ErrorKind`] appears in the list below (connection, addressing and
    /// timeout-related kinds); any other kind, such as `InvalidData`, is not.
    pub fn retryable(&self) -> bool {
        use std::io::ErrorKind as Kind;

        const RETRYABLE_KINDS: &[Kind] = &[
            Kind::NotFound,
            Kind::PermissionDenied,
            Kind::ConnectionRefused,
            Kind::ConnectionAborted,
            Kind::ConnectionReset,
            Kind::NotConnected,
            Kind::AlreadyExists,
            Kind::HostUnreachable,
            Kind::AddrNotAvailable,
            Kind::NetworkDown,
            Kind::BrokenPipe,
            Kind::TimedOut,
            Kind::UnexpectedEof,
            Kind::NetworkUnreachable,
            Kind::AddrInUse,
        ];

        match self {
            Self::Eof => true,
            Self::Err(inner) => RETRYABLE_KINDS.contains(&inner.kind()),
        }
    }
}

impl From<Reason> for std::io::Error {
    /// `Eof` becomes an `UnexpectedEof` error; `Err` yields its wrapped error unchanged.
    fn from(reason: Reason) -> Self {
        match reason {
            Reason::Err(inner) => inner,
            Reason::Eof => Self::new(ErrorKind::UnexpectedEof, "Eof error"),
        }
    }
}
170
/// An I/O wrapper that uses the connector `C` to obtain its underlying I/O object and the
/// resolver `R` to decide what happens when the connection fails.
pub struct Tether<C: Io, R> {
    // Current phase of the connection state machine (see `State`). Fields drop in
    // declaration order, so an in-flight phase future is dropped before `inner`.
    state: State<C::Output>,
    // Connector, resolver, live I/O object and the bookkeeping `Context`.
    inner: TetherInner<C, R>,
}
238
/// Everything a [`Tether`] owns apart from its `state`.
///
/// NOTE(review): presumably kept separate from `Tether::state` so these fields can be
/// borrowed while the state future is polled or replaced — confirm in `implementations`.
struct TetherInner<C: Io, R> {
    // Attempt counters and last recorded reason, handed to resolver callbacks.
    context: Context,
    // Produces I/O objects via `Io::connect` / `Io::reconnect`.
    connector: C,
    // The current underlying I/O object.
    io: C::Output,
    // Policy consulted on connection events.
    resolver: R,
}
249
250impl<C: Io, R: Resolver<C>> TetherInner<C, R> {
251 fn disconnected(&mut self) -> PinFut<bool> {
252 self.resolver
253 .disconnected(&self.context, &mut self.connector)
254 }
255
256 fn reconnected(&mut self) -> PinFut<()> {
257 self.resolver.reconnected(&self.context)
258 }
259}
260
261impl<C: Io, R> Tether<C, R> {
262 pub fn resolver(&self) -> &R {
264 &self.inner.resolver
265 }
266
267 pub fn connector(&self) -> &C {
269 &self.inner.connector
270 }
271}
272
273impl<C, R> Tether<C, R>
274where
275 C: Io,
276 R: Resolver<C>,
277{
278 pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
285 Self::new_with_context(connector, io, resolver, Context::default())
286 }
287
288 fn new_with_context(connector: C, io: C::Output, resolver: R, context: Context) -> Self {
289 Self {
290 state: Default::default(),
291 inner: TetherInner {
292 context,
293 io,
294 resolver,
295 connector,
296 },
297 }
298 }
299
300 fn set_connected(&mut self) {
301 self.state = State::Connected;
302 self.inner.context.reset();
303 }
304
305 fn set_reconnected(&mut self) {
306 let fut = self.inner.reconnected();
307 self.state = State::Reconnected(fut);
308 }
309
310 fn set_reconnecting(&mut self) {
311 let fut = self.inner.connector.reconnect();
312 self.state = State::Reconnecting(fut);
313 }
314
315 fn set_disconnected(&mut self, reason: Reason) {
316 self.inner.context.reason = Some(reason);
317 let fut = self.inner.disconnected();
318 self.state = State::Disconnected(fut);
319 }
320
321 #[inline]
323 pub fn into_inner(self) -> C::Output {
324 self.inner.io
325 }
326
327 pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
329 let mut context = Context::default();
330
331 loop {
332 let state = match connector.connect().await {
333 Ok(io) => {
334 resolver.established(&context).await;
335 context.reset();
336 return Ok(Self::new_with_context(connector, io, resolver, context));
337 }
338 Err(error) => error,
339 };
340
341 context.reason = Some(Reason::Err(state));
342 context.increment_attempts();
343
344 if !resolver.unreachable(&context, &mut connector).await {
345 let Some(Reason::Err(error)) = context.reason else {
346 unreachable!("state is immutable and established as Err above");
347 };
348
349 return Err(error);
350 }
351 }
352 }
353
354 pub async fn connect_without_retry(
359 mut connector: C,
360 mut resolver: R,
361 ) -> Result<Self, std::io::Error> {
362 let context = Context::default();
363
364 let io = connector.connect().await?;
365 resolver.established(&context).await;
366 Ok(Self::new_with_context(connector, io, resolver, context))
367 }
368}
369
370#[derive(Default)]
372enum State<T> {
373 #[default]
374 Connected,
375 Disconnected(PinFut<bool>),
376 Reconnecting(PinFut<Result<T, std::io::Error>>),
377 Reconnected(PinFut<()>),
378}
379
380#[derive(Default, Debug)]
385pub struct Context {
386 total_attempts: usize,
387 current_attempts: usize,
388 reason: Option<Reason>,
389}
390
391impl Context {
392 #[inline]
395 pub fn current_reconnect_attempts(&self) -> usize {
396 self.current_attempts
397 }
398
399 #[inline]
404 pub fn total_reconnect_attempts(&self) -> usize {
405 self.total_attempts
406 }
407
408 fn increment_attempts(&mut self) {
409 self.current_attempts += 1;
410 self.total_attempts += 1;
411 }
412
413 #[inline]
420 pub fn reason(&self) -> &Reason {
421 self.try_reason().unwrap()
422 }
423
424 #[inline]
426 pub fn try_reason(&self) -> Option<&Reason> {
427 self.reason.as_ref()
428 }
429
430 #[inline]
432 fn reset(&mut self) {
433 self.current_attempts = 0;
434 }
435}
436
437pub(crate) mod ready {
438 macro_rules! ready {
439 ($e:expr $(,)?) => {
440 match $e {
441 std::task::Poll::Ready(t) => t,
442 std::task::Poll::Pending => return std::task::Poll::Pending,
443 }
444 };
445 }
446
447 pub(crate) use ready;
448}
449
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_test::io::{Builder, Mock};

    use super::*;

    /// Loopback address with an OS-assigned port for test listeners.
    ///
    /// Binding to loopback rather than `0.0.0.0` keeps test sockets off external interfaces.
    /// It also yields a `local_addr()` that clients can connect to on every platform;
    /// connecting to the wildcard `0.0.0.0` fails on Windows.
    const LOOPBACK_ANY_PORT: &str = "127.0.0.1:0";

    /// Resolver that always answers with the wrapped value.
    struct Value(bool);

    impl<T> Resolver<T> for Value {
        fn disconnected(&mut self, _context: &Context, _connector: &mut T) -> PinFut<bool> {
            let val = self.0;
            Box::pin(async move { val })
        }
    }

    /// Resolver that allows a retry only while no attempt has been counted yet.
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
            let retry = context.total_reconnect_attempts() < 1;

            Box::pin(async move { retry })
        }
    }

    fn other(err: &'static str) -> std::io::Error {
        std::io::Error::other(err)
    }

    trait ReadWrite: 'static + AsyncRead + AsyncWrite + Unpin {}
    impl<T: 'static + AsyncRead + AsyncWrite + Unpin> ReadWrite for T {}

    /// Connector that hands out a fresh mock from the wrapped generator on every connect.
    struct MockConnector<F>(F);

    impl<F: FnMut() -> Mock> Io for MockConnector<F> {
        type Output = Mock;

        fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
            let value = self.0();

            Box::pin(async move { Ok(value) })
        }
    }

    /// Runs `test` against a plain mock and against a tether, asserting identical output.
    async fn tester<A>(test: A, mock: impl ReadWrite, tether: impl ReadWrite)
    where
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock_result = (test)(Box::new(mock)).await;
        let tether_result = (test)(Box::new(tether)).await;

        assert_eq!(mock_result, tether_result);
    }

    async fn mock_acts_as_tether_mock<F, A>(mut gener: F, test: A)
    where
        F: FnMut() -> Mock + 'static + Unpin,
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock = gener();
        let tether_mock = Tether::connect(MockConnector(gener), Value(false))
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn single_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        mock_acts_as_tether_mock(|| Builder::new().read(b"foobar").read(b"").build(), test).await;
    }

    #[tokio::test]
    async fn two_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let builder = || Builder::new().read(b"foo").read(b"bar").read(b"").build();

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn immediate_error() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            let result = reader.read_to_string(&mut output).await;
            format!("{:?}", result)
        };

        let builder = || {
            Builder::new()
                .read_error(std::io::Error::other("oops!"))
                .build()
        };

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn basic_write() {
        let mock = || Builder::new().write(b"foo").write(b"bar").build();

        let mut tether = Tether::connect(MockConnector(mock), Once).await.unwrap();
        tether.write_all(b"foo").await.unwrap();
        tether.write_all(b"bar").await.unwrap();
    }

    #[tokio::test]
    async fn failure_to_connect_doesnt_panic() {
        struct Unreachable;
        impl<T> Resolver<T> for Unreachable {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
                // `connect` records the failure before asking, so this must not panic.
                let _reason = context.reason();
                Box::pin(async move { false })
            }
        }

        // Reserve an ephemeral loopback port and release it immediately, so nothing is
        // listening there. This is more reliable than assuming a hard-coded port is free.
        let addr = {
            let listener = TcpListener::bind(LOOPBACK_ANY_PORT).await.unwrap();
            listener.local_addr().unwrap()
        };

        let result = Tether::connect_tcp(addr, Unreachable).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn read_then_disconnect() {
        struct AllowEof;
        impl<T> Resolver<T> for AllowEof {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<bool> {
                let value = !matches!(context.reason(), Reason::Eof);
                Box::pin(async move { value })
            }
        }

        let mock = Builder::new().read(b"foobarbaz").read(b"").build();
        let mut count = 0;
        let b = move |v: &[u8]| Builder::new().read(v).read_error(other("error")).build();
        let gener = move || {
            let result = match count {
                0 => b(b"foo"),
                1 => b(b"bar"),
                2 => b(b"baz"),
                _ => Builder::new().read(b"").build(),
            };

            count += 1;
            result
        };

        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let tether_mock = Tether::connect(MockConnector(gener), AllowEof)
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn split_works() {
        let listener = TcpListener::bind(LOOPBACK_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_all(b"foobar").await.unwrap();
                stream.shutdown().await.unwrap();
            }
        });

        let stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let (mut read, mut write) = tokio::io::split(stream);
        let mut buf = [0u8; 6];
        read.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foobar");
        write.write_all(b"foobar").await.unwrap();
    }

    #[tokio::test]
    async fn reconnect_value_is_respected() {
        let listener = TcpListener::bind(LOOPBACK_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });

        let mut stream = Tether::connect_tcp(addr, Value(false)).await.unwrap();
        let mut output = String::new();
        stream.read_to_string(&mut output).await.unwrap();
        assert_eq!(&output, "foobar");
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = TcpListener::bind(LOOPBACK_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut connections = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf.as_slice(), &[0, 1])
    }
}
680}