use crate::context::SqlContext;
use crate::{calculate_status, Config, SqlError};
use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::Error;
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::{Pool, Row, Sqlite};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};

use crate::from_row::SqlRequest;

pub use sqlx::sqlite::SqlitePool;

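/// A [Storage] implementation that persists jobs to SQLite via [sqlx].
///
/// Job arguments are serialized with the codec `C` (JSON by default) and
/// stored in the `Jobs` table created by the bundled migrations.
///
/// A minimal usage sketch (assuming the `migrate` feature is enabled and
/// that this type is reachable as `apalis_sql::sqlite::SqliteStorage`):
///
/// ```rust,no_run
/// use apalis_sql::sqlite::{SqlitePool, SqliteStorage};
///
/// # async fn example() -> Result<(), sqlx::Error> {
/// let pool = SqlitePool::connect("sqlite::memory:").await?;
/// SqliteStorage::setup(&pool).await?; // pragmas + migrations
/// let storage: SqliteStorage<String> = SqliteStorage::new(pool);
/// # Ok(())
/// # }
/// ```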
pub struct SqliteStorage<T, C = JsonCodec<String>> {
    pool: Pool<Sqlite>,
    job_type: PhantomData<T>,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
}

impl<T, C> fmt::Debug for SqliteStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqliteStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}

impl<T, C> Clone for SqliteStorage<T, C> {
    fn clone(&self) -> Self {
        SqliteStorage {
            pool: self.pool.clone(),
            job_type: PhantomData,
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
        }
    }
}

impl SqliteStorage<()> {
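    /// Prepare the SQLite database for job processing: switch to WAL
    /// journaling, keep temporary tables in memory, relax fsync, grow the
    /// page cache, then run the bundled migrations.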
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
        sqlx::query("PRAGMA journal_mode = 'WAL';")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
        sqlx::query("PRAGMA synchronous = NORMAL;")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA cache_size = 64000;")
            .execute(pool)
            .await?;
        Self::migrations().run(pool).await?;
        Ok(())
    }

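    /// Get the migrations bundled with this storage (from `migrations/sqlite`),
    /// e.g. to run them through an application's own migration pipeline.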
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/sqlite")
    }
}

impl<T> SqliteStorage<T> {
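    /// Create a new storage instance from a pool, using the job type's name
    /// as the namespace and the default JSON codec.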
    pub fn new(pool: SqlitePool) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config: Config::new(type_name::<T>()),
            codec: PhantomData,
        }
    }

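    /// Create a new storage instance with a custom [Config], e.g. to override
    /// the namespace, poll interval or buffer size.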
    pub fn new_with_config(pool: SqlitePool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            codec: PhantomData,
        }
    }
}

impl<T, C> SqliteStorage<T, C> {
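    /// Register or refresh this worker's heartbeat with the given `last_seen`
    /// timestamp, so that dead workers can be detected and their jobs
    /// re-enqueued.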
    pub async fn keep_alive_at(
        &mut self,
        worker: &Worker<Context>,
        last_seen: i64,
    ) -> Result<(), sqlx::Error> {
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query = "INSERT INTO Workers (id, worker_type, storage_name, layers, last_seen)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO
                    UPDATE SET last_seen = EXCLUDED.last_seen";
        sqlx::query(query)
            .bind(worker.id().to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(worker.get_service())
            .bind(last_seen)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

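    /// Expose the underlying connection pool.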
    pub fn pool(&self) -> &Pool<Sqlite> {
        &self.pool
    }

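    /// Expose the storage's [Config].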
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, C> SqliteStorage<T, C> {
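    /// Expose the codec marker used for (de)serializing job arguments.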
    pub fn codec(&self) -> &PhantomData<C> {
        &self.codec
    }
}

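/// Attempt to claim the job `id` for `worker_id` by atomically flipping it
/// from `Pending` to `Running`. If the claim did not go through (e.g. the row
/// was already locked), re-read the row so a job this worker already holds is
/// still returned.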
async fn fetch_next(
    pool: &Pool<Sqlite>,
    worker_id: &WorkerId,
    id: String,
    config: &Config,
) -> Result<Option<SqlRequest<String>>, sqlx::Error> {
    let now: i64 = Utc::now().timestamp();
    let update_query = "UPDATE Jobs SET status = 'Running', lock_by = ?2, lock_at = ?3 WHERE id = ?1 AND job_type = ?4 AND status = 'Pending' AND lock_by IS NULL RETURNING *";
    let job: Option<SqlRequest<String>> = sqlx::query_as(update_query)
        .bind(id.to_string())
        .bind(worker_id.to_string())
        .bind(now)
        .bind(config.namespace.clone())
        .fetch_optional(pool)
        .await?;

    if job.is_none() {
        let job: Option<SqlRequest<String>> =
            sqlx::query_as("SELECT * FROM Jobs WHERE id = ?1 AND lock_by = ?2 AND job_type = ?3")
                .bind(id.to_string())
                .bind(worker_id.to_string())
                .bind(config.namespace.clone())
                .fetch_optional(pool)
                .await?;

        return Ok(job);
    }

    Ok(job)
}

impl<T, C> SqliteStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin,
    C: Codec<Compact = String>,
{
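    /// Build a stream that polls the `Jobs` table every `interval`, claiming
    /// up to `buffer_size` due jobs per tick for this worker. Idle ticks
    /// yield `None` rather than blocking.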
    fn stream_jobs(
        &self,
        worker: &Worker<Context>,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<Option<Request<T, SqlContext>>, sqlx::Error>> {
        let pool = self.pool.clone();
        let worker = worker.clone();
        let config = self.config.clone();
        let namespace = Namespace(self.config.namespace.clone());
        try_stream! {
            loop {
                apalis_core::sleep(interval).await;
                if !worker.is_ready() {
                    continue;
                }
                let worker_id = worker.id();
                let job_type = &config.namespace;
                let fetch_query = "SELECT id FROM Jobs
                    WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)) AND run_at < ?1 AND job_type = ?2 ORDER BY priority DESC LIMIT ?3";
                let now: i64 = Utc::now().timestamp();
                let ids: Vec<(String,)> = sqlx::query_as(fetch_query)
                    .bind(now)
                    .bind(job_type)
                    .bind(i64::try_from(buffer_size).map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?)
                    .fetch_all(&pool)
                    .await?;

                if ids.is_empty() {
                    yield None::<Request<T, SqlContext>>;
                } else {
                    for id in ids {
                        let res = fetch_next(&pool, worker_id, id.0, &config).await?;
                        yield match res {
                            None => None::<Request<T, SqlContext>>,
                            Some(job) => {
                                let (req, parts) = job.req.take_parts();
                                let args = C::decode(req)
                                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                                let mut req = Request::new_with_parts(args, parts);
                                req.parts.namespace = Some(namespace.clone());
                                Some(req)
                            }
                        };
                    }
                }
            }
        }
    }
}

impl<T, C> Storage for SqliteStorage<T, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = String> + Send + 'static + Sync,
    C::Error: std::error::Error + Send + Sync + 'static,
{
    type Job = T;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = String;

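    // NB: the positional `INSERT INTO Jobs VALUES (...)` queries below assume
    // the column order of the `Jobs` table as defined in the bundled
    // migrations.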
    async fn push_request(
        &mut self,
        job: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
        let (task, parts) = job.take_parts();
        let raw = C::encode(&task)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(raw)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn push_raw_request(
        &mut self,
        job: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query = "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, strftime('%s','now'), NULL, NULL, NULL, NULL, ?5)";
        let (task, parts) = job.take_parts();
        let raw = C::encode(&task)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(raw)
            .bind(parts.task_id.to_string())
            .bind(job_type.to_string())
            .bind(parts.context.max_attempts())
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: i64,
    ) -> Result<Parts<SqlContext>, Self::Error> {
        let query =
            "INSERT INTO Jobs VALUES (?1, ?2, ?3, 'Pending', 0, ?4, ?5, NULL, NULL, NULL, NULL, ?6)";
        let id = &req.parts.task_id;
        let job = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(id.to_string())
            .bind(job_type)
            .bind(req.parts.context.max_attempts())
            .bind(on)
            .bind(req.parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, Self::Error> {
        let fetch_query = "SELECT * FROM Jobs WHERE id = ?1";
        let res: Option<SqlRequest<String>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&self.pool)
            .await?;
        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let args = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

                let mut req: Request<T, SqlContext> = Request::new_with_parts(args, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, Self::Error> {
        let query = "SELECT COUNT(*) AS count FROM Jobs WHERE (status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts))";
        let record = sqlx::query(query).fetch_one(&self.pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<T, SqlContext>,
        wait: Duration,
    ) -> Result<(), Self::Error> {
        let task_id = job.parts.task_id;

        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        let query =
            "UPDATE Jobs SET status = 'Failed', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = ?2 WHERE id = ?1";
        let now: i64 = Utc::now().timestamp();
        let wait_until = now + wait;

        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(wait_until)
            .execute(self.pool())
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), Self::Error> {
        let ctx = job.parts.context;
        let status = ctx.status().to_string();
        let attempts = job.parts.attempt;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();
        let job_id = job.parts.task_id;
        let query =
            "UPDATE Jobs SET status = ?1, attempts = ?2, done_at = ?3, lock_by = ?4, lock_at = ?5, last_error = ?6, priority = ?7 WHERE id = ?8";
        sqlx::query(query)
            .bind(status)
            .bind::<i64>(
                attempts
                    .current()
                    .try_into()
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?,
            )
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(self.pool())
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, Self::Error> {
        self.len().map_ok(|c| c == 0).await
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let query = "DELETE FROM Jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&self.pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

impl<T, C> SqliteStorage<T, C> {
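    /// Move a job owned by `worker_id` back to `Pending` and release its lock
    /// so it can be picked up again.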
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        job_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let query =
            "UPDATE Jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = ?1 AND lock_by = ?2";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(self.pool())
            .await?;
        Ok(())
    }

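    /// Mark a job owned by `worker_id` as `Killed`. Killed jobs are not
    /// selected by the fetch query and will not run again.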
    pub async fn kill(&mut self, worker_id: &WorkerId, job_id: &TaskId) -> Result<(), sqlx::Error> {
        let query =
            "UPDATE Jobs SET status = 'Killed', done_at = strftime('%s','now') WHERE id = ?1 AND lock_by = ?2";
        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(worker_id.to_string())
            .execute(self.pool())
            .await?;
        Ok(())
    }

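    /// Re-enqueue jobs abandoned by dead workers: any `Running` job locked by
    /// a worker not seen since `dead_since` is moved back to `Pending` with
    /// its attempt count bumped, touching at most `count` jobs per call.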
    pub async fn reenqueue_orphaned(
        &self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let query = r#"UPDATE Jobs
                        SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, attempts = attempts + 1, last_error = 'Job was abandoned'
                        WHERE id IN
                            (SELECT Jobs.id FROM Jobs INNER JOIN Workers ON lock_by = Workers.id
                                WHERE status = 'Running' AND Workers.last_seen < ?1
                                AND Workers.worker_type = ?2 ORDER BY lock_at ASC LIMIT ?3);"#;

        sqlx::query(query)
            .bind(dead_since.timestamp())
            .bind(job_type)
            .bind(count)
            .execute(self.pool())
            .await?;
        Ok(())
    }
}

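/// Errors emitted by the background tasks that [SqliteStorage] runs while
/// polling.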
#[derive(thiserror::Error, Debug)]
pub enum SqlitePollError {
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),
}

impl<T, C> Backend<Request<T, SqlContext>> for SqliteStorage<T, C>
where
    C: Codec<Compact = String> + Send + 'static + Sync,
    C::Error: std::error::Error + 'static + Send + Sync,
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
{
    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;
    type Layer = AckLayer<SqliteStorage<T, C>, T, SqlContext, C>;

    type Codec = JsonCodec<String>;

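    // Wires together the job stream plus two background tasks: a heartbeat
    // that reclaims orphans once at startup and then refreshes this worker's
    // `last_seen` every 30s, and a beat that re-enqueues orphaned jobs every
    // poll interval.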
    fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let config = self.config.clone();
        let controller = self.controller.clone();
        let stream = self
            .stream_jobs(worker, config.poll_interval, config.buffer_size)
            .map_err(|e| Error::SourceError(Arc::new(Box::new(e))));
        let stream = BackendStream::new(stream.boxed(), controller);
        let requeue_storage = self.clone();
        let w = worker.clone();
        let heartbeat = async move {
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                w.emit(Event::Error(Box::new(
                    SqlitePollError::ReenqueueOrphanedError(e),
                )));
            }
            loop {
                let now: i64 = Utc::now().timestamp();
                if let Err(e) = self.keep_alive_at(&w, now).await {
                    w.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e))));
                }
                apalis_core::sleep(Duration::from_secs(30)).await;
            }
        }
        .boxed();
        let w = worker.clone();
        let reenqueue_beat = async move {
            loop {
                let dead_since = Utc::now()
                    - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
                if let Err(e) = requeue_storage
                    .reenqueue_orphaned(
                        config
                            .buffer_size
                            .try_into()
                            .expect("could not convert usize to i32"),
                        dead_since,
                    )
                    .await
                {
                    w.emit(Event::Error(Box::new(
                        SqlitePollError::ReenqueueOrphanedError(e),
                    )));
                }
                apalis_core::sleep(config.poll_interval).await;
            }
        };
        Poller::new_with_layer(
            stream,
            async {
                futures::join!(heartbeat, reenqueue_beat);
            },
            layer,
        )
    }
}

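// Acknowledging a job persists its serialized `Result` into `last_error` and
// flips the row to the final status computed by `calculate_status`.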
impl<T: Sync + Send, C: Send, Res: Serialize + Sync> Ack<T, Res, C> for SqliteStorage<T, C> {
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        let pool = self.pool.clone();
        let query =
            "UPDATE Jobs SET status = ?4, attempts = ?5, done_at = strftime('%s','now'), last_error = ?3 WHERE id = ?1 AND lock_by = ?2";
        let result = serde_json::to_string(&res.inner.as_ref().map_err(|r| r.to_string()))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        sqlx::query(query)
            .bind(res.task_id.to_string())
            .bind(
                ctx.lock_by()
                    .as_ref()
                    .expect("Task is not locked")
                    .to_string(),
            )
            .bind(result)
            .bind(calculate_status(ctx, res).to_string())
            .bind(res.attempt.current() as u32)
            .execute(&pool)
            .await?;
        Ok(())
    }
}

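// Expose queue internals (counts per state, job listings and registered
// workers) for dashboards and other introspection tooling.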
impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for SqliteStorage<J, JsonCodec<String>>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
                        FROM Jobs WHERE job_type = ?";

        let res: (i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM Jobs WHERE status = ? AND job_type = ? ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET ?";
        let res: Vec<SqlRequest<String>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.get_config().namespace())
            .bind(i64::from((page - 1) * 10))
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req = JsonCodec::<String>::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM Workers WHERE worker_type = ? ORDER BY last_seen DESC LIMIT 20 OFFSET ?";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.get_config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::request::State;
    use chrono::Utc;
    use email_service::example_good_email;
    use email_service::Email;
    use futures::StreamExt;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);
    sql_storage_tests!(setup::<Email>, SqliteStorage<Email>, Email);

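    /// Build a fresh in-memory storage (with migrations applied) so each test
    /// runs against an isolated database.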
    async fn setup<T: Serialize + DeserializeOwned>() -> SqliteStorage<T> {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        SqliteStorage::setup(&pool)
            .await
            .expect("failed to migrate DB");
        let config = Config::new("apalis::test");
        SqliteStorage::<T>::new_with_config(pool, config)
    }

    #[tokio::test]
    async fn test_inmemory_sqlite_worker() {
        let mut sqlite = setup().await;
        sqlite
            .push(Email {
                subject: "Test Subject".to_string(),
                to: "example@sqlite".to_string(),
                text: "Some Text".to_string(),
            })
            .await
            .expect("Unable to push job");
        let len = sqlite.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);
    }

    async fn consume_one(
        storage: &mut SqliteStorage<Email>,
        worker: &Worker<Context>,
    ) -> Request<Email, SqlContext> {
        let mut stream = storage
            .stream_jobs(worker, std::time::Duration::from_secs(10), 1)
            .boxed();
        stream
            .next()
            .await
            .expect("stream is empty")
            .expect("failed to poll job")
            .expect("no job is pending")
    }

    async fn register_worker_at(
        storage: &mut SqliteStorage<Email>,
        last_seen: i64,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        let worker = Worker::new(worker_id, Default::default());
        storage
            .keep_alive_at(&worker, last_seen)
            .await
            .expect("failed to register worker");
        worker.start();
        worker
    }

    async fn register_worker(storage: &mut SqliteStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut SqliteStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut SqliteStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let len = storage.len().await.expect("Could not fetch the jobs count");
        assert_eq!(len, 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        let worker = register_worker(&mut storage).await;

        push_email(&mut storage, example_good_email()).await;
        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        let ctx = &job.parts.context;
        let res = 1usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(&worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

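    // Worker last seen six minutes ago with a five-minute cutoff: the job it
    // held must be treated as abandoned and re-enqueued.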
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);
        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = &job.parts.task_id;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job = get_job(&mut storage, job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);

        let job = consume_one(&mut storage, &worker).await;
        let ctx = &job.parts.context;
        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 2);
    }

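    // Worker last seen four minutes ago with a six-minute cutoff: the worker
    // still counts as alive, so its job must keep its lock.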
    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_good_email()).await;

        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker).await;
        let job_id = job.parts.task_id;
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;

        job.parts.attempt.increment();
        storage
            .ack(
                ctx,
                &Response::new(Ok("success".to_owned()), job_id.clone(), job.parts.attempt),
            )
            .await
            .unwrap();
        let job = get_job(&mut storage, &job_id).await;
        let ctx = &job.parts.context;
        assert_eq!(*ctx.status(), State::Done);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), Some("{\"Ok\":\"success\"}".to_owned()));
        assert_eq!(job.parts.attempt.current(), 1);
    }
}