1use super::Result;
2use crate::{Counter, Job, RedisPool, RetryOpts, UnitOfWork, WorkerRef};
3use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::error;
7
8#[async_trait]
9pub trait ServerMiddleware {
10 async fn call(
11 &self,
12 iter: ChainIter,
13 job: &Job,
14 worker: Arc<WorkerRef>,
15 redis: RedisPool,
16 ) -> Result<()>;
17}
18
19#[derive(Clone)]
22pub struct ChainIter {
23 stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
24 index: usize,
25}
26
27impl ChainIter {
28 #[inline]
29 pub async fn next(&self, job: &Job, worker: Arc<WorkerRef>, redis: RedisPool) -> Result<()> {
30 let stack = self.stack.read().await;
31
32 if let Some(middleware) = stack.get(self.index) {
33 middleware
34 .call(
35 ChainIter {
36 stack: self.stack.clone(),
37 index: self.index + 1,
38 },
39 job,
40 worker,
41 redis,
42 )
43 .await?;
44 }
45
46 Ok(())
47 }
48}
49
50#[derive(Clone)]
52pub(crate) struct Chain {
53 stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
54}
55
56impl Chain {
57 #[allow(dead_code)]
59 pub(crate) fn empty() -> Self {
60 Self {
61 stack: Arc::new(RwLock::new(vec![])),
62 }
63 }
64
65 pub(crate) fn new_with_stats(counter: Counter) -> Self {
66 Self {
67 stack: Arc::new(RwLock::new(vec![
68 Box::new(RetryMiddleware),
69 Box::new(StatsMiddleware::new(counter)),
70 Box::new(HandlerMiddleware),
71 ])),
72 }
73 }
74
75 pub(crate) async fn using(&mut self, middleware: Box<dyn ServerMiddleware + Send + Sync>) {
76 let mut stack = self.stack.write().await;
77 let index = if stack.is_empty() { 0 } else { stack.len() - 1 };
79
80 stack.insert(index, middleware);
81 }
82
83 #[inline]
84 pub(crate) fn iter(&self) -> ChainIter {
85 ChainIter {
86 stack: self.stack.clone(),
87 index: 0,
88 }
89 }
90
91 #[inline]
92 pub(crate) async fn call(
93 &mut self,
94 job: &Job,
95 worker: Arc<WorkerRef>,
96 redis: RedisPool,
97 ) -> Result<()> {
98 self.iter().next(job, worker, redis).await
103 }
104}
105
106pub struct StatsMiddleware {
107 busy_count: Counter,
108}
109
110impl StatsMiddleware {
111 fn new(busy_count: Counter) -> Self {
112 Self { busy_count }
113 }
114}
115
116#[async_trait]
117impl ServerMiddleware for StatsMiddleware {
118 #[inline]
119 async fn call(
120 &self,
121 chain: ChainIter,
122 job: &Job,
123 worker: Arc<WorkerRef>,
124 redis: RedisPool,
125 ) -> Result<()> {
126 self.busy_count.incrby(1);
127 let res = chain.next(job, worker, redis).await;
128 self.busy_count.decrby(1);
129 res
130 }
131}
132
133struct HandlerMiddleware;
134
135#[async_trait]
136impl ServerMiddleware for HandlerMiddleware {
137 #[inline]
138 async fn call(
139 &self,
140 _chain: ChainIter,
141 job: &Job,
142 worker: Arc<WorkerRef>,
143 _redis: RedisPool,
144 ) -> Result<()> {
145 worker.call(job.args.clone()).await
146 }
147}
148
149struct RetryMiddleware;
150
151#[async_trait]
152impl ServerMiddleware for RetryMiddleware {
153 #[inline]
154 async fn call(
155 &self,
156 chain: ChainIter,
157 job: &Job,
158 worker: Arc<WorkerRef>,
159 redis: RedisPool,
160 ) -> Result<()> {
161 let max_retries = if let RetryOpts::Max(max_retries) = job.retry {
164 max_retries
165 } else {
166 worker.max_retries()
167 };
168
169 let err = {
170 match chain.next(job, worker, redis.clone()).await {
171 Ok(()) => return Ok(()),
172 Err(err) => err,
173 }
174 };
175
176 let mut job = job.clone();
177
178 job.error_message = Some(format!("{err:?}"));
180 job.error_class = Some(match &err {
181 crate::Error::Message(_) => "RuntimeError".to_string(),
182 crate::Error::Json(_) => "JSON::ParserError".to_string(),
183 crate::Error::Redis(_) | crate::Error::BB8(_) => "Redis::BaseError".to_string(),
184 _ => "StandardError".to_string(),
185 });
186 if job.retry_count.is_some() {
187 job.retried_at = Some(chrono::Utc::now().timestamp_millis() as f64);
188 } else {
189 job.failed_at = Some(chrono::Utc::now().timestamp_millis() as f64);
190 }
191 let retry_count = if job.retry_count.is_some() {
193 job.retry_count.unwrap_or(0) + 1
194 } else {
195 0
196 };
197 job.retry_count = Some(retry_count);
198
199 if retry_count >= max_retries || job.retry == RetryOpts::Never {
201 error!({
202 "status" = "dead",
203 "class" = &job.class,
204 "jid" = &job.jid,
205 "queue" = &job.queue,
206 "err" = &job.error_message
207 }, "Max retries exceeded, moving job to dead set");
208
209 let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0;
212 if let Err(err) = redis
213 .get()
214 .await?
215 .zadd("dead".to_string(), serde_json::to_string(&job)?, now)
216 .await
217 {
218 error!("Failed to add job to dead set: {:?}", err);
219 }
220 } else {
221 error!({
222 "status" = "fail",
223 "class" = &job.class,
224 "jid" = &job.jid,
225 "queue" = &job.queue,
226 "retry_queue" = &job.retry_queue,
227 "err" = &job.error_message
228 }, "Scheduling job for retry in the future");
229
230 if let Some(ref retry_queue) = job.retry_queue {
232 job.queue = retry_queue.into();
233 }
234
235 UnitOfWork::from_job(job).reenqueue(&redis).await?;
236 }
237
238 Ok(())
239 }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use crate::{RedisConnectionManager, RedisPool, RetryOpts, Worker};
    use bb8::Pool;
    use tokio::sync::Mutex;

    /// Builds a connection pool for the Redis URL `redis://127.0.0.1/`.
    async fn redis() -> RedisPool {
        let url = "redis://127.0.0.1/";
        let manager = RedisConnectionManager::new(url).unwrap();
        Pool::builder().build(manager).await.unwrap()
    }

    /// A `TestWorker` job on the `default` queue with `RetryOpts::Yes` and no
    /// failure history.
    fn job() -> Job {
        Job {
            class: "TestWorker".into(),
            queue: "default".into(),
            args: vec![1337].into(),
            retry: RetryOpts::Yes,
            jid: crate::new_jid(),
            created_at: 1337.0,
            enqueued_at: None,
            failed_at: None,
            error_message: None,
            error_class: None,
            retry_count: None,
            retried_at: None,
            retry_queue: None,
            unique_for: None,
        }
    }

    /// Worker that flips `touched` to `true` when `perform` runs.
    #[derive(Clone)]
    struct TestWorker {
        touched: Arc<Mutex<bool>>,
    }

    #[async_trait]
    impl Worker<()> for TestWorker {
        async fn perform(&self, _args: ()) -> Result<()> {
            let mut touched = self.touched.lock().await;
            *touched = true;
            Ok(())
        }
    }

    #[tokio::test]
    async fn calls_through_a_middleware_stack() {
        let flag = Arc::new(Mutex::new(false));
        let test_worker = Arc::new(TestWorker {
            touched: Arc::clone(&flag),
        });
        let worker_ref = Arc::new(WorkerRef::wrap(Arc::clone(&test_worker)));

        let mut chain = Chain::empty();
        chain.using(Box::new(HandlerMiddleware)).await;

        let job = job();
        let pool = redis().await;
        chain
            .call(&job, Arc::clone(&worker_ref), pool)
            .await
            .unwrap();

        assert!(
            *flag.lock().await,
            "The job was processed by the middleware",
        );
    }
}
306}