use std::error::Error as StdError;
use std::fmt;
use std::time::Instant;
/// Boxed, thread-safe error type produced by providers and returned by the chains.
pub type DynError = Box<dyn StdError + Send + Sync>;

/// Record of one provider invocation that returned an error.
#[derive(Debug)]
pub struct Attempt {
    /// Name the provider was registered under.
    pub name: String,
    /// Error the provider returned. The chains in this file always store `Some`.
    pub error: Option<DynError>,
    /// Time from invoking the provider until it returned its error, in milliseconds.
    pub duration_ms: f64,
}

/// Successful outcome of a chain call.
#[derive(Debug)]
pub struct ChainResult<O> {
    /// Value produced by the first provider that succeeded.
    pub value: O,
    /// Name of the provider that produced `value`.
    pub provider: String,
    /// Failed attempts that preceded the success, in call order.
    /// This is empty when the first provider succeeded.
    pub attempts: Vec<Attempt>,
}

/// Error returned when every provider in a chain failed with an error the chain fell back on.
///
/// `Display` lists the provider names in the order they were tried. `Error::source` exposes
/// the most recent recorded provider error, so error reporters can show the underlying cause.
#[derive(Debug)]
pub struct AllProvidersFailed {
    /// Every failed attempt, in call order.
    pub attempts: Vec<Attempt>,
}

impl fmt::Display for AllProvidersFailed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let names: Vec<&str> = self.attempts.iter().map(|a| a.name.as_str()).collect();
        write!(f, "all providers failed: {}", names.join(", "))
    }
}

impl StdError for AllProvidersFailed {
    /// Returns the error of the last attempt that recorded one, or `None` if no attempt did.
    ///
    /// Without this, `source()` chains stop at the aggregate and hide why the final provider
    /// failed.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        self.attempts
            .iter()
            .rev()
            .find_map(|attempt| attempt.error.as_deref())
            // Drop the `Send + Sync` auto traits to match the trait's return type.
            .map(|err| err as &(dyn StdError + 'static))
    }
}
/// Synchronous provider: borrows the input and yields either a value or a boxed error.
pub type SyncProvider<I, O> = Box<dyn Fn(&I) -> Result<O, DynError> + Send + Sync>;

/// Fallback policy consulted after each provider error.
/// `true` means try the next provider; `false` means return this error to the caller.
pub type ShouldFallBack = Box<dyn Fn(&DynError) -> bool + Send + Sync>;

/// Hook invoked as `(failed_provider, error, next_provider)` before the chain moves on.
pub type OnFallback = Box<dyn Fn(&str, &DynError, &str) + Send + Sync>;

/// Default policy: treat every error as recoverable, so the chain always tries the next provider.
fn default_should_fall_back(_: &DynError) -> bool {
    true
}
/// Synchronous fallback chain: tries named providers in order and returns the first success.
///
/// The provider list is never empty. Both [`FallbackChain::new`] and
/// [`FallbackChain::with_skip`] refuse to produce an empty chain.
pub struct FallbackChain<I, O> {
    /// Providers in the order they are tried, paired with their names.
    providers: Vec<(String, SyncProvider<I, O>)>,
    /// Policy deciding whether a provider error lets the chain continue.
    should_fall_back: ShouldFallBack,
    /// Optional hook run before moving from one provider to the next.
    on_fallback: Option<OnFallback>,
}

impl<I, O> FallbackChain<I, O> {
    /// Builds a chain from `(name, provider)` pairs, tried in the given order.
    ///
    /// By default every error falls back, and no fallback hook is installed.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `providers` is empty.
    pub fn new<S: Into<String>>(
        providers: Vec<(S, SyncProvider<I, O>)>,
    ) -> Result<Self, &'static str> {
        if providers.is_empty() {
            Err("providers must be a non-empty list")
        } else {
            Ok(Self {
                providers: providers
                    .into_iter()
                    .map(|(label, provider)| (label.into(), provider))
                    .collect(),
                should_fall_back: Box::new(default_should_fall_back),
                on_fallback: None,
            })
        }
    }

    /// Replaces the fallback policy.
    ///
    /// When `f` returns `false` for an error, `call` stops and returns that error unchanged.
    pub fn with_should_fall_back<F>(self, f: F) -> Self
    where
        F: Fn(&DynError) -> bool + Send + Sync + 'static,
    {
        Self {
            should_fall_back: Box::new(f),
            ..self
        }
    }

    /// Installs a hook called as `(failed_name, &error, next_name)`.
    ///
    /// The hook runs whenever a failure is followed by another provider.
    pub fn with_on_fallback<F>(self, f: F) -> Self
    where
        F: Fn(&str, &DynError, &str) + Send + Sync + 'static,
    {
        Self {
            on_fallback: Some(Box::new(f)),
            ..self
        }
    }

