io_tether/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
#![doc = include_str!("../README.md")]
use std::{future::Future, task::Poll};
mod implementations;
/// Represents a type which drives reconnects
///
/// The [`disconnected`](TetherResolver::disconnected) method is asynchronous and is invoked when
/// the underlying stream disconnects, so its body may `.await` other asynchronous functions, such
/// as [`tokio::time::sleep`](https://docs.rs/tokio/latest/tokio/time/fn.sleep.html), before
/// deciding whether to reconnect.
///
/// # Example
///
/// A very simple implementation may look something like the following:
///
/// ```ignore
/// # use io_tether::{Context, State, TetherResolver};
/// # use std::time::Duration;
/// pub struct RetryResolver;
///
/// impl TetherResolver for RetryResolver {
///     type Error = std::io::Error;
///
///     async fn disconnected(&mut self, context: &Context, state: &State<Self::Error>) -> bool {
///         tracing::warn!(?state, "Disconnected from server");
///         if context.reconnect_count() >= 5 {
///             return false;
///         }
///
///         tokio::time::sleep(Duration::from_secs(10)).await;
///         true
///     }
/// }
/// ```
// TODO: Remove the Unpin restriction
// NOTE(review): `Tether::poll_reconnect` builds the `disconnected` future on its own stack and
// drops it whenever it returns `Pending`, creating a fresh one on the next poll. Futures that must
// be polled to completion (such as the sleep in the example above) may therefore be restarted
// rather than resumed - confirm before relying on back-off delays.
pub trait TetherResolver: Unpin {
    /// The error type reported to [`disconnected`](TetherResolver::disconnected) through
    /// [`State::Err`].
    type Error;

    /// Invoked by Tether when an error/disconnect is encountered.
    ///
    /// Returning `true` will result in a reconnect being attempted via
    /// `<T as TetherIo>::reconnect`; returning `false` will result in the error being returned
    /// from the originating call.
    ///
    /// The reconnect counter in `context` has already been incremented for the current attempt
    /// by the time this method is called.
    ///
    /// # Note
    ///
    /// The [`State`] will describe the type of the underlying error. It can either be
    /// `State::Eof`, in which case the end of file was reached, or an error. This information can
    /// be leveraged in this function to determine whether to attempt to reconnect.
    fn disconnected(
        &mut self,
        context: &Context,
        state: &State<Self::Error>,
    ) -> impl Future<Output = bool> + Send;
}
/// Represents an I/O source capable of reconnecting
///
/// `T` is the *initializer*: the value a [`Tether`] stores alongside the I/O object and passes by
/// reference to [`TetherIo::connect`] and [`TetherIo::reconnect`] whenever a new I/O object is
/// needed.
///
/// This trait is implemented for a number of types in the library, with the implementations placed
/// behind feature flags.
pub trait TetherIo<T>: Sized + Unpin {
    /// Error produced when (re)establishing the connection fails.
    ///
    /// The reconnecting methods of [`Tether`] require this to be the same type as the resolver's
    /// [`TetherResolver::Error`], because reconnect failures are reported to the resolver as
    /// [`State::Err`].
    type Error;

    /// Initializes the connection to the I/O source.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the connection cannot be established.
    fn connect(initializer: &T) -> impl Future<Output = Result<Self, Self::Error>> + Send;

