use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tracing::Instrument;
use super::{Worker, WorkerContext, WorkerResult};
use crate::core::Job;
use crate::error::Result;
/// Heap-allocated, pinned, `Send` future borrowing data for `'a` and yielding `T`.
type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
/// A layer wrapped around job execution.
///
/// Each implementation receives the job, its worker context, and a [`Next`]
/// handle for the rest of the stack. Calling `next.run(job, ctx)` hands control
/// to the following layer, or to the terminal worker once no layers remain.
/// Returning without calling it short-circuits everything downstream.
#[async_trait]
pub trait JobMiddleware: Sync + Send {
    /// Processes `job`, optionally delegating to the remaining stack via `next`.
    async fn call<'a>(
        &'a self,
        job: &'a Job,
        ctx: &'a WorkerContext,
        next: Next<'a>,
    ) -> Result<WorkerResult>;
}
/// Handle to the remainder of a middleware stack.
///
/// Holds the layers that have not yet run, plus the terminal worker. Both
/// fields are shared borrows, so cloning the handle is cheap.
///
/// Each clone's [`Next::run`] independently drives every remaining layer and
/// the worker. A layer that needs to re-invoke downstream processing (for
/// example, to retry) can therefore clone the handle before calling `run`.
#[derive(Clone)]
pub struct Next<'a> {
    /// Layers still to run; the first element runs next.
    remaining: &'a [Arc<dyn JobMiddleware>],
    /// Worker executed once `remaining` is exhausted.
    worker: &'a dyn Worker,
}
impl<'a> Next<'a> {
pub(crate) fn new(remaining: &'a [Arc<dyn JobMiddleware>], worker: &'a dyn Worker) -> Self {
Self { remaining, worker }
}
pub fn run(
self,
job: &'a Job,
ctx: &'a WorkerContext,
) -> BoxedFuture<'a, Result<WorkerResult>> {
Box::pin(async move {
match self.remaining.split_first() {
Some((first, rest)) => {
let next = Next {
remaining: rest,
worker: self.worker,
};
first.call(job, ctx, next).await
}
None => self.worker.execute(job, ctx).await,
}
})
}
}
/// Drives `job` through every layer of `middleware`, starting with the first
/// element, and finally through `worker`.
///
/// Returns whatever the first layer returns, or the worker's result directly
/// when `middleware` is empty.
pub(crate) async fn run_stack(
    middleware: &[Arc<dyn JobMiddleware>],
    worker: &dyn Worker,
    job: &Job,
    ctx: &WorkerContext,
) -> Result<WorkerResult> {
    let chain = Next::new(middleware, worker);
    chain.run(job, ctx).await
}
/// Middleware that runs the rest of the stack inside a `qml.job.execute`
/// tracing span.
///
/// The span is tagged with the job's id, method, queue and attempt number.
/// The type is stateless, so it derives `Default`, `Clone` and `Copy`, and
/// `Debug` so it can appear in diagnostic output.
#[derive(Debug, Clone, Copy, Default)]
pub struct TracingMiddleware;
#[async_trait]
impl JobMiddleware for TracingMiddleware {
async fn call<'a>(
&'a self,
job: &'a Job,
ctx: &'a WorkerContext,
next: Next<'a>,
) -> Result<WorkerResult> {
let span = tracing::info_span!(
"qml.job.execute",
job.id = %job.id,
job.method = %job.method,
job.queue = %job.queue,
job.attempt = job.attempt,
);
next.run(job, ctx).instrument(span).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processing::{Worker, WorkerConfig};
    use async_trait::async_trait;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Terminal worker that always succeeds with output "ok".
    struct EchoWorker;

    #[async_trait]
    impl Worker for EchoWorker {
        async fn execute(&self, _job: &Job, _ctx: &WorkerContext) -> Result<WorkerResult> {
            Ok(WorkerResult::success(Some("ok".to_string()), 0))
        }

        fn method_name(&self) -> &str {
            "echo"
        }
    }

    /// Terminal worker that always reports a (non-error) failure result.
    struct FailingWorker;

    #[async_trait]
    impl Worker for FailingWorker {
        async fn execute(&self, _job: &Job, _ctx: &WorkerContext) -> Result<WorkerResult> {
            Ok(WorkerResult::failure("boom".to_string()))
        }

        fn method_name(&self) -> &str {
            "fail"
        }
    }

    /// Layer that appends "<tag>:before" / "<tag>:after" around the downstream call.
    struct RecordingMiddleware {
        tag: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl JobMiddleware for RecordingMiddleware {
        async fn call<'a>(
            &'a self,
            job: &'a Job,
            ctx: &'a WorkerContext,
            next: Next<'a>,
        ) -> Result<WorkerResult> {
            // Each lock guard is dropped before the await, so no std Mutex
            // guard is held across a suspension point.
            self.log
                .lock()
                .unwrap()
                .push(format!("{}:before", self.tag));
            let result = next.run(job, ctx).await;
            self.log.lock().unwrap().push(format!("{}:after", self.tag));
            result
        }
    }

    /// Layer that never calls `next`, returning its own success instead.
    struct ShortCircuitMiddleware;

    #[async_trait]
    impl JobMiddleware for ShortCircuitMiddleware {
        async fn call<'a>(
            &'a self,
            _job: &'a Job,
            _ctx: &'a WorkerContext,
            _next: Next<'a>,
        ) -> Result<WorkerResult> {
            Ok(WorkerResult::success(Some("short".to_string()), 0))
        }
    }

    /// Layer that counts successful vs. non-successful downstream results.
    struct CountingMiddleware {
        successes: Arc<AtomicUsize>,
        failures: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl JobMiddleware for CountingMiddleware {
        async fn call<'a>(
            &'a self,
            job: &'a Job,
            ctx: &'a WorkerContext,
            next: Next<'a>,
        ) -> Result<WorkerResult> {
            let result = next.run(job, ctx).await;
            match &result {
                Ok(WorkerResult::Success { .. }) => {
                    self.successes.fetch_add(1, Ordering::Relaxed);
                }
                _ => {
                    self.failures.fetch_add(1, Ordering::Relaxed);
                }
            }
            result
        }
    }

    fn test_job(method: &str) -> Job {
        Job::new(method, serde_json::Value::Null)
    }

    fn test_ctx() -> WorkerContext {
        WorkerContext::new(WorkerConfig::new("test-worker"))
    }

    #[tokio::test]
    async fn empty_stack_runs_terminal_worker_directly() {
        let job = test_job("echo");
        let ctx = test_ctx();
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![];
        let result = run_stack(&stack, &EchoWorker, &job, &ctx).await.unwrap();
        assert!(matches!(result, WorkerResult::Success { .. }));
    }

    #[tokio::test]
    async fn middleware_runs_in_registration_order_outer_to_inner() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![
            Arc::new(RecordingMiddleware {
                tag: "A",
                log: log.clone(),
            }),
            Arc::new(RecordingMiddleware {
                tag: "B",
                log: log.clone(),
            }),
            Arc::new(RecordingMiddleware {
                tag: "C",
                log: log.clone(),
            }),
        ];
        let job = test_job("echo");
        let ctx = test_ctx();
        run_stack(&stack, &EchoWorker, &job, &ctx).await.unwrap();
        let log = log.lock().unwrap().clone();
        assert_eq!(
            log,
            vec![
                "A:before".to_string(),
                "B:before".to_string(),
                "C:before".to_string(),
                "C:after".to_string(),
                "B:after".to_string(),
                "A:after".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn middleware_can_short_circuit_the_stack() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![
            Arc::new(RecordingMiddleware {
                tag: "A",
                log: log.clone(),
            }),
            Arc::new(ShortCircuitMiddleware),
            Arc::new(RecordingMiddleware {
                tag: "C",
                log: log.clone(),
            }),
        ];
        let job = test_job("echo");
        let ctx = test_ctx();
        // FailingWorker proves the terminal worker never ran: a success can
        // only have come from the short-circuit layer.
        let result = run_stack(&stack, &FailingWorker, &job, &ctx).await.unwrap();
        assert!(matches!(result, WorkerResult::Success { .. }));
        let log = log.lock().unwrap().clone();
        assert_eq!(
            log,
            vec!["A:before".to_string(), "A:after".to_string()],
            "C should never have run — short-circuit layer swallowed the chain"
        );
    }

    #[tokio::test]
    async fn counting_middleware_distinguishes_success_and_failure() {
        let successes = Arc::new(AtomicUsize::new(0));
        let failures = Arc::new(AtomicUsize::new(0));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![Arc::new(CountingMiddleware {
            successes: successes.clone(),
            failures: failures.clone(),
        })];
        let ctx = test_ctx();
        run_stack(&stack, &EchoWorker, &test_job("echo"), &ctx)
            .await
            .unwrap();
        run_stack(&stack, &FailingWorker, &test_job("fail"), &ctx)
            .await
            .unwrap();
        assert_eq!(successes.load(Ordering::Relaxed), 1);
        assert_eq!(failures.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn tracing_middleware_passes_results_through_unchanged() {
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![Arc::new(TracingMiddleware)];
        let ctx = test_ctx();

        let ok = run_stack(&stack, &EchoWorker, &test_job("echo"), &ctx)
            .await
            .unwrap();
        assert!(matches!(ok, WorkerResult::Success { .. }));

        let failed = run_stack(&stack, &FailingWorker, &test_job("fail"), &ctx)
            .await
            .unwrap();
        assert!(!matches!(failed, WorkerResult::Success { .. }));
    }

    #[tokio::test]
    async fn tracing_middleware_forwards_to_inner_layers() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![
            Arc::new(TracingMiddleware),
            Arc::new(RecordingMiddleware {
                tag: "inner",
                log: log.clone(),
            }),
        ];
        let job = test_job("echo");
        let ctx = test_ctx();
        run_stack(&stack, &EchoWorker, &job, &ctx).await.unwrap();
        let log = log.lock().unwrap().clone();
        assert_eq!(
            log,
            vec!["inner:before".to_string(), "inner:after".to_string()]
        );
    }
}