1#![doc = include_str!("../README.md")]
2use std::{fmt, marker::PhantomData};
5
6use apalis_codec::json::JsonCodec;
7use apalis_core::{
8 backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
9 features_table,
10 layers::Stack,
11 task::Task,
12 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
13};
14use apalis_sql::context::SqlContext;
15use futures::{
16 FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
17 channel::mpsc::{self},
18 future::ready,
19 stream::{self, BoxStream, select},
20};
21pub use sqlx::{
22 Connection, Pool, Sqlite, SqliteConnection, SqlitePool,
23 error::Error as SqlxError,
24 pool::{PoolConnection, PoolOptions},
25 sqlite::{SqliteConnectOptions, SqliteOperation},
26};
27use ulid::Ulid;
28
29use crate::{
30 ack::{LockTaskLayer, SqliteAck},
31 callback::update_hook_callback,
32 fetcher::{SqliteFetcher, SqlitePollFetcher, fetch_next},
33 queries::{
34 keep_alive::{initial_heartbeat, keep_alive, keep_alive_stream},
35 reenqueue_orphaned::reenqueue_orphaned_stream,
36 },
37 sink::SqliteSink,
38};
39
mod ack;
mod callback;
/// Task fetchers: interval-based polling and update-hook driven fetching.
pub mod fetcher;
mod from_row;
/// SQL queries used by the backend (keep-alive heartbeats, orphan re-enqueueing).
pub mod queries;
mod shared;
/// Sink for pushing new tasks into the database.
pub mod sink;
50
/// Task context used by this backend: the generic SQL context bound to a [`SqlitePool`].
pub type SqliteContext = SqlContext<SqlitePool>;

/// A task as stored in SQLite: arguments `Args`, a [`SqliteContext`], and a ULID task id.
pub type SqliteTask<Args> = Task<Args, SqliteContext, Ulid>;
pub use apalis_sql::config::Config;
pub use apalis_sql::ext::TaskBuilderExt;
pub use callback::{DbEvent, HookCallbackListener};
pub use shared::{SharedSqliteError, SharedSqliteStorage};

/// Wire format task arguments are serialized into before being written to the database.
pub type CompactType = Vec<u8>;