    /// Re-establishes the connection to the I/O source.
    ///
    /// Called by [`Tether`] after its resolver has asked for a retry. The default implementation
    /// simply delegates to [`TetherIo::connect`]; override it when reconnecting needs to behave
    /// differently from the initial connection.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the connection cannot be re-established; [`Tether`] hands that
    /// error back to its resolver to decide whether to try again.
    fn reconnect(initializer: &T) -> impl Future<Output = Result<Self, Self::Error>> + Send {
        Self::connect(initializer)
    }
}
/// Outcome reported by `Tether::poll_reconnect`.
enum Status<E> {
    /// A replacement I/O object was obtained and has been swapped in.
    Success,
    /// The resolver declined to retry. Carries the state that was current at that point: either
    /// the original disconnect cause or the error from the most recent failed reconnect.
    Failover(State<E>),
}
/// Why the underlying I/O object stopped working.
///
/// At present there are exactly two causes: an I/O error, or reaching the end of the stream.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum State<E> {
    /// The end of the stream was reached.
    ///
    /// # Note
    ///
    /// A TCP peer closing its half of the connection also surfaces as this variant.
    Eof,
    /// The I/O object reported an error.
    Err(E),
}
impl From<State<std::io::Error>> for std::io::Error {
fn from(value: State<std::io::Error>) -> Self {
match value {
State::Eof => std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Eof error"),
State::Err(error) => error,
}
}
}
/// A wrapper type which contains the underlying I/O object, its initializer, and resolver.
///
/// This is the main type exposed by the library. It implements
/// [`AsyncRead`](tokio::io::AsyncRead) and [`AsyncWrite`](tokio::io::AsyncWrite) whenever the
/// underlying I/O object implements them.
///
/// Calling things like
/// [`read_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf) will
/// result in the I/O automatically reconnecting if an error is detected during the underlying I/O
/// call.
///
/// # Note
///
/// Currently, there is no way to obtain a reference into the underlying I/O object, and the only
/// way to reclaim the inner I/O type is by calling [`Tether::into_inner`]. This is by design, since
/// in the future there may be reason to add unsafe code whose soundness cannot be guaranteed if
/// outside callers can obtain references. In the future I may add these as unsafe functions if
/// those cases can be described.
pub struct Tether<I, T, R> {
    /// Reconnect bookkeeping handed to the resolver on every disconnect.
    context: Context,
    /// Value passed to [`TetherIo::connect`] / [`TetherIo::reconnect`] to (re)create `inner`.
    initializer: I,
    /// The current I/O object; replaced wholesale after a successful reconnect.
    inner: T,
    /// Consulted on every disconnect to decide whether a reconnect should be attempted.
    resolver: R,
}
impl<I, T, R> Tether<I, T, R> {
    /// Wraps an already-established I/O object in a [`Tether`].
    ///
    /// The reconnect [`Context`] starts out as `Context::default()`.
    ///
    /// # Note
    ///
    /// When no connection exists yet, [`Tether::connect`] is usually the simpler constructor.
    pub fn new(inner: T, initializer: I, resolver: R) -> Self {
        let context = Context::default();
        Self {
            inner,
            initializer,
            resolver,
            context,
        }
    }

    /// Consumes the tether and hands back the wrapped I/O object.
    ///
    /// The initializer, resolver and context are dropped.
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Shared access to the reconnect context.
    pub fn get_context(&self) -> &Context {
        &self.context
    }

    /// Exclusive access to the reconnect context.
    pub fn get_context_mut(&mut self) -> &mut Context {
        &mut self.context
    }

    /// Shared access to the value used to (re)establish the connection.
    pub fn get_initializer(&self) -> &I {
        &self.initializer
    }

    /// Exclusive access to the value used to (re)establish the connection.
    pub fn get_initializer_mut(&mut self) -> &mut I {
        &mut self.initializer
    }

    /// Shared access to the resolver that decides on reconnects.
    pub fn get_resolver(&self) -> &R {
        &self.resolver
    }

    /// Exclusive access to the resolver that decides on reconnects.
    pub fn get_resolver_mut(&mut self) -> &mut R {
        &mut self.resolver
    }
}
impl<I, T, R> Tether<I, T, R>
where
    T: TetherIo<I>,
{
    /// Establishes a fresh connection and wraps it in a [`Tether`].
    ///
    /// The I/O object is created with [`TetherIo::connect`], the same method that
    /// [`TetherIo::reconnect`] delegates to by default when a reconnect is attempted.
    ///
    /// # Errors
    ///
    /// Returns the error from [`TetherIo::connect`] unchanged if the initial connection fails.
    pub async fn connect(initializer: I, resolver: R) -> Result<Self, T::Error> {
        // Finish the connect (and release its borrow of `initializer`) before moving the
        // initializer into the new tether.
        let connected = T::connect(&initializer).await;
        connected.map(|inner| Self::new(inner, initializer, resolver))
    }
}
/// Contains metrics about the underlying connection
///
/// Passed to the [`TetherResolver`] with each call to [`TetherResolver::disconnected`].
///
/// Currently tracks the number of reconnect attempts, but in the future may be expanded to include
/// additional metrics.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Context {
    /// Number of reconnect attempts so far; exposed through [`Context::reconnect_count`].
    reconnection_attempts: usize,
}

