1use crate::context::SqlContext;
2use crate::{calculate_status, Config, SqlError};
3use apalis_core::backend::{BackendExpose, Stat, WorkerState};
4use apalis_core::codec::json::JsonCodec;
5use apalis_core::error::Error;
6use apalis_core::layers::{Ack, AckLayer};
7use apalis_core::poller::controller::Controller;
8use apalis_core::poller::stream::BackendStream;
9use apalis_core::poller::Poller;
10use apalis_core::request::{Parts, Request, RequestStream, State};
11use apalis_core::response::Response;
12use apalis_core::storage::Storage;
13use apalis_core::task::namespace::Namespace;
14use apalis_core::task::task_id::TaskId;
15use apalis_core::worker::{Context, Event, Worker, WorkerId};
16use apalis_core::{backend::Backend, codec::Codec};
17use async_stream::try_stream;
18use chrono::{DateTime, Utc};
19use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
20use log::error;
21use serde::{de::DeserializeOwned, Serialize};
22use sqlx::{Pool, Row, Sqlite};
23use std::any::type_name;
24use std::convert::TryInto;
25use std::fmt::Debug;
26use std::sync::Arc;
27use std::{fmt, io};
28use std::{marker::PhantomData, time::Duration};
29
30use crate::from_row::SqlRequest;
31
32pub use sqlx::sqlite::SqlitePool;
33
/// A job storage backed by a SQLite database (the `Jobs` and `Workers` tables).
///
/// `T` is the job argument type. `C` is the codec used to turn `T` into the `String`
/// stored in the database and back again; it defaults to `JsonCodec<String>`.
pub struct SqliteStorage<T, C = JsonCodec<String>> {
    /// Connection pool used for every query issued by this storage.
    pool: Pool<Sqlite>,
    /// Type marker for the job argument type; no `T` value is stored.
    job_type: PhantomData<T>,
    /// Controller handed to the `BackendStream` built in `poll`.
    controller: Controller,
    /// Namespace (stored as `job_type`), polling interval, buffer size and orphan timeout.
    config: Config,
    /// Type marker for the codec; no `C` value is stored.
    codec: PhantomData<C>,
}
43
44impl<T, C> fmt::Debug for SqliteStorage<T, C> {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_struct("SqliteStorage")
47 .field("pool", &self.pool)
48 .field("job_type", &"PhantomData<T>")
49 .field("controller", &self.controller)
50 .field("config", &self.config)
51 .field("codec", &std::any::type_name::<C>())
52 .finish()
53 }
54}
55
56impl<T, C> Clone for SqliteStorage<T, C> {
57 fn clone(&self) -> Self {
58 SqliteStorage {
59 pool: self.pool.clone(),
60 job_type: PhantomData,
61 controller: self.controller.clone(),
62 config: self.config.clone(),
63 codec: self.codec,
64 }
65 }
66}
67
68impl SqliteStorage<()> {
69 #[cfg(feature = "migrate")]
71 pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
72 sqlx::query("PRAGMA journal_mode = 'WAL';")
73 .execute(pool)
74 .await?;
75 sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
76 sqlx::query("PRAGMA synchronous = NORMAL;")
77 .execute(pool)
78 .await?;
79 sqlx::query("PRAGMA cache_size = 64000;")
80 .execute(pool)
81 .await?;
82 Self::migrations().run(pool).await?;
83 Ok(())
84 }
85
86 #[cfg(feature = "migrate")]
88 pub fn migrations() -> sqlx::migrate::Migrator {
89 sqlx::migrate!("migrations/sqlite")
90 }
91}
92
93impl<T> SqliteStorage<T> {
94 pub fn new(pool: SqlitePool) -> Self {
96 Self {
97 pool,
98 job_type: PhantomData,
99 controller: Controller::new(),
100 config: Config::new(type_name::<T>()),
101 codec: PhantomData,
102 }
103 }
104
105 pub fn new_with_config(pool: SqlitePool, config: Config) -> Self {
107 Self {
108 pool,
109 job_type: PhantomData,
110 controller: Controller::new(),
111 config,
112 codec: PhantomData,
113 }
114 }
115}
116impl<T, C> SqliteStorage<T, C> {
117 pub async fn keep_alive_at(
119 &mut self,
120 worker: &Worker<Context>,
121 last_seen: i64,
122 ) -> Result<(), sqlx::Error> {
123 let worker_type = self.config.namespace.clone();
124 let storage_name = std::any::type_name::<Self>();
125 let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen)
126 VALUES ($1, $2, $3, $4, $5)
127 ON CONFLICT (id) DO
128 UPDATE SET last_seen = EXCLUDED.last_seen";
129 sqlx::query(query)
130 .bind(worker.id().to_string())
131 .bind(worker_type)
132 .bind(storage_name)
133 .bind(worker.get_service())
134 .bind(last_seen)
135 .execute(&self.pool)
136 .await?;
137 Ok(())
138 }
139
140 pub fn pool(&self) -> &Pool<Sqlite> {
142 &self.pool
143 }
144
145 pub fn get_config(&self) -> &Config {
147 &self.config
148 }
149}
150
151impl<T, C> SqliteStorage<T, C> {
152 pub fn codec(&self) -> &PhantomData<C> {
154 &self.codec
155 }
156}
157
158async fn fetch_next(
159 pool: &Pool<Sqlite>,
160 worker_id: &WorkerId,
161 id: String,
162 config: &Config,
163) -> Result<Option<SqlRequest<String>>, sqlx::Error> {
164 let now: i64 = Utc::now().timestamp();
165 let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL; Select * from Jobs where id = ?1 AND lock_by = ?2 AND job_type = ?4";
166 let job: Option<SqlRequest<String>> = sqlx::query_as(update_query)
167 .bind(id.to_string())
168 .bind(worker_id.to_string())
169 .bind(now)
170 .bind(config.namespace.clone())
171 .fetch_optional(pool)
172 .await?;
173
174 Ok(job)
175}
176
177impl<T, C> SqliteStorage<T, C>
178where
179 T: DeserializeOwned + Send + Unpin,
180 C: Codec<Compact = String>,
181{
182 fn stream_jobs(
183 &self,
184 worker: &Worker<Context>,
185 interval: Duration,
186 buffer_size: usize,
187 ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
188 let pool = self.pool.clone();
189 let worker = worker.clone();
190 let config = self.config.clone();
191 let namespace = Namespace(self.config.namespace.clone());
192 try_stream! {
193 loop {
194 apalis_core::sleep(interval).await;
195 if !worker.is_ready() {
196 continue;
197 }
198 let worker_id = worker.id();
199 let tx = pool.clone();
200 let mut tx = tx.acquire().await?;
201 let job_type = &config.namespace;
202 let fetch_query = "SELECT id FROM Jobs
203 WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 ORDER BY priority DESC LIMIT ?3";
204 let now: i64 = Utc::now().timestamp();
205 let ids: Vec<(String,)> = sqlx::query_as(fetch_query)
206 .bind(now)
207 .bind(job_type)
208 .bind(i64::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?)
209 .fetch_all(&mut *tx)
210 .await?;
211 for id in ids {
212 let res = fetch_next(&pool, worker_id, id.0, &config).await?;
213 yield match res {
214 None => None::<Request<T, SqlContext>>,
215 Some(job) => {
216 let (req, parts) = job.req.take_parts();
217 let args = C::decode(req)
218 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
219 let mut req = Request::new_with_parts(args, parts);
220 req.parts.namespace = Some(namespace.clone());
221 Some(req)
222 }
223 }
224 };
225 }
226 }
227 }
228}
229
230impl<T, C> Storage for SqliteStorage<T, C>
231where
232 T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
233 C: Codec<Compact = String> + Send + 'static + Sync,
234 C::Error: std::error::Error + Send + Sync + 'static,
235{
236 type Job = T;
237
238 type Error = sqlx::Error;
239
240 type Context = SqlContext;
241
242 type Compact = String;
243
244 async fn push_request(
245 &mut self,
246 job: Request<Self::Job, SqlContext>,
247 ) -> Result<Parts<SqlContext>, Self::Error> {
248 let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
249 let (task, parts) = job.take_parts();
250 let raw = C::encode(&task)
251 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
252 let job_type = self.config.namespace.clone();
253 sqlx::query(query)
254 .bind(raw)
255 .bind(parts.task_id.to_string())
256 .bind(job_type.to_string())
257 .bind(parts.context.max_attempts())
258 .bind(parts.context.priority())
259 .execute(&self.pool)
260 .await?;
261 Ok(parts)
262 }
263
264 async fn push_raw_request(
265 &mut self,
266 job: Request<Self::Compact, SqlContext>,
267 ) -> Result<Parts<SqlContext>, Self::Error> {
268 let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
269 let (task, parts) = job.take_parts();
270 let raw = C::encode(&task)
271 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
272 let job_type = self.config.namespace.clone();
273 sqlx::query(query)
274 .bind(raw)
275 .bind(parts.task_id.to_string())
276 .bind(job_type.to_string())
277 .bind(parts.context.max_attempts())
278 .bind(parts.context.priority())
279 .execute(&self.pool)
280 .await?;
281 Ok(parts)
282 }
283
284 async fn schedule_request(
285 &mut self,
286 req: Request<Self::Job, SqlContext>,
287 on: i64,
288 ) -> Result<Parts<SqlContext>, Self::Error> {
289 let query =
290 "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL, ?6)";
291 let id = &req.parts.task_id;
292 let job = C::encode(&req.args)
293 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
294 let job_type = self.config.namespace.clone();
295 sqlx::query(query)
296 .bind(job)
297 .bind(id.to_string())
298 .bind(job_type)
299 .bind(req.parts.context.max_attempts())
300 .bind(req.parts.context.priority())
301 .bind(on)
302 .execute(&self.pool)
303 .await?;
304 Ok(req.parts)
305 }
306
307 async fn fetch_by_id(
308 &mut self,
309 job_id: &TaskId,
310 ) -> Result<Option<Request<Self::Job, SqlContext>>, Self::Error> {
311 let fetch_query = "SELECT * FROM Jobs WHERE id = ?1";
312 let res: Option<SqlRequest<String>> = sqlx::query_as(fetch_query)
313 .bind(job_id.to_string())
314 .fetch_optional(&self.pool)
315 .await?;
316 match res {
317 None => Ok(None),
318 Some(job) => Ok(Some({
319 let (req, parts) = job.req.take_parts();
320 let args = C::decode(req)
321 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
322
323 let mut req: Request<T, SqlContext> = Request::new_with_parts(args, parts);
324 req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
325 req
326 })),
327 }
328 }
329
330 async fn len(&mut self) -> Result<i64, Self::Error> {
331 let query = "Select Count(*) as count from Jobs WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts))";
332 let record = sqlx::query(query).fetch_one(&self.pool).await?;
333 record.try_get("count")
334 }
335
336 async fn reschedule(
337 &mut self,
338 job: Request<T, SqlContext>,
339 wait: Duration,
340 ) -> Result<(), Self::Error> {
341 let task_id = job.parts.task_id;
342
343 let wait: i64 = wait
344 .as_secs()
345 .try_into()
346 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
347
348 let mut tx = self.pool.acquire().await?;
349 let query =
350 "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1";
351 let now: i64 = Utc::now().timestamp();
352 let wait_until = now + wait;
353
354 sqlx::query(query)
355 .bind(task_id.to_string())
356 .bind(wait_until)
357 .execute(&mut *tx)
358 .await?;
359 Ok(())
360 }
361
362 async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), Self::Error> {
363 let ctx = job.parts.context;
364 let status = ctx.status().to_string();
365 let attempts = job.parts.attempt;
366 let done_at = *ctx.done_at();
367 let lock_by = ctx.lock_by().clone();
368 let lock_at = *ctx.lock_at();
369 let last_error = ctx.last_error().clone();
370 let priority = *ctx.priority();
371 let job_id = job.parts.task_id;
372 let mut tx = self.pool.acquire().await?;
373 let query =
374 "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6, priority = ?7 WHERE id = ?8";
375 sqlx::query(query)
376 .bind(status.to_owned())
377 .bind::<i64>(
378 attempts
379 .current()
380 .try_into()
381 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
382 )
383 .bind(done_at)
384 .bind(lock_by.map(|w| w.name().to_string()))
385 .bind(lock_at)
386 .bind(last_error)
387 .bind(priority)
388 .bind(job_id.to_string())
389 .execute(&mut *tx)
390 .await?;
391 Ok(())
392 }
393
394 async fn is_empty(&mut self) -> Result<bool, Self::Error> {
395 self.len().map_ok(|c| c == 0).await
396 }
397
398 async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
399 let query = "Delete from Jobs where status='Done'";
400 let record = sqlx::query(query).execute(&self.pool).await?;
401 Ok(record.rows_affected().try_into().unwrap_or_default())
402 }
403}
404
405impl<T, C> SqliteStorage<T, C> {
406 pub async fn retry(
409 &mut self,
410 worker_id: &WorkerId,
411 job_id: &TaskId,
412 ) -> Result<(), sqlx::Error> {
413 let mut tx = self.pool.acquire().await?;
414 let query =
415 "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2";
416 sqlx::query(query)
417 .bind(job_id.to_string())
418 .bind(worker_id.to_string())
419 .execute(&mut *tx)
420 .await?;
421 Ok(())
422 }
423
424 pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
426 let mut tx = self.pool.begin().await?;
427 let query =
428 "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2";
429 sqlx::query(query)
430 .bind(job_id.to_string())
431 .bind(worker_id.to_string())
432 .execute(&mut *tx)
433 .await?;
434 tx.commit().await?;
435 Ok(())
436 }
437
438 pub async fn reenqueue_orphaned(
440 &self,
441 count: i32,
442 dead_since: DateTime<Utc>,
443 ) -> Result<(), sqlx::Error> {
444 let job_type = self.config.namespace.clone();
445 let mut tx = self.pool.acquire().await?;
446 let query = r#"Update Jobs
447 SET status = "Pending", done_at = NULL, lock_by = NULL, lock_at = NULL, attempts = attempts + 1, last_error ="Job was abandoned"
448 WHERE id in
449 (SELECT Jobs.id from Jobs INNER join Workers ON lock_by = Workers.id
450 WHERE status= "Running" AND workers.last_seen < ?1
451 AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#;
452
453 sqlx::query(query)
454 .bind(dead_since.timestamp())
455 .bind(job_type)
456 .bind(count)
457 .execute(&mut *tx)
458 .await?;
459 Ok(())
460 }
461}
462
/// Errors raised by the background tasks that `poll` starts; they are reported to the
/// worker as `Event::Error`.
#[derive(thiserror::Error, Debug)]
pub enum SqlitePollError {
    /// Writing the worker's `last_seen` heartbeat via `keep_alive_at` failed.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Re-enqueuing orphaned jobs via `reenqueue_orphaned` failed.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}
474
475impl<T, C> Backend<Request<T, SqlContext>> for SqliteStorage<T, C>
476where
477 C: Codec<Compact = String> + Send + 'static + Sync,
478 C::Error: std::error::Error + 'static + Send + Sync,
479 T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
480{
481 type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;
482 type Layer = AckLayer<SqliteStorage<T, C>, T, SqlContext, C>;
483
484 type Codec = JsonCodec<String>;
485
486 fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
487 let layer = AckLayer::new(self.clone());
488 let config = self.config.clone();
489 let controller = self.controller.clone();
490 let stream = self
491 .stream_jobs(worker, config.poll_interval, config.buffer_size)
492 .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
493 let stream = BackendStream::new(stream.boxed(), controller);
494 let requeue_storage = self.clone();
495 let w = worker.clone();
496 let heartbeat = async move {
497 if let Err(e) = self
499 .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
500 .await
501 {
502 w.emit(Event::Error(Box::new(
503 SqlitePollError::ReenqueueOrphanedError(e),
504 )));
505 }
506 loop {
507 let now: i64 = Utc::now().timestamp();
508 if let Err(e) = self.keep_alive_at(&w, now).await {
509 w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e))));
510 }
511 apalis_core::sleep(Duration::from_secs(30)).await;
512 }
513 }
514 .boxed();
515 let w = worker.clone();
516 let reenqueue_beat = async move {
517 loop {
518 let dead_since = Utc::now()
519 - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
520 if let Err(e) = requeue_storage
521 .reenqueue_orphaned(
522 config
523 .buffer_size
524 .try_into()
525 .expect("could not convert usize to i32"),
526 dead_since,
527 )
528 .await
529 {
530 w.emit(Event::Error(Box::new(
531 SqlitePollError::ReenqueueOrphanedError(e),
532 )));
533 }
534 apalis_core::sleep(config.poll_interval).await;
535 }
536 };
537 Poller::new_with_layer(
538 stream,
539 async {
540 futures::join!(heartbeat, reenqueue_beat);
541 },
542 layer,
543 )
544 }
545}
546
547impl<T: Sync + Send, C: Send, Res: Serialize + Sync> Ack<T, Res, C> for SqliteStorage<T, C> {
548 type Context = SqlContext;
549 type AckError = sqlx::Error;
550 async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
551 let pool = self.pool.clone();
552 let query =
553 "UPDATE Jobs SET status = ?4, attempts = ?5, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2";
554 let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string()))
555 .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
556 sqlx::query(query)
557 .bind(res.task_id.to_string())
558 .bind(
559 ctx.lock_by()
560 .as_ref()
561 .expect("Task is not locked")
562 .to_string(),
563 )
564 .bind(result)
565 .bind(calculate_status(ctx, res).to_string())
566 .bind(res.attempt.current() as u32)
567 .execute(&pool)
568 .await?;
569 Ok(())
570 }
571}
572
573impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
574 for SqliteStorage<J, JsonCodec<String>>
575{
576 type Request = Request<J, Parts<SqlContext>>;
577 type Error = SqlError;
578 async fn stats(&self) -> Result<Stat, Self::Error> {
579 let fetch_query = "SELECT
580 COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
581 COUNT(1) FILTER (WHERE status = 'Running') AS running,
582 COUNT(1) FILTER (WHERE status = 'Done') AS done,
583 COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
584 COUNT(1) FILTER (WHERE status = 'Killed') AS killed
585 FROM Jobs WHERE job_type = ?";
586
587 let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
588 .bind(self.get_config().namespace())
589 .fetch_one(self.pool())
590 .await?;
591
592 Ok(Stat {
593 pending: res.0.try_into()?,
594 running: res.1.try_into()?,
595 dead: res.4.try_into()?,
596 failed: res.3.try_into()?,
597 success: res.2.try_into()?,
598 })
599 }
600
601 async fn list_jobs(
602 &self,
603 status: &State,
604 page: i32,
605 ) -> Result<Vec<Self::Request>, Self::Error> {
606 let status = status.to_string();
607 let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
608 let res: Vec<SqlRequest<String>> = sqlx::query_as(fetch_query)
609 .bind(status)
610 .bind(self.get_config().namespace())
611 .bind(((page - 1) * 10).to_string())
612 .fetch_all(self.pool())
613 .await?;
614 Ok(res
615 .into_iter()
616 .map(|j| {
617 let (req, ctx) = j.req.take_parts();
618 let req = JsonCodec::<String>::decode(req).unwrap();
619 Request::new_with_ctx(req, ctx)
620 })
621 .collect())
622 }
623
624 async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
625 let fetch_query =
626 "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
627 let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
628 .bind(self.get_config().namespace())
629 .bind(0)
630 .fetch_all(self.pool())
631 .await?;
632 Ok(res
633 .into_iter()
634 .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
635 .collect())
636 }
637}
638
#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::request::State;
    use chrono::Utc;
    use email_service::example_good_email;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    // Shared storage test suites, instantiated against this backend.
    generic_storage_test!(setup);
    sql_storage_tests!(setup::<Email>, SqliteStorage<Email>, Email);

    /// Builds a storage on a fresh `sqlite::memory:` pool, runs `setup` (PRAGMAs and
    /// migrations) and uses the `apalis::test` config.
    async fn setup<T: Serialize + DeserializeOwned>() -> SqliteStorage<T> {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        SqliteStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let config = Config::new("apalis::test");
        let storage = SqliteStorage::<T>::new_with_config(pool, config);

        storage
    }

    /// Pushing one job makes `len` report exactly one pending job.
    #[tokio::test]
    async fn test_inmemory_sqlite_worker() {
        let mut sqlite = setup().await;
        sqlite
            .push(Email {
                subject: "Test Subject".to_string(),
                to: "example@sqlite".to_string(),
                text: "Some Text".to_string(),
            })
            .await
            .expect("Unable to push job");
        let len = sqlite.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);
    }

    /// Takes the first item from `stream_jobs` with a batch size of 1, panicking if the
    /// stream ends, errors, or yields no job.
    ///
    /// `stream_jobs` sleeps for its interval before fetching, so each call waits about
    /// 10 seconds.
    async fn consume_one(
        storage: &mut SqliteStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1)
            .boxed();
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    /// Registers a worker named `test-worker` with the given `last_seen` timestamp and
    /// starts it, so `stream_jobs` sees it as ready.
    async fn register_worker_at(
        storage: &mut SqliteStorage<Email>,
        last_seen: i64,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        let worker = Worker::new(worker_id, Default::default());
        storage
            .keep_alive_at(&worker, last_seen)
            .await
            .expect("failed to register worker");
        worker.start();
        worker
    }

    /// Registers the test worker as seen now.
    async fn register_worker(storage: &mut SqliteStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    /// Pushes `email` as a new job, panicking on failure.
    async fn push_email(storage: &mut SqliteStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    /// Fetches a job by id, panicking if the query fails or no row exists.
    async fn get_job(
        storage: &mut SqliteStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    /// A consumed job is `Running` and locked by the consuming worker.
    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let len = storage.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    /// Acking a successful result marks the job `Done` and sets `done_at`.
    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        let ctx = &job.parts.context;
        let res = 1usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert!(ctx.done_at().is_some());
    }

    /// Killing a job held by the worker marks it `Killed` and sets `done_at`.
    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(&worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    /// A worker last seen 6 minutes ago counts as dead for a 5-minute cut-off.
    ///
    /// Its running job is reset to `Pending` with the lock cleared and one attempt
    /// recorded. The job can then be consumed and acked again.
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;
        // Simulate the second execution attempt before acking.
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        // `ack` stores the serialized result in `last_error`, even on success.
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 2);
    }

    /// A worker last seen 4 minutes ago is still alive for a 6-minute cut-off.
    ///
    /// Its job stays locked by it and can be acked directly.
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = job.parts.task_id;
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;

        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}