#![doc = include_str!("../README.md")]
use std::{future::Future, io::ErrorKind, pin::Pin};
pub mod config;
#[cfg(feature = "fs")]
pub mod fs;
mod implementations;
#[cfg(feature = "net")]
pub mod tcp;
#[cfg(all(feature = "net", target_family = "unix"))]
pub mod unix;
use config::Config;
/// A pinned, heap-allocated, `Send + 'static` future resolving to `O`.
///
/// Every hook on [`Resolver`] and [`Connector`] returns one of these.
pub type PinFut<O> = Pin<Box<dyn Future<Output = O> + Send + 'static>>;
/// Policy hooks that decide how a [`Tether`] reacts to losing and regaining
/// its connection.
///
/// `C` is the connector type; hooks that receive it get mutable access to the
/// tether's connector.
pub trait Resolver<C> {
    /// Called when the connection has been lost.
    ///
    /// `context` carries the recorded disconnect reason and the attempt
    /// counters. The returned future yields the [`Action`] to take next.
    fn disconnected(&mut self, context: &Context, connector: &mut C) -> PinFut<Action>;

    /// Called after a connection attempt has failed (for example in
    /// [`Tether::connect`]); the future resolves to `true` to try again.
    ///
    /// The default calls [`Resolver::disconnected`] right away and maps
    /// [`Action::AttemptReconnect`] to `true` and every other action to `false`.
    fn unreachable(&mut self, context: &Context, connector: &mut C) -> PinFut<bool> {
        // `context` and `connector` are only borrowed, so the hook has to be
        // invoked here; only awaiting its result can move into the `'static` future.
        let decision = self.disconnected(context, connector);
        Box::pin(async move {
            let action = decision.await;
            match action {
                Action::Exhaust | Action::Ignore => false,
                Action::AttemptReconnect => true,
            }
        })
    }

    /// Called once [`Tether::connect`] or [`Tether::connect_without_retry`]
    /// has established the first connection.
    ///
    /// The default forwards to [`Resolver::reconnected`].
    fn established(&mut self, context: &Context) -> PinFut<()> {
        Self::reconnected(self, context)
    }

    /// Called after a reconnect has installed a new I/O object.
    ///
    /// The default completes immediately and does nothing.
    fn reconnected(&mut self, _context: &Context) -> PinFut<()> {
        let done = std::future::ready(());
        Box::pin(done)
    }
}
/// Creates the I/O object wrapped by a [`Tether`], both initially and after
/// the connection has been lost.
pub trait Connector {
    /// The I/O object this connector produces.
    type Output;

    /// Opens a new connection.
    fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>>;

    /// Opens a replacement connection after a disconnect.
    ///
    /// Defaults to [`Connector::connect`]; override it when reconnecting
    /// should behave differently from the first connection.
    fn reconnect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
        Self::connect(self)
    }
}
/// Why a connection was considered lost.
///
/// Exposed to [`Resolver`] hooks through [`Context::reason`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Reason {
    /// End of file was observed on the I/O object.
    Eof,
    /// An I/O operation or a connection attempt failed with this error.
    Err(std::io::Error),
}

impl std::fmt::Display for Reason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Reason::Eof => f.write_str("End of file detected"),
            // Fully qualified so the inner error is always rendered through its
            // `Display` impl. A bare `error.fmt(f)` depends on which `fmt` trait
            // is in scope, and becomes ambiguous if both `Debug` and `Display`
            // are imported.
            Reason::Err(error) => std::fmt::Display::fmt(error, f),
        }
    }
}

