1#![doc = include_str!("../README.md")]
2use std::{future::Future, io::ErrorKind, pin::Pin};
3
4pub mod config;
5#[cfg(feature = "fs")]
6pub mod fs;
7mod implementations;
8#[cfg(feature = "net")]
9pub mod tcp;
10#[cfg(all(feature = "net", target_family = "unix"))]
11pub mod unix;
12
13use config::Config;
14
/// A heap-allocated, pinned future that is `Send` and owns all of its data (`'static`).
///
/// Every hook in [`Resolver`] and [`Connector`] returns one of these, so the tether can store the
/// in-flight future without borrowing from the resolver or connector that produced it.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
17
18pub trait Resolver<C> {
65 fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<Action>;
70
71 fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
78 let fut = self.disconnected(context, connector);
79 Box::pin(async move {
80 match fut.await {
81 Action::AttemptReconnect => true,
82 Action::Exhaust | Action::Ignore => false,
83 }
84 })
85 }
86
87 fn established(&mut self, context: &Context) -> PinFut<()> {
91 self.reconnected(context)
92 }
93
94 fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
97 Box::pin(std::future::ready(()))
98 }
99}
100
101pub trait Connector {
106 type Output;
107
108 fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;
110
111 fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
113 self.connect()
114 }
115}
116
/// Why a [`Tether`] considers its connection lost.
///
/// Resolvers read it through [`Context::reason`] / [`Context::try_reason`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Reason {
    /// The I/O object reported end of file.
    Eof,
    /// The `std::io::Error` that triggered the disconnect. For example, [`Tether::connect`]
    /// records each failed connection attempt this way.
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Reason::Eof => f.write_str("End of file detected"),
            // Fully qualified so this always forwards to the inner error's `Display` impl. A bare
            // `error.fmt(f)` only resolves because the implemented trait is in scope inside its
            // own impl, and it turns into an ambiguity error (E0034) as soon as
            // `std::fmt::Debug` is imported into this module.
            Reason::Err(error) => std::fmt::Display::fmt(error, f),
        }
    }
}

// No `source()` override: `Display` already renders the inner error's message. Exposing the same
// error as a source would print it twice in error-chain reports.
impl std::error::Error for Reason {}

impl Reason {
    /// Makes an owned copy of this reason (`std::io::Error` itself is not `Clone`).
    ///
    /// OS errors are rebuilt from their raw error code. The copy therefore keeps
    /// `raw_os_error()`, `kind()` and the message. Any other error is rebuilt from its `kind()`
    /// and `Display` text; its source chain and concrete payload type are not preserved.
    pub(crate) fn clone_private(&self) -> Self {
        match self {
            Reason::Eof => Self::Eof,
            Reason::Err(error) => {
                let error = match error.raw_os_error() {
                    // Same code => same kind and same message as the original.
                    Some(code) => std::io::Error::from_raw_os_error(code),
                    None => std::io::Error::new(error.kind(), error.to_string()),
                };
                Self::Err(error)
            }
        }
    }

    /// Whether a fresh connection attempt could plausibly clear this condition.
    ///
    /// `Eof` is always retryable. An `Err` is retryable exactly when its `ErrorKind` is one of
    /// the kinds matched below.
    pub fn retryable(&self) -> bool {
        use std::io::ErrorKind as Kind;

        match self {
            Reason::Eof => true,
            Reason::Err(error) => matches!(
                error.kind(),
                Kind::NotFound
                    | Kind::PermissionDenied
                    | Kind::ConnectionRefused
                    | Kind::ConnectionAborted
                    | Kind::ConnectionReset
                    | Kind::NotConnected
                    | Kind::AlreadyExists
                    | Kind::HostUnreachable
                    | Kind::AddrNotAvailable
                    | Kind::NetworkDown
                    | Kind::BrokenPipe
                    | Kind::TimedOut
                    | Kind::UnexpectedEof
                    | Kind::NetworkUnreachable
                    | Kind::AddrInUse
            ),
        }
    }
}

