use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{Stream, StreamExt, TryStreamExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::mysql::MySqlRow;
use sqlx::{MySql, Pool, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, ops::Add, time::Duration};

use crate::context::SqlContext;
use crate::from_row::SqlRequest;
use crate::{calculate_status, Config, SqlError};

pub use sqlx::mysql::MySqlPool;

/// The default codec, persisting job arguments as JSON values.
type MysqlCodec = JsonCodec<Value>;

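/// A [`Storage`] backed by MySQL, using `SKIP LOCKED` polling to distribute
/// jobs between workers.
///
/// A minimal end-to-end sketch. It assumes the `apalis` `WorkerBuilder` API,
/// an `Email` job type with a `send_email` handler (all illustrative), and a
/// reachable `DATABASE_URL`:
///
/// ```ignore
/// use apalis::prelude::*;
/// use apalis_sql::mysql::{MySqlPool, MysqlStorage};
///
/// let pool = MySqlPool::connect(&std::env::var("DATABASE_URL")?).await?;
/// MysqlStorage::setup(&pool).await?;
/// let mut storage: MysqlStorage<Email> = MysqlStorage::new(pool);
/// storage.push(example_email()).await?;
///
/// WorkerBuilder::new("email-worker")
///     .backend(storage)
///     .build_fn(send_email)
///     .run()
///     .await;
/// ```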
pub struct MysqlStorage<T, C = JsonCodec<Value>>
where
    C: Codec,
{
    pool: Pool<MySql>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
}

impl<T, C> fmt::Debug for MysqlStorage<T, C>
where
    C: Codec,
    C::Compact: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MysqlStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .field("ack_notify", &self.ack_notify)
            .finish()
    }
}

impl<T, C> Clone for MysqlStorage<T, C>
where
    C: Codec,
{
    fn clone(&self) -> Self {
        let pool = self.pool.clone();
        MysqlStorage {
            pool,
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
            ack_notify: self.ack_notify.clone(),
        }
    }
}

impl MysqlStorage<(), JsonCodec<Value>> {
    /// Get the MySQL migrations, without running them.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/mysql")
    }

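    /// Run the bundled migrations against the given pool, creating the
    /// `jobs` and `workers` tables this backend relies on.
    ///
    /// A minimal sketch (the connection URL is a placeholder, and the
    /// `migrate` feature must be enabled):
    ///
    /// ```ignore
    /// use apalis_sql::mysql::{MySqlPool, MysqlStorage};
    ///
    /// let pool = MySqlPool::connect("mysql://localhost/apalis").await?;
    /// MysqlStorage::setup(&pool).await?;
    /// ```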
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> MysqlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
    /// Create a new storage from a [`MySqlPool`], namespaced by the job's
    /// type name and using the default JSON codec.
    pub fn new(pool: MySqlPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

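    /// Create a new storage with a custom [`Config`], e.g. to control the
    /// namespace.
    ///
    /// A sketch, assuming an `Email` job type and an already opened `pool`:
    ///
    /// ```ignore
    /// use apalis_sql::Config;
    ///
    /// let storage: MysqlStorage<Email> =
    ///     MysqlStorage::new_with_config(pool, Config::new("email-queue"));
    /// ```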
    pub fn new_with_config(pool: MySqlPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            ack_notify: Notify::new(),
            codec: PhantomData,
        }
    }

    /// Expose the underlying pool, e.g. for custom queries or migrations.
    pub fn pool(&self) -> &Pool<MySql> {
        &self.pool
    }

    /// Get the [`Config`] in use by this storage.
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> MysqlStorage<T, C>
where
    T: DeserializeOwned + Send + Sync + 'static,
    C: Codec<Compact = Value> + Send + 'static,
{
    /// Poll the database for due jobs and yield them as a stream, locked for
    /// this worker. MySQL lacks `UPDATE ... RETURNING`, so fetching runs in
    /// three steps: select candidate ids with `FOR UPDATE SKIP LOCKED`, mark
    /// them `Running`, then read the claimed rows back.
    fn stream_jobs(
        self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let worker_id = worker.id().to_string();

        try_stream! {
            let buffer_size = u32::try_from(buffer_size)
                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let pool = pool.clone();
                let job_type = self.config.namespace.clone();
                let mut tx = pool.begin().await?;
                // Pick due jobs, skipping rows already locked by other workers.
                let fetch_query = "SELECT id FROM jobs
                    WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at <= NOW() AND job_type = ? ORDER BY priority DESC, run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED";
                let task_ids: Vec<MySqlRow> = sqlx::query(fetch_query)
                    .bind(job_type)
                    .bind(buffer_size)
                    .fetch_all(&mut *tx)
                    .await?;
                if task_ids.is_empty() {
                    tx.rollback().await?;
                    yield None
                } else {
                    let task_ids: Vec<String> =
                        task_ids.iter().map(|r| r.get_unchecked("id")).collect();
                    let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1));
                    // Claim the selected jobs for this worker.
                    let query = format!("UPDATE jobs SET status = 'Running', lock_by = ?, lock_at = NOW() WHERE id IN ({}) AND status = 'Pending' AND lock_by IS NULL;", id_params);
                    let mut query = sqlx::query(&query).bind(worker_id.clone());
                    for i in &task_ids {
                        query = query.bind(i);
                    }
                    query.execute(&mut *tx).await?;
                    tx.commit().await?;

                    // Read the claimed rows back, decode them and hand them out.
                    let fetch_query = format!(
                        "SELECT * FROM jobs WHERE id IN ({}) ORDER BY priority DESC, run_at ASC",
                        id_params
                    );
                    let mut query = sqlx::query_as(&fetch_query);
                    for i in task_ids {
                        query = query.bind(i);
                    }
                    let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;

                    for job in jobs {
                        yield {
                            let (req, ctx) = job.req.take_parts();
                            let req = C::decode(req)
                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                            let mut req: Request<T, SqlContext> = Request::new_with_parts(req, ctx);
                            req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                            Some(req)
                        }
                    }
                }
            }
        }
        .boxed()
    }

    /// Register this worker (or refresh its heartbeat) in the `workers` table.
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut conn = pool.acquire().await?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        // On conflict, refresh `last_seen` so the heartbeat actually advances.
        let query =
            "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE last_seen = ?;";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .bind(last_seen)
            .execute(&mut *conn)
            .await?;
        Ok(())
    }
}

impl<T, C> Storage for MysqlStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send + Sync + 'static,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = Value;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        // Insert a new `Pending` job with zero attempts, due immediately.
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn push_raw_request(
        &mut self,
        job: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

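    /// `on` is a Unix timestamp in seconds; the job stays `Pending` until
    /// `run_at` is due.
    ///
    /// A sketch of scheduling five minutes ahead (names are illustrative):
    ///
    /// ```ignore
    /// let run_at = chrono::Utc::now().timestamp() + 5 * 60;
    /// storage.schedule_request(request, run_at).await?;
    /// ```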
    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        // Convert the Unix timestamp into a `DATETIME` value for `run_at`.
        let on = DateTime::from_timestamp(on, 0);

        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .bind(req.parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let pool = self.pool.clone();

        let fetch_query = "SELECT * FROM jobs WHERE id = ?";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let pool = self.pool.clone();

        // Only jobs still waiting to run are counted.
        let query = "SELECT COUNT(*) AS count FROM jobs WHERE status = 'Pending'";
        let record = sqlx::query(query).fetch_one(&pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let job_id = job.parts.task_id.clone();

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let mut conn = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ? WHERE id = ?";

        sqlx::query(query)
            .bind(Utc::now().timestamp().add(wait))
            .bind(job_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();
        let job_id = job.parts.task_id;
        let mut conn = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ?, priority = ? WHERE id = ?";
        sqlx::query(query)
            .bind(status.to_owned())
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let pool = self.pool.clone();
        let query = "DELETE FROM jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

/// Errors that can occur while polling a MySQL database.
#[derive(thiserror::Error, Debug)]
pub enum MysqlPollError {
    /// An error occurred while acknowledging a job.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// An error occurred while encoding a job result.
    #[error("Encountered an error while encoding the result: `{0}`")]
    CodecError(BoxDynError),

    /// An error occurred during the keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// An error occurred during the re-enqueue-orphaned heartbeat.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<Req, C> Backend<Request<Req, SqlContext>> for MysqlStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Sync + Send + 'static,
    C: Codec<Compact = Value> + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<Req, SqlContext>>>;

    type Layer = AckLayer<MysqlStorage<Req, C>, Req, SqlContext, C>;

    type Codec = C;

    fn poll(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let pool = self.pool.clone();
        let ack_notify = self.ack_notify.clone();
        let mut hb_storage = self.clone();
        let requeue_storage = self.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let w = worker.clone();

        // Flush buffered acknowledgements back to the database in batches.
        let ack_heartbeat = async move {
            while let Some(ids) = ack_notify
                .clone()
                .ready_chunks(config.buffer_size)
                .next()
                .await
            {
                for (ctx, res) in ids {
                    let query = "UPDATE jobs SET status = ?, done_at = now(), last_error = ?, attempts = ? WHERE id = ? AND lock_by = ?";
                    let query = sqlx::query(query);
                    let last_result =
                        C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new);
                    match (last_result, ctx.lock_by()) {
                        (Ok(val), Some(worker_id)) => {
                            let query = query
                                .bind(calculate_status(&ctx, &res).to_string())
                                .bind(val)
                                .bind(res.attempt.current() as i32)
                                .bind(res.task_id.to_string())
                                .bind(worker_id.to_string());
                            if let Err(e) = query.execute(&pool).await {
                                w.emit(Event::Error(Box::new(MysqlPollError::AckError(e))));
                            }
                        }
                        (Err(error), Some(_)) => {
                            w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error))));
                        }
                        _ => {
                            unreachable!(
                                "Attempted to ACK without a worker attached. This is a bug; please file it on the repo."
                            );
                        }
                    }
                }

                apalis_core::sleep(config.poll_interval).await;
            }
        };
        let w = worker.clone();
        // Reclaim jobs abandoned by dead workers once at startup, then keep
        // refreshing this worker's heartbeat.
        let heartbeat = async move {
            if let Err(e) = hb_storage
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                w.emit(Event::Error(Box::new(
                    MysqlPollError::ReenqueueOrphanedError(e),
                )));
            }

            loop {
                let now = Utc::now();
                if let Err(e) = hb_storage.keep_alive_at::<Self::Layer>(w.id(), now).await {
                    w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e))));
                }
                apalis_core::sleep(config.keep_alive).await;
            }
        };
        let w = worker.clone();
        // Periodically return jobs whose workers stopped heartbeating to the queue.
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after)
                        .expect("Could not calculate dead since");
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("Could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        MysqlPollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, ack_heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

impl<T, Res, C> Ack<T, Res, C> for MysqlStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Send + 'static + Sync,
    C: Codec<Compact = Value> + Send,
    C::Error: Debug,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        // Acks are buffered through `ack_notify` and written back in batches
        // by the poller's ack task.
        self.ack_notify
            .notify((
                ctx.clone(),
                res.map(|res| C::encode(res).expect("Could not encode result")),
            ))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> MysqlStorage<T, C> {
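    /// Mark a job locked by this worker as `Killed`, permanently removing it
    /// from the queue.
    ///
    /// A sketch, assuming `storage`, a registered `worker` and a fetched `job`:
    ///
    /// ```ignore
    /// storage.kill(worker.id(), &job.parts.task_id).await?;
    /// ```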
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut conn = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Killed', done_at = NOW() WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

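    /// Put a job locked by this worker instantly back into the queue.
    ///
    /// A sketch, assuming `storage`, a registered `worker` and a fetched `job`:
    ///
    /// ```ignore
    /// storage.retry(worker.id(), &job.parts.task_id).await?;
    /// ```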
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut conn = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

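    /// Return jobs to the queue whose workers have not been seen since
    /// `dead_since`, marking them `Pending` and recording the abandonment.
    ///
    /// A sketch of reclaiming up to 10 jobs from workers silent for 5 minutes:
    ///
    /// ```ignore
    /// let dead_since = chrono::Utc::now() - chrono::Duration::minutes(5);
    /// storage.reenqueue_orphaned(10, dead_since).await?;
    /// ```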
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<bool, sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut conn = self.pool.acquire().await?;
        // Find running jobs locked by workers of this namespace whose last
        // heartbeat is older than `dead_since`, and push them back to `Pending`.
        let query = r#"UPDATE jobs
            INNER JOIN (SELECT workers.id AS worker_id, jobs.id AS job_id FROM workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.attempts < jobs.max_attempts AND jobs.status = "Running" AND workers.last_seen < ? AND workers.worker_type = ?
            ORDER BY lock_at ASC LIMIT ?) AS workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id
            SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = "Job was abandoned", attempts = attempts + 1;"#;

        sqlx::query(query)
            .bind(dead_since)
            .bind(job_type)
            .bind(count)
            .execute(&mut *conn)
            .await?;
        Ok(true)
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for MysqlStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
            COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending,
            COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running,
            COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done,
            COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry,
            COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed,
            COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed
            FROM jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        // Tuple order follows the SELECT: pending, running, done, retry,
        // failed, killed. `Killed` is currently not exposed.
        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req: J = MysqlCodec::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;

    use apalis_core::test_utils::DummyService;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, MysqlStorage<Email>, Email);

    /// Build a fresh storage against `DATABASE_URL`, running migrations and
    /// clearing any leftovers from previous runs.
    async fn setup<T: Serialize + DeserializeOwned>() -> MysqlStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = MySqlPool::connect(db_url).await.unwrap();
        MysqlStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let mut storage = MysqlStorage::new(pool);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Delete the jobs and worker rows a test run leaves behind, so the suite
    /// can be re-run against the same database.
    async fn cleanup<T>(storage: &mut MysqlStorage<T>, worker_id: &WorkerId) {
        sqlx::query("DELETE FROM jobs WHERE job_type = ?")
            .bind(storage.config.namespace())
            .execute(&storage.pool)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM workers WHERE id = ?")
            .bind(worker_id.to_string())
            .execute(&storage.pool)
            .await
            .expect("failed to delete worker");
    }

    async fn consume_one(
        storage: &mut MysqlStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .clone()
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1);
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@mysql".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn register_worker_at(
        storage: &mut MysqlStorage<Email>,
        last_seen: DateTime<Utc>,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        storage
            .keep_alive_at::<DummyService>(wrk.id(), last_seen)
            .await
            .expect("failed to register worker");
        wrk
    }

    async fn register_worker(storage: &mut MysqlStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now()).await
    }

    async fn push_email(storage: &mut MysqlStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut MysqlStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        apalis_core::sleep(Duration::from_secs(1)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        // The job must be claimed by and locked to this worker.
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;

        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        storage
            .keep_alive_at::<Email>(worker.id(), six_minutes_ago)
            .await
            .unwrap();

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .unwrap();

        // The worker was last seen before `five_minutes_ago`, so its job is
        // returned to the queue with the abandonment recorded.
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        let service = apalis_test_service_fn(|_: Request<Email, _>| async move {
            apalis_core::sleep(Duration::from_millis(500)).await;
            Ok::<_, io::Error>("success")
        });
        let (mut t, poller) = TestWrapper::new_with_service(storage.clone(), service);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        storage
            .keep_alive_at::<Email>(t.worker.id(), four_minutes_ago)
            .await
            .unwrap();

        tokio::spawn(poller);

        let parts = storage
            .push(example_email())
            .await
            .expect("failed to push job");

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .unwrap();

        // The worker was seen four minutes ago, inside the six-minute window,
        // so the job must be left untouched.
        let job = storage.fetch_by_id(&parts.task_id).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert_eq!(*ctx.lock_by(), None);
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);

        let res = t.execute_next().await.unwrap();

        apalis_core::sleep(Duration::from_millis(1000)).await;

        let job = storage.fetch_by_id(&res.0).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(t.worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}