use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{Stream, StreamExt, TryStreamExt};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::mysql::MySqlRow;
use sqlx::{MySql, Pool, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, ops::Add, time::Duration};

use crate::context::SqlContext;
use crate::from_row::SqlRequest;
use crate::{calculate_status, Config, SqlError};

pub use sqlx::mysql::MySqlPool;

type MysqlCodec = JsonCodec<Value>;

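/// Represents a [`Storage`] that persists jobs in a MySQL database.
///
/// Jobs live in a `jobs` table and are claimed with `FOR UPDATE SKIP LOCKED`,
/// so several workers can poll the same queue concurrently without picking
/// the same job twice.
///
/// ## Example
///
/// A minimal sketch of pushing a job; it assumes this module is exposed as
/// `apalis_sql::mysql` and uses a hypothetical `Email` job type:
///
/// ```rust,no_run
/// use apalis_core::storage::Storage;
/// use apalis_sql::mysql::{MySqlPool, MysqlStorage};
///
/// #[derive(serde::Serialize, serde::Deserialize)]
/// struct Email { to: String }
///
/// async fn example(db_url: &str) -> Result<(), Box<dyn std::error::Error>> {
///     let pool = MySqlPool::connect(db_url).await?;
///     MysqlStorage::setup(&pool).await?; // requires the `migrate` feature
///     let mut storage: MysqlStorage<Email> = MysqlStorage::new(pool);
///     storage.push(Email { to: "user@example.com".into() }).await?;
///     Ok(())
/// }
/// ```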
pub struct MysqlStorage<T, C = JsonCodec<Value>>
where
    C: Codec,
{
    pool: Pool<MySql>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
}

impl<T, C> fmt::Debug for MysqlStorage<T, C>
where
    C: Debug + Codec,
    C::Compact: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MysqlStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &self.codec)
            .field("ack_notify", &self.ack_notify)
            .finish()
    }
}

impl<T, C> Clone for MysqlStorage<T, C>
where
    C: Debug + Codec,
{
    fn clone(&self) -> Self {
        let pool = self.pool.clone();
        MysqlStorage {
            pool,
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
            ack_notify: self.ack_notify.clone(),
        }
    }
}

impl MysqlStorage<(), JsonCodec<Value>> {
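    /// Get the MySQL migrations, without running them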
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/mysql")
    }

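    /// Run the MySQL migrations, creating the tables this storage expects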
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> MysqlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
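    /// Create a new instance from a pool, using a default config namespaced
    /// by the job's type name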
    pub fn new(pool: MySqlPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

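    /// Create a new instance from a pool and a custom config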
    pub fn new_with_config(pool: MySqlPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            ack_notify: Notify::new(),
            codec: PhantomData,
        }
    }

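    /// Expose the underlying connection pool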
    pub fn pool(&self) -> &Pool<MySql> {
        &self.pool
    }

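    /// Expose the config used by this storage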
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> MysqlStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin + Sync + 'static,
    C: Codec<Compact = Value> + Send + 'static,
{
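    /// Poll the `jobs` table every `interval`, claiming up to `buffer_size`
    /// pending jobs per tick. Claiming happens inside a transaction with
    /// `FOR UPDATE SKIP LOCKED`, so concurrent workers never pick the same job.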
    fn stream_jobs(
        self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let worker_id = worker.id().to_string();

        try_stream! {
            let buffer_size = u32::try_from(buffer_size)
                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let pool = pool.clone();
                let job_type = self.config.namespace.clone();
                let mut tx = pool.begin().await?;
                // Select the next batch of pending jobs, skipping rows already locked by other workers.
                let fetch_query = "SELECT id FROM jobs \
                    WHERE status = 'Pending' AND run_at <= NOW() AND job_type = ? ORDER BY run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED";
                let task_ids: Vec<MySqlRow> = sqlx::query(fetch_query)
                    .bind(job_type)
                    .bind(buffer_size)
                    .fetch_all(&mut *tx).await?;
                if task_ids.is_empty() {
                    tx.rollback().await?;
                    yield None
                } else {
                    let task_ids: Vec<String> = task_ids.iter().map(|r| r.get_unchecked("id")).collect();
                    let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1));
                    // Mark the batch as running and record which worker holds the lock.
                    let query = format!("UPDATE jobs SET status = 'Running', lock_by = ?, lock_at = NOW(), attempts = attempts + 1 WHERE id IN ({}) AND status = 'Pending' AND lock_by IS NULL;", id_params);
                    let mut query = sqlx::query(&query).bind(worker_id.clone());
                    for i in &task_ids {
                        query = query.bind(i);
                    }
                    query.execute(&mut *tx).await?;
                    tx.commit().await?;

                    let fetch_query = format!("SELECT * FROM jobs WHERE id IN ({})", id_params);
                    let mut query = sqlx::query_as(&fetch_query);
                    for i in task_ids {
                        query = query.bind(i);
                    }
                    let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;

                    for job in jobs {
                        yield {
                            let (req, parts) = job.req.take_parts();
                            let req = C::decode(req)
                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                            let mut req: Request<T, SqlContext> = Request::new_with_parts(req, parts);
                            req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                            Some(req)
                        }
                    }
                }
            }
        }.boxed()
    }

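    /// Register the worker, or refresh its heartbeat, by upserting its
    /// `last_seen` timestamp into the `workers` table.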
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        // On conflict, refresh `last_seen` so a live worker is never treated as orphaned.
        let query =
            "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE last_seen = ?;";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .bind(last_seen)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

impl<T, C> Storage for MysqlStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, NOW(), NULL, NULL, NULL, NULL)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query = "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL)";
        let pool = self.pool.clone();

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .execute(&pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let pool = self.pool.clone();

        let fetch_query = "SELECT * FROM jobs WHERE id = ?";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let pool = self.pool.clone();

        let query = "SELECT COUNT(*) AS count FROM jobs WHERE status = 'Pending'";
        let record = sqlx::query(query).fetch_one(&pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let job_id = job.parts.task_id.clone();

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ? WHERE id = ?";

        sqlx::query(query)
            .bind(Utc::now().timestamp().add(wait))
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let job_id = job.parts.task_id;
        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ? WHERE id = ?";
        sqlx::query(query)
            .bind(status)
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let pool = self.pool.clone();
        let query = "DELETE FROM jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

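/// Errors emitted as worker [`Event`]s by the background futures that
/// [`MysqlStorage`] runs while polling.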
#[derive(thiserror::Error, Debug)]
pub enum MysqlPollError {
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    #[error("Encountered an error while encoding the result: {0}")]
    CodecError(BoxDynError),

    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<Req, Res, C> Backend<Request<Req, SqlContext>, Res> for MysqlStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    C: Debug + Codec<Compact = Value> + Clone + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<Req, SqlContext>>>;

    type Layer = AckLayer<MysqlStorage<Req, C>, Req, SqlContext, Res>;

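    // Polling drives the job stream plus three background futures that are
    // joined below: batched ACKs, the keep-alive heartbeat, and re-enqueueing
    // of jobs orphaned by dead workers.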
    fn poll<Svc>(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let pool = self.pool.clone();
        let ack_notify = self.ack_notify.clone();
        let mut hb_storage = self.clone();
        let requeue_storage = self.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let w = worker.clone();

        // Flush acknowledgements in batches: each ACK records the job's final
        // status and result for the worker that held the lock.
        let ack_heartbeat = async move {
            while let Some(ids) = ack_notify
                .clone()
                .ready_chunks(config.buffer_size)
                .next()
                .await
            {
                for (ctx, res) in ids {
                    let query = "UPDATE jobs SET status = ?, done_at = NOW(), last_error = ? WHERE id = ? AND lock_by = ?";
                    let query = sqlx::query(query);
                    let last_result =
                        C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new);
                    match (last_result, ctx.lock_by()) {
                        (Ok(val), Some(worker_id)) => {
                            let query = query
                                .bind(calculate_status(&res.inner).to_string())
                                .bind(val)
                                .bind(res.task_id.to_string())
                                .bind(worker_id.to_string());
                            if let Err(e) = query.execute(&pool).await {
                                w.emit(Event::Error(Box::new(MysqlPollError::AckError(e))));
                            }
                        }
                        (Err(error), Some(_)) => {
                            w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error))));
                        }
                        _ => {
                            unreachable!(
                                "Attempted to ACK without a worker attached. This is a bug; please file it on the repo"
                            );
                        }
                    }
                }

                apalis_core::sleep(config.poll_interval).await;
            }
        };
        let w = worker.clone();
        // Periodically refresh this worker's `last_seen` so it is not treated as dead.
        let heartbeat = async move {
            loop {
                let now = Utc::now();
                if let Err(e) = hb_storage.keep_alive_at::<Self::Layer>(w.id(), now).await {
                    w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e))));
                }
                apalis_core::sleep(config.keep_alive).await;
            }
        };
        let w = worker.clone();
        // Periodically release jobs whose workers have stopped heartbeating.
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after)
                        .expect("Could not calculate dead since");
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("Could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        MysqlPollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, ack_heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

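// Acknowledgements are not written to the database here; they are pushed onto
// `ack_notify` and flushed in batches by the ACK future that `poll` drives.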
impl<T, Res, C> Ack<T, Res> for MysqlStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Send + 'static + Sync,
    C: Codec<Compact = Value> + Send,
    C::Error: Debug,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        self.ack_notify
            .notify((
                ctx.clone(),
                res.map(|res| C::encode(res).expect("Could not encode result")),
            ))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> MysqlStorage<T, C> {
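    /// Mark a job owned by `worker_id` as `Killed`, recording when it was finished off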
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Killed', done_at = NOW() WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

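    /// Put a job owned by `worker_id` back onto the queue for another attempt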
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

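    /// Re-enqueue up to `count` `Running` jobs whose workers have not been
    /// seen since `dead_since`, releasing their locks and recording
    /// "Job was abandoned" as the last error.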
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<bool, sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = r#"UPDATE jobs
            INNER JOIN (SELECT workers.id AS worker_id, jobs.id AS job_id FROM workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.status = 'Running' AND workers.last_seen < ? AND workers.worker_type = ?
            ORDER BY lock_at ASC LIMIT ?) AS workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id
            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned';"#;

        sqlx::query(query)
            .bind(dead_since)
            .bind(job_type)
            .bind(count)
            .execute(&mut *tx)
            .await?;
        Ok(true)
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for MysqlStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
            COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending,
            COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running,
            COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done,
            COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry,
            COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed,
            COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed
        FROM jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req: J = MysqlCodec::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;

    use apalis_core::test_utils::DummyService;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, MysqlStorage<Email>, Email);

    async fn setup<T: Serialize + DeserializeOwned>() -> MysqlStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = MySqlPool::connect(db_url).await.unwrap();
        MysqlStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let mut storage = MysqlStorage::new(pool);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    async fn cleanup<T>(storage: &mut MysqlStorage<T>, worker_id: &WorkerId) {
        sqlx::query("DELETE FROM jobs WHERE job_type = ?")
            .bind(storage.config.namespace())
            .execute(&storage.pool)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM workers WHERE id = ?")
            .bind(worker_id.to_string())
            .execute(&storage.pool)
            .await
            .expect("failed to delete worker");
    }

    async fn consume_one(
        storage: &mut MysqlStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .clone()
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1);
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@mysql".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn register_worker_at(
        storage: &mut MysqlStorage<Email>,
        last_seen: DateTime<Utc>,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        storage
            .keep_alive_at::<DummyService>(wrk.id(), last_seen)
            .await
            .expect("failed to register worker");
        wrk
    }

    async fn register_worker(storage: &mut MysqlStorage<Email>) -> Worker<Context> {
        let now = Utc::now();

        register_worker_at(storage, now).await
    }

    async fn push_email(storage: &mut MysqlStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut MysqlStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // Give the storage a moment to settle before reading the row back.
        apalis_core::sleep(Duration::from_secs(1)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;

        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        // push a job
        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        // register a worker that was last seen six minutes ago
        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        storage
            .keep_alive_at::<Email>(worker.id(), six_minutes_ago)
            .await
            .unwrap();

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .unwrap();

        // the job should be back to `Pending`, with its lock released
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        // push a job
        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        // register a worker that was last seen four minutes ago
        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        storage
            .keep_alive_at::<Email>(worker.id(), four_minutes_ago)
            .await
            .unwrap();

        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .unwrap();

        // the worker is still alive, so the job should remain locked and `Running`
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 1);
    }
}
900}