1use apalis_core::backend::{BackendExpose, Stat, WorkerState};
2use apalis_core::codec::json::JsonCodec;
3use apalis_core::error::{BoxDynError, Error};
4use apalis_core::layers::{Ack, AckLayer};
5use apalis_core::notify::Notify;
6use apalis_core::poller::controller::Controller;
7use apalis_core::poller::stream::BackendStream;
8use apalis_core::poller::Poller;
9use apalis_core::request::{Parts, Request, RequestStream, State};
10use apalis_core::response::Response;
11use apalis_core::storage::Storage;
12use apalis_core::task::namespace::Namespace;
13use apalis_core::task::task_id::TaskId;
14use apalis_core::worker::{Context, Event, Worker, WorkerId};
15use apalis_core::{backend::Backend, codec::Codec};
16use async_stream::try_stream;
17use chrono::{DateTime, Utc};
18use futures::{Stream, StreamExt, TryStreamExt};
19use log::error;
20use serde::{de::DeserializeOwned, Serialize};
21use serde_json::Value;
22use sqlx::mysql::MySqlRow;
23use sqlx::{MySql, Pool, Row};
24use std::any::type_name;
25use std::convert::TryInto;
26use std::fmt::Debug;
27use std::sync::Arc;
28use std::{fmt, io};
29use std::{marker::PhantomData, ops::Add, time::Duration};
30
31use crate::context::SqlContext;
32use crate::from_row::SqlRequest;
33use crate::{calculate_status, Config, SqlError};
34
35pub use sqlx::mysql::MySqlPool;
36
/// JSON codec used by `BackendExpose::list_jobs` to decode the stored job payloads
/// (read from the `jobs` table as `serde_json::Value`).
type MysqlCodec = JsonCodec<Value>;
38
39pub struct MysqlStorage<T, C = JsonCodec<Value>>
42where
43 C: Codec,
44{
45 pool: Pool<MySql>,
46 job_type: PhantomData<T>,
47 controller: Controller,
48 config: Config,
49 codec: PhantomData<C>,
50 ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
51}
52
53impl<T, C> fmt::Debug for MysqlStorage<T, C>
54where
55 C: Codec,
56 C::Compact: Debug,
57{
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 f.debug_struct("MysqlStorage")
60 .field("pool", &self.pool)
61 .field("job_type", &"PhantomData<T>")
62 .field("controller", &self.controller)
63 .field("config", &self.config)
64 .field("codec", &std::any::type_name::<C>())
65 .field("ack_notify", &self.ack_notify)
66 .finish()
67 }
68}
69
70impl<T, C> Clone for MysqlStorage<T, C>
71where
72 C: Codec,
73{
74 fn clone(&self) -> Self {
75 let pool = self.pool.clone();
76 MysqlStorage {
77 pool,
78 job_type: PhantomData,
79 controller: self.controller.clone(),
80 config: self.config.clone(),
81 codec: self.codec,
82 ack_notify: self.ack_notify.clone(),
83 }
84 }
85}
86
87impl MysqlStorage<(), JsonCodec<Value>> {
88 #[cfg(feature = "migrate")]
90 pub fn migrations() -> sqlx::migrate::Migrator {
91 sqlx::migrate!("migrations/mysql")
92 }
93
94 #[cfg(feature = "migrate")]
96 pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
97 Self::migrations().run(pool).await?;
98 Ok(())
99 }
100}
101
102impl<T> MysqlStorage<T>
103where
104 T: Serialize + DeserializeOwned,
105{
106 pub fn new(pool: MySqlPool) -> Self {
108 Self::new_with_config(pool, Config::new(type_name::<T>()))
109 }
110
111 pub fn new_with_config(pool: MySqlPool, config: Config) -> Self {
113 Self {
114 pool,
115 job_type: PhantomData,
116 controller: Controller::new(),
117 config,
118 ack_notify: Notify::new(),
119 codec: PhantomData,
120 }
121 }
122
123 pub fn pool(&self) -> &Pool<MySql> {
125 &self.pool
126 }
127
128 pub fn get_config(&self) -> &Config {
130 &self.config
131 }
132}
133
134impl<T, C> MysqlStorage<T, C>
135where
136 T: DeserializeOwned + Send + Sync + 'static,
137 C: Codec<Compact = Value> + Send + 'static,
138{
139 fn stream_jobs(
140 self,
141 worker: &Worker<Context>,
142 interval: Duration,
143 buffer_size: usize,
144 ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
145 let pool = self.pool.clone();
146 let worker = worker.clone();
147 let worker_id = worker.id().to_string();
148
149 try_stream! {
150 let buffer_size = u32::try_from(buffer_size)
151 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
152 loop {
153 apalis_core::sleep(interval).await;
154 if !worker.is_ready() {
155 continue;
156 }
157 let pool = pool.clone();
158 let job_type = self.config.namespace.clone();
159 let mut tx = pool.begin().await?;
160 let fetch_query = "SELECT id FROM jobs
161 WHERE (status='Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at <= NOW() AND job_type = ? ORDER BY priority DESC, run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED";
162 let task_ids: Vec<MySqlRow> = sqlx::query(fetch_query)
163 .bind(job_type)
164 .bind(buffer_size)
165 .fetch_all(&mut *tx).await?;
166 if task_ids.is_empty() {
167 tx.rollback().await?;
168 yield None
169 } else {
170 let task_ids: Vec<String> = task_ids.iter().map(|r| r.get_unchecked("id")).collect();
171 let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1));
172 let query = format!("UPDATE jobs SET status = 'Running', lock_by = ?, lock_at = NOW() WHERE id IN({}) AND status = 'Pending' AND lock_by IS NULL;", id_params);
173 let mut query = sqlx::query(&query).bind(worker_id.clone());
174 for i in &task_ids {
175 query = query.bind(i);
176 }
177 query.execute(&mut *tx).await?;
178 tx.commit().await?;
179
180 let fetch_query = format!("SELECT * FROM jobs WHERE ID IN ({}) ORDER BY priority DESC, run_at ASC", id_params);
181 let mut query = sqlx::query_as(&fetch_query);
182 for i in task_ids {
183 query = query.bind(i);
184 }
185 let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;
186
187 for job in jobs {
188 yield {
189 let (req, ctx) = job.req.take_parts();
190 let req = C::decode(req)
191 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
192 let mut req: Request<T, SqlContext> = Request::new_with_parts(req, ctx);
193 req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
194 Some(req)
195 }
196 }
197 }
198 }
199 }.boxed()
200 }
201
202 async fn keep_alive_at<Service>(
203 &mut self,
204 worker_id: &WorkerId,
205 last_seen: DateTime<Utc>,
206 ) -> Result<(), sqlx::Error> {
207 let pool = self.pool.clone();
208
209 let mut tx = pool.acquire().await?;
210 let worker_type = self.config.namespace.clone();
211 let storage_name = std::any::type_name::<Self>();
212 let query =
213 "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE id = ?;";
214 sqlx::query(query)
215 .bind(worker_id.to_string())
216 .bind(worker_type)
217 .bind(storage_name)
218 .bind(std::any::type_name::<Service>())
219 .bind(last_seen)
220 .bind(worker_id.to_string())
221 .execute(&mut *tx)
222 .await?;
223 Ok(())
224 }
225}
226
227impl<T, C> Storage for MysqlStorage<T, C>
228where
229 T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
230 C: Codec<Compact = Value> + Send + Sync + 'static,
231 C::Error: std::error::Error + 'static + Send + Sync,
232{
233 type Job = T;
234
235 type Error = sqlx::Error;
236
237 type Context = SqlContext;
238
239 type Compact = Value;
240
241 async fn push_request(
242 &mut self,
243 job: Request<Self::Job, SqlContext>,
244 ) -> Result<Parts<SqlContext>, sqlx::Error> {
245 let (args, parts) = job.take_parts();
246 let query =
247 "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
248 let pool = self.pool.clone();
249
250 let job = C::encode(args)
251 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
252 let job_type = self.config.namespace.clone();
253 sqlx::query(query)
254 .bind(job)
255 .bind(parts.task_id.to_string())
256 .bind(job_type.to_string())
257 .bind(parts.context.max_attempts())
258 .bind(parts.context.priority())
259 .execute(&pool)
260 .await?;
261 Ok(parts)
262 }
263
264 async fn push_raw_request(
265 &mut self,
266 job: Request<Self::Compact, SqlContext>,
267 ) -> Result<Parts<SqlContext>, sqlx::Error> {
268 let (args, parts) = job.take_parts();
269 let query =
270 "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
271 let pool = self.pool.clone();
272
273 let job = C::encode(args)
274 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
275 let job_type = self.config.namespace.clone();
276 sqlx::query(query)
277 .bind(job)
278 .bind(parts.task_id.to_string())
279 .bind(job_type.to_string())
280 .bind(parts.context.max_attempts())
281 .bind(parts.context.priority())
282 .execute(&pool)
283 .await?;
284 Ok(parts)
285 }
286
287 async fn schedule_request(
288 &mut self,
289 req: Request<Self::Job, SqlContext>,
290 on: i64,
291 ) -> Result<Parts<Self::Context>, sqlx::Error> {
292 let query =
293 "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL, ?)";
294 let pool = self.pool.clone();
295
296 let args = C::encode(&req.args)
297 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
298
299 let job_type = self.config.namespace.clone();
300 sqlx::query(query)
301 .bind(args)
302 .bind(req.parts.task_id.to_string())
303 .bind(job_type)
304 .bind(req.parts.context.max_attempts())
305 .bind(on)
306 .bind(req.parts.context.priority())
307 .execute(&pool)
308 .await?;
309 Ok(req.parts)
310 }
311
312 async fn fetch_by_id(
313 &mut self,
314 job_id: &TaskId,
315 ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
316 let pool = self.pool.clone();
317
318 let fetch_query = "SELECT * FROM jobs WHERE id = ?";
319 let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
320 .bind(job_id.to_string())
321 .fetch_optional(&pool)
322 .await?;
323 match res {
324 None => Ok(None),
325 Some(job) => Ok(Some({
326 let (req, parts) = job.req.take_parts();
327 let req = C::decode(req)
328 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
329 let mut req = Request::new_with_parts(req, parts);
330 req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
331 req
332 })),
333 }
334 }
335
336 async fn len(&mut self) -> Result<i64, sqlx::Error> {
337 let pool = self.pool.clone();
338
339 let query = "Select Count(*) as count from jobs where status='Pending'";
340 let record = sqlx::query(query).fetch_one(&pool).await?;
341 record.try_get("count")
342 }
343
344 async fn reschedule(
345 &mut self,
346 job: Request<T, SqlContext>,
347 wait: Duration,
348 ) -> Result<(), sqlx::Error> {
349 let pool = self.pool.clone();
350 let job_id = job.parts.task_id.clone();
351
352 let wait: i64 = wait
353 .as_secs()
354 .try_into()
355 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
356 let mut tx = pool.acquire().await?;
357 let query =
358 "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ? WHERE id = ?";
359
360 sqlx::query(query)
361 .bind(Utc::now().timestamp().add(wait))
362 .bind(job_id.to_string())
363 .execute(&mut *tx)
364 .await?;
365 Ok(())
366 }
367
368 async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
369 let pool = self.pool.clone();
370 let ctx = job.parts.context;
371 let status = ctx.status().to_string();
372 let attempts = job.parts.attempt;
373 let done_at = *ctx.done_at();
374 let lock_by = ctx.lock_by().clone();
375 let lock_at = *ctx.lock_at();
376 let last_error = ctx.last_error().clone();
377 let priority = *ctx.priority();
378 let job_id = job.parts.task_id;
379 let mut tx = pool.acquire().await?;
380 let query =
381 "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ?, priority = ? WHERE id = ?";
382 sqlx::query(query)
383 .bind(status.to_owned())
384 .bind::<i64>(
385 attempts
386 .current()
387 .try_into()
388 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
389 )
390 .bind(done_at)
391 .bind(lock_by.map(|w| w.name().to_string()))
392 .bind(lock_at)
393 .bind(last_error)
394 .bind(priority)
395 .bind(job_id.to_string())
396 .execute(&mut *tx)
397 .await?;
398 Ok(())
399 }
400
401 async fn is_empty(&mut self) -> Result<bool, Self::Error> {
402 Ok(self.len().await? == 0)
403 }
404
405 async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
406 let pool = self.pool.clone();
407 let query = "Delete from jobs where status='Done'";
408 let record = sqlx::query(query).execute(&pool).await?;
409 Ok(record.rows_affected().try_into().unwrap_or_default())
410 }
411}
412
/// Errors raised by the background loops spawned from `Backend::poll`.
///
/// These are not returned to a caller. They are reported on the worker through
/// `Event::Error`, and the loop that produced them keeps running.
#[derive(thiserror::Error, Debug)]
pub enum MysqlPollError {
    /// The database update recording a job's result (the ACK) failed.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// The job's result could not be encoded with the storage codec before ACKing.
    #[error("Encountered an error during encoding the result: {0}")]
    CodecError(BoxDynError),

