// apalis_sql/mysql.rs

use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{Stream, StreamExt, TryStreamExt};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::mysql::MySqlRow;
use sqlx::{MySql, Pool, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, ops::Add, time::Duration};

use crate::context::SqlContext;
use crate::from_row::SqlRequest;
use crate::{calculate_status, Config, SqlError};

pub use sqlx::mysql::MySqlPool;

type MysqlCodec = JsonCodec<Value>;

/// Represents a [Storage] that persists to MySQL
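///
/// # Example
///
/// A minimal end-to-end sketch, assuming a reachable MySQL database whose
/// migrations have already been run (the `Email` type and connection URL
/// are illustrative):
///
/// ```rust,no_run
/// # use apalis_sql::mysql::{MysqlStorage, MySqlPool};
/// # use apalis_core::storage::Storage;
/// # #[derive(serde::Serialize, serde::Deserialize)]
/// # struct Email { to: String }
/// # async fn doc() -> Result<(), sqlx::Error> {
/// let pool = MySqlPool::connect("mysql://user:pass@localhost/apalis").await?;
/// let mut storage: MysqlStorage<Email> = MysqlStorage::new(pool);
/// storage.push(Email { to: "test@example.com".to_string() }).await?;
/// # Ok(())
/// # }
/// ```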
pub struct MysqlStorage<T, C = JsonCodec<Value>>
where
    C: Codec,
{
    pool: Pool<MySql>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
}

impl<T, C> fmt::Debug for MysqlStorage<T, C>
where
    C: Debug + Codec,
    C::Compact: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MysqlStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &self.codec)
            .field("ack_notify", &self.ack_notify)
            .finish()
    }
}

impl<T, C> Clone for MysqlStorage<T, C>
where
    C: Debug + Codec,
{
    fn clone(&self) -> Self {
        let pool = self.pool.clone();
        MysqlStorage {
            pool,
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
            ack_notify: self.ack_notify.clone(),
        }
    }
}

impl MysqlStorage<(), JsonCodec<Value>> {
    /// Get the MySQL migrations without running them
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/mysql")
    }

    /// Run the migrations for MySQL
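    ///
    /// A minimal sketch, assuming `DATABASE_URL` points at a MySQL instance
    /// you are allowed to migrate:
    ///
    /// ```rust,no_run
    /// # async fn doc() -> Result<(), sqlx::Error> {
    /// use apalis_sql::mysql::{MysqlStorage, MySqlPool};
    ///
    /// let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap()).await?;
    /// MysqlStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```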
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> MysqlStorage<T>
where
    T: Serialize + DeserializeOwned,
{
    /// Create a new instance from a pool
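    ///
    /// A minimal sketch, assuming migrations have already been run against
    /// the pool (`Email` stands in for your job type):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::mysql::{MysqlStorage, MySqlPool};
    /// # #[derive(serde::Serialize, serde::Deserialize)]
    /// # struct Email { to: String }
    /// # fn doc(pool: MySqlPool) {
    /// let storage: MysqlStorage<Email> = MysqlStorage::new(pool);
    /// # }
    /// ```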
    pub fn new(pool: MySqlPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

    /// Create a new instance from a pool and custom config
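    ///
    /// A sketch with an explicit namespace (`"email-queue"` is a
    /// hypothetical value):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::mysql::{MysqlStorage, MySqlPool};
    /// # use apalis_sql::Config;
    /// # #[derive(serde::Serialize, serde::Deserialize)]
    /// # struct Email { to: String }
    /// # fn doc(pool: MySqlPool) {
    /// let storage: MysqlStorage<Email> =
    ///     MysqlStorage::new_with_config(pool, Config::new("email-queue"));
    /// # }
    /// ```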
    pub fn new_with_config(pool: MySqlPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            ack_notify: Notify::new(),
            codec: PhantomData,
        }
    }

    /// Expose the pool for other functionality, e.g. custom migrations
    pub fn pool(&self) -> &Pool<MySql> {
        &self.pool
    }

    /// Get the config used by the storage
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> MysqlStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin + Sync + 'static,
    C: Codec<Compact = Value> + Send + 'static,
{
    fn stream_jobs(
        self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let worker_id = worker.id().to_string();

        try_stream! {
            let buffer_size = u32::try_from(buffer_size)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let pool = pool.clone();
                let job_type = self.config.namespace.clone();
                let mut tx = pool.begin().await?;
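                // Claim a batch atomically: FOR UPDATE SKIP LOCKED lets concurrent
                // workers select disjoint pending rows without blocking one another.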
                let fetch_query = "SELECT id FROM jobs
                WHERE status = 'Pending' AND run_at <= NOW() AND job_type = ? ORDER BY run_at ASC LIMIT ? FOR UPDATE SKIP LOCKED";
                let task_ids: Vec<MySqlRow> = sqlx::query(fetch_query)
                    .bind(job_type)
                    .bind(buffer_size)
                    .fetch_all(&mut *tx).await?;
                if task_ids.is_empty() {
                    tx.rollback().await?;
                    yield None
                } else {
                    let task_ids: Vec<String> = task_ids.iter().map(|r| r.get_unchecked("id")).collect();
                    let id_params = format!("?{}", ", ?".repeat(task_ids.len() - 1));
                    let query = format!("UPDATE jobs SET status = 'Running', lock_by = ?, lock_at = NOW(), attempts = attempts + 1 WHERE id IN({}) AND status = 'Pending' AND lock_by IS NULL;", id_params);
                    let mut query = sqlx::query(&query).bind(worker_id.clone());
                    for i in &task_ids {
                        query = query.bind(i);
                    }
                    query.execute(&mut *tx).await?;
                    tx.commit().await?;

                    let fetch_query = format!("SELECT * FROM jobs WHERE id IN ({})", id_params);
                    let mut query = sqlx::query_as(&fetch_query);
                    for i in task_ids {
                        query = query.bind(i);
                    }
                    let jobs: Vec<SqlRequest<Value>> = query.fetch_all(&pool).await?;

                    for job in jobs {
                        yield {
                            let (req, ctx) = job.req.take_parts();
                            let req = C::decode(req)
                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                            let mut req: Request<T, SqlContext> = Request::new_with_parts(req, ctx);
                            req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                            Some(req)
                        }
                    }
                }
            }
        }.boxed()
    }

    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query =
            "INSERT INTO workers (id, worker_type, storage_name, layers, last_seen) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE last_seen = ?;";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .bind(last_seen)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

impl<T, C> Storage for MysqlStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let (args, parts) = job.take_parts();
        let query =
            "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, NOW(), NULL, NULL, NULL, NULL)";
        let pool = self.pool.clone();

        let job = C::encode(args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .execute(&pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query = "INSERT INTO jobs VALUES (?, ?, ?, 'Pending', 0, ?, ?, NULL, NULL, NULL, NULL)";
        let pool = self.pool.clone();

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .execute(&pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let pool = self.pool.clone();

        let fetch_query = "SELECT * FROM jobs WHERE id = ?";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let pool = self.pool.clone();

        let query = "SELECT COUNT(*) AS count FROM jobs WHERE status = 'Pending'";
        let record = sqlx::query(query).fetch_one(&pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let job_id = job.parts.task_id.clone();

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ? WHERE id = ?";

        sqlx::query(query)
            .bind(Utc::now().timestamp().add(wait))
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let job_id = job.parts.task_id;
        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = ?, attempts = ?, done_at = ?, lock_by = ?, lock_at = ?, last_error = ? WHERE id = ?";
        sqlx::query(query)
            .bind(status.to_owned())
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let pool = self.pool.clone();
        let query = "DELETE FROM jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

/// Errors that can occur while polling a MySQL database.
#[derive(thiserror::Error, Debug)]
pub enum MysqlPollError {
    /// Error during task acknowledgment.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// Error during result encoding.
    #[error("Encountered an error while encoding the result: {0}")]
    CodecError(BoxDynError),

    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<Req, Res, C> Backend<Request<Req, SqlContext>, Res> for MysqlStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    C: Debug + Codec<Compact = Value> + Clone + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<Req, SqlContext>>>;

    type Layer = AckLayer<MysqlStorage<Req, C>, Req, SqlContext, Res>;

    fn poll<Svc>(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let pool = self.pool.clone();
        let ack_notify = self.ack_notify.clone();
        let mut hb_storage = self.clone();
        let requeue_storage = self.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let w = worker.clone();

        let ack_heartbeat = async move {
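            // Drain completed results in chunks of up to `buffer_size`,
            // writing each one back to its row; a job may only be acked by
            // the worker that holds its lock.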
            while let Some(ids) = ack_notify
                .clone()
                .ready_chunks(config.buffer_size)
                .next()
                .await
            {
                for (ctx, res) in ids {
                    let query = "UPDATE jobs SET status = ?, done_at = NOW(), last_error = ? WHERE id = ? AND lock_by = ?";
                    let query = sqlx::query(query);
                    let last_result =
                        C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new);
                    match (last_result, ctx.lock_by()) {
                        (Ok(val), Some(worker_id)) => {
                            let query = query
                                .bind(calculate_status(&res.inner).to_string())
                                .bind(val)
                                .bind(res.task_id.to_string())
                                .bind(worker_id.to_string());
                            if let Err(e) = query.execute(&pool).await {
                                w.emit(Event::Error(Box::new(MysqlPollError::AckError(e))));
                            }
                        }
                        (Err(error), Some(_)) => {
                            w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error))));
                        }
                        _ => {
                            unreachable!(
                                "Attempted to ACK without a worker attached. This is a bug; please file an issue on the repo."
                            );
                        }
                    }
                }

                apalis_core::sleep(config.poll_interval).await;
            }
        };
        let w = worker.clone();
        let heartbeat = async move {
            loop {
                let now = Utc::now();
                if let Err(e) = hb_storage.keep_alive_at::<Self::Layer>(w.id(), now).await {
                    w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e))));
                }
                apalis_core::sleep(config.keep_alive).await;
            }
        };
        let w = worker.clone();
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after)
                        .expect("Could not calculate dead since");
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("Could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        MysqlPollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, ack_heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

impl<T, Res, C> Ack<T, Res> for MysqlStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Send + 'static + Sync,
    C: Codec<Compact = Value> + Send,
    C::Error: Debug,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        self.ack_notify
            .notify((
                ctx.clone(),
                res.map(|res| C::encode(res).expect("Could not encode result")),
            ))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> MysqlStorage<T, C> {
    /// Kill a job
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
            "UPDATE jobs SET status = 'Killed', done_at = NOW() WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Puts the job instantly back into the queue
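    ///
    /// A sketch, assuming the job identified by `task_id` is currently
    /// locked by `worker_id`:
    ///
    /// ```rust,no_run
    /// # use apalis_sql::mysql::MysqlStorage;
    /// # use apalis_core::{task::task_id::TaskId, worker::WorkerId};
    /// # async fn doc(
    /// #     mut storage: MysqlStorage<()>,
    /// #     worker_id: WorkerId,
    /// #     task_id: TaskId,
    /// # ) -> Result<(), sqlx::Error> {
    /// storage.retry(&worker_id, &task_id).await?;
    /// # Ok(())
    /// # }
    /// ```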
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();

        let mut tx = pool.acquire().await?;
        let query =
                "UPDATE jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ? AND lock_by = ?";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Re-add abandoned jobs to the queue
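    ///
    /// A sketch mirroring what the poller's re-enqueue heartbeat does; the
    /// five-minute cutoff and batch size of 10 are example values:
    ///
    /// ```rust,no_run
    /// # use apalis_sql::mysql::MysqlStorage;
    /// # use chrono::Utc;
    /// # async fn doc(storage: MysqlStorage<()>) -> Result<(), sqlx::Error> {
    /// let dead_since = Utc::now() - chrono::Duration::minutes(5);
    /// storage.reenqueue_orphaned(10, dead_since).await?;
    /// # Ok(())
    /// # }
    /// ```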
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<bool, sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = r#"UPDATE jobs
                        INNER JOIN (SELECT workers.id AS worker_id, jobs.id AS job_id FROM workers INNER JOIN jobs ON jobs.lock_by = workers.id WHERE jobs.status = 'Running' AND workers.last_seen < ? AND workers.worker_type = ?
                            ORDER BY lock_at ASC LIMIT ?) AS workers ON jobs.lock_by = workers.worker_id AND jobs.id = workers.job_id
                        SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned';"#;

        sqlx::query(query)
            .bind(dead_since)
            .bind(job_type)
            .bind(count)
            .execute(&mut *tx)
            .await?;
        Ok(true)
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for MysqlStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
            COUNT(CASE WHEN status = 'Pending' THEN 1 END) AS pending,
            COUNT(CASE WHEN status = 'Running' THEN 1 END) AS running,
            COUNT(CASE WHEN status = 'Done' THEN 1 END) AS done,
            COUNT(CASE WHEN status = 'Retry' THEN 1 END) AS retry,
            COUNT(CASE WHEN status = 'Failed' THEN 1 END) AS failed,
            COUNT(CASE WHEN status = 'Killed' THEN 1 END) AS killed
        FROM jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req: J = MysqlCodec::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;

    use apalis_core::test_utils::DummyService;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, MysqlStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> MysqlStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        // Because connections cannot be shared across async runtimes
        // (a separate runtime is created for each test),
        // we don't share the storage and tests must be run sequentially.
        let pool = MySqlPool::connect(db_url).await.unwrap();
        MysqlStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let mut storage = MysqlStorage::new(pool);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Roll back DB changes made by tests.
    /// Deletes the following rows:
    ///  - jobs whose state is `Pending` or locked by `worker_id`
    ///  - the worker identified by `worker_id`
    ///
    /// You should execute this function at the end of a test
    async fn cleanup<T>(storage: &mut MysqlStorage<T>, worker_id: &WorkerId) {
        sqlx::query("DELETE FROM jobs WHERE job_type = ?")
            .bind(storage.config.namespace())
            .execute(&storage.pool)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM workers WHERE id = ?")
            .bind(worker_id.to_string())
            .execute(&storage.pool)
            .await
            .expect("failed to delete worker");
    }

    async fn consume_one(
        storage: &mut MysqlStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .clone()
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1);
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@mysql".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn register_worker_at(
        storage: &mut MysqlStorage<Email>,
        last_seen: DateTime<Utc>,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        storage
            .keep_alive_at::<DummyService>(wrk.id(), last_seen)
            .await
            .expect("failed to register worker");
        wrk
    }

    async fn register_worker(storage: &mut MysqlStorage<Email>) -> Worker<Context> {
        let now = Utc::now();

        register_worker_at(storage, now).await
    }

    async fn push_email(storage: &mut MysqlStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut MysqlStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // add a slight delay to allow background actions like ack to complete
        apalis_core::sleep(Duration::from_secs(1)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        // TODO: Fix assertions
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;

        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        // TODO: Fix assertions
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        // push an Email job
        storage
            .push(example_email())
            .await
            .expect("failed to push job");
        // register a worker that has been unresponsive for six minutes
        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let six_minutes_ago = Utc::now() - Duration::from_secs(60 * 6);

        storage
            .keep_alive_at::<Email>(worker.id(), six_minutes_ago)
            .await
            .unwrap();

        // fetch job
        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .unwrap();

        // then, the job status has changed to Pending
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    #[tokio::test]
    async fn test_storage_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        // push an Email job
        storage
            .push(example_email())
            .await
            .expect("failed to push job");

        // register a worker last seen four minutes ago
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let worker_id = WorkerId::new("test-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();
        storage
            .keep_alive_at::<Email>(worker.id(), four_minutes_ago)
            .await
            .unwrap();

        // fetch job
        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;

        assert_eq!(*ctx.status(), State::Running);

        // heartbeat with ReenqueueOrphaned pulse
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .unwrap();

        // then, the job status is not changed
        let job = storage
            .fetch_by_id(&job.parts.task_id)
            .await
            .unwrap()
            .unwrap();
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 1);
    }
}