// apalis_sql/mysql.rs

use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{Stream, StreamExt, TryStreamExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::mysql::MySqlRow;
use sqlx::{MySql, Pool, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, ops::Add, time::Duration};

use crate::context::SqlContext;
use crate::from_row::SqlRequest;
use crate::{calculate_status, Config, SqlError};

pub use sqlx::mysql::MySqlPool;

type MysqlCodec = JsonCodec<Value>;

/// Represents a [Storage] that persists jobs to MySQL
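///
/// A minimal end-to-end sketch (the connection URL is a placeholder and the
/// `Email` job type is illustrative; `setup` requires the `migrate` feature):
///
/// ```no_run
/// use apalis_core::storage::Storage;
/// use apalis_sql::mysql::{MySqlPool, MysqlStorage};
///
/// #[derive(serde::Serialize, serde::Deserialize)]
/// struct Email { to: String }
///
/// # async fn example() -> Result<(), sqlx::Error> {
/// let pool = MySqlPool::connect("mysql://user:pass@localhost/apalis").await?;
/// MysqlStorage::setup(&pool).await?; // run the bundled migrations
/// let mut storage: MysqlStorage<Email> = MysqlStorage::new(pool);
/// storage.push(Email { to: "test@example.com".into() }).await?;
/// # Ok(())
/// # }
/// ```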
pub struct MysqlStorage<T, C = JsonCodec<Value>>
where
    C: Codec,
{
    pool: Pool<MySql>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
}

impl<T, C> fmt::Debug for MysqlStorage<T, C>
where
    C: Codec,
    C::Compact: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MysqlStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .field("ack_notify", &self.ack_notify)
            .finish()
    }
}

impl<T, C> Clone for MysqlStorage<T, C>
where
    C: Codec,
{
    fn clone(&self) -> Self {
        let pool = self.pool.clone();
        MysqlStorage {
            pool,
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
            ack_notify: self.ack_notify.clone(),
        }
    }
}

impl MysqlStorage<(), JsonCodec<Value>> {
    /// Get the MySQL migrations without running them
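    ///
    /// A sketch of applying them manually (equivalent to [`MysqlStorage::setup`]):
    ///
    /// ```no_run
    /// # async fn example(pool: &sqlx::MySqlPool) -> Result<(), sqlx::Error> {
    /// let migrator = apalis_sql::mysql::MysqlStorage::migrations();
    /// migrator.run(pool).await?; // apply any pending migrations
    /// # Ok(())
    /// # }
    /// ```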
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/mysql")
    }

    /// Run the migrations for MySQL
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> MysqlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
    /// Create a new instance from a pool
    pub fn new(pool: MySqlPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

    /// Create a new instance from a pool and a custom config
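    ///
    /// A sketch with an explicit namespace (the name is arbitrary; `new`
    /// derives it from the job's type name instead):
    ///
    /// ```no_run
    /// use apalis_sql::{mysql::MysqlStorage, Config};
    /// # fn example(pool: apalis_sql::mysql::MySqlPool) {
    /// let config = Config::new("apalis::Email");
    /// let storage: MysqlStorage<serde_json::Value> = MysqlStorage::new_with_config(pool, config);
    /// # }
    /// ```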
    pub fn new_with_config(pool: MySqlPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            ack_notify: Notify::new(),
            codec: PhantomData,
        }
    }

    /// Expose the pool for other functionality, e.g. custom migrations
    pub fn pool(&self) -> &Pool<MySql> {
        &self.pool
    }

    /// Get the config used by the storage
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> MysqlStorage<T, C>
where
    T: DeserializeOwned + Send + Sync + 'static,
    C: Codec<Compact = Value> + Send + 'static,
{
    /// Poll the database for new jobs, yielding them as a stream.
    ///
    /// Each iteration selects up to `buffer_size` candidate ids with
    /// `FOR UPDATE SKIP LOCKED`, marks them as `Running` for this worker,
    /// and then fetches the full rows.
    fn stream_jobs(
        self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let worker_id = worker.id().to_string();

        try_stream! {
            let buffer_size = u32::try_from(buffer_size)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let pool = pool.clone();
                let job_type = self.config.namespace.clone();
                let mut tx = pool.begin().await?;
                // Select candidate ids, skipping rows locked by other workers.
                let fetch_query = "SELECT id FROM jobs
                WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at <= NOW() AND job_type = ? ORDER BY priority DESC, run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED";
                let task_ids: Vec<MySqlRow> = sqlx::query(fetch_query)
                    .bind(job_type)
                    .bind(buffer_size)
                    .fetch_all(&mut *tx).await?;
                if task_ids.is_empty() {
                    tx.rollback().await?;
                    yield None
                } else {
                    // Lock the selected jobs for this worker and mark them Running.
                    let task_ids: Vec<String> = task_ids.iter().map(|r| r.get_unchecked("id")).collect();
                    let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1));
                    let query = format!("UPDATE jobs SET status = 'Running', lock_by = ?, lock_at = NOW() WHERE id IN({}) AND status = 'Pending' AND lock_by IS NULL;", id_params);
                    let mut query = sqlx::query(&query).bind(worker_id.clone());
                    for i in &task_ids {
                        query = query.bind(i);
                    }
                    query.execute(&mut *tx).await?;
                    tx.commit().await?;

                    // Fetch the full rows for the jobs we just locked.
                    let fetch_query = format!("SELECT * FROM jobs WHERE id IN ({}) ORDER BY priority DESC, run_at ASC", id_params);
                    let mut query = sqlx::query_as(&fetch_query);
                    for i in task_ids {
                        query = query.bind(i);
                    }
                    let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;

                    for job in jobs {
                        yield {
                            // Decode the stored payload and attach the namespace.
                            let (req, ctx) = job.req.take_parts();
                            let req = C::decode(req)
                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                            let mut req: Request<T, SqlContext> = Request::new_with_parts(req, ctx);
                            req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                            Some(req)
                        }
                    }
                }
            }
        }.boxed()
    }

    /// Record a heartbeat for this worker, registering it on first call.
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        // Upsert: refresh `last_seen` when the worker row already exists.
        let query =
            "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE last_seen = ?;";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .bind(last_seen)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

impl<T, C> Storage for MysqlStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send + Sync + 'static,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = Value;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn push_raw_request(
        &mut self,
        job: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .bind(req.parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let pool = self.pool.clone();

        let fetch_query = "SELECT * FROM jobs WHERE id = ?";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let pool = self.pool.clone();

        let query = "SELECT COUNT(*) AS count FROM jobs WHERE status = 'Pending'";
        let record = sqlx::query(query).fetch_one(&pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let job_id = job.parts.task_id.clone();

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ? WHERE id = ?";

        sqlx::query(query)
            .bind(Utc::now().timestamp().add(wait))
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();
        let job_id = job.parts.task_id;
        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ?, priority = ? WHERE id = ?";
        sqlx::query(query)
            .bind(status.to_owned())
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let pool = self.pool.clone();
        let query = "DELETE FROM jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

/// Errors that can occur while polling a MySQL database.
#[derive(thiserror::Error, Debug)]
pub enum MysqlPollError {
    /// Error during task acknowledgment.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// Error during result encoding.
    #[error("Encountered an error while encoding the result: {0}")]
    CodecError(BoxDynError),

    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<Req, C> Backend<Request<Req, SqlContext>> for MysqlStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Sync + Send + 'static,
    C: Codec<Compact = Value> + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<Req, SqlContext>>>;

    type Layer = AckLayer<MysqlStorage<Req, C>, Req, SqlContext, C>;

    type Codec = C;

    fn poll(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let pool = self.pool.clone();
        let ack_notify = self.ack_notify.clone();
        let mut hb_storage = self.clone();
        let requeue_storage = self.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let w = worker.clone();

        // Persist acknowledgements in batches of up to `buffer_size`.
        let ack_heartbeat = async move {
            while let Some(ids) = ack_notify
                .clone()
                .ready_chunks(config.buffer_size)
                .next()
                .await
            {
                for (ctx, res) in ids {
                    let query = "UPDATE jobs SET status = ?, done_at = now(), last_error = ?, attempts = ? WHERE id = ? AND lock_by = ?";
                    let query = sqlx::query(query);
                    let last_result =
                        C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new);
                    match (last_result, ctx.lock_by()) {
                        (Ok(val), Some(worker_id)) => {
                            let query = query
                                .bind(calculate_status(&ctx, &res).to_string())
                                .bind(val)
                                .bind(res.attempt.current() as i32)
                                .bind(res.task_id.to_string())
                                .bind(worker_id.to_string());
                            if let Err(e) = query.execute(&pool).await {
                                w.emit(Event::Error(Box::new(MysqlPollError::AckError(e))));
                            }
                        }
                        (Err(error), Some(_)) => {
                            w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error))));
                        }
                        _ => {
                            unreachable!(
                                "Attempted to ACK without a worker attached. This is a bug; file it on the repo."
                            );
                        }
                    }
                }

                apalis_core::sleep(config.poll_interval).await;
            }
        };
        let w = worker.clone();
        let heartbeat = async move {
            // Re-enqueue any jobs that belonged to this worker in case it previously died
            if let Err(e) = hb_storage
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                w.emit(Event::Error(Box::new(
                    MysqlPollError::ReenqueueOrphanedError(e),
                )));
            }

            loop {
                let now = Utc::now();
                if let Err(e) = hb_storage.keep_alive_at::<Self::Layer>(w.id(), now).await {
                    w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e))));
                }
                apalis_core::sleep(config.keep_alive).await;
            }
        };
        let w = worker.clone();
        // Periodically re-enqueue jobs whose workers have stopped heartbeating.
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after)
                        .expect("Could not calculate dead since");
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("Could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        MysqlPollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, ack_heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

impl<T, Res, C> Ack<T, Res, C> for MysqlStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Send + 'static + Sync,
    C: Codec<Compact = Value> + Send,
    C::Error: Debug,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        self.ack_notify
            .notify((
                ctx.clone(),
                res.map(|res| C::encode(res).expect("Could not encode result")),
            ))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> MysqlStorage<T, C> {
    /// Kill a job
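    ///
    /// A sketch, assuming the job is currently locked by `worker_id`
    /// (the update only matches rows with `lock_by = worker_id`):
    ///
    /// ```no_run
    /// # use apalis_core::{task::task_id::TaskId, worker::WorkerId};
    /// # use apalis_sql::mysql::MysqlStorage;
    /// # async fn example(
    /// #     storage: &mut MysqlStorage<serde_json::Value>,
    /// #     worker_id: &WorkerId,
    /// #     job_id: &TaskId,
    /// # ) -> Result<(), sqlx::Error> {
    /// storage.kill(worker_id, job_id).await?;
    /// # Ok(())
    /// # }
    /// ```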
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Killed', done_at = NOW() WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Puts the job instantly back into the queue
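    ///
    /// A sketch, assuming the job previously failed under `worker_id`:
    ///
    /// ```no_run
    /// # use apalis_core::{task::task_id::TaskId, worker::WorkerId};
    /// # use apalis_sql::mysql::MysqlStorage;
    /// # async fn example(
    /// #     storage: &mut MysqlStorage<serde_json::Value>,
    /// #     worker_id: &WorkerId,
    /// #     job_id: &TaskId,
    /// # ) -> Result<(), sqlx::Error> {
    /// storage.retry(worker_id, job_id).await?;
    /// # Ok(())
    /// # }
    /// ```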
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Re-add abandoned jobs to the queue
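    ///
    /// A sketch that re-enqueues up to 10 jobs held by workers not seen for
    /// five minutes (both numbers are arbitrary):
    ///
    /// ```no_run
    /// # use apalis_sql::mysql::MysqlStorage;
    /// # async fn example(storage: &MysqlStorage<serde_json::Value>) -> Result<(), sqlx::Error> {
    /// let dead_since = chrono::Utc::now() - chrono::Duration::minutes(5);
    /// storage.reenqueue_orphaned(10, dead_since).await?;
    /// # Ok(())
    /// # }
    /// ```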
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<bool, sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = r#"UPDATE jobs
                        INNER JOIN ( SELECT workers.id AS worker_id, jobs.id AS job_id FROM workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.attempts < jobs.max_attempts AND jobs.status = 'Running' AND workers.last_seen < ? AND workers.worker_type = ?
                            ORDER BY lock_at ASC LIMIT ?) AS workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id
                        SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned', attempts = attempts + 1;"#;

        sqlx::query(query)
            .bind(dead_since)
            .bind(job_type)
            .bind(count)
            .execute(&mut *tx)
            .await?;
        Ok(true)
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for MysqlStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
            COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending,
            COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running,
            COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done,
            COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry,
            COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed,
            COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed
        FROM jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req: J = MysqlCodec::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;

    use apalis_core::test_utils::DummyService;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, MysqlStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> MysqlStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        // Because connections cannot be shared across async runtimes
        // (a new runtime is created for each test),
        // we don't share the storage, and tests must run sequentially.
        let pool = MySqlPool::connect(db_url).await.unwrap();
        MysqlStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let mut storage = MysqlStorage::new(pool);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Roll back DB changes made by tests.
    /// Deletes the following rows:
    ///  - jobs in this storage's namespace (`job_type`)
    ///  - the worker identified by `worker_id`
    ///
    /// Execute this function at the end of a test.
    async fn cleanup<T>(storage: &mut MysqlStorage<T>, worker_id: &WorkerId) {
        sqlx::query("DELETE FROM jobs WHERE job_type = ?")
            .bind(storage.config.namespace())
            .execute(&storage.pool)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM workers WHERE id = ?")
            .bind(worker_id.to_string())
            .execute(&storage.pool)
            .await
            .expect("failed to delete worker");
    }

    async fn consume_one(
        storage: &mut MysqlStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .clone()
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1);
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@mysql".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn register_worker_at(
        storage: &mut MysqlStorage<Email>,
        last_seen: DateTime<Utc>,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        storage
            .keep_alive_at::<DummyService>(&wrk.id(), last_seen)
            .await
            .expect("failed to register worker");
        wrk
    }

    async fn register_worker(storage: &mut MysqlStorage<Email>) -> Worker<Context> {
        let now = Utc::now();

        register_worker_at(storage, now).await
    }

    async fn push_email(storage: &mut MysqlStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut MysqlStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // add a slight delay to allow background actions like ack to complete
        apalis_core::sleep(Duration::from_secs(1)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        // TODO: Fix assertions
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;

        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        // TODO: Fix assertions
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        // push an Email job
        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        // register a worker that has not responded for six minutes
        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        storage
            .keep_alive_at::<Email>(worker.id(), six_minutes_ago)
            .await
            .unwrap();

        // fetch job
        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .unwrap();

        // then, the job status has changed to Pending
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        let service = apalis_test_service_fn(|_: Request<Email, _>| async move {
            apalis_core::sleep(Duration::from_millis(500)).await;
            Ok::<_, io::Error>("success")
        });
        let (mut t, poller) = TestWrapper::new_with_service(storage.clone(), service);
        // register a worker that last responded four minutes ago
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        storage
            .keep_alive_at::<Email>(&t.worker.id(), four_minutes_ago)
            .await
            .unwrap();

        tokio::spawn(poller);

        // push an Email job
        let parts = storage
            .push(example_email())
            .await
            .expect("failed to push job");

        // run the ReenqueueOrphaned pulse against workers dead since six minutes ago
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .unwrap();

        // then, the job status is not changed
        let job = storage.fetch_by_id(&parts.task_id).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert_eq!(*ctx.lock_by(), None);
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);

        let res = t.execute_next().await.unwrap();

        apalis_core::sleep(Duration::from_millis(1000)).await;

        let job = storage.fetch_by_id(&res.0).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(t.worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}