// apalis_sql/sqlite.rs

use crate::context::SqlContext;
use crate::{calculate_status, Config, SqlError};
use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::Error;
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::{Pool, Row, Sqlite};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};

use crate::from_row::SqlRequest;

pub use sqlx::sqlite::SqlitePool;

/// Represents a [Storage] that persists to Sqlite
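///
/// A minimal usage sketch (assumes an async runtime, the `migrate` feature
/// for [`SqliteStorage::setup`], and that this module is reachable as
/// `apalis_sql::sqlite`):
///
/// ```rust,no_run
/// use apalis_sql::sqlite::{SqlitePool, SqliteStorage};
///
/// # async fn example() -> Result<(), sqlx::Error> {
/// // Connect to an in-memory database and run the bundled migrations.
/// let pool = SqlitePool::connect("sqlite::memory:").await?;
/// SqliteStorage::setup(&pool).await?;
/// // A storage for jobs of some serializable type (here `String`).
/// let storage: SqliteStorage<String> = SqliteStorage::new(pool);
/// # let _ = storage;
/// # Ok(())
/// # }
/// ```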
// `Debug` is implemented manually below since `T` and `C` are not required to be `Debug`.
pub struct SqliteStorage<T, C = JsonCodec<String>> {
    pool: Pool<Sqlite>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
}

impl<T, C> fmt::Debug for SqliteStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqliteStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}

impl<T, C> Clone for SqliteStorage<T, C> {
    fn clone(&self) -> Self {
        SqliteStorage {
            pool: self.pool.clone(),
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
        }
    }
}

impl SqliteStorage<()> {
    /// Apply recommended SQLite PRAGMAs and perform the migrations for storage
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
        sqlx::query("PRAGMA journal_mode = 'WAL';")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
        sqlx::query("PRAGMA synchronous = NORMAL;")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA cache_size = 64000;")
            .execute(pool)
            .await?;
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Get sqlite migrations without running them
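    ///
    /// A sketch of running the bundled migrations by hand, equivalent to the
    /// migration step inside [`SqliteStorage::setup`] (assumes the `migrate`
    /// feature):
    ///
    /// ```rust,no_run
    /// # async fn example(pool: &sqlx::SqlitePool) -> Result<(), sqlx::migrate::MigrateError> {
    /// apalis_sql::sqlite::SqliteStorage::<()>::migrations()
    ///     .run(pool)
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```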
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/sqlite")
    }
}

impl<T> SqliteStorage<T> {
    /// Create a new instance
    pub fn new(pool: SqlitePool) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config: Config::new(type_name::<T>()),
            codec: PhantomData,
        }
    }

    /// Create a new instance with a custom config
    pub fn new_with_config(pool: SqlitePool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            codec: PhantomData,
        }
    }
}

impl<T, C> SqliteStorage<T, C> {
    /// Manually notify the storage that a worker is still alive.
    ///
    /// `last_seen` is a unix timestamp in seconds.
    pub async fn keep_alive_at(
        &mut self,
        worker: &Worker<Context>,
        last_seen: i64,
    ) -> Result<(), sqlx::Error> {
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO
                   UPDATE SET last_seen = EXCLUDED.last_seen";
        sqlx::query(query)
            .bind(worker.id().to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(worker.get_service())
            .bind(last_seen)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Expose the pool for other functionality, e.g. custom migrations
    pub fn pool(&self) -> &Pool<Sqlite> {
        &self.pool
    }

    /// Get the config used by the storage
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> SqliteStorage<T, C> {
    /// Expose the codec used
    pub fn codec(&self) -> &PhantomData<C> {
        &self.codec
    }
}

async fn fetch_next(
    pool: &Pool<Sqlite>,
    worker_id: &WorkerId,
    id: String,
    config: &Config,
) -> Result<Option<SqlRequest<String>>, sqlx::Error> {
    let now: i64 = Utc::now().timestamp();
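    // Two statements in one round trip: first claim the job for this worker
    // (only if it is still 'Pending' and unlocked), then read the row back so
    // the caller can see whether the claim succeeded.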
    let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL; SELECT * FROM Jobs WHERE id = ?1 AND lock_by = ?2 AND job_type = ?4";
    let job: Option<SqlRequest<String>> = sqlx::query_as(update_query)
        .bind(id)
        .bind(worker_id.to_string())
        .bind(now)
        .bind(config.namespace.clone())
        .fetch_optional(pool)
        .await?;

    Ok(job)
}

impl<T, C> SqliteStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin,
    C: Codec<Compact = String>,
{
    fn stream_jobs(
        &self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let config = self.config.clone();
        let namespace = Namespace(self.config.namespace.clone());
        try_stream! {
            loop {
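                // Poll on a fixed interval, skipping a tick entirely while the
                // worker is not ready for more jobs.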
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let worker_id = worker.id();
                let tx = pool.clone();
                let mut tx = tx.acquire().await?;
                let job_type = &config.namespace;
                let fetch_query = "SELECT id FROM Jobs
                    WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 ORDER BY priority DESC LIMIT ?3";
                let now: i64 = Utc::now().timestamp();
                let ids: Vec<(String,)> = sqlx::query_as(fetch_query)
                    .bind(now)
                    .bind(job_type)
                    .bind(i64::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?)
                    .fetch_all(&mut *tx)
                    .await?;
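                // Claim each candidate id individually; `fetch_next` yields `None`
                // when another worker got to the job first.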
                for id in ids {
                    let res = fetch_next(&pool, worker_id, id.0, &config).await?;
                    yield match res {
                        None => None::<Request<T, SqlContext>>,
                        Some(job) => {
                            let (req, parts) = job.req.take_parts();
                            let args = C::decode(req)
                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                            let mut req = Request::new_with_parts(args, parts);
                            req.parts.namespace = Some(namespace.clone());
                            Some(req)
                        }
                    }
                };
            }
        }
    }
}

impl<T, C> Storage for SqliteStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = String> + Send + 'static + Sync,
    C::Error: std::error::Error + Send + Sync + 'static,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = String;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, Self::Error> {
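        // Positional INSERT into `Jobs`: payload, task id, job type, initial
        // status 'Pending', zero attempts, max_attempts, run_at = now, four
        // NULL bookkeeping columns (assumed to be last_error / lock_at /
        // lock_by / done_at per the bundled migrations), and priority.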
        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
        let (task, parts) = job.take_parts();
        let raw = C::encode(&task)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(raw)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn push_raw_request(
        &mut self,
        job: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
        let (task, parts) = job.take_parts();
        let raw = C::encode(&task)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(raw)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<SqlContext>, Self::Error> {
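        // Same positional INSERT as `push_request`, except `run_at` is bound
        // to the caller-supplied unix timestamp `on` instead of "now".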
        let query =
            "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL, ?6)";
        let id = &req.parts.task_id;
        let job = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(req.parts.context.priority())
            .bind(on)
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, Self::Error> {
        let fetch_query = "SELECT * FROM Jobs WHERE id = ?1";
        let res: Option<SqlRequest<String>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&self.pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let args = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

                let mut req: Request<T, SqlContext> = Request::new_with_parts(args, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, Self::Error> {
        let query = "SELECT COUNT(*) AS count FROM Jobs WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts))";
        let record = sqlx::query(query).fetch_one(&self.pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), Self::Error> {
        let task_id = job.parts.task_id;

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1";
        let now: i64 = Utc::now().timestamp();
        let wait_until = now + wait;

        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(wait_until)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), Self::Error> {
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();
        let job_id = job.parts.task_id;
        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6, priority = ?7 WHERE id = ?8";
        sqlx::query(query)
            .bind(status)
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        self.len().map_ok(|c| c == 0).await
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let query = "DELETE FROM Jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&self.pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

impl<T, C> SqliteStorage<T, C> {
    /// Puts the job instantly back into the queue so that another worker may
    /// consume it
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Kill a job
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let mut tx = self.pool.begin().await?;
        let query =
                "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        tx.commit().await?;
        Ok(())
    }
    /// Re-enqueue running jobs whose workers appear to have died.
    ///
    /// Puts back up to `count` jobs locked by workers that have not been seen
    /// since `dead_since`.
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = r#"UPDATE Jobs
                            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, attempts = attempts + 1, last_error = 'Job was abandoned'
                            WHERE id IN
                                (SELECT Jobs.id FROM Jobs INNER JOIN Workers ON lock_by = Workers.id
                                    WHERE status = 'Running' AND Workers.last_seen < ?1
                                    AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#;

        sqlx::query(query)
            .bind(dead_since.timestamp())
            .bind(job_type)
            .bind(count)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

/// Errors that can occur while polling an SQLite database.
#[derive(thiserror::Error, Debug)]
pub enum SqlitePollError {
    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<T, C> Backend<Request<T, SqlContext>> for SqliteStorage<T, C>
where
    C: Codec<Compact = String> + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
{
    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;
    type Layer = AckLayer<SqliteStorage<T, C>, T, SqlContext, C>;

    type Codec = JsonCodec<String>;

    fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let requeue_storage = self.clone();
        let w = worker.clone();
        let heartbeat = async move {
            // Re-enqueue any jobs that still belong to this worker, in case it
            // previously died mid-run.
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                w.emit(Event::Error(Box::new(
                    SqlitePollError::ReenqueueOrphanedError(e),
                )));
            }
            loop {
                let now: i64 = Utc::now().timestamp();
                if let Err(e) = self.keep_alive_at(&w, now).await {
                    w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e))));
                }
                apalis_core::sleep(Duration::from_secs(30)).await;
            }
        }
        .boxed();
        let w = worker.clone();
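        // A second background future: periodically re-enqueue jobs orphaned by
        // workers that have not been seen for `reenqueue_orphaned_after`.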
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        SqlitePollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

impl<T: Sync + Send, C: Send, Res: Serialize + Sync> Ack<T, Res, C> for SqliteStorage<T, C> {
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
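        // Record the final status, attempt count, completion time, and the
        // JSON-encoded result; the WHERE clause ensures only the worker that
        // holds the lock can acknowledge the task.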
        let pool = self.pool.clone();
        let query =
                "UPDATE Jobs SET status = ?4, attempts = ?5, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2";
        let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string()))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        sqlx::query(query)
            .bind(res.task_id.to_string())
            .bind(
                ctx.lock_by()
                    .as_ref()
                    .expect("Task is not locked")
                    .to_string(),
            )
            .bind(result)
            .bind(calculate_status(ctx, res).to_string())
            .bind(res.attempt.current() as u32)
            .execute(&pool)
            .await?;
        Ok(())
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for SqliteStorage<J, JsonCodec<String>>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
                        FROM Jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<String>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req = JsonCodec::<String>::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::request::State;
    use chrono::Utc;
    use email_service::example_good_email;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);
    sql_storage_tests!(setup::<Email>, SqliteStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> SqliteStorage<T> {
        // Because connections cannot be shared across async runtimes
        // (a new runtime is created for each test),
        // we don't share the storage and tests must run sequentially.
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        SqliteStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let config = Config::new("apalis::test");
        SqliteStorage::<T>::new_with_config(pool, config)
    }

    #[tokio::test]
    async fn test_inmemory_sqlite_worker() {
        let mut sqlite = setup().await;
        sqlite
            .push(Email {
                subject: "Test Subject".to_string(),
                to: "example@sqlite".to_string(),
                text: "Some Text".to_string(),
            })
            .await
            .expect("Unable to push job");
        let len = sqlite.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);
    }

    async fn consume_one(
        storage: &mut SqliteStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1)
            .boxed();
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    async fn register_worker_at(
        storage: &mut SqliteStorage<Email>,
        last_seen: i64,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        let worker = Worker::new(worker_id, Default::default());
        storage
            .keep_alive_at(&worker, last_seen)
            .await
            .expect("failed to register worker");
        worker.start();
        worker
    }

    async fn register_worker(storage: &mut SqliteStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut SqliteStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut SqliteStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let len = storage.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        let ctx = &job.parts.context;
        let res = 1usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;
        // Simulate worker
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        // end simulate worker

        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 2);
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = job.parts.task_id;
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;

        // Simulate worker
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        // end simulate worker

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}