use apalis_core::error::AbortError;
use apalis_core::task::Task;
use apalis_core::task::builder::TaskBuilder;
use apalis_core::task::metadata::MetadataExt;
use apalis_core::task::status::Status;
use apalis_core::worker::context::WorkerContext;
use std::any::Any;
use std::fmt::Debug;
use tower::retry::backoff::Backoff;
pub use tower::retry::*;
pub use tower::util::rng::HasherRng;
#[derive(Clone, Debug)]
pub struct BackoffRetryPolicy<B> {
retries: usize,
backoff: B,
}
impl<B> BackoffRetryPolicy<B> {
pub fn new(retries: usize, backoff: B) -> Self {
Self { retries, backoff }
}
pub fn retry_if<F, Err>(self, predicate: F) -> RetryIfPolicy<Self, F>
where
F: Fn(&Err) -> bool + Send + Sync + 'static,
{
RetryIfPolicy::new(self, predicate)
}
pub fn from_task_config(self) -> FromTaskConfigPolicy<Self> {
FromTaskConfigPolicy::new(self)
}
}
/// Retries failed tasks, returning the backoff strategy's next delay future for tower's
/// `Retry` service to wait on before re-issuing the task.
///
/// No retry is attempted when any of the following holds:
/// - the task carries no [`WorkerContext`], or the worker is shutting down;
/// - the result is `Ok`;
/// - the policy was configured with zero retries;
/// - the task status is [`Status::Killed`];
/// - the error is an [`AbortError`], either as the error type itself or boxed inside a
///   `Box<dyn std::error::Error + Send + Sync>`;
/// - the task's current attempt number is greater than the configured retry limit.
impl<T, Res, Ctx, B, Err: Any, IdType> Policy<Task<T, Ctx, IdType>, Res, Err>
    for BackoffRetryPolicy<B>