    /// Refreshing the worker's row in the `workers` table failed.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Re-enqueuing jobs held by workers considered dead failed.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}
432
433impl<Req, C> Backend<Request<Req, SqlContext>> for MysqlStorage<Req, C>
434where
435 Req: Serialize + DeserializeOwned + Sync + Send + 'static,
436 C: Codec<Compact = Value> + Send + 'static + Sync,
437 C::Error: std::error::Error + 'static + Send + Sync,
438{
439 type Stream = BackendStream<RequestStream<Request<Req, SqlContext>>>;
440
441 type Layer = AckLayer<MysqlStorage<Req, C>, Req, SqlContext, C>;
442
443 type Codec = C;
444
445 fn poll(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
446 let layer = AckLayer::new(self.clone());
447 let config = self.config.clone();
448 let controller = self.controller.clone();
449 let pool = self.pool.clone();
450 let ack_notify = self.ack_notify.clone();
451 let mut hb_storage = self.clone();
452 let requeue_storage = self.clone();
453 let stream = self
454 .stream_jobs(worker, config.poll_interval, config.buffer_size)
455 .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
456 let stream = BackendStream::new(stream.boxed(), controller);
457 let w = worker.clone();
458
459 let ack_heartbeat = async move {
460 while let Some(ids) = ack_notify
461 .clone()
462 .ready_chunks(config.buffer_size)
463 .next()
464 .await
465 {
466 for (ctx, res) in ids {
467 let query = "UPDATE jobs SET status = ?, done_at = now(), last_error = ?, attempts = ? WHERE id = ? AND lock_by = ?";
468 let query = sqlx::query(query);
469 let last_result =
470 C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new);
471 match (last_result, ctx.lock_by()) {
472 (Ok(val), Some(worker_id)) => {
473 let query = query
474 .bind(calculate_status(&ctx, &res).to_string())
475 .bind(val)
476 .bind(res.attempt.current() as i32)
477 .bind(res.task_id.to_string())
478 .bind(worker_id.to_string());
479 if let Err(e) = query.execute(&pool).await {
480 w.emit(Event::Error(Box::new(MysqlPollError::AckError(e))));
481 }
482 }
483 (Err(error), Some(_)) => {
484 w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error))));
485 }
486 _ => {
487 unreachable!(
488 "Attempted to ACK without a worker attached. This is a bug, File it on the repo"
489 );
490 }
491 }
492 }
493
494 apalis_core::sleep(config.poll_interval).await;
495 }
496 };
497 let w = worker.clone();
498 let heartbeat = async move {
499 if let Err(e) = hb_storage
501 .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
502 .await
503 {
504 w.emit(Event::Error(Box::new(
505 MysqlPollError::ReenqueueOrphanedError(e),
506 )));
507 }
508
509 loop {
510 let now = Utc::now();
511 if let Err(e) = hb_storage.keep_alive_at::<Self::Layer>(w.id(), now).await {
512 w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e))));
513 }
514 apalis_core::sleep(config.keep_alive).await;
515 }
516 };
517 let w = worker.clone();
518 let reenqueue_beat = async move {
519 loop {
520 let dead_since = Utc::now()
521 - chrono::Duration::from_std(config.reenqueue_orphaned_after)
522 .expect("Could not calculate dead since");
523 if let Err(e) = requeue_storage
524 .reenqueue_orphaned(
525 config
526 .buffer_size
527 .try_into()
528 .expect("Could not convert usize to i32"),
529 dead_since,
530 )
531 .await
532 {
533 w.emit(Event::Error(Box::new(
534 MysqlPollError::ReenqueueOrphanedError(e),
535 )));
536 }
537 apalis_core::sleep(config.poll_interval).await;
538 }
539 };
540 Poller::new_with_layer(
541 stream,
542 async {
543 futures::join!(heartbeat, ack_heartbeat, reenqueue_beat);
544 },
545 layer,
546 )
547 }
548}
549
550impl<T, Res, C> Ack<T, Res, C> for MysqlStorage<T, C>
551where
552 T: Sync + Send,
553 Res: Serialize + Send + 'static + Sync,
554 C: Codec<Compact = Value> + Send,
555 C::Error: Debug,
556{
557 type Context = SqlContext;
558 type AckError = sqlx::Error;
559 async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
560 self.ack_notify
561 .notify((
562 ctx.clone(),
563 res.map(|res| C::encode(res).expect("Could not encode result")),
564 ))
565 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
566
567 Ok(())
568 }
569}
570
571impl<T, C: Codec> MysqlStorage<T, C> {
572 pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
574 let pool = self.pool.clone();
575
576 let mut tx = pool.acquire().await?;
577 let query =
578 "UPDATE jobs SET status = 'Killed', done_at = NOW() WHERE id = ? AND lock_by = ?";
579 sqlx::query(query)
580 .bind(job_id.to_string())
581 .bind(worker_id.to_string())
582 .execute(&mut *tx)
583 .await?;
584 Ok(())
585 }
586
587 pub async fn retry(
589 &mut self,
590 worker_id: &WorkerId,
591 job_id: &TaskId,
592 ) -> Result<(), sqlx::Error> {
593 let pool = self.pool.clone();
594
595 let mut tx = pool.acquire().await?;
596 let query =
597 "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ? AND lock_by = ?";
598 sqlx::query(query)
599 .bind(job_id.to_string())
600 .bind(worker_id.to_string())
601 .execute(&mut *tx)
602 .await?;
603 Ok(())
604 }
605
606 pub async fn reenqueue_orphaned(
608 &self,
609 count: i32,
610 dead_since: DateTime<Utc>,
611 ) -> Result<bool, sqlx::Error> {
612 let job_type = self.config.namespace.clone();
613 let mut tx = self.pool.acquire().await?;
614 let query = r#"Update jobs
615 INNER JOIN ( SELECT workers.id as worker_id, jobs.id as job_id from workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.attempts < jobs.max_attempts AND jobs.status = "Running" AND workers.last_seen < ? AND workers.worker_type = ?
616 ORDER BY lock_at ASC LIMIT ?) as workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id
617 SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ="Job was abandoned", attempts = attempts + 1;"#;
618
619 sqlx::query(query)
620 .bind(dead_since)
621 .bind(job_type)
622 .bind(count)
623 .execute(&mut *tx)
624 .await?;
625 Ok(true)
626 }
627}
628
629impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
630 for MysqlStorage<J>
631{
632 type Request = Request<J, Parts<SqlContext>>;
633 type Error = SqlError;
634 async fn stats(&self) -> Result<Stat, Self::Error> {
635 let fetch_query = "SELECT
636 COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending,
637 COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running,
638 COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done,
639 COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry,
640 COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed,
641 COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed
642 FROM jobs WHERE job_type = ?";
643
644 let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
645 .bind(self.get_config().namespace())
646 .fetch_one(self.pool())
647 .await?;
648
649 Ok(Stat {
650 pending: res.0.try_into()?,
651 running: res.1.try_into()?,
652 dead: res.4.try_into()?,
653 failed: res.3.try_into()?,
654 success: res.2.try_into()?,
655 })
656 }
657
658 async fn list_jobs(
659 &self,
660 status: &State,
661 page: i32,
662 ) -> Result<Vec<Self::Request>, Self::Error> {
663 let status = status.to_string();
664 let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
665 let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
666 .bind(status)
667 .bind(self.get_config().namespace())
668 .bind(((page - 1) * 10).to_string())
669 .fetch_all(self.pool())
670 .await?;
671 Ok(res
672 .into_iter()
673 .map(|j| {
674 let (req, ctx) = j.req.take_parts();
675 let req: J = MysqlCodec::decode(req).unwrap();
676 Request::new_with_ctx(req, ctx)
677 })
678 .collect())
679 }
680
681 async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
682 let fetch_query =
683 "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
684 let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
685 .bind(self.get_config().namespace())
686 .bind(0)
687 .fetch_all(self.pool())
688 .await?;
689 Ok(res
690 .into_iter()
691 .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
692 .collect())
693 }
694}
695
// Integration tests for `MysqlStorage`.
// They require a reachable MySQL database whose URL is given in the `DATABASE_URL`
// environment variable.
#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;

    use apalis_core::test_utils::DummyService;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    // Shared test suites from apalis-core and this crate, instantiated with `setup`.
    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, MysqlStorage<Email>, Email);

    /// Connects to `DATABASE_URL` and runs the migrations. Then it removes jobs of `T`'s
    /// namespace and the `test-worker` row left by previous runs.
    async fn setup<T: Serialize + DeserializeOwned>() -> MysqlStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = MySqlPool::connect(db_url).await.unwrap();
        MysqlStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let mut storage = MysqlStorage::new(pool);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Deletes every job of the storage's namespace and the given worker's row.
    async fn cleanup<T>(storage: &mut MysqlStorage<T>, worker_id: &WorkerId) {
        sqlx::query("DELETE FROM jobs WHERE job_type = ?")
            .bind(storage.config.namespace())
            .execute(&storage.pool)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM workers WHERE id = ?")
            .bind(worker_id.to_string())
            .execute(&storage.pool)
            .await
            .expect("failed to delete worker");
    }

    /// Takes the first item of a fresh job stream for `worker` (interval 10s, buffer size 1).
    ///
    /// The stream sleeps for the interval before each fetch. Panics if the stream ends,
    /// errors, or yields "no job".
    async fn consume_one(
        storage: &mut MysqlStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .clone()
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1);
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    /// A fixed sample job payload.
    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@mysql".to_string(),
            text: "Some Text".to_string(),
        }
    }

    /// Starts a `test-worker` and records it in `workers` with the given `last_seen`.
    async fn register_worker_at(
        storage: &mut MysqlStorage<Email>,
        last_seen: DateTime<Utc>,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        storage
            .keep_alive_at::<DummyService>(&wrk.id(), last_seen)
            .await
            .expect("failed to register worker");
        wrk
    }

    /// Same as `register_worker_at`, with `last_seen` set to the current time.
    async fn register_worker(storage: &mut MysqlStorage<Email>) -> Worker<Context> {
        let now = Utc::now();

        register_worker_at(storage, now).await
    }

    /// Pushes `email` as a new job, panicking on failure.
    async fn push_email(storage: &mut MysqlStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    /// Waits one second, then loads the job by id, panicking if it is missing.
    async fn get_job(
        storage: &mut MysqlStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // Short pause before reading the row back.
        apalis_core::sleep(Duration::from_secs(1)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    // A consumed job is marked Running and locked by the consuming worker.
    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    // Killing a job held by the worker marks it Killed and sets `done_at`.
    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;

        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    // A worker last seen 6 minutes ago is older than the 5-minute cutoff, so its Running job
    // is re-enqueued as Pending with the lock cleared, the "abandoned" error set and attempts
    // incremented.
    #[tokio::test]
    async fn test_storage_heartbeat_reenqueuorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let six_minutes_ago = Utc::now() - Duration::from_secs(60 * 6);

        storage
            .keep_alive_at::<Email>(worker.id(), six_minutes_ago)
            .await
            .unwrap();

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .unwrap();

        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    // A worker last seen 4 minutes ago is newer than the 6-minute cutoff, so re-enqueuing
    // leaves the freshly pushed job untouched. The job is then executed through the test
    // wrapper and ends up Done, with the serialized result stored in `last_error`.
    #[tokio::test]
    async fn test_storage_heartbeat_reenqueuorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        let service = apalis_test_service_fn(|_: Request<Email, _>| async move {
            apalis_core::sleep(Duration::from_millis(500)).await;
            Ok::<_, io::Error>("success")
        });
        let (mut t, poller) = TestWrapper::new_with_service(storage.clone(), service);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        storage
            .keep_alive_at::<Email>(&t.worker.id(), four_minutes_ago)
            .await
            .unwrap();

        tokio::spawn(poller);

        let parts = storage
            .push(example_email())
            .await
            .expect("failed to push job");

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .unwrap();

        let job = storage.fetch_by_id(&parts.task_id).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert_eq!(*ctx.lock_by(), None);
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);

        let res = t.execute_next().await.unwrap();

        // Gives the asynchronous ACK loop time to write the result back.
        apalis_core::sleep(Duration::from_millis(1000)).await;

        let job = storage.fetch_by_id(&res.0).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(t.worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}