use std::marker::PhantomData;
use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::ExecutionContext;
use entelix_core::error::{Error, Result};
use entelix_core::transports::{DefaultRetryClassifier, RetryClassifier};
use crate::runnable::Runnable;
/// Combinator that pairs a primary [`Runnable`] with an ordered list of
/// fallback runnables sharing the same input type `I` and output type `O`.
///
/// `R` is the type of the primary runnable and `F` the (single) type shared
/// by all fallbacks. `I` must be `Clone` because each attempt is handed its
/// own copy of the input.
pub struct Fallback<R, F, I, O>
where
R: Runnable<I, O> + 'static,
F: Runnable<I, O> + 'static,
I: Clone + Send + 'static,
O: Send + 'static,
{
/// Runnable that is always tried first.
primary: Arc<R>,
/// Alternatives, stored in the order supplied to the constructor.
fallbacks: Vec<Arc<F>>,
/// Consulted after a failed attempt; its verdict decides whether the
/// error is returned as-is or the next runnable gets a turn.
classifier: Arc<dyn RetryClassifier>,
/// Zero-sized marker that makes the otherwise unused `I` and `O`
/// parameters part of the type without storing values of either.
_io: PhantomData<fn(I) -> O>,
}
impl<R, F, I, O> Fallback<R, F, I, O>
where
    R: Runnable<I, O> + 'static,
    F: Runnable<I, O> + 'static,
    I: Clone + Send + 'static,
    O: Send + 'static,
{
    /// Creates a `Fallback` from a primary runnable and its alternatives.
    ///
    /// The order of `fallbacks` is preserved. The classifier starts out as
    /// [`DefaultRetryClassifier`]; use [`Fallback::with_classifier`] to
    /// replace it.
    pub fn new(primary: R, fallbacks: Vec<F>) -> Self {
        let primary = Arc::new(primary);
        let fallbacks: Vec<Arc<F>> = fallbacks.into_iter().map(Arc::new).collect();
        let classifier: Arc<dyn RetryClassifier> = Arc::new(DefaultRetryClassifier);

        Self {
            primary,
            fallbacks,
            classifier,
            _io: PhantomData,
        }
    }

    /// Returns this `Fallback` with its classifier swapped for `classifier`.
    ///
    /// Builder-style: consumes `self` and hands back the updated value, so
    /// the result must be used.
    #[must_use]
    pub fn with_classifier(self, classifier: Arc<dyn RetryClassifier>) -> Self {
        // Every other field is carried over unchanged.
        Self { classifier, ..self }
    }
}
impl<R, F, I, O> std::fmt::Debug for Fallback<R, F, I, O>
where
    R: Runnable<I, O> + 'static,
    F: Runnable<I, O> + 'static,
    I: Clone + Send + 'static,
    O: Send + 'static,
{
    /// Formats the value as a `Fallback` struct showing only the number of
    /// fallbacks; the output is marked non-exhaustive because the remaining
    /// fields are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fallback_count = self.fallbacks.len();

        let mut repr = f.debug_struct("Fallback");
        repr.field("fallback_count", &fallback_count);
        repr.finish_non_exhaustive()
    }
}
#[async_trait]
impl<R, F, I, O> Runnable<I, O> for Fallback<R, F, I, O>
where
    R: Runnable<I, O> + 'static,
    F: Runnable<I, O> + 'static,
    I: Clone + Send + 'static,
    O: Send + 'static,
{
    /// Runs the primary runnable and, while failures are classified as
    /// retryable, each fallback in order until one succeeds.
    ///
    /// Every attempt receives its own clone of `input`. The attempt number
    /// passed to the classifier is `0` for the primary and grows by one
    /// (saturating) for each fallback.
    ///
    /// # Errors
    ///
    /// * [`Error::Cancelled`] if `ctx` reports cancellation before the
    ///   primary runs or before any fallback runs.
    /// * The error of the first attempt the classifier declines to retry.
    /// * The error of the last attempt made when every runnable failed with
    ///   a retryable error.
    async fn invoke(&self, input: I, ctx: &ExecutionContext) -> Result<O> {
        if ctx.is_cancelled() {
            return Err(Error::Cancelled);
        }

        let mut attempt: u32 = 0;

        // Primary attempt: success and non-retryable failure both end the
        // call here; only a retryable failure moves on to the fallbacks.
        let mut last_err = match self.primary.invoke(input.clone(), ctx).await {
            Ok(value) => return Ok(value),
            Err(err) if self.classifier.should_retry(&err, attempt).retry => err,
            Err(err) => return Err(err),
        };

        for fallback in &self.fallbacks {
            attempt = attempt.saturating_add(1);

            // Re-check before each fallback so a cancelled context never
            // starts another attempt.
            if ctx.is_cancelled() {
                return Err(Error::Cancelled);
            }

            let failure = match fallback.invoke(input.clone(), ctx).await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };

            if !self.classifier.should_retry(&failure, attempt).retry {
                return Err(failure);
            }
            last_err = failure;
        }

        // Every runnable failed with a retryable error: report the most
        // recent one.
        Err(last_err)
    }
}