// Table watched by the update-hook listener; must match the table name used in migrations.
const JOBS_TABLE: &str = "Jobs";
65
#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_sqlite::SqliteStorage;
    # use sqlx::SqlitePool;
    # let pool = SqlitePool::connect(":memory:").await.unwrap();
    # SqliteStorage::setup(&pool).await.unwrap();
    # SqliteStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedSqliteStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct SqliteStorage<T, C, Fetcher> {
    /// Connection pool shared by the sink, the fetcher and the middleware layers.
    pool: Pool<Sqlite>,
    /// Marker for the task argument type; no value of `T` is stored.
    job_type: PhantomData<T>,
    /// Queue name, polling strategy, buffer size, keep-alive interval, etc.
    config: Config,
    /// Marker for the codec used to (de)serialize task arguments.
    codec: PhantomData<C>,
    /// Sink used to push new tasks into the database.
    #[pin]
    sink: SqliteSink<T, CompactType, C>,
    /// Fetching strategy (e.g. interval polling or update-hook driven).
    #[pin]
    fetcher: Fetcher,
}
106
107impl<T, C, F> fmt::Debug for SqliteStorage<T, C, F> {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("SqliteStorage")
110 .field("pool", &self.pool)
111 .field("job_type", &"PhantomData<T>")
112 .field("config", &self.config)
113 .field("codec", &std::any::type_name::<C>())
114 .finish()
115 }
116}
117
118impl<T, C, F: Clone> Clone for SqliteStorage<T, C, F> {
119 fn clone(&self) -> Self {
120 Self {
121 sink: self.sink.clone(),
122 pool: self.pool.clone(),
123 job_type: PhantomData,
124 config: self.config.clone(),
125 codec: self.codec,
126 fetcher: self.fetcher.clone(),
127 }
128 }
129}
130
131impl SqliteStorage<(), (), ()> {
132 #[cfg(feature = "migrate")]
134 pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
135 sqlx::query("PRAGMA journal_mode = 'WAL';")
136 .execute(pool)
137 .await?;
138 sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
139 sqlx::query("PRAGMA synchronous = NORMAL;")
140 .execute(pool)
141 .await?;
142 sqlx::query("PRAGMA cache_size = 64000;")
143 .execute(pool)
144 .await?;
145 Self::migrations().run(pool).await?;
146 Ok(())
147 }
148
149 #[cfg(feature = "migrate")]
151 #[must_use]
152 pub fn migrations() -> sqlx::migrate::Migrator {
153 sqlx::migrate!("./migrations")
154 }
155}
156
157impl<T> SqliteStorage<T, (), ()> {
158 #[must_use]
160 pub fn new(
161 pool: &Pool<Sqlite>,
162 ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
163 let config = Config::new(std::any::type_name::<T>());
164 SqliteStorage {
165 pool: pool.clone(),
166 job_type: PhantomData,
167 sink: SqliteSink::new(pool, &config),
168 config,
169 codec: PhantomData,
170 fetcher: fetcher::SqliteFetcher,
171 }
172 }
173
174 #[must_use]
176 pub fn new_in_queue(
177 pool: &Pool<Sqlite>,
178 queue: &str,
179 ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
180 let config = Config::new(queue);
181 SqliteStorage {
182 pool: pool.clone(),
183 job_type: PhantomData,
184 sink: SqliteSink::new(pool, &config),
185 config,
186 codec: PhantomData,
187 fetcher: fetcher::SqliteFetcher,
188 }
189 }
190
191 #[must_use]
193 pub fn new_with_config(
194 pool: &Pool<Sqlite>,
195 config: &Config,
196 ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
197 SqliteStorage {
198 pool: pool.clone(),
199 job_type: PhantomData,
200 config: config.clone(),
201 codec: PhantomData,
202 sink: SqliteSink::new(pool, config),
203 fetcher: fetcher::SqliteFetcher,
204 }
205 }
206
207 #[must_use]
209 pub fn new_with_callback(
210 url: &str,
211 config: &Config,
212 ) -> SqliteStorage<T, JsonCodec<CompactType>, HookCallbackListener> {
213 let (tx, rx) = mpsc::unbounded::<DbEvent>();
214
215 let listener = HookCallbackListener::new(rx);
216 let pool = PoolOptions::<Sqlite>::new()
217 .after_connect(move |conn, _meta| {
218 let mut tx = tx.clone();
219 Box::pin(async move {
220 let mut lock_handle = conn.lock_handle().await?;
221 lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
222 Ok(())
223 })
224 })
225 .connect_lazy(url)
226 .expect("Failed to create Sqlite pool");
227 SqliteStorage {
228 pool: pool.clone(),
229 job_type: PhantomData,
230 config: config.clone(),
231 codec: PhantomData,
232 sink: SqliteSink::new(&pool, config),
233 fetcher: listener,
234 }
235 }
236}
237
238impl<T, C, F> SqliteStorage<T, C, F> {
239 pub fn with_codec<D>(self) -> SqliteStorage<T, D, F> {
241 SqliteStorage {
242 sink: SqliteSink::new(&self.pool, &self.config),
243 pool: self.pool,
244 job_type: PhantomData,
245 config: self.config,
246 codec: PhantomData,
247 fetcher: self.fetcher,
248 }
249 }
250
251 pub fn config(&self) -> &Config {
253 &self.config
254 }
255
256 pub fn pool(&self) -> &Pool<Sqlite> {
258 &self.pool
259 }
260}
261
262impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SqliteFetcher>
263where
264 Args: Send + 'static + Unpin,
265 Decode: Codec<Args, Compact = CompactType> + 'static + Send,
266 Decode::Error: std::error::Error + Send + Sync + 'static,
267{
268 type Args = Args;
269 type IdType = Ulid;
270
271 type Context = SqliteContext;
272
273 type Error = sqlx::Error;
274
275 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
276
277 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
278
279 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;
280
281 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
282 let pool = self.pool.clone();
283 let config = self.config.clone();
284 let worker = worker.clone();
285 let keep_alive = keep_alive_stream(pool, config, worker);
286 let reenqueue = reenqueue_orphaned_stream(
287 self.pool.clone(),
288 self.config.clone(),
289 *self.config.keep_alive(),
290 )
291 .map_ok(|_| ());
292 futures::stream::select(keep_alive, reenqueue).boxed()
293 }
294
295 fn middleware(&self) -> Self::Layer {
296 let lock = LockTaskLayer::new(self.pool.clone());
297 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
298 Stack::new(lock, ack)
299 }
300
301 fn poll(self, worker: &WorkerContext) -> Self::Stream {
302 self.poll_default(worker)
303 .map(|a| match a {
304 Ok(Some(task)) => Ok(Some(
305 task.try_map(|t| Decode::decode(&t))
306 .map_err(|e| sqlx::Error::Decode(e.into()))?,
307 )),
308 Ok(None) => Ok(None),
309 Err(e) => Err(e),
310 })
311 .boxed()
312 }
313}
314
315impl<Args, Decode: Send + 'static> BackendExt for SqliteStorage<Args, Decode, SqliteFetcher>
316where
317 Self: Backend<Args = Args, IdType = Ulid, Context = SqliteContext, Error = sqlx::Error>,
318 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
319 Decode::Error: std::error::Error + Send + Sync + 'static,
320 Args: Send + 'static + Unpin,
321{
322 type Codec = Decode;
323 type Compact = CompactType;
324 type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
325
326 fn get_queue(&self) -> Queue {
327 self.config.queue().to_owned()
328 }
329
330 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
331 self.poll_default(worker).boxed()
332 }
333}
334
335impl<Args, Decode> Backend for SqliteStorage<Args, Decode, HookCallbackListener>
336where
337 Args: Send + 'static + Unpin,
338 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
339 Decode::Error: std::error::Error + Send + Sync + 'static,
340{
341 type Args = Args;
342 type IdType = Ulid;
343
344 type Context = SqliteContext;
345
346 type Error = sqlx::Error;
347
348 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
349
350 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
351
352 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;
353
354 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
355 let pool = self.pool.clone();
356 let config = self.config.clone();
357 let worker = worker.clone();
358 let keep_alive = keep_alive_stream(pool, config, worker);
359 let reenqueue = reenqueue_orphaned_stream(
360 self.pool.clone(),
361 self.config.clone(),
362 *self.config.keep_alive(),
363 )
364 .map_ok(|_| ());
365 futures::stream::select(keep_alive, reenqueue).boxed()
366 }
367
368 fn middleware(&self) -> Self::Layer {
369 let lock = LockTaskLayer::new(self.pool.clone());
370 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
371 Stack::new(lock, ack)
372 }
373
374 fn poll(self, worker: &WorkerContext) -> Self::Stream {
375 self.poll_with_listener(worker)
376 .map(|a| match a {
377 Ok(Some(task)) => Ok(Some(
378 task.try_map(|t| Decode::decode(&t))
379 .map_err(|e| sqlx::Error::Decode(e.into()))?,
380 )),
381 Ok(None) => Ok(None),
382 Err(e) => Err(e),
383 })
384 .boxed()
385 }
386}
387
388impl<Args, Decode: Send + 'static> BackendExt for SqliteStorage<Args, Decode, HookCallbackListener>
389where
390 Self: Backend<Args = Args, IdType = Ulid, Context = SqliteContext, Error = sqlx::Error>,
391 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
392 Decode::Error: std::error::Error + Send + Sync + 'static,
393 Args: Send + 'static + Unpin,
394{
395 type Codec = Decode;
396 type Compact = CompactType;
397 type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
398
399 fn get_queue(&self) -> Queue {
400 self.config.queue().to_owned()
401 }
402
403 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
404 self.poll_with_listener(worker).boxed()
405 }
406}
407
impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, HookCallbackListener> {
    /// Build the polling stream for the hook-driven backend.
    ///
    /// The resulting stream first registers this worker via an initial
    /// heartbeat (yielding `Ok(None)`), then merges two task sources:
    /// - a "lazy" fetcher that fires when the SQLite update hook reports an
    ///   `INSERT` on the jobs table, and
    /// - an "eager" interval-based poll fetcher as a fallback.
    fn poll_with_listener(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + Send + 'static
    {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // One-shot worker registration; errors propagate into the stream.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SqliteStorageWithHook",
        );
        let register_worker = stream::once(register_worker.map_ok(|_| None));
        let eager_fetcher: SqlitePollFetcher<CompactType, Decode> =
            SqlitePollFetcher::new(&self.pool, &self.config, &worker);
        let lazy_fetcher = self
            .fetcher
            // Only INSERTs on the jobs table signal new work; every other
            // database event is ignored.
            .filter(|a| {
                ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
            })
            .inspect(|db_event| {
                log::debug!("Received new job event: {db_event:?}");
            })
            // Coalesce bursts of insert events (up to `buffer_size`) into a
            // single fetch instead of querying once per event.
            .ready_chunks(self.config.buffer_size())
            .then(move |_| fetch_next(pool.clone(), config.clone(), worker.clone()))
            // Flatten each fetched batch into individual task items; a fetch
            // error becomes a single `Err` item.
            .flat_map(|res| match res {
                Ok(tasks) => stream::iter(tasks).map(Ok).boxed(),
                Err(e) => stream::iter(vec![Err(e)]).boxed(),
            })
            // Wrap tasks in `Some` to match the `Option`-based stream item.
            .map(|res| match res {
                Ok(task) => Ok(Some(task)),
                Err(e) => Err(e),
            });

        register_worker.chain(select(lazy_fetcher, eager_fetcher))
    }
}
448
449impl<Args, Decode: Send + 'static, F> SqliteStorage<Args, Decode, F> {
450 fn poll_default(
451 self,
452 worker: &WorkerContext,
453 ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + Send + 'static
454 {
455 let fut = initial_heartbeat(
456 self.pool.clone(),
457 self.config().clone(),
458 worker.clone(),
459 "SqliteStorage",
460 );
461 let register = stream::once(fut.map(|_| Ok(None)));
462 register.chain(SqlitePollFetcher::<CompactType, Decode>::new(
463 &self.pool,
464 &self.config,
465 worker,
466 ))
467 }
468}
469
470#[cfg(test)]
471mod tests {
472 use std::time::Duration;
473
474 use apalis::prelude::*;
475 use apalis_workflow::*;
476 use chrono::Local;
477 use serde::{Deserialize, Serialize};
478 use sqlx::SqlitePool;
479
480 use super::*;
481
482 #[tokio::test]
483 async fn basic_worker() {
484 const ITEMS: usize = 10;
485 let pool = SqlitePool::connect(std::env::var("DATABASE_URL").unwrap().as_str())
486 .await
487 .unwrap();
488 SqliteStorage::setup(&pool).await.unwrap();
489
490 let mut backend = SqliteStorage::new(&pool);
491
492 let mut start = 0;
493
494 let mut items = stream::repeat_with(move || {
495 start += 1;
496 start
497 })
498 .take(ITEMS);
499 backend.push_stream(&mut items).await.unwrap();
500
501 println!("Starting worker at {}", Local::now());
502
503 async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
504 if ITEMS == item {
505 wrk.stop().unwrap();
506 }
507 Ok(())
508 }
509
510 let worker = WorkerBuilder::new("rango-tango-basic")
511 .backend(backend)
512 .build(send_reminder);
513 worker.run().await.unwrap();
514 }
515
516 #[tokio::test]
517 async fn hooked_worker() {
518 const ITEMS: usize = 20;
519
520 let lazy_strategy = StrategyBuilder::new()
521 .apply(IntervalStrategy::new(Duration::from_secs(5)))
522 .build();
523 let config = Config::new("rango-tango-queue")
524 .with_poll_interval(lazy_strategy)
525 .set_buffer_size(5);
526 let backend = SqliteStorage::new_with_callback(
527 std::env::var("DATABASE_URL").unwrap().as_str(),
528 &config,
529 );
530 let pool = backend.pool().clone();
531 SqliteStorage::setup(&pool).await.unwrap();
532
533 tokio::spawn(async move {
534 tokio::time::sleep(Duration::from_secs(2)).await;
535 let mut start = 0;
536
537 let items = stream::repeat_with(move || {
538 start += 1;
539
540 Task::builder(serde_json::to_vec(&start).unwrap())
541 .with_ctx(SqliteContext::new().with_priority(start))
542 .build()
543 })
544 .take(ITEMS)
545 .collect::<Vec<_>>()
546 .await;
547 crate::sink::push_tasks(pool.clone(), config, items)
548 .await
549 .unwrap();
550 });
551
552 async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
553 if item == 1 {
555 apalis_core::timer::sleep(Duration::from_secs(1)).await;
556 wrk.stop().unwrap();
557 }
558 Ok(())
559 }
560
561 let worker = WorkerBuilder::new("rango-tango-hooked")
562 .backend(backend)
563 .build(send_reminder);
564 worker.run().await.unwrap();
565 }
566
567 #[tokio::test]
568 async fn test_workflow() {
569 let workflow = Workflow::new("odd-numbers-workflow")
570 .and_then(|a: usize| async move { Ok::<_, BoxDynError>((0..=a).collect::<Vec<_>>()) })
571 .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
572 .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
573 .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
574 .delay_for(Duration::from_millis(1000))
575 .and_then(|a: Vec<usize>| async move {
576 println!("Sum: {}", a.iter().sum::<usize>());
577 Err::<(), BoxDynError>("Intentional Error".into())
578 });
579
580 let mut sqlite = SqliteStorage::new_with_callback(
581 std::env::var("DATABASE_URL").unwrap().as_str(),
582 &Config::new("workflow-queue").with_poll_interval(
583 StrategyBuilder::new()
584 .apply(IntervalStrategy::new(Duration::from_millis(100)))
585 .build(),
586 ),
587 );
588
589 SqliteStorage::setup(sqlite.pool()).await.unwrap();
590
591 sqlite.push_start(100usize).await.unwrap();
592
593 let worker = WorkerBuilder::new("rango-tango-workflow")
594 .backend(sqlite)
595 .on_event(|ctx, ev| {
596 println!("On Event = {:?}", ev);
597 if matches!(ev, Event::Error(_)) {
598 ctx.stop().unwrap();
599 }
600 })
601 .build(workflow);
602 worker.run().await.unwrap();
603 }
604
    /// End-to-end workflow: preprocess -> classify -> confidence filter ->
    /// optional sentiment -> summary, exercising custom worker events and
    /// `Data`-injected configuration.
    #[tokio::test]
    async fn test_workflow_complete() {
        // Pipeline settings injected into the worker via `.data(...)`.
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        // A single classified token.
        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: normalize the input and announce it via a custom event.
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::custom(format!(
                    "Preprocessing input: {}",
                    input.text
                )));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            // Step 2: classify every whitespace-separated token.
            .and_then(|text: String| async move {
                let confidence = 0.85; let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Step 3: drop low-confidence classifications.
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            // Step 4: summarize, optionally attaching naive sentiment.
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            // Step 5: report the summary count and stop the worker.
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let mut sqlite =
            SqliteStorage::new_with_callback(":memory:", &Config::new("text-pipeline"));

        SqliteStorage::setup(sqlite.pool()).await.unwrap();

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        sqlite.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            // Surface custom messages, stop on the first error, and log
            // everything else.
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {}", m);
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {:?}", ev);
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {:?}", ev);
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
726 }
727}