apalis_sql/
sqlite.rs

1use crate::context::SqlContext;
2use crate::{calculate_status, Config, SqlError};
3use apalis_core::backend::{BackendExpose, Stat, WorkerState};
4use apalis_core::codec::json::JsonCodec;
5use apalis_core::error::Error;
6use apalis_core::layers::{Ack, AckLayer};
7use apalis_core::poller::controller::Controller;
8use apalis_core::poller::stream::BackendStream;
9use apalis_core::poller::Poller;
10use apalis_core::request::{Parts, Request, RequestStream, State};
11use apalis_core::response::Response;
12use apalis_core::storage::Storage;
13use apalis_core::task::namespace::Namespace;
14use apalis_core::task::task_id::TaskId;
15use apalis_core::worker::{Context, Event, Worker, WorkerId};
16use apalis_core::{backend::Backend, codec::Codec};
17use async_stream::try_stream;
18use chrono::{DateTime, Utc};
19use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
20use log::error;
21use serde::{de::DeserializeOwned, Serialize};
22use sqlx::{Pool, Row, Sqlite};
23use std::any::type_name;
24use std::convert::TryInto;
25use std::fmt::Debug;
26use std::sync::Arc;
27use std::{fmt, io};
28use std::{marker::PhantomData, time::Duration};
29
30use crate::from_row::SqlRequest;
31
32pub use sqlx::sqlite::SqlitePool;
33
34/// Represents a [Storage] that persists to Sqlite
35// #[derive(Debug)]
36pub struct SqliteStorage<T, C = JsonCodec<String>> {
37    pool: Pool<Sqlite>,
38    job_type: PhantomData<T>,
39    controller: Controller,
40    config: Config,
41    codec: PhantomData<C>,
42}
43
44impl<T, C> fmt::Debug for SqliteStorage<T, C> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_struct("SqliteStorage")
47            .field("pool", &self.pool)
48            .field("job_type", &"PhantomData<T>")
49            .field("controller", &self.controller)
50            .field("config", &self.config)
51            .field("codec", &std::any::type_name::<C>())
52            .finish()
53    }
54}
55
56impl<T, C> Clone for SqliteStorage<T, C> {
57    fn clone(&self) -> Self {
58        SqliteStorage {
59            pool: self.pool.clone(),
60            job_type: PhantomData,
61            controller: self.controller.clone(),
62            config: self.config.clone(),
63            codec: self.codec,
64        }
65    }
66}
67
68impl SqliteStorage<()> {
69    /// Perform migrations for storage
70    #[cfg(feature = "migrate")]
71    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
72        sqlx::query("PRAGMA journal_mode = 'WAL';")
73            .execute(pool)
74            .await?;
75        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
76        sqlx::query("PRAGMA synchronous = NORMAL;")
77            .execute(pool)
78            .await?;
79        sqlx::query("PRAGMA cache_size = 64000;")
80            .execute(pool)
81            .await?;
82        Self::migrations().run(pool).await?;
83        Ok(())
84    }
85
86    /// Get sqlite migrations without running them
87    #[cfg(feature = "migrate")]
88    pub fn migrations() -> sqlx::migrate::Migrator {
89        sqlx::migrate!("migrations/sqlite")
90    }
91}
92
93impl<T> SqliteStorage<T> {
94    /// Create a new instance
95    pub fn new(pool: SqlitePool) -> Self {
96        Self {
97            pool,
98            job_type: PhantomData,
99            controller: Controller::new(),
100            config: Config::new(type_name::<T>()),
101            codec: PhantomData,
102        }
103    }
104
105    /// Create a new instance with a custom config
106    pub fn new_with_config(pool: SqlitePool, config: Config) -> Self {
107        Self {
108            pool,
109            job_type: PhantomData,
110            controller: Controller::new(),
111            config,
112            codec: PhantomData,
113        }
114    }
115}
116impl<T, C> SqliteStorage<T, C> {
117    /// Keeps a storage notified that the worker is still alive manually
118    pub async fn keep_alive_at(
119        &mut self,
120        worker: &Worker<Context>,
121        last_seen: i64,
122    ) -> Result<(), sqlx::Error> {
123        let worker_type = self.config.namespace.clone();
124        let storage_name = std::any::type_name::<Self>();
125        let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen)
126                VALUES ($1, $2, $3, $4, $5)
127                ON CONFLICT (id) DO
128                   UPDATE SET last_seen = EXCLUDED.last_seen";
129        sqlx::query(query)
130            .bind(worker.id().to_string())
131            .bind(worker_type)
132            .bind(storage_name)
133            .bind(worker.get_service())
134            .bind(last_seen)
135            .execute(&self.pool)
136            .await?;
137        Ok(())
138    }
139
140    /// Expose the pool for other functionality, eg custom migrations
141    pub fn pool(&self) -> &Pool<Sqlite> {
142        &self.pool
143    }
144
145    /// Get the config used by the storage
146    pub fn get_config(&self) -> &Config {
147        &self.config
148    }
149}
150
151impl<T, C> SqliteStorage<T, C> {
152    /// Expose the code used
153    pub fn codec(&self) -> &PhantomData<C> {
154        &self.codec
155    }
156}
157
158async fn fetch_next(
159    pool: &Pool<Sqlite>,
160    worker_id: &WorkerId,
161    id: String,
162    config: &Config,
163) -> Result<Option<SqlRequest<String>>, sqlx::Error> {
164    let now: i64 = Utc::now().timestamp();
165    let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL RETURNING *";
166    let job: Option<SqlRequest<String>> = sqlx::query_as(update_query)
167        .bind(id.to_string())
168        .bind(worker_id.to_string())
169        .bind(now)
170        .bind(config.namespace.clone())
171        .fetch_optional(pool)
172        .await?;
173
174    if job.is_none() {
175        // Retry path.
176        let job: Option<SqlRequest<String>> =
177            sqlx::query_as("SELECT * FROM Jobs WHERE id = ?1 AND lock_by = ?2 AND job_type = ?3")
178                .bind(id.to_string())
179                .bind(worker_id.to_string())
180                .bind(config.namespace.clone())
181                .fetch_optional(pool)
182                .await?;
183
184        return Ok(job);
185    }
186
187    Ok(job)
188}
189
190impl<T, C> SqliteStorage<T, C>
191where
192    T: DeserializeOwned + Send + Unpin,
193    C: Codec<Compact = String>,
194{
195    fn stream_jobs(
196        &self,
197        worker: &Worker<Context>,
198        interval: Duration,
199        buffer_size: usize,
200    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
201        let pool = self.pool.clone();
202        let worker = worker.clone();
203        let config = self.config.clone();
204        let namespace = Namespace(self.config.namespace.clone());
205        try_stream! {
206            loop {
207                apalis_core::sleep(interval).await;
208                if !worker.is_ready() {
209                    continue;
210                }
211                let worker_id = worker.id();
212                let job_type = &config.namespace;
213                let fetch_query = "SELECT id FROM Jobs
214                    WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 ORDER BY priority DESC LIMIT ?3";
215                let now: i64 = Utc::now().timestamp();
216                let ids: Vec<(String,)> = sqlx::query_as(fetch_query)
217                    .bind(now)
218                    .bind(job_type)
219                    .bind(i64::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?)
220                    .fetch_all(&pool)
221                    .await?;
222
223                for id in ids {
224                    let res = fetch_next(&pool, worker_id, id.0, &config).await?;
225                    yield match res {
226                        None => None::<Request<T, SqlContext>>,
227                        Some(job) => {
228                            let (req, parts) = job.req.take_parts();
229                            let args = C::decode(req)
230                                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
231                            let mut req = Request::new_with_parts(args, parts);
232                            req.parts.namespace = Some(namespace.clone());
233                            Some(req)
234                        }
235                    }
236                };
237            }
238        }
239    }
240}
241
242impl<T, C> Storage for SqliteStorage<T, C>
243where
244    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
245    C: Codec<Compact = String> + Send + 'static + Sync,
246    C::Error: std::error::Error + Send + Sync + 'static,
247{
248    type Job = T;
249
250    type Error = sqlx::Error;
251
252    type Context = SqlContext;
253
254    type Compact = String;
255
256    async fn push_request(
257        &mut self,
258        job: Request<Self::Job, SqlContext>,
259    ) -> Result<Parts<SqlContext>, Self::Error> {
260        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
261        let (task, parts) = job.take_parts();
262        let raw = C::encode(&task)
263            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
264        let job_type = self.config.namespace.clone();
265        sqlx::query(query)
266            .bind(raw)
267            .bind(parts.task_id.to_string())
268            .bind(job_type.to_string())
269            .bind(parts.context.max_attempts())
270            .bind(parts.context.priority())
271            .execute(&self.pool)
272            .await?;
273        Ok(parts)
274    }
275
276    async fn push_raw_request(
277        &mut self,
278        job: Request<Self::Compact, SqlContext>,
279    ) -> Result<Parts<SqlContext>, Self::Error> {
280        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
281        let (task, parts) = job.take_parts();
282        let raw = C::encode(&task)
283            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
284        let job_type = self.config.namespace.clone();
285        sqlx::query(query)
286            .bind(raw)
287            .bind(parts.task_id.to_string())
288            .bind(job_type.to_string())
289            .bind(parts.context.max_attempts())
290            .bind(parts.context.priority())
291            .execute(&self.pool)
292            .await?;
293        Ok(parts)
294    }
295
296    async fn schedule_request(
297        &mut self,
298        req: Request<Self::Job, SqlContext>,
299        on: i64,
300    ) -> Result<Parts<SqlContext>, Self::Error> {
301        let query =
302            "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL, ?6)";
303        let id = &req.parts.task_id;
304        let job = C::encode(&req.args)
305            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
306        let job_type = self.config.namespace.clone();
307        sqlx::query(query)
308            .bind(job)
309            .bind(id.to_string())
310            .bind(job_type)
311            .bind(req.parts.context.max_attempts())
312            .bind(on)
313            .bind(req.parts.context.priority())
314            .execute(&self.pool)
315            .await?;
316        Ok(req.parts)
317    }
318
319    async fn fetch_by_id(
320        &mut self,
321        job_id: &TaskId,
322    ) -> Result<Option<Request<Self::Job, SqlContext>>, Self::Error> {
323        let fetch_query = "SELECT * FROM Jobs WHERE id = ?1";
324        let res: Option<SqlRequest<String>> = sqlx::query_as(fetch_query)
325            .bind(job_id.to_string())
326            .fetch_optional(&self.pool)
327            .await?;
328        match res {
329            None => Ok(None),
330            Some(job) => Ok(Some({
331                let (req, parts) = job.req.take_parts();
332                let args = C::decode(req)
333                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
334
335                let mut req: Request<T, SqlContext> = Request::new_with_parts(args, parts);
336                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
337                req
338            })),
339        }
340    }
341
342    async fn len(&mut self) -> Result<i64, Self::Error> {
343        let query = "Select COUNT(*) AS count FROM Jobs WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts))";
344        let record = sqlx::query(query).fetch_one(&self.pool).await?;
345        record.try_get("count")
346    }
347
348    async fn reschedule(
349        &mut self,
350        job: Request<T, SqlContext>,
351        wait: Duration,
352    ) -> Result<(), Self::Error> {
353        let task_id = job.parts.task_id;
354
355        let wait: i64 = wait
356            .as_secs()
357            .try_into()
358            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
359
360        let query =
361                "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1";
362        let now: i64 = Utc::now().timestamp();
363        let wait_until = now + wait;
364
365        sqlx::query(query)
366            .bind(task_id.to_string())
367            .bind(wait_until)
368            .execute(self.pool())
369            .await?;
370        Ok(())
371    }
372
373    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), Self::Error> {
374        let ctx = job.parts.context;
375        let status = ctx.status().to_string();
376        let attempts = job.parts.attempt;
377        let done_at = *ctx.done_at();
378        let lock_by = ctx.lock_by().clone();
379        let lock_at = *ctx.lock_at();
380        let last_error = ctx.last_error().clone();
381        let priority = *ctx.priority();
382        let job_id = job.parts.task_id;
383        let query =
384                "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6, priority = ?7 WHERE id = ?8";
385        sqlx::query(query)
386            .bind(status.to_owned())
387            .bind::<i64>(
388                attempts
389                    .current()
390                    .try_into()
391                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
392            )
393            .bind(done_at)
394            .bind(lock_by.map(|w| w.name().to_string()))
395            .bind(lock_at)
396            .bind(last_error)
397            .bind(priority)
398            .bind(job_id.to_string())
399            .execute(self.pool())
400            .await?;
401        Ok(())
402    }
403
404    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
405        self.len().map_ok(|c| c == 0).await
406    }
407
408    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
409        let query = "Delete from Jobs where status='Done'";
410        let record = sqlx::query(query).execute(&self.pool).await?;
411        Ok(record.rows_affected().try_into().unwrap_or_default())
412    }
413}
414
415impl<T, C> SqliteStorage<T, C> {
416    /// Puts the job instantly back into the queue
417    /// Another Worker may consume
418    pub async fn retry(
419        &mut self,
420        worker_id: &WorkerId,
421        job_id: &TaskId,
422    ) -> Result<(), sqlx::Error> {
423        let query =
424                "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2";
425        sqlx::query(query)
426            .bind(job_id.to_string())
427            .bind(worker_id.to_string())
428            .execute(self.pool())
429            .await?;
430        Ok(())
431    }
432
433    /// Kill a job
434    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
435        let query =
436                "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2";
437        sqlx::query(query)
438            .bind(job_id.to_string())
439            .bind(worker_id.to_string())
440            .execute(self.pool())
441            .await?;
442        Ok(())
443    }
444
445    /// Add jobs that workers have disappeared to the queue
446    pub async fn reenqueue_orphaned(
447        &self,
448        count: i32,
449        dead_since: DateTime<Utc>,
450    ) -> Result<(), sqlx::Error> {
451        let job_type = self.config.namespace.clone();
452        let query = r#"Update Jobs
453                            SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, attempts = attempts + 1, last_error ="Job was abandoned"
454                            WHERE id in
455                                (SELECT Jobs.id from Jobs INNER join Workers ON lock_by = Workers.id
456                                    WHERE status= "Running" AND workers.last_seen < ?1
457                                    AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#;
458
459        sqlx::query(query)
460            .bind(dead_since.timestamp())
461            .bind(job_type)
462            .bind(count)
463            .execute(self.pool())
464            .await?;
465        Ok(())
466    }
467}
468
469/// Errors that can occur while polling an SQLite database.
470#[derive(thiserror::Error, Debug)]
471pub enum SqlitePollError {
472    /// Error during a keep-alive heartbeat.
473    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
474    KeepAliveError(sqlx::Error),
475
476    /// Error during re-enqueuing orphaned tasks.
477    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
478    ReenqueueOrphanedError(sqlx::Error),
479}
480
481impl<T, C> Backend<Request<T, SqlContext>> for SqliteStorage<T, C>
482where
483    C: Codec<Compact = String> + Send + 'static + Sync,
484    C::Error: std::error::Error + 'static + Send + Sync,
485    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
486{
487    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;
488    type Layer = AckLayer<SqliteStorage<T, C>, T, SqlContext, C>;
489
490    type Codec = JsonCodec<String>;
491
492    fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
493        let layer = AckLayer::new(self.clone());
494        let config = self.config.clone();
495        let controller = self.controller.clone();
496        let stream = self
497            .stream_jobs(worker, config.poll_interval, config.buffer_size)
498            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
499        let stream = BackendStream::new(stream.boxed(), controller);
500        let requeue_storage = self.clone();
501        let w = worker.clone();
502        let heartbeat = async move {
503            // Lets reenqueue any jobs that belonged to this worker in case of a death
504            if let Err(e) = self
505                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
506                .await
507            {
508                w.emit(Event::Error(Box::new(
509                    SqlitePollError::ReenqueueOrphanedError(e),
510                )));
511            }
512            loop {
513                let now: i64 = Utc::now().timestamp();
514                if let Err(e) = self.keep_alive_at(&w, now).await {
515                    w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e))));
516                }
517                apalis_core::sleep(Duration::from_secs(30)).await;
518            }
519        }
520        .boxed();
521        let w = worker.clone();
522        let reenqueue_beat = async move {
523            loop {
524                let dead_since = Utc::now()
525                    - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
526                if let Err(e) = requeue_storage
527                    .reenqueue_orphaned(
528                        config
529                            .buffer_size
530                            .try_into()
531                            .expect("could not convert usize to i32"),
532                        dead_since,
533                    )
534                    .await
535                {
536                    w.emit(Event::Error(Box::new(
537                        SqlitePollError::ReenqueueOrphanedError(e),
538                    )));
539                }
540                apalis_core::sleep(config.poll_interval).await;
541            }
542        };
543        Poller::new_with_layer(
544            stream,
545            async {
546                futures::join!(heartbeat, reenqueue_beat);
547            },
548            layer,
549        )
550    }
551}
552
553impl<T: Sync + Send, C: Send, Res: Serialize + Sync> Ack<T, Res, C> for SqliteStorage<T, C> {
554    type Context = SqlContext;
555    type AckError = sqlx::Error;
556    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
557        let pool = self.pool.clone();
558        let query =
559                "UPDATE Jobs SET status = ?4, attempts = ?5, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2";
560        let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string()))
561            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
562        sqlx::query(query)
563            .bind(res.task_id.to_string())
564            .bind(
565                ctx.lock_by()
566                    .as_ref()
567                    .expect("Task is not locked")
568                    .to_string(),
569            )
570            .bind(result)
571            .bind(calculate_status(ctx, res).to_string())
572            .bind(res.attempt.current() as u32)
573            .execute(&pool)
574            .await?;
575        Ok(())
576    }
577}
578
579impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
580    for SqliteStorage<J, JsonCodec<String>>
581{
582    type Request = Request<J, Parts<SqlContext>>;
583    type Error = SqlError;
584    async fn stats(&self) -> Result<Stat, Self::Error> {
585        let fetch_query = "SELECT
586                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
587                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
588                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
589                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
590                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
591                        FROM Jobs WHERE job_type = ?";
592
593        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
594            .bind(self.get_config().namespace())
595            .fetch_one(self.pool())
596            .await?;
597
598        Ok(Stat {
599            pending: res.0.try_into()?,
600            running: res.1.try_into()?,
601            dead: res.4.try_into()?,
602            failed: res.3.try_into()?,
603            success: res.2.try_into()?,
604        })
605    }
606
607    async fn list_jobs(
608        &self,
609        status: &State,
610        page: i32,
611    ) -> Result<Vec<Self::Request>, Self::Error> {
612        let status = status.to_string();
613        let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
614        let res: Vec<SqlRequest<String>> = sqlx::query_as(fetch_query)
615            .bind(status)
616            .bind(self.get_config().namespace())
617            .bind(((page - 1) * 10).to_string())
618            .fetch_all(self.pool())
619            .await?;
620        Ok(res
621            .into_iter()
622            .map(|j| {
623                let (req, ctx) = j.req.take_parts();
624                let req = JsonCodec::<String>::decode(req).unwrap();
625                Request::new_with_ctx(req, ctx)
626            })
627            .collect())
628    }
629
630    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
631        let fetch_query =
632            "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
633        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
634            .bind(self.get_config().namespace())
635            .bind(0)
636            .fetch_all(self.pool())
637            .await?;
638        Ok(res
639            .into_iter()
640            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
641            .collect())
642    }
643}
644
#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::request::State;
    use chrono::Utc;
    use email_service::example_good_email;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);
    sql_storage_tests!(setup::<Email>, SqliteStorage<Email>, Email);

    /// Migrates a fresh in-memory database and returns a storage using the
    /// `apalis::test` namespace.
    async fn setup<T: Serialize + DeserializeOwned>() -> SqliteStorage<T> {
        // Because connections cannot be shared across async runtime
        // (different runtimes are created for each test),
        // we don't share the storage and tests must be run sequentially.
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        SqliteStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let config = Config::new("apalis::test");
        SqliteStorage::<T>::new_with_config(pool, config)
    }

    #[tokio::test]
    async fn test_inmemory_sqlite_worker() {
        let mut sqlite = setup().await;
        sqlite
            .push(Email {
                subject: "Test Subject".to_string(),
                to: "example@sqlite".to_string(),
                text: "Some Text".to_string(),
            })
            .await
            .expect("Unable to push job");
        let len = sqlite.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);
    }

    /// Polls the job stream once (after its 10 s sleep) and returns the fetched job,
    /// panicking if nothing was fetched.
    async fn consume_one(
        storage: &mut SqliteStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1)
            .boxed();
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    /// Registers `test-worker` with the given `last_seen` timestamp and starts it.
    async fn register_worker_at(
        storage: &mut SqliteStorage<Email>,
        last_seen: i64,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        let worker = Worker::new(worker_id, Default::default());
        storage
            .keep_alive_at(&worker, last_seen)
            .await
            .expect("failed to register worker");
        worker.start();
        worker
    }

    /// Registers `test-worker` as seen right now.
    async fn register_worker(storage: &mut SqliteStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut SqliteStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut SqliteStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let len = storage.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        let ctx = &job.parts.context;
        let res = 1usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    /// A worker last seen 6 minutes ago is treated as dead by a sweep with a 5-minute
    /// cutoff: its job is requeued, then consumed and acknowledged again.
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;
        // Simulate worker
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        //end simulate worker

        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 2);
    }

    /// A worker last seen 4 minutes ago survives a sweep with a 6-minute cutoff, so its
    /// job stays locked and is acknowledged on the first attempt.
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = job.parts.task_id;
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;

        // Simulate worker
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        //end simulate worker

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}