use core::future::Future;
use std::{
fmt::{Debug, Display},
marker::PhantomData,
sync::atomic::{AtomicUsize, Ordering},
sync::Arc,
};
pub use async_trait::async_trait;
use tokio::sync::RwLock;
#[cfg(feature = "tracing")]
use tracing::{
field::{debug, display, Empty},
instrument, Instrument, Span,
};
/// Classifies an error as "try the next source" versus "give up".
///
/// [`RoundRobin::run`] consults this for every failed attempt, whether the failure came from
/// [`Connector::connect`] or from the caller's operation. When `is_next` returns `true`, the
/// cached service may be invalidated and the operation is retried against the next source, up
/// to the configured attempt limit. When it returns `false`, the error is returned to the
/// caller immediately.
pub trait Next {
    /// Returns `true` if the failure is worth retrying against another source.
    fn is_next(&self) -> bool;
}
/// I/O errors are retryable on another source when they indicate that the connection itself
/// broke, was refused, timed out or was interrupted. Every other kind (missing resources,
/// permission problems, invalid input/data, ...) is reported to the caller as-is.
impl Next for std::io::Error {
    fn is_next(&self) -> bool {
        use std::io::ErrorKind;

        matches!(
            self.kind(),
            ErrorKind::ConnectionRefused
                | ErrorKind::ConnectionReset
                | ErrorKind::ConnectionAborted
                | ErrorKind::NotConnected
                | ErrorKind::BrokenPipe
                | ErrorKind::TimedOut
                | ErrorKind::Interrupted
                | ErrorKind::UnexpectedEof
        )
    }
}
/// Opens a connection to a service, given a description of where it lives.
///
/// [`RoundRobin`] calls [`connect`](Connector::connect) lazily, whenever it has no cached
/// service, passing the source selected by its round-robin counter. The resulting `Svc` is
/// wrapped in an `Arc` and shared by subsequent operations until it is invalidated.
#[async_trait]
pub trait Connector<SvcSrc, Svc, E> {
    /// Connects to the service described by `src`.
    ///
    /// An error here counts as a failed attempt. With `E: Next`, [`RoundRobin`] uses
    /// [`Next::is_next`] to decide whether to move on to the next source.
    async fn connect(&self, src: &SvcSrc) -> Result<Svc, E>;
}
pub struct RoundRobin<SvcSrc, Svc, E, Conn>
where
Conn: Connector<SvcSrc, Svc, E>,
{
sources: Vec<SvcSrc>,
connector: Conn,
max_attempts: usize,
service: RwLock<Option<Arc<Svc>>>,
current: AtomicUsize,
_phantom: PhantomData<E>,
}
impl<SvcSrc, Svc, E, Conn> RoundRobin<SvcSrc, Svc, E, Conn>
where
    SvcSrc: Debug,
    E: Next + Display,
    Conn: Connector<SvcSrc, Svc, E>,
{
    /// Creates a pool that round-robins over `sources`, connecting through `connector`.
    ///
    /// Nothing is connected yet. The first [`run`](Self::run) connects lazily to `sources[0]`.
    /// The attempt limit defaults to `sources.len() + 1`. Change it with
    /// [`set_max_attempts`](Self::set_max_attempts) or [`max_attempts`](Self::max_attempts).
    pub fn new(sources: Vec<SvcSrc>, connector: Conn) -> Self {
        Self {
            max_attempts: sources.len() + 1,
            sources,
            connector,
            service: RwLock::new(None),
            current: AtomicUsize::new(0),
            _phantom: PhantomData,
        }
    }

    /// Sets the maximum number of attempts a single [`run`](Self::run) call may make.
    ///
    /// The first attempt is always made, so `0` behaves like `1`.
    pub fn set_max_attempts(&mut self, count: usize) {
        self.max_attempts = count;
    }

    /// Builder-style counterpart of [`set_max_attempts`](Self::set_max_attempts).
    #[must_use]
    pub fn max_attempts(self, count: usize) -> Self {
        Self {
            max_attempts: count,
            ..self
        }
    }

    /// Returns the cached service, connecting to `sources[index]` first if none is cached.
    ///
    /// The write lock is held across the connect. Concurrent callers that all find the cache
    /// empty therefore wait for a single connection attempt instead of each opening their own,
    /// and the returned `Arc` never depends on a second read of the cache.
    async fn connected_service(&self, index: usize) -> Result<Arc<Svc>, E> {
        // Fast path: clone the cached handle and release the read lock immediately.
        let cached = self.service.read().await.clone();
        if let Some(svc) = cached {
            return Ok(svc);
        }

        let mut slot = self.service.write().await;
        // Re-check: another caller may have connected while we waited for the write lock.
        if let Some(svc) = &*slot {
            return Ok(Arc::clone(svc));
        }
        let svc = Arc::new(self.connector.connect(&self.sources[index]).await?);
        *slot = Some(Arc::clone(&svc));
        Ok(svc)
    }

    /// Performs one attempt: ensures a service is connected, then invokes `run` on it.
    ///
    /// `current` is the round-robin counter snapshot taken by the caller. A fresh connection
    /// uses `sources[current % sources.len()]`. An already cached service is reused as-is.
    ///
    /// When `run` fails with an error for which [`Next::is_next`] is `true`, the cached service
    /// is dropped, but only if it is still the exact instance this attempt used. That way a
    /// connection freshly established by a concurrent caller is never discarded.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, run), err, fields(service = Empty, index = Empty)),
    )]
    async fn run_inner<Run, RunFut, T>(&self, run: &Run, current: usize) -> Result<T, E>
    where
        Run: Fn(Arc<Svc>) -> RunFut,
        RunFut: Future<Output = Result<T, E>>,
    {
        let index = current % self.sources.len();
        #[cfg(feature = "tracing")]
        {
            let span = Span::current();
            span.record("index", &display(index));
            span.record("service", &debug(&self.sources[index]));
        }
        let svc = self.connected_service(index).await?;
        let fut = run(Arc::clone(&svc));
        #[cfg(feature = "tracing")]
        let fut = fut.instrument(tracing::debug_span!("run_fn"));
        let res = fut.await;
        if let Err(ref e) = res {
            if e.is_next() {
                let mut slot = self.service.write().await;
                if matches!(&*slot, Some(cached) if Arc::ptr_eq(cached, &svc)) {
                    *slot = None;
                }
            }
        }
        res
    }

    /// Runs `run` against the connected service, failing over to the next source on
    /// retryable errors.
    ///
    /// Each attempt connects to the source selected by the round-robin counter if no service
    /// is cached, then calls `run` with the service. An error for which [`Next::is_next`] is
    /// `true` is logged, the counter is advanced, and the operation is retried. This applies to
    /// errors from connecting as well as from `run`. Because `run` may be called several times,
    /// it should be safe to repeat.
    ///
    /// # Errors
    ///
    /// Returns the first error for which [`Next::is_next`] is `false`, or the last retryable
    /// error once the attempt limit is reached. At least one attempt is always made.
    ///
    /// # Panics
    ///
    /// Panics if the pool was created with an empty list of sources.
    #[cfg_attr(feature = "tracing", instrument(skip(self, run), err))]
    pub async fn run<R, Fut, T>(&self, run: R) -> Result<T, E>
    where
        R: Fn(Arc<Svc>) -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        let n_svc = self.sources.len();
        assert!(n_svc > 0, "RoundRobin::run requires at least one service source");
        let mut attempts = 0usize;
        loop {
            let current = self.current.load(Ordering::Relaxed);
            let e = match self.run_inner(&run, current).await {
                Ok(t) => return Ok(t),
                Err(e) => e,
            };
            if !e.is_next() {
                return Err(e);
            }
            #[cfg(feature = "tracing")]
            tracing::error!("Service {}/{} failed: {}", current % n_svc, n_svc, e);
            #[cfg(not(feature = "tracing"))]
            eprintln!("Service {}/{} failed: {}", current % n_svc, n_svc, e);
            // Advance only if no concurrent caller has already moved past this source.
            // Otherwise several tasks failing on the same service would each bump the counter
            // and skip healthy sources. Wrapping matches the overflow behavior of `fetch_add`.
            let _ = self.current.compare_exchange(
                current,
                current.wrapping_add(1),
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            attempts += 1;
            if attempts >= self.max_attempts {
                return Err(e);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Test error type: `Timeout` is retryable on the next source, `NotFound` is not.
    #[derive(Debug, PartialEq)]
    enum Error {
        Timeout,
        NotFound,
    }

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "{:?}", self)
        }
    }

    impl Next for Error {
        fn is_next(&self) -> bool {
            matches!(self, Error::Timeout)
        }
    }

    /// Connector over integer sources. It fails with `Timeout` for every source below
    /// `ok_from` and counts how many times it was invoked.
    struct Conn {
        count: Arc<AtomicUsize>,
        ok_from: i32,
    }

    #[async_trait]
    impl Connector<i32, i32, Error> for Conn {
        async fn connect(&self, src: &i32) -> Result<i32, Error> {
            self.count.fetch_add(1, Ordering::Relaxed);
            if *src >= self.ok_from {
                Ok(*src)
            } else {
                Err(Error::Timeout)
            }
        }
    }

    /// Builds a pool over `svcs` and returns it together with the connector's call counter.
    fn build_rr(
        svcs: Vec<i32>,
        ok_from: i32,
    ) -> (RoundRobin<i32, i32, Error, Conn>, Arc<AtomicUsize>) {
        let connects = Arc::new(AtomicUsize::new(0));
        let connector = Conn {
            count: Arc::clone(&connects),
            ok_from,
        };
        (RoundRobin::new(svcs, connector), connects)
    }

    /// Operation that fails with a retryable error on service `0` and succeeds elsewhere.
    async fn timeout_on_zero(svc: Arc<i32>) -> Result<Arc<i32>, Error> {
        if *svc == 0 {
            Err(Error::Timeout)
        } else {
            Ok(svc)
        }
    }

    #[tokio::test]
    async fn test_first_called() {
        let (rr, count_conn) = build_rr(vec![0, 1], 0);
        let count_run = AtomicUsize::new(0);
        for expected_runs in 1..=2 {
            rr.run(|_| async { Ok(count_run.fetch_add(1, Ordering::Relaxed)) })
                .await
                .unwrap();
            // The connection made for the first run is reused by the second.
            assert_eq!(count_conn.load(Ordering::Relaxed), 1);
            assert_eq!(count_run.load(Ordering::Relaxed), expected_runs);
        }
    }

    #[tokio::test]
    async fn test_exhausted_run() {
        let (mut rr, count_conn) = build_rr(vec![0, 1], 0);
        rr.set_max_attempts(1);
        let count = AtomicUsize::new(0);
        let res = rr
            .run(|n| {
                count.fetch_add(1, Ordering::Relaxed);
                timeout_on_zero(n)
            })
            .await;
        assert_eq!(count_conn.load(Ordering::Relaxed), 1);
        match res {
            Ok(_) => panic!("Run did not error"),
            Err(Error::Timeout) => (),
            Err(_) => panic!("Wrong error"),
        }
        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_try_next_run() {
        let (rr, count_conn) = build_rr(vec![0, 1], 0);
        let count = AtomicUsize::new(0);
        rr.run(|n| {
            count.fetch_add(1, Ordering::Relaxed);
            timeout_on_zero(n)
        })
        .await
        .unwrap();
        assert_eq!(count_conn.load(Ordering::Relaxed), 2);
        assert_eq!(count.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_exhausted_connector() {
        let (rr, count_conn) = build_rr(vec![0, 1], 1);
        let rr = rr.max_attempts(1);
        let count = AtomicUsize::new(0);
        let res = rr
            .run(|_| async { Ok(count.fetch_add(1, Ordering::Relaxed)) })
            .await;
        assert_eq!(count_conn.load(Ordering::Relaxed), 1);
        match res {
            Ok(_) => panic!("Connect did not error"),
            Err(Error::Timeout) => (),
            Err(_) => panic!("Wrong error"),
        }
        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_try_next_connector() {
        let (rr, count_conn) = build_rr(vec![0, 1], 1);
        let count = AtomicUsize::new(0);
        rr.run(|_| async { Ok(count.fetch_add(1, Ordering::Relaxed)) })
            .await
            .unwrap();
        assert_eq!(count_conn.load(Ordering::Relaxed), 2);
        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_abort() {
        let (rr, _) = build_rr(vec![0, 1], 1);
        match rr.run(|_| async { Err::<(), _>(Error::NotFound) }).await {
            Ok(()) => panic!("Run did not error"),
            Err(Error::Timeout) => panic!("Connector error aborted"),
            Err(Error::NotFound) => (),
        }
    }
}