where
    T: Clone,
    Ctx: Clone,
    IdType: Clone,
    B: Backoff,
    B::Future: Send + 'static,
{
    type Future = B::Future;

    fn retry(
        &mut self,
        req: &mut Task<T, Ctx, IdType>,
        result: &mut Result<Res, Err>,
    ) -> Option<Self::Future> {
        // Detects an `AbortError` whether the error type *is* `AbortError` or the abort was
        // boxed by `?` into a `Box<dyn Error + Send + Sync>` (presumably `BoxDynError`).
        // Checking only the concrete type misses the boxed case, so aborted tasks would
        // still be retried.
        fn is_abort<E: Any>(err: &E) -> bool {
            let any: &dyn Any = err;
            if any.is::<AbortError>() {
                return true;
            }
            match any.downcast_ref::<Box<dyn std::error::Error + Send + Sync>>() {
                Some(boxed) => (**boxed).is::<AbortError>(),
                None => false,
            }
        }

        let attempt = req.parts.attempt.current();
        let status = req.parts.status.load();
        let worker = req.parts.data.get::<WorkerContext>()?;
        if worker.is_shutting_down() {
            return None;
        }
        match result.as_ref() {
            Ok(_) => None,
            Err(_) if self.retries == 0 => None,
            Err(_) if status == Status::Killed => None,
            Err(err) if is_abort(err) => None,
            Err(_) if self.retries >= attempt => Some(self.backoff.next_backoff()),
            Err(_) => None,
        }
    }

    fn clone_request(&mut self, req: &Task<T, Ctx, IdType>) -> Option<Task<T, Ctx, IdType>> {
        Some(req.clone())
    }
}
#[derive(Clone, Debug)]
pub struct RetryPolicy {
retries: usize,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self { retries: 5 }
}
}
impl RetryPolicy {
pub fn retries(retries: usize) -> Self {
Self { retries }
}
pub fn with_backoff<B: Backoff>(self, backoff: B) -> BackoffRetryPolicy<B> {
BackoffRetryPolicy {
retries: self.retries,
backoff,
}
}
pub fn retry_if<F, Err>(self, predicate: F) -> RetryIfPolicy<Self, F>
where
F: Fn(&Err) -> bool + Send + Sync + 'static,
{
RetryIfPolicy::new(self, predicate)
}
pub fn from_task_config(self) -> FromTaskConfigPolicy<Self> {
FromTaskConfigPolicy::new(self)
}
}
/// Retries failed tasks immediately: the returned future is `std::future::ready(())`.
///
/// No retry is attempted when any of the following holds:
/// - the task carries no [`WorkerContext`], or the worker is shutting down;
/// - the result is `Ok`;
/// - the policy was configured with zero retries;
/// - the task status is [`Status::Killed`];
/// - the error is an [`AbortError`], either as the error type itself or boxed inside a
///   `Box<dyn std::error::Error + Send + Sync>`;
/// - the task's current attempt number is greater than the configured retry limit.
impl<T, Res, Ctx, Err: Any, IdType> Policy<Task<T, Ctx, IdType>, Res, Err> for RetryPolicy
where
    T: Clone,
    Ctx: Clone,
    IdType: Clone,
{
    type Future = std::future::Ready<()>;

    fn retry(
        &mut self,
        req: &mut Task<T, Ctx, IdType>,
        result: &mut Result<Res, Err>,
    ) -> Option<Self::Future> {
        // Detects an `AbortError` whether the error type *is* `AbortError` or the abort was
        // boxed by `?` into a `Box<dyn Error + Send + Sync>` (presumably `BoxDynError`).
        // Checking only the concrete type misses the boxed case, so aborted tasks would
        // still be retried.
        fn is_abort<E: Any>(err: &E) -> bool {
            let any: &dyn Any = err;
            if any.is::<AbortError>() {
                return true;
            }
            match any.downcast_ref::<Box<dyn std::error::Error + Send + Sync>>() {
                Some(boxed) => (**boxed).is::<AbortError>(),
                None => false,
            }
        }

        let attempt = req.parts.attempt.current();
        let status = req.parts.status.load();
        let worker = req.parts.data.get::<WorkerContext>()?;
        if worker.is_shutting_down() {
            return None;
        }
        match result.as_ref() {
            Ok(_) => None,
            Err(_) if self.retries == 0 => None,
            Err(_) if status == Status::Killed => None,
            Err(err) if is_abort(err) => None,
            Err(_) if self.retries >= attempt => Some(std::future::ready(())),
            Err(_) => None,
        }
    }

    fn clone_request(&mut self, req: &Task<T, Ctx, IdType>) -> Option<Task<T, Ctx, IdType>> {
        Some(req.clone())
    }
}
#[derive(Debug, Clone)]
pub struct RetryIfPolicy<P, F> {
inner: P,
predicate: F,
}
impl<P, F> RetryIfPolicy<P, F> {
pub fn new(inner: P, predicate: F) -> Self {
Self { inner, predicate }
}
pub fn from_task_config(self) -> FromTaskConfigPolicy<Self> {
FromTaskConfigPolicy::new(self)
}
}
/// Filters errors through the predicate, then delegates the actual retry decision (and the
/// delay future) to the inner policy.
///
/// Returns `None` without consulting the inner policy when the task has no
/// [`WorkerContext`], the worker is shutting down, the result is `Ok`, or the predicate
/// rejects the error.
impl<T, Res, Ctx, P, F, Err, IdType> Policy<Task<T, Ctx, IdType>, Res, Err> for RetryIfPolicy<P, F>
where
    T: Clone,
    Ctx: Clone,
    P: Policy<Task<T, Ctx, IdType>, Res, Err>,
    F: Fn(&Err) -> bool + Send + Sync + 'static,
{
    type Future = P::Future;

    fn retry(
        &mut self,
        req: &mut Task<T, Ctx, IdType>,
        result: &mut Result<Res, Err>,
    ) -> Option<Self::Future> {
        if req.parts.data.get::<WorkerContext>()?.is_shutting_down() {
            return None;
        }
        let retryable = match &*result {
            Ok(_) => false,
            Err(err) => (self.predicate)(err),
        };
        if retryable {
            self.inner.retry(req, result)
        } else {
            None
        }
    }

    fn clone_request(&mut self, req: &Task<T, Ctx, IdType>) -> Option<Task<T, Ctx, IdType>> {
        self.inner.clone_request(req)
    }
}
/// Per-task retry setting carried in the task context's metadata.
///
/// Attach it with `TaskBuilder::meta` or [`RetryMetadataExt::retries`]. It is read back
/// by [`FromTaskConfigPolicy`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Per-task limit, interpreted by [`FromTaskConfigPolicy`]: once the task's current
    /// attempt number reaches this value, no further retry is attempted.
    pub retries: usize,
}
#[derive(Debug, Clone)]
pub struct FromTaskConfigPolicy<P> {
inner: P,
}
impl<P> FromTaskConfigPolicy<P> {
pub fn new(inner: P) -> Self {
Self { inner }
}
pub fn retry_if<F, Err>(self, predicate: F) -> RetryIfPolicy<Self, F>
where
F: Fn(&Err) -> bool + Send + Sync + 'static,
{
RetryIfPolicy::new(self, predicate)
}
}
impl Default for FromTaskConfigPolicy<RetryPolicy> {
fn default() -> Self {
Self {
inner: RetryPolicy::retries(0),
}
}
}
/// Applies the per-task [`RetryConfig`] limit (if the context has one) and otherwise
/// delegates to the inner policy.
///
/// NOTE(review): the config check stops once `attempt >= cfg.retries`, so `cfg.retries`
/// behaves as a cap on *total* attempts. `RetryPolicy` instead retries while
/// `attempt <= retries`, which is one more attempt for the same number. The existing test
/// relies on the current behaviour, so it is kept as-is.
impl<T, Res, Ctx, P, Err, IdType> Policy<Task<T, Ctx, IdType>, Res, Err> for FromTaskConfigPolicy<P>
where
    T: Clone,
    Ctx: Clone + MetadataExt<RetryConfig>,
    P: Policy<Task<T, Ctx, IdType>, Res, Err>,
{
    type Future = P::Future;

    fn retry(
        &mut self,
        req: &mut Task<T, Ctx, IdType>,
        result: &mut Result<Res, Err>,
    ) -> Option<Self::Future> {
        let worker = req.parts.data.get::<WorkerContext>()?;
        if worker.is_shutting_down() || result.is_ok() {
            return None;
        }
        let attempt = req.parts.attempt.current();
        // A missing or unreadable config imposes no extra limit.
        let limit_reached = match req.parts.ctx.extract() {
            Ok(cfg) => attempt >= cfg.retries,
            Err(_) => false,
        };
        if limit_reached {
            None
        } else {
            self.inner.retry(req, result)
        }
    }

    fn clone_request(&mut self, req: &Task<T, Ctx, IdType>) -> Option<Task<T, Ctx, IdType>> {
        self.inner.clone_request(req)
    }
}
/// Builder extension for attaching a per-task [`RetryConfig`].
pub trait RetryMetadataExt {
    /// Records `RetryConfig { retries }` on the task being built and returns the builder.
    fn retries(self, retries: usize) -> Self;
}