    /// Drops every provider whose name satisfies `predicate`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no provider remains. The chain is consumed in that case.
    pub fn with_skip<P>(mut self, predicate: P) -> Result<Self, &'static str>
    where
        P: Fn(&str) -> bool,
    {
        self.providers.retain(|entry| !predicate(entry.0.as_str()));
        if self.providers.is_empty() {
            Err("with_skip removed all providers")
        } else {
            Ok(self)
        }
    }

    /// Provider names in the order they will be tried.
    pub fn names(&self) -> Vec<&str> {
        self.providers
            .iter()
            .map(|entry| entry.0.as_str())
            .collect()
    }

    /// Runs the providers in order and returns the first successful value.
    ///
    /// For each provider that returns `Err`:
    /// - If `should_fall_back` rejects the error, that error is returned as-is. Attempts
    ///   recorded so far are discarded.
    /// - Otherwise, if another provider follows, the `on_fallback` hook (when set) is called
    ///   with `(failed_name, &error, next_name)`. The failure is then recorded as an
    ///   [`Attempt`], timed with `Instant` in milliseconds.
    ///
    /// # Errors
    ///
    /// Returns either the rejected provider error or a boxed [`AllProvidersFailed`].
    /// The latter holds every recorded attempt when all providers fail.
    pub fn call(&self, input: &I) -> Result<ChainResult<O>, DynError> {
        let mut attempts: Vec<Attempt> = Vec::new();
        for (idx, (name, run)) in self.providers.iter().enumerate() {
            let started = Instant::now();
            let err = match run(input) {
                Ok(value) => {
                    return Ok(ChainResult {
                        value,
                        provider: name.clone(),
                        attempts,
                    });
                }
                Err(err) => err,
            };
            let duration_ms = started.elapsed().as_secs_f64() * 1000.0;
            if !(self.should_fall_back)(&err) {
                return Err(err);
            }
            // Only announce a fallback when there is a next provider to fall back to.
            if let (Some(notify), Some((next_name, _))) =
                (&self.on_fallback, self.providers.get(idx + 1))
            {
                notify(name, &err, next_name);
            }
            attempts.push(Attempt {
                name: name.clone(),
                error: Some(err),
                duration_ms,
            });
        }
        Err(Box::new(AllProvidersFailed { attempts }))
    }
}
/// Async counterpart of `FallbackChain`, compiled only when the `tokio` feature is enabled.
///
/// Its public items are re-exported from the parent module via `pub use async_chain::...`.
#[cfg(feature = "tokio")]
mod async_chain {
    use super::{
        default_should_fall_back, AllProvidersFailed, Attempt, ChainResult, DynError, OnFallback,
        ShouldFallBack,
    };
    use futures::future::BoxFuture;
    use std::time::Instant;

