apalis_sql/
sqlite.rs

use crate::context::SqlContext;
use crate::{calculate_status, Config, SqlError};
use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::Error;
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::{Pool, Row, Sqlite};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};

use crate::from_row::SqlRequest;

pub use sqlx::sqlite::SqlitePool;
33
/// Represents a [Storage] that persists to SQLite
pub struct SqliteStorage<T, C = JsonCodec<String>> {
    pool: Pool<Sqlite>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
}

impl<T, C> fmt::Debug for SqliteStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqliteStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}

impl<T, C> Clone for SqliteStorage<T, C> {
    fn clone(&self) -> Self {
        SqliteStorage {
            pool: self.pool.clone(),
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
        }
    }
}

impl SqliteStorage<()> {
    /// Perform migrations for storage
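    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled here) using an in-memory database:
    ///
    /// ```rust,ignore
    /// let pool = SqlitePool::connect("sqlite::memory:").await?;
    /// SqliteStorage::setup(&pool).await?;
    /// ```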
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
        sqlx::query("PRAGMA journal_mode = 'WAL';")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
        sqlx::query("PRAGMA synchronous = NORMAL;")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA cache_size = 64000;")
            .execute(pool)
            .await?;
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Get sqlite migrations without running them
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/sqlite")
    }
}

impl<T> SqliteStorage<T> {
    /// Create a new instance
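    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled here); `Email` stands in for any
    /// job type you define:
    ///
    /// ```rust,ignore
    /// let pool = SqlitePool::connect("sqlite::memory:").await?;
    /// SqliteStorage::setup(&pool).await?;
    /// let storage: SqliteStorage<Email> = SqliteStorage::new(pool);
    /// ```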
    pub fn new(pool: SqlitePool) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config: Config::new(type_name::<T>()),
            codec: PhantomData,
        }
    }

    /// Create a new instance with a custom config
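    ///
    /// # Example
    ///
    /// A sketch: the namespace set in [Config] determines the `job_type`
    /// this storage reads and writes in the shared `Jobs` table.
    ///
    /// ```rust,ignore
    /// let config = Config::new("email-service");
    /// let storage: SqliteStorage<Email> = SqliteStorage::new_with_config(pool, config);
    /// ```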
    pub fn new_with_config(pool: SqlitePool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            codec: PhantomData,
        }
    }
}

impl<T, C> SqliteStorage<T, C> {
    /// Manually notify the storage that a worker is still alive
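    ///
    /// # Example
    ///
    /// A sketch; `worker` is assumed to be an already registered worker:
    ///
    /// ```rust,ignore
    /// storage.keep_alive_at(&worker, Utc::now().timestamp()).await?;
    /// ```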
    pub async fn keep_alive_at(
        &mut self,
        worker: &Worker<Context>,
        last_seen: i64,
    ) -> Result<(), sqlx::Error> {
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO
                   UPDATE SET last_seen = EXCLUDED.last_seen";
        sqlx::query(query)
            .bind(worker.id().to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(worker.get_service())
            .bind(last_seen)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Expose the pool for other functionality, e.g. custom migrations
    pub fn pool(&self) -> &Pool<Sqlite> {
        &self.pool
    }

    /// Get the config used by the storage
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> SqliteStorage<T, C> {
    /// Expose the codec used
    pub fn codec(&self) -> &PhantomData<C> {
        &self.codec
    }
}

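/// Attempt to atomically lock and return the job with the given `id`.
///
/// The `UPDATE ... RETURNING` only matches jobs that are still `Pending`
/// and unlocked; when it returns no row, the job may already be locked by
/// this worker (e.g. a `Failed` job being retried), so a plain `SELECT`
/// scoped to this worker's lock is attempted before giving up.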
async fn fetch_next(
    pool: &Pool<Sqlite>,
    worker_id: &WorkerId,
    id: String,
    config: &Config,
) -> Result<Option<SqlRequest<String>>, sqlx::Error> {
    let now: i64 = Utc::now().timestamp();
    let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL RETURNING *";
    let job: Option<SqlRequest<String>> = sqlx::query_as(update_query)
        .bind(id.to_string())
        .bind(worker_id.to_string())
        .bind(now)
        .bind(config.namespace.clone())
        .fetch_optional(pool)
        .await?;

    if job.is_none() {
        // The UPDATE matched nothing: the job may already be locked by this
        // worker (e.g. a `Failed` job being retried), so fall back to
        // fetching it under this worker's existing lock.
        let job: Option<SqlRequest<String>> =
            sqlx::query_as("SELECT * FROM Jobs WHERE id = ?1 AND lock_by = ?2 AND job_type = ?3")
                .bind(id.to_string())
                .bind(worker_id.to_string())
                .bind(config.namespace.clone())
                .fetch_optional(pool)
                .await?;

        return Ok(job);
    }

    Ok(job)
}

impl<T, C> SqliteStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin,
    C: Codec<Compact = String>,
{
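    /// Build the polling stream of jobs consumed by `worker`.
    ///
    /// Every `interval`, up to `buffer_size` due job ids are fetched
    /// (pending jobs plus failed jobs with attempts left, highest
    /// priority first) and each is locked via `fetch_next` before being
    /// decoded and yielded.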
    fn stream_jobs(
        &self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let config = self.config.clone();
        let namespace = Namespace(self.config.namespace.clone());
        try_stream! {
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let worker_id = worker.id();
                let job_type = &config.namespace;
                let fetch_query = "SELECT id FROM Jobs
                    WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 ORDER BY priority DESC LIMIT ?3";
                let now: i64 = Utc::now().timestamp();
                let ids: Vec<(String,)> = sqlx::query_as(fetch_query)
                    .bind(now)
                    .bind(job_type)
                    .bind(i64::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?)
                    .fetch_all(&pool)
                    .await?;

                if ids.is_empty() {
                    yield None::<Request<T, SqlContext>>;
                } else {
                    for id in ids {
                        let res = fetch_next(&pool, worker_id, id.0, &config).await?;
                        yield match res {
                            None => None::<Request<T, SqlContext>>,
                            Some(job) => {
                                let (req, parts) = job.req.take_parts();
                                let args = C::decode(req)
                                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                                let mut req = Request::new_with_parts(args, parts);
                                req.parts.namespace = Some(namespace.clone());
                                Some(req)
                            }
                        };
                    }
                }
            }
        }
    }
}

impl<T, C> Storage for SqliteStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = String> + Send + 'static + Sync,
    C::Error: std::error::Error + Send + Sync + 'static,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = String;

    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
        let (task, parts) = job.take_parts();
        let raw = C::encode(&task)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(raw)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn push_raw_request(
        &mut self,
        job: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
        let (task, parts) = job.take_parts();
        let raw = C::encode(&task)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(raw)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query =
            "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL, ?6)";
        let id = &req.parts.task_id;
        let job = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .bind(req.parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, Self::Error> {
        let fetch_query = "SELECT * FROM Jobs WHERE id = ?1";
        let res: Option<SqlRequest<String>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&self.pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let args = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

                let mut req: Request<T, SqlContext> = Request::new_with_parts(args, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, Self::Error> {
        let query = "SELECT COUNT(*) AS count FROM Jobs WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts))";
        let record = sqlx::query(query).fetch_one(&self.pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), Self::Error> {
        let task_id = job.parts.task_id;

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let query =
                "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1";
        let now: i64 = Utc::now().timestamp();
        let wait_until = now + wait;

        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(wait_until)
            .execute(self.pool())
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), Self::Error> {
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();
        let job_id = job.parts.task_id;
        let query =
                "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6, priority = ?7 WHERE id = ?8";
        sqlx::query(query)
            .bind(status)
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(self.pool())
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        self.len().map_ok(|c| c == 0).await
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let query = "DELETE FROM Jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&self.pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

impl<T, C> SqliteStorage<T, C> {
    /// Put the job back into the queue immediately so that another
    /// worker may consume it
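    ///
    /// # Example
    ///
    /// A sketch; the worker must currently hold the job's lock:
    ///
    /// ```rust,ignore
    /// storage.retry(worker.id(), &job_id).await?;
    /// ```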
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let query =
                "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(self.pool())
            .await?;
        Ok(())
    }

    /// Kill a job
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let query =
                "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(self.pool())
            .await?;
        Ok(())
    }

    /// Re-enqueue jobs whose workers have disappeared
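    ///
    /// # Example
    ///
    /// A sketch: re-enqueue up to 10 jobs locked by workers that have
    /// not been seen for five minutes.
    ///
    /// ```rust,ignore
    /// let dead_since = Utc::now() - chrono::Duration::minutes(5);
    /// storage.reenqueue_orphaned(10, dead_since).await?;
    /// ```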
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let query = r#"UPDATE Jobs
                            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, attempts = attempts + 1, last_error = 'Job was abandoned'
                            WHERE id IN
                                (SELECT Jobs.id FROM Jobs INNER JOIN Workers ON lock_by = Workers.id
                                    WHERE status = 'Running' AND Workers.last_seen < ?1
                                    AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#;

        sqlx::query(query)
            .bind(dead_since.timestamp())
            .bind(job_type)
            .bind(count)
            .execute(self.pool())
            .await?;
        Ok(())
    }
}

/// Errors that can occur while polling an SQLite database.
#[derive(thiserror::Error, Debug)]
pub enum SqlitePollError {
    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<T, C> Backend<Request<T, SqlContext>> for SqliteStorage<T, C>
where
    C: Codec<Compact = String> + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
{
    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;
    type Layer = AckLayer<SqliteStorage<T, C>, T, SqlContext, C>;

    type Codec = JsonCodec<String>;

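    /// Drive the backend: the returned [Poller] runs the job stream
    /// alongside two heartbeats — a keep-alive that re-registers this
    /// worker every 30 seconds, and a periodic sweep (every
    /// `poll_interval`) that re-enqueues jobs orphaned by dead workers.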
    fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let requeue_storage = self.clone();
        let w = worker.clone();
        let heartbeat = async move {
            // Re-enqueue any jobs that belonged to this worker in case it previously died
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                w.emit(Event::Error(Box::new(
                    SqlitePollError::ReenqueueOrphanedError(e),
                )));
            }
            loop {
                let now: i64 = Utc::now().timestamp();
                if let Err(e) = self.keep_alive_at(&w, now).await {
                    w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e))));
                }
                apalis_core::sleep(Duration::from_secs(30)).await;
            }
        }
        .boxed();
        let w = worker.clone();
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        SqlitePollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

impl<T: Sync + Send, C: Send, Res: Serialize + Sync> Ack<T, Res, C> for SqliteStorage<T, C> {
    type Context = SqlContext;
    type AckError = sqlx::Error;
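    /// Record the outcome of a task: the serialized result is written to
    /// `last_error` and the row's status is set from `calculate_status`,
    /// but only while this worker still holds the lock on the job.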
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let query =
                "UPDATE Jobs SET status = ?4, attempts = ?5, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2";
        let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string()))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        sqlx::query(query)
            .bind(res.task_id.to_string())
            .bind(
                ctx.lock_by()
                    .as_ref()
                    .expect("Task is not locked")
                    .to_string(),
            )
            .bind(result)
            .bind(calculate_status(ctx, res).to_string())
            .bind(res.attempt.current() as u32)
            .execute(&pool)
            .await?;
        Ok(())
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for SqliteStorage<J, JsonCodec<String>>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
                        FROM Jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<String>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(((page - 1) * 10).to_string())
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req = JsonCodec::<String>::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::request::State;
    use chrono::Utc;
    use email_service::example_good_email;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);
    sql_storage_tests!(setup::<Email>, SqliteStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> SqliteStorage<T> {
        // Because connections cannot be shared across async runtimes
        // (a separate runtime is created for each test), we don't share
        // the storage and tests must run sequentially.
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        SqliteStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let config = Config::new("apalis::test");
        SqliteStorage::<T>::new_with_config(pool, config)
    }

    #[tokio::test]
    async fn test_inmemory_sqlite_worker() {
        let mut sqlite = setup().await;
        sqlite
            .push(Email {
                subject: "Test Subject".to_string(),
                to: "example@sqlite".to_string(),
                text: "Some Text".to_string(),
            })
            .await
            .expect("Unable to push job");
        let len = sqlite.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);
    }

    async fn consume_one(
        storage: &mut SqliteStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1)
            .boxed();
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    async fn register_worker_at(
        storage: &mut SqliteStorage<Email>,
        last_seen: i64,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        let worker = Worker::new(worker_id, Default::default());
        storage
            .keep_alive_at(&worker, last_seen)
            .await
            .expect("failed to register worker");
        worker.start();
        worker
    }

    async fn register_worker(storage: &mut SqliteStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut SqliteStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut SqliteStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let len = storage.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        let ctx = &job.parts.context;
        let res = 1usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;
        // Simulate the worker completing the job
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        // End of worker simulation

        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 2);
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = job.parts.task_id;
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;

        // Simulate the worker completing the job
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        // End of worker simulation

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}