apalis_sql/
sqlite.rs

1use crate::context::SqlContext;
2use crate::{calculate_status, Config, SqlError};
3use apalis_core::backend::{BackendExpose, Stat, WorkerState};
4use apalis_core::codec::json::JsonCodec;
5use apalis_core::error::Error;
6use apalis_core::layers::{Ack, AckLayer};
7use apalis_core::poller::controller::Controller;
8use apalis_core::poller::stream::BackendStream;
9use apalis_core::poller::Poller;
10use apalis_core::request::{Parts, Request, RequestStream, State};
11use apalis_core::response::Response;
12use apalis_core::storage::Storage;
13use apalis_core::task::namespace::Namespace;
14use apalis_core::task::task_id::TaskId;
15use apalis_core::worker::{Context, Event, Worker, WorkerId};
16use apalis_core::{backend::Backend, codec::Codec};
17use async_stream::try_stream;
18use chrono::{DateTime, Utc};
19use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
20use log::error;
21use serde::{de::DeserializeOwned, Serialize};
22use sqlx::{Pool, Row, Sqlite};
23use std::any::type_name;
24use std::convert::TryInto;
25use std::sync::Arc;
26use std::{fmt, io};
27use std::{marker::PhantomData, time::Duration};
28
29use crate::from_row::SqlRequest;
30
31pub use sqlx::sqlite::SqlitePool;
32
33/// Represents a [Storage] that persists to Sqlite
34// #[derive(Debug)]
35pub struct SqliteStorage<T, C = JsonCodec<String>> {
36    pool: Pool<Sqlite>,
37    job_type: PhantomData<T>,
38    controller: Controller,
39    config: Config,
40    codec: PhantomData<C>,
41}
42
43impl<T, C> fmt::Debug for SqliteStorage<T, C> {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        f.debug_struct("MysqlStorage")
46            .field("pool", &self.pool)
47            .field("job_type", &"PhantomData<T>")
48            .field("controller", &self.controller)
49            .field("config", &self.config)
50            .field("codec", &std::any::type_name::<C>())
51            .finish()
52    }
53}
54
55impl<T> Clone for SqliteStorage<T> {
56    fn clone(&self) -> Self {
57        SqliteStorage {
58            pool: self.pool.clone(),
59            job_type: PhantomData,
60            controller: self.controller.clone(),
61            config: self.config.clone(),
62            codec: self.codec,
63        }
64    }
65}
66
67impl SqliteStorage<()> {
68    /// Perform migrations for storage
69    #[cfg(feature = "migrate")]
70    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
71        sqlx::query("PRAGMA journal_mode = 'WAL';")
72            .execute(pool)
73            .await?;
74        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
75        sqlx::query("PRAGMA synchronous = NORMAL;")
76            .execute(pool)
77            .await?;
78        sqlx::query("PRAGMA cache_size = 64000;")
79            .execute(pool)
80            .await?;
81        Self::migrations().run(pool).await?;
82        Ok(())
83    }
84
85    /// Get sqlite migrations without running them
86    #[cfg(feature = "migrate")]
87    pub fn migrations() -> sqlx::migrate::Migrator {
88        sqlx::migrate!("migrations/sqlite")
89    }
90}
91
92impl<T: Serialize + DeserializeOwned> SqliteStorage<T> {
93    /// Create a new instance
94    pub fn new(pool: SqlitePool) -> Self {
95        Self {
96            pool,
97            job_type: PhantomData,
98            controller: Controller::new(),
99            config: Config::new(type_name::<T>()),
100            codec: PhantomData,
101        }
102    }
103
104    /// Create a new instance with a custom config
105    pub fn new_with_config(pool: SqlitePool, config: Config) -> Self {
106        Self {
107            pool,
108            job_type: PhantomData,
109            controller: Controller::new(),
110            config,
111            codec: PhantomData,
112        }
113    }
114    /// Keeps a storage notified that the worker is still alive manually
115    pub async fn keep_alive_at<Service>(
116        &mut self,
117        worker_id: &WorkerId,
118        last_seen: i64,
119    ) -> Result<(), sqlx::Error> {
120        let worker_type = self.config.namespace.clone();
121        let storage_name = std::any::type_name::<Self>();
122        let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen)
123                VALUES ($1, $2, $3, $4, $5)
124                ON CONFLICT (id) DO
125                   UPDATE SET last_seen = EXCLUDED.last_seen";
126        sqlx::query(query)
127            .bind(worker_id.to_string())
128            .bind(worker_type)
129            .bind(storage_name)
130            .bind(std::any::type_name::<Service>())
131            .bind(last_seen)
132            .execute(&self.pool)
133            .await?;
134        Ok(())
135    }
136
137    /// Expose the pool for other functionality, eg custom migrations
138    pub fn pool(&self) -> &Pool<Sqlite> {
139        &self.pool
140    }
141
142    /// Get the config used by the storage
143    pub fn get_config(&self) -> &Config {
144        &self.config
145    }
146}
147
148impl<T, C> SqliteStorage<T, C> {
149    /// Expose the code used
150    pub fn codec(&self) -> &PhantomData<C> {
151        &self.codec
152    }
153}
154
155async fn fetch_next(
156    pool: &Pool<Sqlite>,
157    worker_id: &WorkerId,
158    id: String,
159    config: &Config,
160) -> Result<Option<SqlRequest<String>>, sqlx::Error> {
161    let now: i64 = Utc::now().timestamp();
162    let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3, attempts = attempts + 1 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL; Select * from Jobs where id = ?1 AND lock_by = ?2 AND job_type = ?4";
163    let job: Option<SqlRequest<String>> = sqlx::query_as(update_query)
164        .bind(id.to_string())
165        .bind(worker_id.to_string())
166        .bind(now)
167        .bind(config.namespace.clone())
168        .fetch_optional(pool)
169        .await?;
170
171    Ok(job)
172}
173
174impl<T, C> SqliteStorage<T, C>
175where
176    T: DeserializeOwned + Send + Unpin,
177    C: Codec<Compact = String>,
178{
179    fn stream_jobs(
180        &self,
181        worker: &Worker<Context>,
182        interval: Duration,
183        buffer_size: usize,
184    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
185        let pool = self.pool.clone();
186        let worker = worker.clone();
187        let config = self.config.clone();
188        let namespace = Namespace(self.config.namespace.clone());
189        try_stream! {
190            loop {
191                apalis_core::sleep(interval).await;
192                if !worker.is_ready() {
193                    continue;
194                }
195                let worker_id = worker.id();
196                let tx = pool.clone();
197                let mut tx = tx.acquire().await?;
198                let job_type = &config.namespace;
199                let fetch_query = "SELECT id FROM Jobs
200                    WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 LIMIT ?3";
201                let now: i64 = Utc::now().timestamp();
202                let ids: Vec<(String,)> = sqlx::query_as(fetch_query)
203                    .bind(now)
204                    .bind(job_type)
205                    .bind(i64::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?)
206                    .fetch_all(&mut *tx)
207                    .await?;
208                for id in ids {
209                    let res = fetch_next(&pool, worker_id, id.0, &config).await?;
210                    yield match res {
211                        None => None::<Request<T, SqlContext>>,
212                        Some(job) => {
213                            let (req, parts) = job.req.take_parts();
214                            let args = C::decode(req)
215                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
216                            let mut req = Request::new_with_parts(args, parts);
217                            req.parts.namespace = Some(namespace.clone());
218                            Some(req)
219                        }
220                    }
221                };
222            }
223        }
224    }
225}
226
227impl<T, C> Storage for SqliteStorage<T, C>
228where
229    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
230    C: Codec<Compact = String> + Send,
231{
232    type Job = T;
233
234    type Error = sqlx::Error;
235
236    type Context = SqlContext;
237
238    async fn push_request(
239        &mut self,
240        job: Request<Self::Job, SqlContext>,
241    ) -> Result<Parts<SqlContext>, Self::Error> {
242        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL)";
243        let (task, parts) = job.take_parts();
244        let raw = C::encode(&task)
245            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
246        let job_type = self.config.namespace.clone();
247        sqlx::query(query)
248            .bind(raw)
249            .bind(parts.task_id.to_string())
250            .bind(job_type.to_string())
251            .bind(parts.context.max_attempts())
252            .execute(&self.pool)
253            .await?;
254        Ok(parts)
255    }
256
257    async fn schedule_request(
258        &mut self,
259        req: Request<Self::Job, SqlContext>,
260        on: i64,
261    ) -> Result<Parts<SqlContext>, Self::Error> {
262        let query =
263            "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL)";
264        let id = &req.parts.task_id;
265        let job = C::encode(&req.args)
266            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
267        let job_type = self.config.namespace.clone();
268        sqlx::query(query)
269            .bind(job)
270            .bind(id.to_string())
271            .bind(job_type)
272            .bind(req.parts.context.max_attempts())
273            .bind(on)
274            .execute(&self.pool)
275            .await?;
276        Ok(req.parts)
277    }
278
279    async fn fetch_by_id(
280        &mut self,
281        job_id: &TaskId,
282    ) -> Result<Option<Request<Self::Job, SqlContext>>, Self::Error> {
283        let fetch_query = "SELECT * FROM Jobs WHERE id = ?1";
284        let res: Option<SqlRequest<String>> = sqlx::query_as(fetch_query)
285            .bind(job_id.to_string())
286            .fetch_optional(&self.pool)
287            .await?;
288        match res {
289            None => Ok(None),
290            Some(job) => Ok(Some({
291                let (req, parts) = job.req.take_parts();
292                let args = C::decode(req)
293                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
294
295                let mut req: Request<T, SqlContext> = Request::new_with_parts(args, parts);
296                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
297                req
298            })),
299        }
300    }
301
302    async fn len(&mut self) -> Result<i64, Self::Error> {
303        let query = "Select Count(*) as count from Jobs where status='Pending'";
304        let record = sqlx::query(query).fetch_one(&self.pool).await?;
305        record.try_get("count")
306    }
307
308    async fn reschedule(
309        &mut self,
310        job: Request<T, SqlContext>,
311        wait: Duration,
312    ) -> Result<(), Self::Error> {
313        let task_id = job.parts.task_id;
314
315        let wait: i64 = wait
316            .as_secs()
317            .try_into()
318            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
319
320        let mut tx = self.pool.acquire().await?;
321        let query =
322                "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1";
323        let now: i64 = Utc::now().timestamp();
324        let wait_until = now + wait;
325
326        sqlx::query(query)
327            .bind(task_id.to_string())
328            .bind(wait_until)
329            .execute(&mut *tx)
330            .await?;
331        Ok(())
332    }
333
334    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), Self::Error> {
335        let ctx = job.parts.context;
336        let status = ctx.status().to_string();
337        let attempts = job.parts.attempt;
338        let done_at = *ctx.done_at();
339        let lock_by = ctx.lock_by().clone();
340        let lock_at = *ctx.lock_at();
341        let last_error = ctx.last_error().clone();
342        let job_id = job.parts.task_id;
343        let mut tx = self.pool.acquire().await?;
344        let query =
345                "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6 WHERE id = ?7";
346        sqlx::query(query)
347            .bind(status.to_owned())
348            .bind::<i64>(
349                attempts
350                    .current()
351                    .try_into()
352                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
353            )
354            .bind(done_at)
355            .bind(lock_by.map(|w| w.name().to_string()))
356            .bind(lock_at)
357            .bind(last_error)
358            .bind(job_id.to_string())
359            .execute(&mut *tx)
360            .await?;
361        Ok(())
362    }
363
364    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
365        self.len().map_ok(|c| c == 0).await
366    }
367
368    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
369        let query = "Delete from Jobs where status='Done'";
370        let record = sqlx::query(query).execute(&self.pool).await?;
371        Ok(record.rows_affected().try_into().unwrap_or_default())
372    }
373}
374
375impl<T> SqliteStorage<T> {
376    /// Puts the job instantly back into the queue
377    /// Another Worker may consume
378    pub async fn retry(
379        &mut self,
380        worker_id: &WorkerId,
381        job_id: &TaskId,
382    ) -> Result<(), sqlx::Error> {
383        let mut tx = self.pool.acquire().await?;
384        let query =
385                "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2";
386        sqlx::query(query)
387            .bind(job_id.to_string())
388            .bind(worker_id.to_string())
389            .execute(&mut *tx)
390            .await?;
391        Ok(())
392    }
393
394    /// Kill a job
395    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
396        let mut tx = self.pool.begin().await?;
397        let query =
398                "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2";
399        sqlx::query(query)
400            .bind(job_id.to_string())
401            .bind(worker_id.to_string())
402            .execute(&mut *tx)
403            .await?;
404        tx.commit().await?;
405        Ok(())
406    }
407
408    /// Add jobs that failed back to the queue if there are still remaining attemps
409    pub async fn reenqueue_failed(&mut self) -> Result<(), sqlx::Error> {
410        let job_type = self.config.namespace.clone();
411        let mut tx = self.pool.acquire().await?;
412        let query = r#"Update Jobs
413                            SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL
414                            WHERE id in
415                                (SELECT Jobs.id from Jobs
416                                    WHERE status= "Failed" AND Jobs.attempts < Jobs.max_attempts
417                                     ORDER BY lock_at ASC LIMIT ?2);"#;
418        sqlx::query(query)
419            .bind(job_type)
420            .bind::<u32>(
421                self.config
422                    .buffer_size
423                    .try_into()
424                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
425            )
426            .execute(&mut *tx)
427            .await?;
428        Ok(())
429    }
430
431    /// Add jobs that workers have disappeared to the queue
432    pub async fn reenqueue_orphaned(
433        &self,
434        count: i32,
435        dead_since: DateTime<Utc>,
436    ) -> Result<(), sqlx::Error> {
437        let job_type = self.config.namespace.clone();
438        let mut tx = self.pool.acquire().await?;
439        let query = r#"Update Jobs
440                            SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ="Job was abandoned"
441                            WHERE id in
442                                (SELECT Jobs.id from Jobs INNER join Workers ON lock_by = Workers.id
443                                    WHERE status= "Running" AND workers.last_seen < ?1
444                                    AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#;
445
446        sqlx::query(query)
447            .bind(dead_since.timestamp())
448            .bind(job_type)
449            .bind(count)
450            .execute(&mut *tx)
451            .await?;
452        Ok(())
453    }
454}
455
456/// Errors that can occur while polling an SQLite database.
457#[derive(thiserror::Error, Debug)]
458pub enum SqlitePollError {
459    /// Error during a keep-alive heartbeat.
460    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
461    KeepAliveError(sqlx::Error),
462
463    /// Error during re-enqueuing orphaned tasks.
464    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
465    ReenqueueOrphanedError(sqlx::Error),
466}
467
468impl<T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static, Res>
469    Backend<Request<T, SqlContext>, Res> for SqliteStorage<T>
470{
471    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;
472    type Layer = AckLayer<SqliteStorage<T>, T, SqlContext, Res>;
473
474    fn poll<Svc>(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
475        let layer = AckLayer::new(self.clone());
476        let config = self.config.clone();
477        let controller = self.controller.clone();
478        let stream = self
479            .stream_jobs(worker, config.poll_interval, config.buffer_size)
480            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
481        let stream = BackendStream::new(stream.boxed(), controller);
482        let requeue_storage = self.clone();
483        let w = worker.clone();
484        let heartbeat = async move {
485            loop {
486                let now: i64 = Utc::now().timestamp();
487                if let Err(e) = self.keep_alive_at::<Self::Layer>(w.id(), now).await {
488                    w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e))));
489                }
490                apalis_core::sleep(Duration::from_secs(30)).await;
491            }
492        }
493        .boxed();
494        let w = worker.clone();
495        let reenqueue_beat = async move {
496            loop {
497                let dead_since = Utc::now()
498                    - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
499                if let Err(e) = requeue_storage
500                    .reenqueue_orphaned(
501                        config
502                            .buffer_size
503                            .try_into()
504                            .expect("could not convert usize to i32"),
505                        dead_since,
506                    )
507                    .await
508                {
509                    w.emit(Event::Error(Box::new(
510                        SqlitePollError::ReenqueueOrphanedError(e),
511                    )));
512                }
513                apalis_core::sleep(config.poll_interval).await;
514            }
515        };
516        Poller::new_with_layer(
517            stream,
518            async {
519                futures::join!(heartbeat, reenqueue_beat);
520            },
521            layer,
522        )
523    }
524}
525
526impl<T: Sync + Send, Res: Serialize + Sync> Ack<T, Res> for SqliteStorage<T> {
527    type Context = SqlContext;
528    type AckError = sqlx::Error;
529    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
530        let pool = self.pool.clone();
531        let query =
532                "UPDATE Jobs SET status = ?4, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2";
533        let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string()))
534            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
535        sqlx::query(query)
536            .bind(res.task_id.to_string())
537            .bind(
538                ctx.lock_by()
539                    .as_ref()
540                    .expect("Task is not locked")
541                    .to_string(),
542            )
543            .bind(result)
544            .bind(calculate_status(&res.inner).to_string())
545            .execute(&pool)
546            .await?;
547        Ok(())
548    }
549}
550
551impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
552    for SqliteStorage<J, JsonCodec<String>>
553{
554    type Request = Request<J, Parts<SqlContext>>;
555    type Error = SqlError;
556    async fn stats(&self) -> Result<Stat, Self::Error> {
557        let fetch_query = "SELECT
558                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
559                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
560                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
561                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
562                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
563                        FROM Jobs WHERE job_type = ?";
564
565        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
566            .bind(self.get_config().namespace())
567            .fetch_one(self.pool())
568            .await?;
569
570        Ok(Stat {
571            pending: res.0.try_into()?,
572            running: res.1.try_into()?,
573            dead: res.4.try_into()?,
574            failed: res.3.try_into()?,
575            success: res.2.try_into()?,
576        })
577    }
578
579    async fn list_jobs(
580        &self,
581        status: &State,
582        page: i32,
583    ) -> Result<Vec<Self::Request>, Self::Error> {
584        let status = status.to_string();
585        let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
586        let res: Vec<SqlRequest<String>> = sqlx::query_as(fetch_query)
587            .bind(status)
588            .bind(self.get_config().namespace())
589            .bind(((page - 1) * 10).to_string())
590            .fetch_all(self.pool())
591            .await?;
592        Ok(res
593            .into_iter()
594            .map(|j| {
595                let (req, ctx) = j.req.take_parts();
596                let req = JsonCodec::<String>::decode(req).unwrap();
597                Request::new_with_ctx(req, ctx)
598            })
599            .collect())
600    }
601
602    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
603        let fetch_query =
604            "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
605        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
606            .bind(self.get_config().namespace())
607            .bind(0)
608            .fetch_all(self.pool())
609            .await?;
610        Ok(res
611            .into_iter()
612            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
613            .collect())
614    }
615}
616
#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::request::State;
    use apalis_core::test_utils::DummyService;
    use chrono::Utc;
    use email_service::example_good_email;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    // NOTE(review): not referenced directly below; presumably required by the
    // expansion of `generic_storage_test!` — confirm before removing.
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    // Shared test suites defined elsewhere, instantiated against `setup`.
    generic_storage_test!(setup);
    sql_storage_tests!(setup::<Email>, SqliteStorage<Email>, Email);

    /// migrate DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> SqliteStorage<T> {
        // Because connections cannot be shared across async runtime
        // (different runtimes are created for each test),
        // we don't share the storage and tests must be run sequentially.
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        SqliteStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let storage = SqliteStorage::<T>::new(pool);

        storage
    }

    /// Pushing one job makes `len` report exactly one pending job.
    #[tokio::test]
    async fn test_inmemory_sqlite_worker() {
        let mut sqlite = setup().await;
        sqlite
            .push(Email {
                subject: "Test Subject".to_string(),
                to: "example@sqlite".to_string(),
                text: "Some Text".to_string(),
            })
            .await
            .expect("Unable to push job");
        let len = sqlite.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);
    }

    /// Poll the storage once (batch size 1) on behalf of `worker` and return the
    /// claimed job, panicking if the stream ends, errors, or yields no job.
    ///
    /// Note: `stream_jobs` sleeps for the given interval before its first query, so
    /// each call waits 10 seconds.
    async fn consume_one(
        storage: &mut SqliteStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1)
            .boxed();
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    /// Register (upsert) the fixed worker "test-worker" with the given `last_seen`
    /// unix timestamp and return a started worker handle for it.
    async fn register_worker_at(
        storage: &mut SqliteStorage<Email>,
        last_seen: i64,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        storage
            .keep_alive_at::<DummyService>(&worker_id, last_seen)
            .await
            .expect("failed to register worker");
        let wrk = Worker::new(worker_id, Context::default());
        // Started so that `worker.is_ready()` lets `stream_jobs` poll (assumed
        // from the readiness check in `stream_jobs`).
        wrk.start();
        wrk
    }

    /// Register the test worker as seen right now.
    async fn register_worker(storage: &mut SqliteStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    /// Push `email` as a new job, panicking on failure.
    async fn push_email(storage: &mut SqliteStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    /// Fetch a job by id, panicking if the query fails or the job does not exist.
    async fn get_job(
        storage: &mut SqliteStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    /// A consumed job is `Running` and locked by the consuming worker.
    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let len = storage.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    /// Acknowledging a successful response marks the job `Done` with `done_at` set.
    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        let ctx = &job.parts.context;
        let res = 1usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert!(ctx.done_at().is_some());
    }

    /// Killing a job held by the worker marks it `Killed` with `done_at` set.
    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(&worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    /// The worker was last seen 6 minutes ago and the cutoff is 5 minutes ago, so it
    /// counts as dead. Its running job is returned to `Pending`, unlocked, and
    /// tagged "Job was abandoned". The attempt count of 1 comes from the claim
    /// performed by `consume_one`.
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }

    /// The worker was last seen 4 minutes ago and the cutoff is 6 minutes ago, so it
    /// is still alive. Its job must remain `Running` and locked, with no error recorded.
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 1);
    }
}
829}