    /// Boxed async provider.
    ///
    /// Given an input borrowed for `'a`, it returns a `BoxFuture<'a, _>` that resolves to the
    /// provider's value or error.
    pub type AsyncProvider<I, O> =
        Box<dyn for<'a> Fn(&'a I) -> BoxFuture<'a, Result<O, DynError>> + Send + Sync>;

    /// Wraps a future-returning closure as an [`AsyncProvider`].
    ///
    /// Each future the closure returns is pinned and boxed. `Fut` is one type for every input
    /// lifetime and must be `'static`, so the future `f` returns cannot borrow from the `&I`
    /// it was given. Copy or clone whatever it needs before building the future.
    pub fn async_provider<I, O, F, Fut>(f: F) -> AsyncProvider<I, O>
    where
        F: for<'a> Fn(&'a I) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<O, DynError>> + Send + 'static,
        I: 'static,
    {
        Box::new(move |i: &I| {
            let fut = f(i);
            // Erase the concrete future type: `Pin<Box<Fut>>` becomes the `BoxFuture`
            // that `AsyncProvider` returns.
            Box::pin(fut) as BoxFuture<'_, _>
        })
    }

    /// Async fallback chain: awaits named providers one at a time, in order, and returns the
    /// first success. The provider list is never empty (see `new` and `with_skip`).
    pub struct AsyncFallbackChain<I, O> {
        // Providers in the order they are tried, paired with their names.
        providers: Vec<(String, AsyncProvider<I, O>)>,
        // Policy deciding whether a provider error lets the chain continue.
        should_fall_back: ShouldFallBack,
        // Optional hook run before moving from one provider to the next.
        on_fallback: Option<OnFallback>,
    }

    // NOTE(review): the `I: Send + Sync, O: Send` bounds are presumably there to keep the
    // future returned by `call` `Send`; confirm before relaxing them.
    impl<I: Send + Sync, O: Send> AsyncFallbackChain<I, O> {
        /// Builds a chain from `(name, provider)` pairs, tried in the given order.
        ///
        /// By default every error falls back, and no fallback hook is installed.
        ///
        /// # Errors
        ///
        /// Returns `Err` if `providers` is empty.
        pub fn new<S: Into<String>>(
            providers: Vec<(S, AsyncProvider<I, O>)>,
        ) -> Result<Self, &'static str> {
            if providers.is_empty() {
                return Err("providers must be a non-empty list");
            }
            let providers = providers
                .into_iter()
                .map(|(name, fn_)| (name.into(), fn_))
                .collect();
            Ok(Self {
                providers,
                should_fall_back: Box::new(default_should_fall_back),
                on_fallback: None,
            })
        }

        /// Replaces the fallback policy.
        ///
        /// When `f` returns `false` for an error, `call` stops and returns that error unchanged.
        pub fn with_should_fall_back<F>(mut self, f: F) -> Self
        where
            F: Fn(&DynError) -> bool + Send + Sync + 'static,
        {
            self.should_fall_back = Box::new(f);
            self
        }

        /// Installs a hook called as `(failed_name, &error, next_name)`.
        ///
        /// The hook runs whenever a failure is followed by another provider.
        pub fn with_on_fallback<F>(mut self, f: F) -> Self
        where
            F: Fn(&str, &DynError, &str) + Send + Sync + 'static,
        {
            self.on_fallback = Some(Box::new(f));
            self
        }

        /// Drops every provider whose name satisfies `predicate`.
        ///
        /// # Errors
        ///
        /// Returns `Err` if no provider remains. The chain is consumed in that case.
        pub fn with_skip<P>(mut self, predicate: P) -> Result<Self, &'static str>
        where
            P: Fn(&str) -> bool,
        {
            self.providers.retain(|(name, _)| !predicate(name));
            if self.providers.is_empty() {
                return Err("with_skip removed all providers");
            }
            Ok(self)
        }

        /// Provider names in the order they will be tried.
        pub fn names(&self) -> Vec<&str> {
            self.providers.iter().map(|(n, _)| n.as_str()).collect()
        }

        /// Awaits the providers sequentially and returns the first successful value.
        ///
        /// For each provider whose future resolves to `Err`:
        /// - If `should_fall_back` rejects the error, that error is returned as-is. Attempts
        ///   recorded so far are discarded.
        /// - Otherwise, if another provider follows, the `on_fallback` hook (when set) is
        ///   called with `(failed_name, &error, next_name)`. The failure is then recorded as an
        ///   `Attempt`. `duration_ms` spans creating and awaiting the provider's future.
        ///
        /// # Errors
        ///
        /// Returns either the rejected provider error or a boxed `AllProvidersFailed`.
        /// The latter holds every recorded attempt when all providers fail.
        pub async fn call(&self, input: &I) -> Result<ChainResult<O>, DynError> {
            let mut failures: Vec<Attempt> = Vec::new();
            // Cannot underflow: `new` and `with_skip` never leave `providers` empty.
            let last = self.providers.len() - 1;
            for (i, (name, fn_)) in self.providers.iter().enumerate() {
                let start = Instant::now();
                match fn_(input).await {
                    Ok(value) => {
                        return Ok(ChainResult {
                            value,
                            provider: name.clone(),
                            attempts: failures,
                        });
                    }
                    Err(err) => {
                        let elapsed = start.elapsed().as_secs_f64() * 1000.0;
                        if !(self.should_fall_back)(&err) {
                            return Err(err);
                        }
                        // Notify only when there is a next provider to fall back to.
                        if i < last {
                            if let Some(cb) = &self.on_fallback {
                                let next_name = &self.providers[i + 1].0;
                                cb(name, &err, next_name);
                            }
                        }
                        failures.push(Attempt {
                            name: name.clone(),
                            error: Some(err),
                            duration_ms: elapsed,
                        });
                    }
                }
            }
            Err(Box::new(AllProvidersFailed { attempts: failures }))
        }
    }
}
#[cfg(feature = "tokio")]
pub use async_chain::{async_provider, AsyncFallbackChain, AsyncProvider};
/// Serializable views of the chain's records, compiled only with the `serde` feature.
#[cfg(feature = "serde")]
mod serde_impls {
    use super::Attempt;
    use serde::Serialize;

    /// Serializable snapshot of an [`Attempt`].
    ///
    /// The boxed error is flattened to its `Display` text, so the record can be emitted by any
    /// serde serializer.
    #[derive(Debug, Serialize)]
    pub struct AttemptView {
        /// Name of the provider that was tried.
        pub name: String,
        /// `Display` rendering of the provider's error, if one was recorded.
        pub error: Option<String>,
        /// Copied unchanged from `Attempt::duration_ms` (milliseconds).
        pub duration_ms: f64,
    }

    impl From<&Attempt> for AttemptView {
        fn from(attempt: &Attempt) -> Self {
            let Attempt {
                name,
                error,
                duration_ms,
            } = attempt;
            let error = error.as_ref().map(|err| err.to_string());
            AttemptView {
                name: name.clone(),
                error,
                duration_ms: *duration_ms,
            }
        }
    }
}
#[cfg(feature = "serde")]
pub use serde_impls::AttemptView;