impl Context {
    /// The number of times a reconnect has been attempted.
    ///
    /// The counter is incremented immediately before every call to
    /// [`TetherResolver::disconnected`], so during the very first call this returns `1`.
    // NOTE(review): nothing in lib.rs resets this counter after a successful reconnect, so it
    // appears to be cumulative over the tether's lifetime - confirm against the I/O
    // implementations before relying on it as a per-outage count.
    pub fn reconnect_count(&self) -> usize {
        self.reconnection_attempts
    }
}
impl<I, T, R> Tether<I, T, R>
where
    T: TetherIo<I, Error = R::Error>,
    R: TetherResolver,
{
    /// Runs the resolver/reconnect cycle after the inner I/O object has failed.
    ///
    /// Every iteration:
    ///
    /// 1. increments the reconnect counter in the [`Context`];
    /// 2. asks the resolver, via [`TetherResolver::disconnected`], whether to retry;
    /// 3. if it declines, returns `Status::Failover` carrying the most recent `state`;
    /// 4. otherwise calls [`TetherIo::reconnect`]. On success the new I/O object replaces `inner`
    ///    and `Status::Success` is returned; on failure the error becomes the new `State::Err`
    ///    and the cycle repeats.
    ///
    /// # Cancellation hazard
    ///
    /// Both the resolver future and the reconnect future are pinned on this function's stack.
    /// When either reports `Pending`, `ready!` returns early and the future is dropped. The next
    /// call starts over: it increments the counter again and calls `disconnected` again with
    /// whatever `state` the caller passes in. Futures that must be polled to completion to make
    /// progress (timers, in-flight connects) can therefore be cancelled instead of resumed.
    /// Whether the task is ever woken again depends on how the dropped future treats the waker
    /// it registered.
    ///
    /// NOTE(review): a proper fix needs the in-flight futures to be stored inside `Tether`. The
    /// current trait signatures return futures that borrow the resolver and initializer, which
    /// does not permit that without unsafe code.
    pub(crate) fn poll_reconnect(
        &mut self,
        cx: &mut std::task::Context<'_>,
        mut state: State<R::Error>,
    ) -> Poll<Status<R::Error>> {
        loop {
            self.context.reconnection_attempts += 1;

            // The resolver future borrows `state`. Scoping it to this block guarantees it is
            // dropped before `state` is moved into `Failover` or overwritten below.
            let retry = {
                // `disconnected` takes `&mut self`, so it can be called on the field directly;
                // only the returned future needs pinning in order to be polled.
                let resolver_fut = self.resolver.disconnected(&self.context, &state);
                let resolver_fut = std::pin::pin!(resolver_fut);
                ready::ready!(resolver_fut.poll(cx))
            };

            if !retry {
                return Poll::Ready(Status::Failover(state));
            }

            let reconnect_fut = std::pin::pin!(T::reconnect(&self.initializer));
            match ready::ready!(reconnect_fut.poll(cx)) {
                Ok(new_stream) => {
                    // NOTE: Overwriting the field in place relies on the I/O type being `Unpin`
                    // (required by `TetherIo`), so swapping in a new value cannot invalidate a
                    // pin. Not aware of a safe alternative for `!Unpin` I/O types.
                    self.inner = new_stream;
                    return Poll::Ready(Status::Success);
                }
                // Feed the reconnect failure back to the resolver on the next iteration.
                Err(new_error) => state = State::Err(new_error),
            }
        }
    }
}
pub(crate) mod ready {
macro_rules! ready {
($e:expr $(,)?) => {
match $e {
std::task::Poll::Ready(t) => t,
std::task::Poll::Pending => return std::task::Poll::Pending,
}
};
}
pub(crate) use ready;
}
// The only test in this module drives a real TCP connection behind the `net` feature, so the whole
// module (helpers and imports included) is compiled only when that feature is enabled. Otherwise
// `create_tcp_pair`, `CallbackResolver` and the `AsyncReadExt`/`AsyncWriteExt` imports go unused.
#[cfg(all(test, feature = "net"))]
mod tests {
    use super::*;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Binds a loopback listener on an ephemeral port and returns a connected
    /// `(client, server)` pair of TCP streams.
    async fn create_tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (client, server) = tokio::join!(TcpStream::connect(addr), listener.accept());
        (client.unwrap(), server.unwrap().0)
    }

    /// Resolver that signals every disconnect on a channel and never requests a reconnect.
    pub struct CallbackResolver {
        inner: mpsc::Sender<()>,
    }

    impl TetherResolver for CallbackResolver {
        type Error = std::io::Error;

        async fn disconnected(&mut self, _context: &Context, _state: &State<Self::Error>) -> bool {
            self.inner.send(()).await.unwrap();
            false
        }
    }

    #[tokio::test]
    async fn disconnect_triggers_callback() {
        let (tx, mut rx) = mpsc::channel(1);
        let (client, mut server) = create_tcp_pair().await;
        let resolver = CallbackResolver { inner: tx };
        let mut tether = Tether::new(client, "", resolver);
        let mut buf = Vec::new();

        server.write_all(b"foo-bar").await.unwrap();
        tether.read_buf(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foo-bar");

        // Shutting down the server side makes the client observe end of file, which should be
        // reported to the resolver.
        server.shutdown().await.unwrap();
        // NOTE(review): this relies on the `AsyncRead` impl not surfacing an error once the
        // resolver gives up on EOF - confirm against the implementation.
        tether.read_buf(&mut buf).await.unwrap();
        assert!(rx.recv().await.is_some());
    }
}