#![doc = include_str!("../README.md")]
use std::{fmt, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, BackendExt, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    features_table,
    layers::Stack,
    task::Task,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::context::SqlContext;
use futures::{
    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
    channel::mpsc::{self},
    future::ready,
    stream::{self, BoxStream, select},
};
pub use sqlx::{
    Connection, Pool, Sqlite, SqliteConnection, SqlitePool,
    error::Error as SqlxError,
    pool::{PoolConnection, PoolOptions},
    sqlite::{SqliteConnectOptions, SqliteOperation},
};
use ulid::Ulid;

use crate::{
    ack::{LockTaskLayer, SqliteAck},
    callback::update_hook_callback,
    fetcher::{SqliteFetcher, SqlitePollFetcher, fetch_next},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::SqliteSink,
};

mod ack;
mod callback;
mod config;
pub mod fetcher;
mod from_row;
pub mod queries;
mod shared;
pub mod sink;

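/// A task as stored by this backend: `Args` for the arguments, [`SqlContext`]
/// for SQL-specific metadata, and a [`Ulid`] as the task id.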
pub type SqliteTask<Args> = Task<Args, SqlContext, Ulid>;

pub use callback::{DbEvent, HookCallbackListener};
pub use config::Config;
pub use shared::{SharedSqliteError, SharedSqliteStorage};

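/// The compact wire format used to persist task arguments: raw bytes produced
/// by whichever [`Codec`] the storage is configured with (JSON by default).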
pub type CompactType = Vec<u8>;

const JOBS_TABLE: &str = "Jobs";

#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_sqlite::SqliteStorage;
    # use sqlx::SqlitePool;
    # let pool = SqlitePool::connect(":memory:").await.unwrap();
    # SqliteStorage::setup(&pool).await.unwrap();
    # SqliteStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedSqliteStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct SqliteStorage<T, C, Fetcher> {
    pool: Pool<Sqlite>,
    job_type: PhantomData<T>,
    config: Config,
    codec: PhantomData<C>,
    #[pin]
    sink: SqliteSink<T, CompactType, C>,
    #[pin]
    fetcher: Fetcher,
}

impl<T, C, F> fmt::Debug for SqliteStorage<T, C, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqliteStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}

impl<T, C, F: Clone> Clone for SqliteStorage<T, C, F> {
    fn clone(&self) -> Self {
        Self {
            sink: self.sink.clone(),
            pool: self.pool.clone(),
            job_type: PhantomData,
            config: self.config.clone(),
            codec: self.codec,
            fetcher: self.fetcher.clone(),
        }
    }
}

impl SqliteStorage<(), (), ()> {
    #[cfg(feature = "migrate")]
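    /// Prepares the database: applies recommended PRAGMAs (WAL journal mode,
    /// in-memory temp store, `synchronous = NORMAL`, a larger cache) and runs
    /// the bundled migrations.
    ///
    /// A minimal sketch, mirroring this crate's own tests:
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::SqliteStorage;
    /// # use sqlx::SqlitePool;
    /// # async fn example() -> Result<(), sqlx::Error> {
    /// let pool = SqlitePool::connect(":memory:").await?;
    /// SqliteStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```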
    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
        sqlx::query("PRAGMA journal_mode = 'WAL';")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
        sqlx::query("PRAGMA synchronous = NORMAL;")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA cache_size = 64000;")
            .execute(pool)
            .await?;
        Self::migrations().run(pool).await?;
        Ok(())
    }

    #[cfg(feature = "migrate")]
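    /// Returns the migrator for the bundled migrations, for callers that want
    /// to run (or revert) them manually instead of using [`Self::setup`].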
    #[must_use]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<T> SqliteStorage<T, (), ()> {
    #[must_use]
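    /// Creates a storage for tasks of type `T` on the given pool, using the
    /// type name of `T` as the queue name and JSON as the codec.
    ///
    /// A minimal sketch (the `u32` argument type is just for illustration):
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::SqliteStorage;
    /// # use sqlx::SqlitePool;
    /// # async fn example() {
    /// let pool = SqlitePool::connect(":memory:").await.unwrap();
    /// SqliteStorage::setup(&pool).await.unwrap();
    /// let storage: SqliteStorage<u32, _, _> = SqliteStorage::new(&pool);
    /// # }
    /// ```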
    pub fn new(
        pool: &Pool<Sqlite>,
    ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
        let config = Config::new(std::any::type_name::<T>());
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            sink: SqliteSink::new(pool, &config),
            config,
            codec: PhantomData,
            fetcher: fetcher::SqliteFetcher,
        }
    }

    #[must_use]
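    /// Like [`Self::new`], but stores tasks under an explicit queue name
    /// instead of deriving it from the type name of `T`.
    ///
    /// A sketch with an arbitrary queue name:
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::SqliteStorage;
    /// # use sqlx::SqlitePool;
    /// # async fn example() {
    /// let pool = SqlitePool::connect(":memory:").await.unwrap();
    /// let storage: SqliteStorage<u32, _, _> = SqliteStorage::new_in_queue(&pool, "emails");
    /// # }
    /// ```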
    pub fn new_in_queue(
        pool: &Pool<Sqlite>,
        queue: &str,
    ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
        let config = Config::new(queue);
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            sink: SqliteSink::new(pool, &config),
            config,
            codec: PhantomData,
            fetcher: fetcher::SqliteFetcher,
        }
    }

    #[must_use]
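    /// Like [`Self::new`], but takes a pre-built [`Config`], which controls
    /// the queue name, poll interval, buffer size, and related settings.
    ///
    /// A minimal sketch:
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::{Config, SqliteStorage};
    /// # use sqlx::SqlitePool;
    /// # async fn example() {
    /// let pool = SqlitePool::connect(":memory:").await.unwrap();
    /// let config = Config::new("emails");
    /// let storage: SqliteStorage<u32, _, _> = SqliteStorage::new_with_config(&pool, &config);
    /// # }
    /// ```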
    pub fn new_with_config(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: fetcher::SqliteFetcher,
        }
    }

    #[must_use]
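    /// Creates a storage that is notified of task inserts via SQLite's update
    /// hook instead of relying on polling alone. The pool is created lazily
    /// from `url`, and every new connection registers a hook that forwards
    /// [`DbEvent`]s to the fetcher.
    ///
    /// # Panics
    ///
    /// Panics if the pool cannot be created from `url`.
    ///
    /// A minimal sketch, mirroring this crate's own tests:
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::{Config, SqliteStorage};
    /// # async fn example() {
    /// let config = Config::new("emails");
    /// let backend: SqliteStorage<u32, _, _> =
    ///     SqliteStorage::new_with_callback(":memory:", &config);
    /// SqliteStorage::setup(backend.pool()).await.unwrap();
    /// # }
    /// ```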
    pub fn new_with_callback(
        url: &str,
        config: &Config,
    ) -> SqliteStorage<T, JsonCodec<CompactType>, HookCallbackListener> {
        let (tx, rx) = mpsc::unbounded::<DbEvent>();

        let listener = HookCallbackListener::new(rx);
        let pool = PoolOptions::<Sqlite>::new()
            .after_connect(move |conn, _meta| {
                let mut tx = tx.clone();
                Box::pin(async move {
                    let mut lock_handle = conn.lock_handle().await?;
                    lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
                    Ok(())
                })
            })
            .connect_lazy(url)
            .expect("Failed to create Sqlite pool");
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(&pool, config),
            fetcher: listener,
        }
    }
}

impl<T, C, F> SqliteStorage<T, C, F> {
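    /// Swaps the codec used to encode and decode task arguments, rebuilding
    /// the sink in the process. The new codec `D` must produce the same
    /// compact representation ([`CompactType`]) for the storage to be usable
    /// as a backend.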
    pub fn with_codec<D>(self) -> SqliteStorage<T, D, F> {
        SqliteStorage {
            sink: SqliteSink::new(&self.pool, &self.config),
            pool: self.pool,
            job_type: PhantomData,
            config: self.config,
            codec: PhantomData,
            fetcher: self.fetcher,
        }
    }

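    /// Returns a reference to the [`Config`] this storage was built with.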
    pub fn config(&self) -> &Config {
        &self.config
    }

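    /// Returns a reference to the underlying connection pool.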
    pub fn pool(&self) -> &Pool<Sqlite> {
        &self.pool
    }
}

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SqliteFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;
    type Context = SqlContext;
    type Error = sqlx::Error;
    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(lock, ack)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_default(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode: Send + 'static> BackendExt for SqliteStorage<Args, Decode, SqliteFetcher>
where
    Self: Backend<Args = Args, IdType = Ulid, Context = SqlContext, Error = sqlx::Error>,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
    Args: Send + 'static + Unpin,
{
    type Codec = Decode;
    type Compact = CompactType;
    type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_default(worker).boxed()
    }
}

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, HookCallbackListener>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;
    type Context = SqlContext;
    type Error = sqlx::Error;
    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(lock, ack)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_with_listener(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode: Send + 'static> BackendExt for SqliteStorage<Args, Decode, HookCallbackListener>
where
    Self: Backend<Args = Args, IdType = Ulid, Context = SqlContext, Error = sqlx::Error>,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
    Args: Send + 'static + Unpin,
{
    type Codec = Decode;
    type Compact = CompactType;
    type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_listener(worker).boxed()
    }
}

impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, HookCallbackListener> {
    fn poll_with_listener(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + Send + 'static
    {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SqliteStorageWithHook",
        );
        let register_worker = stream::once(register_worker.map_ok(|_| None));
        // Eager path: polls the jobs table on the configured interval.
        let eager_fetcher: SqlitePollFetcher<CompactType, Decode> =
            SqlitePollFetcher::new(&self.pool, &self.config, &worker);
        // Lazy path: fetches only when the update hook reports an insert
        // into the jobs table.
        let lazy_fetcher = self
            .fetcher
            .filter(|a| {
                ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
            })
            .inspect(|db_event| {
                log::debug!("Received new job event: {db_event:?}");
            })
            .ready_chunks(self.config.buffer_size())
            .then(move |_| fetch_next(pool.clone(), config.clone(), worker.clone()))
            .flat_map(|res| match res {
                Ok(tasks) => stream::iter(tasks).map(Ok).boxed(),
                Err(e) => stream::iter(vec![Err(e)]).boxed(),
            })
            .map(|res| match res {
                Ok(task) => Ok(Some(task)),
                Err(e) => Err(e),
            });

        register_worker.chain(select(lazy_fetcher, eager_fetcher))
    }
}

impl<Args, Decode: Send + 'static, F> SqliteStorage<Args, Decode, F> {
    fn poll_default(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + Send + 'static
    {
        // Register the worker once before the first poll; registration errors
        // are surfaced into the task stream, matching `poll_with_listener`.
        let fut = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SqliteStorage",
        );
        let register = stream::once(fut.map_ok(|_| None));
        register.chain(SqlitePollFetcher::<CompactType, Decode>::new(
            &self.pool,
            &self.config,
            worker,
        ))
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis::prelude::*;
    use apalis_workflow::*;
    use chrono::Local;
    use serde::{Deserialize, Serialize};
    use sqlx::SqlitePool;

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        const ITEMS: usize = 10;
        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();

        let mut backend = SqliteStorage::new(&pool);

        let mut start = 0;

        let mut items = stream::repeat_with(move || {
            start += 1;
            start
        })
        .take(ITEMS);
        backend.push_stream(&mut items).await.unwrap();

        println!("Starting worker at {}", Local::now());

        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            // Stop the worker once the last queued item has been handled.
            if item == ITEMS {
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn hooked_worker() {
        const ITEMS: usize = 20;

        let lazy_strategy = StrategyBuilder::new()
            .apply(IntervalStrategy::new(Duration::from_secs(5)))
            .build();
        let config = Config::new("rango-tango-queue")
            .with_poll_interval(lazy_strategy)
            .set_buffer_size(5);
        let backend = SqliteStorage::new_with_callback(":memory:", &config);
        let pool = backend.pool().clone();
        SqliteStorage::setup(&pool).await.unwrap();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut start = 0;

            let items = stream::repeat_with(move || {
                start += 1;

                Task::builder(serde_json::to_vec(&start).unwrap())
                    .with_ctx(SqlContext::new().with_priority(start))
                    .build()
            })
            .take(ITEMS)
            .collect::<Vec<_>>()
            .await;
            crate::sink::push_tasks(pool.clone(), config, items)
                .await
                .unwrap();
        });

        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            // When item 1 is seen, pause briefly and then stop the worker.
            if item == 1 {
                apalis_core::timer::sleep(Duration::from_secs(1)).await;
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow() {
        let workflow = Workflow::new("odd-numbers-workflow")
            .and_then(|a: usize| async move { Ok::<_, BoxDynError>((0..=a).collect::<Vec<_>>()) })
            .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
            .delay_for(Duration::from_millis(1000))
            .and_then(|a: Vec<usize>| async move {
                println!("Sum: {}", a.iter().sum::<usize>());
                Err::<(), BoxDynError>("Intentional Error".into())
            });

        let mut sqlite = SqliteStorage::new_with_callback(
            ":memory:",
            &Config::new("workflow-queue").with_poll_interval(
                StrategyBuilder::new()
                    .apply(IntervalStrategy::new(Duration::from_millis(100)))
                    .build(),
            ),
        );

        SqliteStorage::setup(sqlite.pool()).await.unwrap();

        sqlite.push_start(100usize).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .on_event(|ctx, ev| {
                println!("On Event = {:?}", ev);
                if matches!(ev, Event::Error(_)) {
                    ctx.stop().unwrap();
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::custom(format!(
                    "Preprocessing input: {}",
                    input.text
                )));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            .and_then(|text: String| async move {
                // Mock classifier: every word gets the same confidence score.
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let mut sqlite =
            SqliteStorage::new_with_callback(":memory:", &Config::new("text-pipeline"));

        SqliteStorage::setup(sqlite.pool()).await.unwrap();

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        sqlite.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {}", m);
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {:?}", ev);
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {:?}", ev);
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}