use super::task::Task;
use std::sync::Arc;
/// Data handed by reference to every [`TaskHooks`] callback.
///
/// The registry in this module never mutates the context; it only forwards
/// a shared reference to each registered hook.
#[derive(Clone, Debug)]
pub struct TaskHookContext {
    /// The task the callback is about.
    pub task: Task,

    /// Attempt counter supplied by the caller.
    ///
    /// NOTE(review): the unit test in this file passes `1`, so this is
    /// presumably 1-based — confirm against the executor.
    pub attempt: u32,

    /// Optional executor label. Nothing in this module reads it.
    ///
    /// NOTE(review): presumably identifies the worker running the task — confirm.
    pub executor: Option<String>,
}
/// Verdict returned by the failure and timeout callbacks of a hook.
///
/// Within this module only [`RetryDecision::Fail`] has special meaning: the
/// registry treats it as "no opinion" and keeps consulting further hooks.
/// How the other variants are acted upon is decided by the code that consumes
/// the decision, which is outside this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RetryDecision {
    /// Ask for another attempt.
    ///
    /// NOTE(review): `delay_secs` is presumably the wait, in seconds, before
    /// the next attempt — confirm with the consumer.
    Retry {
        /// Requested delay in seconds.
        delay_secs: u64,
    },

    /// NOTE(review): presumably "skip this task" — confirm with the consumer.
    Skip,

    /// Report the task as failed. Also the default answer of
    /// `TaskHooks::on_failure` and the registry's fallback result.
    Fail,

    /// NOTE(review): presumably "disregard the failure", carrying an
    /// explanatory message — confirm with the consumer.
    Ignore {
        /// Free-form text attached to the decision.
        message: String,
    },
}
/// Callbacks a component can implement to observe or influence task handling.
///
/// Every method has a default body, so an implementor overrides only the
/// callbacks it cares about. Implementations must be `Send + Sync`.
/// The call sites that decide *when* each callback fires are outside this
/// module.
#[async_trait::async_trait]
pub trait TaskHooks: Send + Sync {
    /// Pre-execution callback. The default does nothing.
    async fn before_execute(&self, _ctx: &TaskHookContext) {}

    /// Post-execution callback receiving the task's textual `_result`.
    /// The default does nothing.
    async fn after_execute(&self, _ctx: &TaskHookContext, _result: &str) {}

    /// Failure callback receiving the error text.
    ///
    /// Returns the hook's verdict; the default is [`RetryDecision::Fail`].
    async fn on_failure(&self, _ctx: &TaskHookContext, _error: &str) -> RetryDecision {
        RetryDecision::Fail
    }

    /// Timeout callback.
    ///
    /// The default forwards to [`TaskHooks::on_failure`] with a fixed error
    /// text, so overriding only `on_failure` also covers timeouts.
    async fn on_timeout(&self, ctx: &TaskHookContext) -> RetryDecision {
        let reason = "Task timed out";
        self.on_failure(ctx, reason).await
    }

    /// Cancellation callback. The default does nothing.
    async fn on_cancelled(&self, _ctx: &TaskHookContext) {}

    /// Whether this hook permits a retry for the given error.
    /// The default always answers `true`.
    async fn should_retry(&self, _ctx: &TaskHookContext, _error: &str) -> bool {
        true
    }
}
/// Hook set that overrides nothing: every callback uses the default body
/// provided by the [`TaskHooks`] trait.
pub struct NoopHooks;
// Intentionally empty: all methods fall back to the trait defaults.
#[async_trait::async_trait]
impl TaskHooks for NoopHooks {}
/// [`TaskHooks`] implementation that emits a `tracing` event for each callback.
///
/// Besides logging, it answers failure callbacks with a retry request while
/// the task still has retries left, and answers timeouts with
/// [`RetryDecision::Fail`].
pub struct LoggingHooks;

#[async_trait::async_trait]
impl TaskHooks for LoggingHooks {
    /// Logs the task id, subject and attempt at `info` level.
    async fn before_execute(&self, ctx: &TaskHookContext) {
        tracing::info!(
            task_id = %ctx.task.id,
            subject = %ctx.task.subject,
            attempt = ctx.attempt,
            "task_before_execute"
        );
    }