/// Converts a [`Reason`] back into a plain I/O error.
///
/// `Err` yields the wrapped error unchanged. `Eof` becomes an `UnexpectedEof` error with the
/// message `"Eof error"`.
impl From<Reason> for std::io::Error {
    fn from(value: Reason) -> Self {
        match value {
            Reason::Eof => std::io::Error::new(ErrorKind::UnexpectedEof, "Eof error"),
            Reason::Err(error) => error,
        }
    }
}
190
/// A self-healing wrapper around the I/O object produced by a [`Connector`].
///
/// When the wrapped I/O object disconnects, the [`Resolver`] `R` is asked what to do. If it asks
/// for a reconnect, the connector supplies a replacement I/O object. The tests use a `Tether` as
/// a tokio `AsyncRead + AsyncWrite` stream. NOTE(review): those trait impls live in the
/// `implementations` module, which is not shown here.
pub struct Tether<C: Connector, R> {
    // Progress of any in-flight disconnect/reconnect handling. Declared before `inner`, so it is
    // dropped first; keep this order.
    state: State<C::Output>,
    // Everything else: config, connector, context, the live I/O object and the resolver.
    inner: TetherInner<C, R>,
}
261
/// Everything a [`Tether`] owns apart from its current [`State`].
///
/// Kept separate from `state` so the state-transition helpers can take `&mut self` and
/// `&mut State` at the same time.
struct TetherInner<C: Connector, R> {
    // Behavioural knobs; replaceable at runtime through `Tether::set_config`.
    config: Config,
    // Produces replacement I/O objects (see `set_reconnecting`).
    connector: C,
    // Attempt counters and the latest disconnect reason, shown to the resolver.
    context: Context,
    // The live I/O object; swapped in `set_reconnected`.
    io: C::Output,
    // User policy consulted on connect/disconnect events.
    resolver: R,
    // Starts as `None` (see `Tether::new_with_context`). NOTE(review): presumably remembers why
    // a write failed so it can be reported later; its readers are in `implementations` — confirm.
    last_write: Option<Reason>,
}
275
276impl<C: Connector, R: Resolver<C>> TetherInner<C, R> {
277 fn set_connected(&mut self, state: &mut State<C::Output>) {
278 *state = State::Connected;
279 self.context.reset();
280 }
281
282 fn set_reconnected(&mut self, state: &mut State<C::Output>, new_io: <C as Connector>::Output) {
283 self.io = new_io;
284 let fut = self.resolver.reconnected(&self.context);
285 *state = State::Reconnected(fut);
286 }
287
288 fn set_reconnecting(&mut self, state: &mut State<C::Output>) {
289 let fut = self.connector.reconnect();
290 *state = State::Reconnecting(fut);
291 }
292
293 fn set_disconnected(&mut self, state: &mut State<C::Output>, reason: Reason, source: Source) {
294 self.context.reason = Some((reason, source));
295 let fut = self
296 .resolver
297 .disconnected(&self.context, &mut self.connector);
298 *state = State::Disconnected(fut);
299 }
300}
301
302impl<C: Connector, R> Tether<C, R> {
303 pub fn resolver(&self) -> &R {
305 &self.inner.resolver
306 }
307
308 pub fn connector(&self) -> &C {
310 &self.inner.connector
311 }
312
313 pub fn context(&self) -> &Context {
315 &self.inner.context
316 }
317}
318
319impl<C, R> Tether<C, R>
320where
321 C: Connector,
322 R: Resolver<C>,
323{
324 pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
331 Self::new_with_config(connector, io, resolver, Config::default())
332 }
333
334 pub fn new_with_config(connector: C, io: C::Output, resolver: R, config: Config) -> Self {
335 Self::new_with_context(connector, io, resolver, Context::default(), config)
336 }
337
338 fn new_with_context(
339 connector: C,
340 io: C::Output,
341 resolver: R,
342 context: Context,
343 config: Config,
344 ) -> Self {
345 Self {
346 state: Default::default(),
347 inner: TetherInner {
348 config,
349 connector,
350 context,
351 io,
352 resolver,
353 last_write: None,
354 },
355 }
356 }
357
358 pub fn set_config(&mut self, config: Config) {
360 self.inner.config = config;
361 }
362
363 #[inline]
365 pub fn into_inner(self) -> C::Output {
366 self.inner.io
367 }
368
369 pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
371 let mut context = Context::default();
372
373 loop {
374 let state = match connector.connect().await {
375 Ok(io) => {
376 resolver.established(&context).await;
377 context.reset();
378 return Ok(Self::new_with_context(
379 connector,
380 io,
381 resolver,
382 context,
383 Config::default(),
384 ));
385 }
386 Err(error) => error,
387 };
388
389 context.reason = Some((Reason::Err(state), Source::Reconnect));
390 context.increment_attempts();
391
392 if !resolver.unreachable(&context, &mut connector).await {
393 let Some((Reason::Err(error), _)) = context.reason else {
394 unreachable!("state is immutable and established as Err above");
395 };
396
397 return Err(error);
398 }
399 }
400 }
401
402 pub async fn connect_without_retry(
407 mut connector: C,
408 mut resolver: R,
409 ) -> Result<Self, std::io::Error> {
410 let context = Context::default();
411
412 let io = connector.connect().await?;
413 resolver.established(&context).await;
414 Ok(Self::new_with_context(
415 connector,
416 io,
417 resolver,
418 context,
419 Config::default(),
420 ))
421 }
422}
423
/// The answer a [`Resolver`] gives after a disconnect.
///
/// Variant order is significant: it defines the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Action {
    /// Try to re-establish the connection.
    AttemptReconnect,
    /// Stop trying to reconnect.
    Exhaust,
    /// NOTE(review): handled in the I/O implementations (not visible here). In
    /// `Resolver::unreachable`'s default it is treated like `Exhaust`.
    Ignore,
}
440
441#[derive(Default)]
443enum State<T> {
444 #[default]
445 Connected,
446 Disconnected(PinFut<Action>),
447 Reconnecting(PinFut<Result<T, std::io::Error>>),
448 Reconnected(PinFut<()>),
449}
450
/// Where a recorded disconnect [`Reason`] originated.
///
/// Variant order is significant: it defines the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum Source {
    /// NOTE(review): presumably set by the read/write paths in `implementations` — confirm.
    Io,
    /// A connection attempt failed (e.g. in `Tether::connect`).
    Reconnect,
}
456
457#[derive(Default, Debug)]
462pub struct Context {
463 total_attempts: usize,
464 current_attempts: usize,
465 reason: Option<(Reason, Source)>,
466}
467
468impl Context {
469 #[inline]
472 pub fn current_reconnect_attempts(&self) -> usize {
473 self.current_attempts
474 }
475
476 #[inline]
481 pub fn total_reconnect_attempts(&self) -> usize {
482 self.total_attempts
483 }
484
485 fn increment_attempts(&mut self) {
486 self.current_attempts += 1;
487 self.total_attempts += 1;
488 }
489
490 #[inline]
497 pub fn reason(&self) -> &Reason {
498 self.try_reason().unwrap()
499 }
500
501 #[inline]
503 pub fn try_reason(&self) -> Option<&Reason> {
504 self.reason.as_ref().map(|val| &val.0)
505 }
506
507 #[inline]
509 fn reset(&mut self) {
510 self.current_attempts = 0;
511 }
512}
513
/// Crate-internal home of the `ready!` macro used when hand-polling futures.
///
/// This re-exports [`std::task::ready!`] (stable since Rust 1.64), which expands to the usual
/// `match`. On `Poll::Ready(t)` it evaluates to `t`; on `Poll::Pending` it does
/// `return Poll::Pending` from the enclosing function. Callers keep using
/// `crate::ready::ready!` exactly as before.
pub(crate) mod ready {
    pub(crate) use std::task::ready;
}
526
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_test::io::{Builder, Mock};

    use super::*;

    /// Resolver that always answers with the same, fixed [`Action`].
    struct Value(Action);

    impl<T> Resolver<T> for Value {
        fn disconnected(&mut self, _context: &Context, _connector: &mut T) -> PinFut<Action> {
            let val = self.0;
            Box::pin(async move { val })
        }
    }

    /// Resolver that asks for a reconnect only while the lifetime attempt counter is still zero,
    /// and exhausts afterwards.
    struct Once;

    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<Action> {
            let retry = if context.total_reconnect_attempts() < 1 {
                Action::AttemptReconnect
            } else {
                Action::Exhaust
            };

            Box::pin(async move { retry })
        }
    }

    fn other(err: &'static str) -> std::io::Error {
        std::io::Error::other(err)
    }

    /// Binds a listener on an ephemeral loopback port.
    ///
    /// The tests connect to the listener's `local_addr()`. Binding the unspecified address
    /// (`0.0.0.0`) would make that `0.0.0.0:port`, which is not a portable connect target
    /// (Windows rejects it, for example).
    async fn loopback_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").await.unwrap()
    }

    trait ReadWrite: 'static + AsyncRead + AsyncWrite + Unpin {}
    impl<T: 'static + AsyncRead + AsyncWrite + Unpin> ReadWrite for T {}

    /// Connector that builds a fresh `tokio_test` mock from the wrapped closure on every connect.
    struct MockConnector<F>(F);

    impl<F: FnMut() -> Mock> Connector for MockConnector<F> {
        type Output = Mock;

        fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
            let value = self.0();

            Box::pin(async move { Ok(value) })
        }
    }

    /// Runs `test` against `mock` and then `tether`, asserting both produce the same output.
    async fn tester<A>(test: A, mock: impl ReadWrite, tether: impl ReadWrite)
    where
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock_result = (test)(Box::new(mock)).await;
        let tether_result = (test)(Box::new(tether)).await;

        assert_eq!(mock_result, tether_result);
    }

    /// Checks that a never-retrying tether over mocks from `gener` behaves exactly like a bare
    /// mock from the same generator.
    async fn mock_acts_as_tether_mock<F, A>(mut gener: F, test: A)
    where
        F: FnMut() -> Mock + 'static + Unpin,
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock = gener();
        let tether_mock = Tether::connect(MockConnector(gener), Value(Action::Exhaust))
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn single_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        mock_acts_as_tether_mock(|| Builder::new().read(b"foobar").read(b"").build(), test).await;
    }

    #[tokio::test]
    async fn two_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let builder = || Builder::new().read(b"foo").read(b"bar").read(b"").build();

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn immediate_error() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            let result = reader.read_to_string(&mut output).await;
            format!("{:?}", result)
        };

        let builder = || {
            Builder::new()
                .read_error(std::io::Error::other("oops!"))
                .build()
        };

        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn basic_write() {
        let mock = || Builder::new().write(b"foo").write(b"bar").build();

        let mut tether = Tether::connect(MockConnector(mock), Once).await.unwrap();
        tether.write_all(b"foo").await.unwrap();
        tether.write_all(b"bar").await.unwrap();
    }

    #[tokio::test]
    async fn failure_to_connect_doesnt_panic() {
        struct Unreachable;
        impl<T> Resolver<T> for Unreachable {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<Action> {
                // `reason()` panics when nothing was recorded; calling it is the point of the test.
                let _reason = context.reason();
                Box::pin(async move { Action::Exhaust })
            }
        }

        // Reserve an ephemeral port and release it immediately, so nothing is listening there.
        // A fixed port would fail spuriously whenever another process happens to use it.
        let addr = loopback_listener().await.local_addr().unwrap();
        let result = Tether::connect_tcp(addr, Unreachable).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn read_then_disconnect() {
        /// Retries after errors, but exhausts on a clean end of file.
        struct AllowEof;
        impl<T> Resolver<T> for AllowEof {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<Action> {
                let value = if !matches!(context.reason(), Reason::Eof) {
                    Action::AttemptReconnect
                } else {
                    Action::Exhaust
                };
                Box::pin(async move { value })
            }
        }

        let mock = Builder::new().read(b"foobarbaz").read(b"").build();
        let mut count = 0;
        let b = move |v: &[u8]| Builder::new().read(v).read_error(other("error")).build();
        // Three connections each deliver a chunk and then fail; the fourth reports EOF.
        let gener = move || {
            let result = match count {
                0 => b(b"foo"),
                1 => b(b"bar"),
                2 => b(b"baz"),
                _ => Builder::new().read(b"").build(),
            };

            count += 1;
            result
        };

        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };

        let tether_mock = Tether::connect(MockConnector(gener), AllowEof)
            .await
            .unwrap();

        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn split_works() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_all(b"foobar").await.unwrap();
                stream.shutdown().await.unwrap();
            }
        });

        let stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let (mut read, mut write) = tokio::io::split(stream);
        let mut buf = [0u8; 6];
        read.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foobar");
        write.write_all(b"foobar").await.unwrap();
    }

    #[tokio::test]
    async fn reconnect_value_is_respected() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });

        let mut stream = Tether::connect_tcp(addr, Value(Action::Exhaust))
            .await
            .unwrap();
        let mut output = String::new();
        stream.read_to_string(&mut output).await.unwrap();
        assert_eq!(&output, "foobar");
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut connections = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf.as_slice(), &[0, 1])
    }

    #[tokio::test]
    async fn error_is_consumed_when_set() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });

        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        stream.set_config(Config {
            error_propagation_on_no_retry: config::ErrorPropagation::IoOperations,
            ..Default::default()
        });
        let mut buf = Vec::new();

        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"foobar".as_slice())
    }

    #[tokio::test]
    async fn write_data_is_silently_dropped_when_set() {
        let listener = loopback_listener().await;
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let mut buf = vec![0u8; 3];

            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.read_exact(&mut buf[..]).await.unwrap();
            stream.shutdown().await.unwrap();

            buf
        });

        let mut stream = Tether::connect_tcp(addr, Value(Action::Exhaust))
            .await
            .unwrap();
        stream.set_config(Config {
            keep_data_on_failed_write: false,
            ..Default::default()
        });

        stream.write_all(b"foo").await.unwrap();

        let buf = handle.await.unwrap();

        stream.write_all(b"bar").await.unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        stream.write_all(b"baz").await.unwrap();

        assert_eq!(b"foo".as_slice(), buf)
    }
}