/// Implemented for any [`TaskBuilder`] whose context can store [`RetryConfig`] metadata.
impl<Args, Ctx, IdType> RetryMetadataExt for TaskBuilder<Args, Ctx, IdType>
where
    Ctx: MetadataExt<RetryConfig>,
    Ctx::Error: Debug,
{
    fn retries(self, retries: usize) -> Self {
        let config = RetryConfig { retries };
        self.meta(config)
    }
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use apalis_core::{
backend::memory::MemoryStorage,
error::BoxDynError,
task::{attempt::Attempt, builder::TaskBuilder},
worker::{
builder::WorkerBuilder, context::WorkerContext, ext::event_listener::EventListenerExt,
},
};
use futures_util::SinkExt;
use crate::layers::WorkerBuilderExt;
use super::*;
// End-to-end check of the policy stack built below,
// `RetryIfPolicy<FromTaskConfigPolicy<RetryPolicy>>`, over an in-memory backend.
// The handler panics via `unreachable!` if a task runs more attempts than expected.
#[tokio::test]
async fn basic_worker_retries() {
let mut in_memory = MemoryStorage::new();
// Task 1: per-task RetryConfig attached directly through `meta`.
let task1 = TaskBuilder::new(1).meta(RetryConfig { retries: 3 }).build();
// Task 2: per-task RetryConfig attached through the `RetryMetadataExt::retries` helper.
let task2 = TaskBuilder::new(2).retries(5).build();
// Task 3: no per-task config; its error is rejected by the `retry_if` predicate below.
let task3 = TaskBuilder::new(3).build();
in_memory.send(task1).await.unwrap();
in_memory.send(task2).await.unwrap();
in_memory.send(task3).await.unwrap();
// Handler shared by all three tasks. It always returns an error, so every attempt
// goes through the retry policy.
async fn task(
task: u32,
worker: WorkerContext,
attempts: Attempt,
) -> Result<(), BoxDynError> {
// FromTaskConfigPolicy stops once the attempt number reaches RetryConfig::retries (3),
// so a 4th attempt of task 1 would indicate a retry-policy bug.
if task == 1 && attempts.current() == 4 {
unreachable!("Task 1 reached 4 attempts");
}
// SkipRetryError must never be retried, so task 3 must not see a 2nd attempt.
if task == 3 && attempts.current() == 2 {
unreachable!("Task 3 reached retried");
}
println!("Task {task} attempt {attempts:?}");
tokio::time::sleep(Duration::from_secs(1)).await;
// Task 2's config (5) is looser than the inner RetryPolicy (3), so the inner limit
// governs and attempt 4 is its last. Stopping the worker there is expected to let
// `worker.run()` below return.
if task == 2 && attempts.current() == 4 {
worker.stop().unwrap();
}
// `?` boxes SkipRetryError into BoxDynError, which the `retry_if` predicate downcasts.
if task == 3 {
return Err(SkipRetryError)?;
}
Err("Always fail if not 3")?
}
// Marker error that the `retry_if` predicate treats as non-retryable.
#[derive(Debug)]
struct SkipRetryError;
impl std::error::Error for SkipRetryError {}
impl std::fmt::Display for SkipRetryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SkipRetryError")
}
}
// Retry up to 3 times, let per-task RetryConfig stop earlier, and never retry
// SkipRetryError.
let worker = WorkerBuilder::new("rango-tango")
.backend(in_memory)
.retry(
RetryPolicy::retries(3)
.from_task_config()
.retry_if(|e: &BoxDynError| e.downcast_ref::<SkipRetryError>().is_none()),
)
.on_event(|ctx, ev| {
println!("CTX {:?}, On Event = {ev:?}", ctx.name());
})
.build(task);
worker.run().await.unwrap();
}
}