/// `Display` already renders the wrapped error, so it is not also reported
/// through `source()`.
impl std::error::Error for Reason {}
impl Reason {
pub(crate) fn clone_private(&self) -> Self {
match self {
Reason::Eof => Self::Eof,
Reason::Err(error) => {
let kind = error.kind();
let error = std::io::Error::new(kind, error.to_string());
Self::Err(error)
}
}
}
pub fn retryable(&self) -> bool {
use std::io::ErrorKind as Kind;
match self {
Reason::Eof => true,
Reason::Err(error) => matches!(
error.kind(),
Kind::NotFound
| Kind::PermissionDenied
| Kind::ConnectionRefused
| Kind::ConnectionAborted
| Kind::ConnectionReset
| Kind::NotConnected
| Kind::AlreadyExists
| Kind::HostUnreachable
| Kind::AddrNotAvailable
| Kind::NetworkDown
| Kind::BrokenPipe
| Kind::TimedOut
| Kind::UnexpectedEof
| Kind::NetworkUnreachable
| Kind::AddrInUse
),
}
}
}
impl From<Reason> for std::io::Error {
fn from(value: Reason) -> Self {
match value {
Reason::Eof => std::io::Error::new(ErrorKind::UnexpectedEof, "Eof error"),
Reason::Err(error) => error,
}
}
}
/// An I/O wrapper that re-establishes its underlying connection when it is lost.
///
/// `C` is the [`Connector`] that produces the I/O object (`C::Output`), and
/// `R` is the [`Resolver`] consulted when the connection is lost or restored.
pub struct Tether<C: Connector, R> {
    // Current step of the disconnect/reconnect state machine. Kept outside
    // `inner` so `TetherInner`'s transition methods can take `&mut self` and
    // `&mut State` at the same time.
    state: State<C::Output>,
    // Configuration, connector, context, live I/O object and resolver.
    inner: TetherInner<C, R>,
}
/// Everything in a [`Tether`] except its `State`; its methods perform the
/// state transitions.
struct TetherInner<C: Connector, R> {
    // Behaviour settings; replaced wholesale by `Tether::set_config`.
    config: Config,
    // Used to create replacement I/O objects (`Connector::reconnect`).
    connector: C,
    // Attempt counters and the last disconnect reason, handed to the resolver.
    context: Context,
    // The live I/O object; swapped out in `set_reconnected`.
    io: C::Output,
    // Decides how disconnects and reconnects are handled.
    resolver: R,
    // NOTE(review): presumably the failure of the most recent write, kept so it
    // can be reported later; this file only initialises it to `None` — confirm.
    last_write: Option<Reason>,
}
impl<C: Connector, R: Resolver<C>> TetherInner<C, R> {
    /// Returns to `State::Connected` and clears the current attempt counter
    /// (the total is kept).
    fn set_connected(&mut self, state: &mut State<C::Output>) {
        *state = State::Connected;
        self.context.reset();
    }

    /// Installs `new_io` as the live I/O object, then starts the resolver's
    /// `reconnected` hook and parks its future in `State::Reconnected`.
    fn set_reconnected(&mut self, state: &mut State<C::Output>, new_io: <C as Connector>::Output) {
        // The previous I/O object is dropped by this assignment.
        self.io = new_io;
        *state = State::Reconnected(self.resolver.reconnected(&self.context));
    }

    /// Starts a reconnect attempt and parks it in `State::Reconnecting`.
    fn set_reconnecting(&mut self, state: &mut State<C::Output>) {
        *state = State::Reconnecting(self.connector.reconnect());
    }

