1use async_trait::async_trait;
60use std::future::Future;
61use std::pin::Pin;
62use std::sync::Arc;
63use tracing::Instrument;
64
65use super::{Worker, WorkerContext, WorkerResult};
66use crate::core::Job;
67use crate::error::Result;
68
69type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
73
#[async_trait]
/// A composable layer wrapped around job execution.
///
/// Layers are stacked (see `run_stack`) and each one receives the [`Job`], its
/// [`WorkerContext`], and a [`Next`] handle for the rest of the chain. A layer
/// can run code before and/or after `next.run(job, ctx).await`, inspect or
/// replace the result, or return without calling `next` at all. Returning
/// without `next` short-circuits every inner layer and the terminal worker.
///
/// Layers are stored and shared as `Arc<dyn JobMiddleware>`. `#[async_trait]`
/// turns `call` into a method returning a boxed `Send` future, which keeps the
/// trait usable as a trait object. The `Send + Sync` supertraits let those
/// `Arc`s be shared across threads.
pub trait JobMiddleware: Send + Sync {
    /// Handles `job`, optionally delegating to the rest of the chain via `next`.
    ///
    /// `next` is consumed by [`Next::run`], so a layer drives the remaining
    /// chain at most once per `call`.
    ///
    /// # Errors
    /// Returns any error the layer produces itself or propagates from `next`.
    async fn call<'a>(
        &'a self,
        job: &'a Job,
        ctx: &'a WorkerContext,
        next: Next<'a>,
    ) -> Result<WorkerResult>;
}
89
90pub struct Next<'a> {
96 remaining: &'a [Arc<dyn JobMiddleware>],
97 worker: &'a dyn Worker,
98}
99
100impl<'a> Next<'a> {
101 pub(crate) fn new(remaining: &'a [Arc<dyn JobMiddleware>], worker: &'a dyn Worker) -> Self {
102 Self { remaining, worker }
103 }
104
105 pub fn run(
111 self,
112 job: &'a Job,
113 ctx: &'a WorkerContext,
114 ) -> BoxedFuture<'a, Result<WorkerResult>> {
115 Box::pin(async move {
116 match self.remaining.split_first() {
117 Some((first, rest)) => {
118 let next = Next {
119 remaining: rest,
120 worker: self.worker,
121 };
122 first.call(job, ctx, next).await
123 }
124 None => self.worker.execute(job, ctx).await,
125 }
126 })
127 }
128}
129
130pub(crate) async fn run_stack(
138 middleware: &[Arc<dyn JobMiddleware>],
139 worker: &dyn Worker,
140 job: &Job,
141 ctx: &WorkerContext,
142) -> Result<WorkerResult> {
143 Next::new(middleware, worker).run(job, ctx).await
144}
145
/// Middleware that runs the rest of the chain inside an `info`-level
/// `qml.job.execute` tracing span.
///
/// The span is tagged with the job's id, method, queue and attempt.
///
/// This is a stateless unit struct, so the cheap standard traits are derived:
/// `Debug` for diagnostics, `Copy`/`Clone` for free duplication, and
/// `Default` for generic construction.
#[derive(Debug, Clone, Copy, Default)]
pub struct TracingMiddleware;
154
155#[async_trait]
156impl JobMiddleware for TracingMiddleware {
157 async fn call<'a>(
158 &'a self,
159 job: &'a Job,
160 ctx: &'a WorkerContext,
161 next: Next<'a>,
162 ) -> Result<WorkerResult> {
163 let span = tracing::info_span!(
164 "qml.job.execute",
165 job.id = %job.id,
166 job.method = %job.method,
167 job.queue = %job.queue,
168 job.attempt = job.attempt,
169 );
170 next.run(job, ctx).instrument(span).await
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processing::{Worker, WorkerConfig};
    use async_trait::async_trait;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Terminal worker that always succeeds with output `"ok"`.
    struct EchoWorker;

    #[async_trait]
    impl Worker for EchoWorker {
        async fn execute(&self, _job: &Job, _ctx: &WorkerContext) -> Result<WorkerResult> {
            Ok(WorkerResult::success(Some("ok".to_string()), 0))
        }

        fn method_name(&self) -> &str {
            "echo"
        }
    }

    /// Terminal worker that always returns `Ok(WorkerResult::failure(..))`.
    /// This is a non-success result, not an `Err`.
    struct FailingWorker;

    #[async_trait]
    impl Worker for FailingWorker {
        async fn execute(&self, _job: &Job, _ctx: &WorkerContext) -> Result<WorkerResult> {
            Ok(WorkerResult::failure("boom".to_string()))
        }

        fn method_name(&self) -> &str {
            "fail"
        }
    }

    /// Middleware that logs `"<tag>:before"` and `"<tag>:after"` around its
    /// call to `next`, so tests can observe how layers nest.
    struct RecordingMiddleware {
        // Prefix used for this layer's log entries.
        tag: &'static str,
        // Log shared by every layer in a stack and by the test body.
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl JobMiddleware for RecordingMiddleware {
        async fn call<'a>(
            &'a self,
            job: &'a Job,
            ctx: &'a WorkerContext,
            next: Next<'a>,
        ) -> Result<WorkerResult> {
            // Each lock guard is a temporary dropped at the end of its statement,
            // so the std `Mutex` is never held across the `.await` below.
            self.log
                .lock()
                .unwrap()
                .push(format!("{}:before", self.tag));
            let result = next.run(job, ctx).await;
            self.log.lock().unwrap().push(format!("{}:after", self.tag));
            result
        }
    }

    /// Middleware that never calls `next`. It returns a success with output
    /// `"short"`, so no inner layer or worker runs.
    struct ShortCircuitMiddleware;

    #[async_trait]
    impl JobMiddleware for ShortCircuitMiddleware {
        async fn call<'a>(
            &'a self,
            _job: &'a Job,
            _ctx: &'a WorkerContext,
            _next: Next<'a>,
        ) -> Result<WorkerResult> {
            Ok(WorkerResult::success(Some("short".to_string()), 0))
        }
    }

    /// Middleware that tallies the outcome of the rest of the chain.
    /// `Ok(WorkerResult::Success { .. })` increments `successes`; any other
    /// outcome, including an `Err`, increments `failures`.
    struct CountingMiddleware {
        successes: Arc<AtomicUsize>,
        failures: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl JobMiddleware for CountingMiddleware {
        async fn call<'a>(
            &'a self,
            job: &'a Job,
            ctx: &'a WorkerContext,
            next: Next<'a>,
        ) -> Result<WorkerResult> {
            let result = next.run(job, ctx).await;
            match &result {
                Ok(WorkerResult::Success { .. }) => {
                    self.successes.fetch_add(1, Ordering::Relaxed);
                }
                _ => {
                    self.failures.fetch_add(1, Ordering::Relaxed);
                }
            }
            result
        }
    }

    /// Builds a job for `method` with a `serde_json::Value::Null` second
    /// argument (presumably the job payload — confirm against `Job::new`).
    fn test_job(method: &str) -> Job {
        Job::new(method, serde_json::Value::Null)
    }

    /// Builds a worker context from `WorkerConfig::new("test-worker")`.
    fn test_ctx() -> WorkerContext {
        WorkerContext::new(WorkerConfig::new("test-worker"))
    }

    // With no middleware registered, `run_stack` must hand the job straight to
    // the terminal worker.
    #[tokio::test]
    async fn empty_stack_runs_terminal_worker_directly() {
        let job = test_job("echo");
        let ctx = test_ctx();
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![];
        let result = run_stack(&stack, &EchoWorker, &job, &ctx).await.unwrap();
        assert!(matches!(result, WorkerResult::Success { .. }));
    }

    // Layers nest like an onion. The "before" hooks run outer→inner in slice
    // order, and the "after" hooks unwind inner→outer.
    #[tokio::test]
    async fn middleware_runs_in_registration_order_outer_to_inner() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![
            Arc::new(RecordingMiddleware {
                tag: "A",
                log: log.clone(),
            }),
            Arc::new(RecordingMiddleware {
                tag: "B",
                log: log.clone(),
            }),
            Arc::new(RecordingMiddleware {
                tag: "C",
                log: log.clone(),
            }),
        ];

        let job = test_job("echo");
        let ctx = test_ctx();
        run_stack(&stack, &EchoWorker, &job, &ctx).await.unwrap();

        let log = log.lock().unwrap().clone();
        assert_eq!(
            log,
            vec![
                "A:before".to_string(),
                "B:before".to_string(),
                "C:before".to_string(),
                "C:after".to_string(),
                "B:after".to_string(),
                "A:after".to_string(),
            ]
        );
    }

    // A layer that skips `next` stops the chain. Layers outside it still
    // complete, but layers inside it and the worker never run.
    #[tokio::test]
    async fn middleware_can_short_circuit_the_stack() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![
            Arc::new(RecordingMiddleware {
                tag: "A",
                log: log.clone(),
            }),
            Arc::new(ShortCircuitMiddleware),
            Arc::new(RecordingMiddleware {
                tag: "C",
                log: log.clone(),
            }),
        ];

        let job = test_job("echo");
        let ctx = test_ctx();
        let result = run_stack(&stack, &FailingWorker, &job, &ctx).await.unwrap();
        // The terminal worker is `FailingWorker`, and layer A passes the result
        // through unchanged, so a `Success` can only have come from the
        // short-circuit layer.
        assert!(matches!(result, WorkerResult::Success { .. }));

        let log = log.lock().unwrap().clone();
        assert_eq!(
            log,
            vec!["A:before".to_string(), "A:after".to_string()],
            "C should never have run — short-circuit layer swallowed the chain"
        );
    }

    // The same stack is reused for two sequential runs. `EchoWorker` yields a
    // `Success` and `FailingWorker` a non-success result, so each counter
    // should end at exactly 1.
    #[tokio::test]
    async fn counting_middleware_distinguishes_success_and_failure() {
        let successes = Arc::new(AtomicUsize::new(0));
        let failures = Arc::new(AtomicUsize::new(0));
        let stack: Vec<Arc<dyn JobMiddleware>> = vec![Arc::new(CountingMiddleware {
            successes: successes.clone(),
            failures: failures.clone(),
        })];

        let ctx = test_ctx();
        run_stack(&stack, &EchoWorker, &test_job("echo"), &ctx)
            .await
            .unwrap();
        run_stack(&stack, &FailingWorker, &test_job("fail"), &ctx)
            .await
            .unwrap();

        assert_eq!(successes.load(Ordering::Relaxed), 1);
        assert_eq!(failures.load(Ordering::Relaxed), 1);
    }
}