// apalis_sql/mysql.rs

use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{Stream, StreamExt, TryStreamExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::mysql::MySqlRow;
use sqlx::{MySql, Pool, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, ops::Add, time::Duration};

use crate::context::SqlContext;
use crate::from_row::SqlRequest;
use crate::{calculate_status, Config, SqlError};

pub use sqlx::mysql::MySqlPool;

type MysqlCodec = JsonCodec<Value>;

/// Represents a [Storage] that persists to MySQL
pub struct MysqlStorage<T, C = JsonCodec<Value>>
where
    C: Codec,
{
    pool: Pool<MySql>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
}

impl<T, C> fmt::Debug for MysqlStorage<T, C>
where
    C: Codec,
    C::Compact: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MysqlStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .field("ack_notify", &self.ack_notify)
            .finish()
    }
}

impl<T, C> Clone for MysqlStorage<T, C>
where
    C: Codec,
{
    fn clone(&self) -> Self {
        let pool = self.pool.clone();
        MysqlStorage {
            pool,
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
            ack_notify: self.ack_notify.clone(),
        }
    }
}

impl MysqlStorage<(), JsonCodec<Value>> {
    /// Get the MySQL migrations without running them
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/mysql")
    }

    /// Run the MySQL migrations
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> MysqlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
    /// Create a new instance from a pool
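    ///
    /// A minimal construction sketch (not from the crate's documented examples).
    /// The connection string is illustrative, and `String` merely stands in for a
    /// real job type that implements `Serialize` and `DeserializeOwned`:
    ///
    /// ```rust,no_run
    /// # use apalis_sql::mysql::{MySqlPool, MysqlStorage};
    /// # async fn connect() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = MySqlPool::connect("mysql://user:pass@localhost/apalis").await?;
    /// // MysqlStorage::setup(&pool).await?; // run migrations (requires the `migrate` feature)
    /// let storage: MysqlStorage<String> = MysqlStorage::new(pool);
    /// # drop(storage);
    /// # Ok(())
    /// # }
    /// ```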
    pub fn new(pool: MySqlPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

    /// Create a new instance from a pool and custom config
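    ///
    /// A sketch of overriding the namespace; the `"apalis::Email"` name is an
    /// illustrative assumption, not a value required by the crate:
    ///
    /// ```rust,no_run
    /// # use apalis_sql::{mysql::{MySqlPool, MysqlStorage}, Config};
    /// # async fn connect() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = MySqlPool::connect("mysql://user:pass@localhost/apalis").await?;
    /// let storage: MysqlStorage<String> =
    ///     MysqlStorage::new_with_config(pool, Config::new("apalis::Email"));
    /// # drop(storage);
    /// # Ok(())
    /// # }
    /// ```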
    pub fn new_with_config(pool: MySqlPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            ack_notify: Notify::new(),
            codec: PhantomData,
        }
    }

    /// Expose the pool for other functionality, e.g. custom migrations
    pub fn pool(&self) -> &Pool<MySql> {
        &self.pool
    }

    /// Get the config used by the storage
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> MysqlStorage<T, C>
where
    T: DeserializeOwned + Send + Sync + 'static,
    C: Codec<Compact = Value> + Send + 'static,
{
    fn stream_jobs(
        self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let worker_id = worker.id().to_string();

        try_stream! {
            let buffer_size = u32::try_from(buffer_size)
                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let pool = pool.clone();
                let job_type = self.config.namespace.clone();
                let mut tx = pool.begin().await?;
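                // Fetch only tasks that are due for this namespace; FOR UPDATE SKIP LOCKED
                // lets concurrent workers claim disjoint rows without blocking each other.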
                let fetch_query = "SELECT id FROM jobs
                WHERE (status='Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at <= NOW() AND job_type = ? ORDER BY priority DESC, run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED";
                let task_ids: Vec<MySqlRow> = sqlx::query(fetch_query)
                    .bind(job_type)
                    .bind(buffer_size)
                    .fetch_all(&mut *tx).await?;
                if task_ids.is_empty() {
                    tx.rollback().await?;
                    yield None
                } else {
                    let task_ids: Vec<String> = task_ids.iter().map(|r| r.get_unchecked("id")).collect();
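                    // Claim the fetched ids for this worker in a single UPDATE, then commit
                    // so the row locks are released before the full rows are re-read below.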
                    let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1));
                    let query = format!("UPDATE jobs SET status = 'Running', lock_by = ?, lock_at = NOW() WHERE id IN({}) AND status = 'Pending' AND lock_by IS NULL;", id_params);
                    let mut query = sqlx::query(&query).bind(worker_id.clone());
                    for i in &task_ids {
                        query = query.bind(i);
                    }
                    query.execute(&mut *tx).await?;
                    tx.commit().await?;

                    let fetch_query = format!("SELECT * FROM jobs WHERE id IN ({}) ORDER BY priority DESC, run_at ASC", id_params);
                    let mut query = sqlx::query_as(&fetch_query);
                    for i in task_ids {
                        query = query.bind(i);
                    }
                    let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;

                    for job in jobs {
                        yield {
                            let (req, ctx) = job.req.take_parts();
                            let req = C::decode(req)
                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                            let mut req: Request<T, SqlContext> = Request::new_with_parts(req, ctx);
                            req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                            Some(req)
                        }
                    }
                }
            }
        }.boxed()
    }

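    /// Insert or refresh this worker's heartbeat row in the `workers` table so the
    /// worker is considered alive when orphaned jobs are re-enqueued.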
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query =
            "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE id = ?;";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

impl<T, C> Storage for MysqlStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send + Sync + 'static,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = Value;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
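        // Positional insert: the value order must match the `jobs` table column
        // order created by this crate's MySQL migrations.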
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn push_raw_request(
        &mut self,
        job: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, now(), NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL, ?)";
        let pool = self.pool.clone();

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let on = DateTime::from_timestamp(on, 0);

        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .bind(req.parts.context.priority())
            .execute(&pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let pool = self.pool.clone();

        let fetch_query = "SELECT * FROM jobs WHERE id = ?";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let pool = self.pool.clone();

        let query = "SELECT COUNT(*) AS count FROM jobs WHERE status = 'Pending'";
        let record = sqlx::query(query).fetch_one(&pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let job_id = job.parts.task_id.clone();

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ? WHERE id = ?";

        sqlx::query(query)
            .bind(Utc::now().timestamp().add(wait))
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();
        let job_id = job.parts.task_id;
        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ?, priority = ? WHERE id = ?";
        sqlx::query(query)
            .bind(status.to_owned())
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let pool = self.pool.clone();
        let query = "DELETE FROM jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

/// Errors that can occur while polling a MySQL database.
#[derive(thiserror::Error, Debug)]
pub enum MysqlPollError {
    /// Error during task acknowledgment.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// Error during result encoding.
    #[error("Encountered an error while encoding the result: {0}")]
    CodecError(BoxDynError),

    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<Req, C> Backend<Request<Req, SqlContext>> for MysqlStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Sync + Send + 'static,
    C: Codec<Compact = Value> + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<Req, SqlContext>>>;

    type Layer = AckLayer<MysqlStorage<Req, C>, Req, SqlContext, C>;

    type Codec = C;

    fn poll(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let pool = self.pool.clone();
        let ack_notify = self.ack_notify.clone();
        let mut hb_storage = self.clone();
        let requeue_storage = self.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let w = worker.clone();

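        // Drain batched acknowledgements and write each task's final status, result
        // and attempt count back to the `jobs` table.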
        let ack_heartbeat = async move {
            while let Some(ids) = ack_notify
                .clone()
                .ready_chunks(config.buffer_size)
                .next()
                .await
            {
                for (ctx, res) in ids {
                    let query = "UPDATE jobs SET status = ?, done_at = now(), last_error = ?, attempts = ? WHERE id = ? AND lock_by = ?";
                    let query = sqlx::query(query);
                    let last_result =
                        C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new);
                    match (last_result, ctx.lock_by()) {
                        (Ok(val), Some(worker_id)) => {
                            let query = query
                                .bind(calculate_status(&ctx, &res).to_string())
                                .bind(val)
                                .bind(res.attempt.current() as i32)
                                .bind(res.task_id.to_string())
                                .bind(worker_id.to_string());
                            if let Err(e) = query.execute(&pool).await {
                                w.emit(Event::Error(Box::new(MysqlPollError::AckError(e))));
                            }
                        }
                        (Err(error), Some(_)) => {
                            w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error))));
                        }
                        _ => {
                            unreachable!(
                                "Attempted to ACK without a worker attached. This is a bug; please file it on the repo"
                            );
                        }
                    }
                }

                apalis_core::sleep(config.poll_interval).await;
            }
        };
        let w = worker.clone();
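        // Heartbeat future: recover this worker's own orphaned jobs once at startup,
        // then periodically refresh its `last_seen` timestamp.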
        let heartbeat = async move {
            // Re-enqueue any jobs that belonged to this worker in case it previously died
            if let Err(e) = hb_storage
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                w.emit(Event::Error(Box::new(
                    MysqlPollError::ReenqueueOrphanedError(e),
                )));
            }

            loop {
                let now = Utc::now();
                if let Err(e) = hb_storage.keep_alive_at::<Self::Layer>(w.id(), now).await {
                    w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e))));
                }
                apalis_core::sleep(config.keep_alive).await;
            }
        };
        let w = worker.clone();
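        // Periodically return jobs held by workers that have stopped heartbeating
        // back to the Pending state.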
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after)
                        .expect("Could not calculate dead since");
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("Could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        MysqlPollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, ack_heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

impl<T, Res, C> Ack<T, Res, C> for MysqlStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Send + 'static + Sync,
    C: Codec<Compact = Value> + Send,
    C::Error: Debug,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
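    // Acks are not written to MySQL here; they are queued on `ack_notify` and
    // flushed in batches by the poller's ack heartbeat.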
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        self.ack_notify
            .notify((
                ctx.clone(),
                res.map(|res| C::encode(res).expect("Could not encode result")),
            ))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> MysqlStorage<T, C> {
    /// Kill a job
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Killed', done_at = NOW() WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Puts the job instantly back into the queue
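    ///
    /// A usage sketch (the ids are illustrative): `storage.retry(&worker_id, &task_id).await?;`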
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Re-add abandoned jobs back to the queue
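    ///
    /// A usage sketch (the count and cutoff are illustrative values), re-enqueuing up
    /// to 10 jobs from workers not seen within the last 5 minutes:
    ///
    /// ```rust,ignore
    /// let cutoff = chrono::Utc::now() - chrono::Duration::minutes(5);
    /// storage.reenqueue_orphaned(10, cutoff).await?;
    /// ```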
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<bool, sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = r#"Update jobs
                        INNER JOIN ( SELECT workers.id as worker_id, jobs.id as job_id from workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.attempts < jobs.max_attempts AND jobs.status = "Running" AND workers.last_seen < ? AND workers.worker_type = ?
                            ORDER BY lock_at ASC LIMIT ?) as workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id
                        SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ="Job was abandoned", attempts = attempts + 1;"#;

        sqlx::query(query)
            .bind(dead_since)
            .bind(job_type)
            .bind(count)
            .execute(&mut *tx)
            .await?;
        Ok(true)
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for MysqlStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
            COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending,
            COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running,
            COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done,
            COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry,
            COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed,
            COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed
        FROM jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req: J = MysqlCodec::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;

    use apalis_core::test_utils::DummyService;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, MysqlStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> MysqlStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        // Because connections cannot be shared across async runtimes
        // (a separate runtime is created for each test),
        // we don't share the storage, and tests must run sequentially.
        let pool = MySqlPool::connect(db_url).await.unwrap();
        MysqlStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let mut storage = MysqlStorage::new(pool);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Roll back DB changes made by tests.
    /// Delete the following rows:
    ///  - jobs whose state is `Pending` or locked by `worker_id`
    ///  - worker identified by `worker_id`
    ///
    /// You should execute this function at the end of a test
    async fn cleanup<T>(storage: &mut MysqlStorage<T>, worker_id: &WorkerId) {
        sqlx::query("DELETE FROM jobs WHERE job_type = ?")
            .bind(storage.config.namespace())
            .execute(&storage.pool)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM workers WHERE id = ?")
            .bind(worker_id.to_string())
            .execute(&storage.pool)
            .await
            .expect("failed to delete worker");
    }

    async fn consume_one(
        storage: &mut MysqlStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .clone()
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1);
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@mysql".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn register_worker_at(
        storage: &mut MysqlStorage<Email>,
        last_seen: DateTime<Utc>,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        storage
            .keep_alive_at::<DummyService>(&wrk.id(), last_seen)
            .await
            .expect("failed to register worker");
        wrk
    }

    async fn register_worker(storage: &mut MysqlStorage<Email>) -> Worker<Context> {
        let now = Utc::now();

        register_worker_at(storage, now).await
    }

    async fn push_email(storage: &mut MysqlStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut MysqlStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // add a slight delay to allow background actions like ack to complete
        apalis_core::sleep(Duration::from_secs(1)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        // TODO: Fix assertions
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;

        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        // TODO: Fix assertions
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        // push an Email job
        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        // register a worker that has not responded for the last 6 minutes
        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        storage
            .keep_alive_at::<Email>(worker.id(), six_minutes_ago)
            .await
            .unwrap();

        // fetch job
        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .unwrap();

        // then, the job status has changed to Pending
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        let service = apalis_test_service_fn(|_: Request<Email, _>| async move {
            apalis_core::sleep(Duration::from_millis(500)).await;
            Ok::<_, io::Error>("success")
        });
        let (mut t, poller) = TestWrapper::new_with_service(storage.clone(), service);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        storage
            .keep_alive_at::<Email>(&t.worker.id(), four_minutes_ago)
            .await
            .unwrap();

        tokio::spawn(poller);

        // push an Email job
        let parts = storage
            .push(example_email())
            .await
            .expect("failed to push job");

        // the worker was last seen 4 minutes ago, so a ReenqueueOrphaned pulse with a
        // 6-minute cutoff should leave its jobs untouched
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .unwrap();

        // then, the job status is not changed
        let job = storage.fetch_by_id(&parts.task_id).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert_eq!(*ctx.lock_by(), None);
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);

        let res = t.execute_next().await.unwrap();

        apalis_core::sleep(Duration::from_millis(1000)).await;

        let job = storage.fetch_by_id(&res.0).await.unwrap().unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(t.worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}