    /// Logs the task id, subject and a shortened preview of `result` at
    /// `info` level.
    ///
    /// Results longer than 100 bytes are cut to at most 100 bytes and
    /// suffixed with `...`; shorter results are logged unchanged.
    async fn after_execute(&self, ctx: &TaskHookContext, result: &str) {
        // Upper bound, in bytes, on how much of the result is logged.
        const PREVIEW_MAX_BYTES: usize = 100;

        let preview = if result.len() > PREVIEW_MAX_BYTES {
            // Slicing a `str` at a byte offset inside a multi-byte character
            // panics, so back up to the nearest char boundary first. Offset 0
            // is always a boundary, which guarantees the loop terminates.
            let mut end = PREVIEW_MAX_BYTES;
            while !result.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &result[..end])
        } else {
            result.to_string()
        };

        tracing::info!(
            task_id = %ctx.task.id,
            subject = %ctx.task.subject,
            result = %preview,
            "task_after_execute"
        );
    }

    /// Logs the failure at `warn` level and decides whether to retry.
    ///
    /// Returns `Retry { delay_secs: 1 }` while `task.retry_count` is below
    /// `task.max_retries`, otherwise [`RetryDecision::Fail`].
    async fn on_failure(&self, ctx: &TaskHookContext, error: &str) -> RetryDecision {
        tracing::warn!(
            task_id = %ctx.task.id,
            subject = %ctx.task.subject,
            attempt = ctx.attempt,
            error = %error,
            "task_on_failure"
        );

        if ctx.task.retry_count < ctx.task.max_retries {
            RetryDecision::Retry { delay_secs: 1 }
        } else {
            RetryDecision::Fail
        }
    }

    /// Logs the timeout (including `task.timeout_secs`) at `warn` level and
    /// returns [`RetryDecision::Fail`].
    ///
    /// Unlike `on_failure`, this never requests a retry.
    /// NOTE(review): presumably intentional — confirm timeouts should not be retried.
    async fn on_timeout(&self, ctx: &TaskHookContext) -> RetryDecision {
        tracing::warn!(
            task_id = %ctx.task.id,
            subject = %ctx.task.subject,
            timeout_secs = ctx.task.timeout_secs,
            "task_on_timeout"
        );
        RetryDecision::Fail
    }

    /// Logs the cancellation at `info` level.
    async fn on_cancelled(&self, ctx: &TaskHookContext) {
        tracing::info!(
            task_id = %ctx.task.id,
            subject = %ctx.task.subject,
            "task_on_cancelled"
        );
    }
}
/// Ordered collection of [`TaskHooks`] implementations.
///
/// Hooks are stored behind `Arc`, so one hook instance can be shared between
/// the registry and other owners.
pub struct TaskHookRegistry {
    /// Registered hooks, in registration order.
    hooks: Vec<Arc<dyn TaskHooks>>,
}
impl TaskHookRegistry {
    /// Creates a registry with no hooks.
    pub fn new() -> Self {
        let hooks = Vec::new();
        Self { hooks }
    }

    /// Creates a registry whose only hook is [`LoggingHooks`].
    pub fn with_logging() -> Self {
        let logging: Arc<dyn TaskHooks> = Arc::new(LoggingHooks);
        Self {
            hooks: vec![logging],
        }
    }

    /// Appends `hook`; hooks are invoked in the order they were registered.
    pub fn register(&mut self, hook: Arc<dyn TaskHooks>) {
        self.hooks.push(hook);
    }

    /// Awaits `before_execute` on every hook, one after another.
    pub async fn before_execute(&self, ctx: &TaskHookContext) {
        for hook in self.hooks.iter() {
            hook.before_execute(ctx).await;
        }
    }

    /// Awaits `after_execute` on every hook, one after another.
    pub async fn after_execute(&self, ctx: &TaskHookContext, result: &str) {
        for hook in self.hooks.iter() {
            hook.after_execute(ctx, result).await;
        }
    }