    /// Records why and where the connection was lost, then asks the resolver
    /// what to do; the pending decision is parked in `State::Disconnected`.
    fn set_disconnected(&mut self, state: &mut State<C::Output>, reason: Reason, source: Source) {
        // Must be stored first: the resolver reads it from the context.
        self.context.reason = Some((reason, source));
        *state = State::Disconnected(
            self.resolver
                .disconnected(&self.context, &mut self.connector),
        );
    }
}
impl<C: Connector, R> Tether<C, R> {
pub fn resolver(&self) -> &R {
&self.inner.resolver
}
pub fn connector(&self) -> &C {
&self.inner.connector
}
pub fn context(&self) -> &Context {
&self.inner.context
}
}
impl<C, R> Tether<C, R>
where
C: Connector,
R: Resolver<C>,
{
pub fn new(connector: C, io: C::Output, resolver: R) -> Self {
Self::new_with_config(connector, io, resolver, Config::default())
}
pub fn new_with_config(connector: C, io: C::Output, resolver: R, config: Config) -> Self {
Self::new_with_context(connector, io, resolver, Context::default(), config)
}
fn new_with_context(
connector: C,
io: C::Output,
resolver: R,
context: Context,
config: Config,
) -> Self {
Self {
state: Default::default(),
inner: TetherInner {
config,
connector,
context,
io,
resolver,
last_write: None,
},
}
}
pub fn set_config(&mut self, config: Config) {
self.inner.config = config;
}
#[inline]
pub fn into_inner(self) -> C::Output {
self.inner.io
}
pub async fn connect(mut connector: C, mut resolver: R) -> Result<Self, std::io::Error> {
let mut context = Context::default();
loop {
let state = match connector.connect().await {
Ok(io) => {
resolver.established(&context).await;
context.reset();
return Ok(Self::new_with_context(
connector,
io,
resolver,
context,
Config::default(),
));
}
Err(error) => error,
};
context.reason = Some((Reason::Err(state), Source::Reconnect));
context.increment_attempts();
if !resolver.unreachable(&context, &mut connector).await {
let Some((Reason::Err(error), _)) = context.reason else {
unreachable!("state is immutable and established as Err above");
};
return Err(error);
}
}
}
pub async fn connect_without_retry(
mut connector: C,
mut resolver: R,
) -> Result<Self, std::io::Error> {
let context = Context::default();
let io = connector.connect().await?;
resolver.established(&context).await;
Ok(Self::new_with_context(
connector,
io,
resolver,
context,
Config::default(),
))
}
}
/// What a [`Resolver`] asks the tether to do after a disconnect.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Action {
    /// Try to re-establish the connection. The default
    /// [`Resolver::unreachable`] maps this to `true` (retry).
    AttemptReconnect,
    /// Do not reconnect. The default [`Resolver::unreachable`] maps this to
    /// `false`. NOTE(review): presumably the disconnect is then surfaced to the
    /// caller; that handling lives outside this file — confirm.
    Exhaust,
    /// Do not reconnect. The default [`Resolver::unreachable`] maps this to
    /// `false`, like `Exhaust`. NOTE(review): how it differs from `Exhaust`
    /// during I/O is defined outside this file — confirm.
    Ignore,
}
/// Internal step of a [`Tether`]'s disconnect/reconnect cycle.
#[derive(Default)]
enum State<T> {
    /// No disconnect is being handled; this is the initial state.
    #[default]
    Connected,
    /// Waiting for [`Resolver::disconnected`] to decide the next [`Action`].
    Disconnected(PinFut<Action>),
    /// Waiting for [`Connector::reconnect`] to produce a new I/O object.
    Reconnecting(PinFut<Result<T, std::io::Error>>),
    /// Waiting for [`Resolver::reconnected`] after the new I/O object was installed.
    Reconnected(PinFut<()>),
}
/// Where a recorded disconnect [`Reason`] came from.
///
/// Variant order defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Source {
    /// NOTE(review): presumably a failed read/write on the live I/O object; it
    /// is recorded by code outside this file — confirm.
    Io,
    /// A failed connection attempt, e.g. in [`Tether::connect`].
    Reconnect,
}
/// State shared with [`Resolver`] hooks: reconnect attempt counters and the
/// most recently recorded disconnect reason.
#[derive(Default, Debug)]
pub struct Context {
    // Incremented on every attempt; `reset` leaves it untouched.
    total_attempts: usize,
    // Incremented on every attempt; cleared by `reset`.
    current_attempts: usize,
    // Latest disconnect reason and its origin; `None` until one is recorded.
    reason: Option<(Reason, Source)>,
}
impl Context {
    /// Reconnect attempts counted since the counter was last reset, which
    /// happens whenever a connection is (re-)established.
    #[inline]
    pub fn current_reconnect_attempts(&self) -> usize {
        self.current_attempts
    }

    /// Reconnect attempts counted over this context's lifetime; unlike
    /// [`Context::current_reconnect_attempts`], it is not cleared on reconnect.
    #[inline]
    pub fn total_reconnect_attempts(&self) -> usize {
        self.total_attempts
    }

    /// Counts one more attempt in both counters.
    fn increment_attempts(&mut self) {
        self.current_attempts += 1;
        self.total_attempts += 1;
    }

    /// Returns the most recently recorded disconnect reason.
    ///
    /// [`Tether`] records a reason before it calls [`Resolver::disconnected`]
    /// or [`Resolver::unreachable`], so calling this from those hooks is safe.
    /// Elsewhere, e.g. in [`Resolver::established`] after a first-try
    /// connection, no reason may exist yet.
    ///
    /// # Panics
    ///
    /// Panics if no reason has been recorded. Use [`Context::try_reason`] when
    /// that is possible.
    #[inline]
    pub fn reason(&self) -> &Reason {
        self.try_reason().expect(
            "Context::reason() called before a disconnect reason was recorded; \
             use Context::try_reason() instead",
        )
    }

    /// Returns the most recently recorded disconnect reason, or `None` if
    /// none has been recorded yet.
    #[inline]
    pub fn try_reason(&self) -> Option<&Reason> {
        self.reason.as_ref().map(|(reason, _source)| reason)
    }

    /// Clears the current attempt counter; the total and the recorded reason
    /// are kept.
    #[inline]
    fn reset(&mut self) {
        self.current_attempts = 0;
    }
}
pub(crate) mod ready {
    /// Extracts the value from `std::task::Poll::Ready`, or returns
    /// `std::task::Poll::Pending` from the enclosing function.
    ///
    /// Behaves like `std::task::ready!` and accepts an optional trailing comma.
    macro_rules! ready {
        ($e:expr $(,)?) => {
            match $e {
                std::task::Poll::Pending => return std::task::Poll::Pending,
                std::task::Poll::Ready(t) => t,
            }
        };
    }

    pub(crate) use ready;
}
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_test::io::{Builder, Mock};

    use super::*;

    /// Loopback with an OS-assigned port. A listener bound to `0.0.0.0`
    /// reports `0.0.0.0:port` as its address, and connecting to that is not
    /// portable (e.g. Windows rejects it), so tests bind to loopback instead.
    const LOCALHOST_ANY_PORT: &str = "127.0.0.1:0";

    /// Resolver that always answers with the same [`Action`].
    struct Value(Action);
    impl<T> Resolver<T> for Value {
        fn disconnected(&mut self, _context: &Context, _connector: &mut T) -> PinFut<Action> {
            let val = self.0;
            Box::pin(async move { val })
        }
    }

    /// Resolver that allows exactly one reconnect attempt over the tether's lifetime.
    struct Once;
    impl<T> Resolver<T> for Once {
        fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<Action> {
            let retry = if context.total_reconnect_attempts() < 1 {
                Action::AttemptReconnect
            } else {
                Action::Exhaust
            };
            Box::pin(async move { retry })
        }
    }

    fn other(err: &'static str) -> std::io::Error {
        std::io::Error::other(err)
    }

    trait ReadWrite: 'static + AsyncRead + AsyncWrite + Unpin {}
    impl<T: 'static + AsyncRead + AsyncWrite + Unpin> ReadWrite for T {}

    /// Connector that produces a fresh mock from the wrapped closure on every connect.
    struct MockConnector<F>(F);
    impl<F: FnMut() -> Mock> Connector for MockConnector<F> {
        type Output = Mock;
        fn connect(&mut self) -> PinFut<Result<Self::Output, std::io::Error>> {
            let value = self.0();
            Box::pin(async move { Ok(value) })
        }
    }

    /// Runs `test` against the raw mock and against the tether, and asserts
    /// that both produce the same output.
    async fn tester<A>(test: A, mock: impl ReadWrite, tether: impl ReadWrite)
    where
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock_result = (test)(Box::new(mock)).await;
        let tether_result = (test)(Box::new(tether)).await;
        assert_eq!(mock_result, tether_result);
    }

    async fn mock_acts_as_tether_mock<F, A>(mut gener: F, test: A)
    where
        F: FnMut() -> Mock + 'static + Unpin,
        A: AsyncFn(Box<dyn ReadWrite>) -> String,
    {
        let mock = gener();
        let tether_mock = Tether::connect(MockConnector(gener), Value(Action::Exhaust))
            .await
            .unwrap();
        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn single_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };
        mock_acts_as_tether_mock(|| Builder::new().read(b"foobar").read(b"").build(), test).await;
    }

    #[tokio::test]
    async fn two_read_then_eof() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };
        let builder = || Builder::new().read(b"foo").read(b"bar").read(b"").build();
        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn immediate_error() {
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            let result = reader.read_to_string(&mut output).await;
            format!("{:?}", result)
        };
        let builder = || {
            Builder::new()
                .read_error(std::io::Error::other("oops!"))
                .build()
        };
        mock_acts_as_tether_mock(builder, test).await;
    }

    #[tokio::test]
    async fn basic_write() {
        let mock = || Builder::new().write(b"foo").write(b"bar").build();
        let mut tether = Tether::connect(MockConnector(mock), Once).await.unwrap();
        tether.write_all(b"foo").await.unwrap();
        tether.write_all(b"bar").await.unwrap();
    }

    #[tokio::test]
    async fn failure_to_connect_doesnt_panic() {
        struct Unreachable;
        impl<T> Resolver<T> for Unreachable {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<Action> {
                // `reason()` panics when nothing is recorded, so this also checks
                // that the failure is stored before the resolver is consulted.
                let _reason = context.reason();
                Box::pin(async move { Action::Exhaust })
            }
        }
        // Reserve a free port and release it again so nothing is listening on
        // it, rather than relying on a hard-coded port being unused.
        let listener = TcpListener::bind(LOCALHOST_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);
        let result = Tether::connect_tcp(addr, Unreachable).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn read_then_disconnect() {
        struct AllowEof;
        impl<T> Resolver<T> for AllowEof {
            fn disconnected(&mut self, context: &Context, _connector: &mut T) -> PinFut<Action> {
                let value = if !matches!(context.reason(), Reason::Eof) {
                    Action::AttemptReconnect
                } else {
                    Action::Exhaust
                };
                Box::pin(async move { value })
            }
        }
        let mock = Builder::new().read(b"foobarbaz").read(b"").build();
        let mut count = 0;
        let b = move |v: &[u8]| Builder::new().read(v).read_error(other("error")).build();
        let gener = move || {
            let result = match count {
                0 => b(b"foo"),
                1 => b(b"bar"),
                2 => b(b"baz"),
                _ => Builder::new().read(b"").build(),
            };
            count += 1;
            result
        };
        let test = async |mut reader: Box<dyn ReadWrite>| {
            let mut output = String::new();
            reader.read_to_string(&mut output).await.unwrap();
            output
        };
        let tether_mock = Tether::connect(MockConnector(gener), AllowEof)
            .await
            .unwrap();
        tester(test, mock, tether_mock).await
    }

    #[tokio::test]
    async fn split_works() {
        let listener = TcpListener::bind(LOCALHOST_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_all(b"foobar").await.unwrap();
                stream.shutdown().await.unwrap();
            }
        });
        let stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let (mut read, mut write) = tokio::io::split(stream);
        let mut buf = [0u8; 6];
        read.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foobar");
        write.write_all(b"foobar").await.unwrap();
    }

    #[tokio::test]
    async fn reconnect_value_is_respected() {
        let listener = TcpListener::bind(LOCALHOST_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });
        let mut stream = Tether::connect_tcp(addr, Value(Action::Exhaust))
            .await
            .unwrap();
        let mut output = String::new();
        stream.read_to_string(&mut output).await.unwrap();
        assert_eq!(&output, "foobar");
    }

    #[tokio::test]
    async fn disconnect_is_retried() {
        let listener = TcpListener::bind(LOCALHOST_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut connections = 0;
            loop {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                stream.write_u8(connections).await.unwrap();
                connections += 1;
            }
        });
        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf.as_slice(), &[0, 1])
    }

    #[tokio::test]
    async fn error_is_consumed_when_set() {
        let listener = TcpListener::bind(LOCALHOST_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.write_all(b"foobar").await.unwrap();
            stream.shutdown().await.unwrap();
        });
        let mut stream = Tether::connect_tcp(addr, Once).await.unwrap();
        stream.set_config(Config {
            error_propagation_on_no_retry: config::ErrorPropagation::IoOperations,
            ..Default::default()
        });
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"foobar".as_slice())
    }

    #[tokio::test]
    async fn write_data_is_silently_dropped_when_set() {
        let listener = TcpListener::bind(LOCALHOST_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let mut buf = vec![0u8; 3];
            let (mut stream, _addr) = listener.accept().await.unwrap();
            stream.read_exact(&mut buf[..]).await.unwrap();
            stream.shutdown().await.unwrap();
            buf
        });
        let mut stream = Tether::connect_tcp(addr, Value(Action::Exhaust))
            .await
            .unwrap();
        stream.set_config(Config {
            keep_data_on_failed_write: false,
            ..Default::default()
        });
        stream.write_all(b"foo").await.unwrap();
        let buf = handle.await.unwrap();
        stream.write_all(b"bar").await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        stream.write_all(b"baz").await.unwrap();
        assert_eq!(b"foo".as_slice(), buf)
    }
}