    /// Asks hooks, in order, how to handle a failure.
    ///
    /// The first answer other than [`RetryDecision::Fail`] is returned
    /// immediately and the remaining hooks are not called. If every hook
    /// answers `Fail`, or no hook is registered, the result is `Fail`.
    pub async fn on_failure(&self, ctx: &TaskHookContext, error: &str) -> RetryDecision {
        for hook in self.hooks.iter() {
            let verdict = hook.on_failure(ctx, error).await;
            if verdict == RetryDecision::Fail {
                // `Fail` counts as "no opinion": keep asking the next hook.
                continue;
            }
            return verdict;
        }
        RetryDecision::Fail
    }

    /// Asks hooks, in order, how to handle a timeout.
    ///
    /// Uses the same rule as [`TaskHookRegistry::on_failure`]: the first
    /// answer other than [`RetryDecision::Fail`] wins, otherwise `Fail`.
    pub async fn on_timeout(&self, ctx: &TaskHookContext) -> RetryDecision {
        for hook in self.hooks.iter() {
            let verdict = hook.on_timeout(ctx).await;
            if verdict == RetryDecision::Fail {
                continue;
            }
            return verdict;
        }
        RetryDecision::Fail
    }

    /// Awaits `on_cancelled` on every hook, one after another.
    pub async fn on_cancelled(&self, ctx: &TaskHookContext) {
        for hook in self.hooks.iter() {
            hook.on_cancelled(ctx).await;
        }
    }

    /// Returns `true` only if every hook permits a retry.
    ///
    /// Stops at the first hook that answers `false`. An empty registry
    /// yields `true`.
    pub async fn should_retry(&self, ctx: &TaskHookContext, error: &str) -> bool {
        for hook in self.hooks.iter() {
            let permitted = hook.should_retry(ctx, error).await;
            if !permitted {
                return false;
            }
        }
        true
    }

    /// Whether no hooks are registered.
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }

    /// Number of registered hooks.
    pub fn len(&self) -> usize {
        self.hooks.len()
    }
}
impl Default for TaskHookRegistry {
fn default() -> Self {
Self::new()
}
}
/// Cloning copies the list of `Arc` handles, so the clone shares the same
/// hook instances as the original rather than duplicating them.
impl Clone for TaskHookRegistry {
    fn clone(&self) -> Self {
        let hooks = self.hooks.iter().cloned().collect();
        Self { hooks }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook double that counts how often selected callbacks are invoked.
    struct TestHooks {
        before_count: AtomicU32,
        after_count: AtomicU32,
        failure_count: AtomicU32,
    }

    impl TestHooks {
        /// All counters start at zero.
        fn new() -> Self {
            Self {
                before_count: AtomicU32::new(0),
                after_count: AtomicU32::new(0),
                failure_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl TaskHooks for TestHooks {
        async fn before_execute(&self, _ctx: &TaskHookContext) {
            self.before_count.fetch_add(1, Ordering::SeqCst);
        }

        async fn after_execute(&self, _ctx: &TaskHookContext, _result: &str) {
            self.after_count.fetch_add(1, Ordering::SeqCst);
        }

        async fn on_failure(&self, _ctx: &TaskHookContext, _error: &str) -> RetryDecision {
            self.failure_count.fetch_add(1, Ordering::SeqCst);
            RetryDecision::Skip
        }
    }

    /// Each registry entry point must reach the registered hook exactly once.
    #[tokio::test]
    async fn test_hooks_called() {
        let hooks = Arc::new(TestHooks::new());
        let mut registry = TaskHookRegistry::new();
        registry.register(hooks.clone());

        let ctx = TaskHookContext {
            task: Task::new("test", "Test task"),
            attempt: 1,
            executor: None,
        };

        registry.before_execute(&ctx).await;
        registry.after_execute(&ctx, "done").await;
        registry.on_failure(&ctx, "error").await;

        let observed = (
            hooks.before_count.load(Ordering::SeqCst),
            hooks.after_count.load(Ordering::SeqCst),
            hooks.failure_count.load(Ordering::SeqCst),
        );
        assert_eq!(observed, (1, 1, 1));
    }

    /// `Default` yields a registry without hooks.
    #[test]
    fn test_registry_default() {
        assert!(TaskHookRegistry::default().is_empty());
    }

    /// `with_logging` registers exactly one hook.
    #[test]
    fn test_registry_with_logging() {
        let hook_count = TaskHookRegistry::with_logging().len();
        assert_eq!(hook_count, 1);
    }
}