1#![doc = include_str!("../README.md")]
2use std::{fmt, marker::PhantomData};
5
6use apalis_core::{
7 backend::{
8 Backend, BackendExt, TaskStream,
9 codec::{Codec, json::JsonCodec},
10 },
11 features_table,
12 layers::Stack,
13 task::Task,
14 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
15};
16pub use apalis_sql::context::SqlContext;
17use futures::{
18 FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
19 channel::mpsc::{self},
20 future::ready,
21 stream::{self, BoxStream, select},
22};
23pub use sqlx::{
24 Connection, Pool, Sqlite, SqliteConnection, SqlitePool,
25 pool::{PoolConnection, PoolOptions},
26 sqlite::{SqliteConnectOptions, SqliteOperation},
27};
28use ulid::Ulid;
29
30use crate::{
31 ack::{LockTaskLayer, SqliteAck},
32 callback::update_hook_callback,
33 fetcher::{SqliteFetcher, SqlitePollFetcher, fetch_next},
34 queries::{
35 keep_alive::{initial_heartbeat, keep_alive, keep_alive_stream},
36 reenqueue_orphaned::reenqueue_orphaned_stream,
37 },
38 sink::SqliteSink,
39};
40
41mod ack;
42mod callback;
43mod config;
44pub mod fetcher;
46mod from_row;
47pub mod queries;
49mod shared;
50pub mod sink;
52
/// A [`Task`] carrying `Args`, using [`SqlContext`] as its context and a [`Ulid`] as its id.
pub type SqliteTask<Args> = Task<Args, SqlContext, Ulid>;
55pub use callback::{DbEvent, HookCallbackListener};
56pub use config::Config;
57pub use shared::{SharedSqliteError, SharedSqliteStorage};
58
/// Byte buffer used as the `Compact` representation required of codecs by the
/// `Backend` implementations in this file.
pub type CompactType = Vec<u8>;
61
/// Table name that update-hook events are compared against when filtering for
/// newly inserted jobs.
const JOBS_TABLE: &str = "Jobs";
63
/// Storage handle built around a `sqlx` SQLite pool.
///
/// The type is generic over the task argument type `T`, the codec `C` and the
/// `Fetcher`. The constructors in this file produce either a
/// `fetcher::SqliteFetcher` or a `HookCallbackListener` as the fetcher.
///
#[doc = features_table! {
    setup = r#"
        # {
        #   use apalis_sqlite::SqliteStorage;
        #   use sqlx::SqlitePool;
        #   let pool = SqlitePool::connect(":memory:").await.unwrap();
        #   SqliteStorage::setup(&pool).await.unwrap();
        #   SqliteStorage::new(&pool)
        # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedSqliteStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct SqliteStorage<T, C, Fetcher> {
    /// Connection pool handed to the sink, fetchers, heartbeat and middleware.
    pool: Pool<Sqlite>,
    /// Marker tying the storage to its task argument type; holds no data.
    job_type: PhantomData<T>,
    /// Configuration cloned into the sink, fetchers and heartbeat streams.
    config: Config,
    /// Marker tying the storage to its codec type; holds no data.
    codec: PhantomData<C>,
    /// Sink built from the same pool and config via `SqliteSink::new`.
    #[pin]
    sink: SqliteSink<T, CompactType, C>,
    /// Fetcher value selected by the constructor that built this storage.
    #[pin]
    fetcher: Fetcher,
}
104
105impl<T, C, F> fmt::Debug for SqliteStorage<T, C, F> {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 f.debug_struct("SqliteStorage")
108 .field("pool", &self.pool)
109 .field("job_type", &"PhantomData<T>")
110 .field("config", &self.config)
111 .field("codec", &std::any::type_name::<C>())
112 .finish()
113 }
114}
115
116impl<T, C, F: Clone> Clone for SqliteStorage<T, C, F> {
117 fn clone(&self) -> Self {
118 Self {
119 sink: self.sink.clone(),
120 pool: self.pool.clone(),
121 job_type: PhantomData,
122 config: self.config.clone(),
123 codec: self.codec,
124 fetcher: self.fetcher.clone(),
125 }
126 }
127}
128
129impl SqliteStorage<(), (), ()> {
130 #[cfg(feature = "migrate")]
132 pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
133 sqlx::query("PRAGMA journal_mode = 'WAL';")
134 .execute(pool)
135 .await?;
136 sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
137 sqlx::query("PRAGMA synchronous = NORMAL;")
138 .execute(pool)
139 .await?;
140 sqlx::query("PRAGMA cache_size = 64000;")
141 .execute(pool)
142 .await?;
143 Self::migrations().run(pool).await?;
144 Ok(())
145 }
146
147 #[cfg(feature = "migrate")]
149 #[must_use]
150 pub fn migrations() -> sqlx::migrate::Migrator {
151 sqlx::migrate!("./migrations")
152 }
153}
154
155impl<T> SqliteStorage<T, (), ()> {
156 #[must_use]
158 pub fn new(
159 pool: &Pool<Sqlite>,
160 ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
161 let config = Config::new(std::any::type_name::<T>());
162 SqliteStorage {
163 pool: pool.clone(),
164 job_type: PhantomData,
165 sink: SqliteSink::new(pool, &config),
166 config,
167 codec: PhantomData,
168 fetcher: fetcher::SqliteFetcher,
169 }
170 }
171
172 #[must_use]
174 pub fn new_in_queue(
175 pool: &Pool<Sqlite>,
176 queue: &str,
177 ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
178 let config = Config::new(queue);
179 SqliteStorage {
180 pool: pool.clone(),
181 job_type: PhantomData,
182 sink: SqliteSink::new(pool, &config),
183 config,
184 codec: PhantomData,
185 fetcher: fetcher::SqliteFetcher,
186 }
187 }
188
189 #[must_use]
191 pub fn new_with_config(
192 pool: &Pool<Sqlite>,
193 config: &Config,
194 ) -> SqliteStorage<T, JsonCodec<CompactType>, fetcher::SqliteFetcher> {
195 SqliteStorage {
196 pool: pool.clone(),
197 job_type: PhantomData,
198 config: config.clone(),
199 codec: PhantomData,
200 sink: SqliteSink::new(pool, config),
201 fetcher: fetcher::SqliteFetcher,
202 }
203 }
204
205 #[must_use]
207 pub fn new_with_callback(
208 url: &str,
209 config: &Config,
210 ) -> SqliteStorage<T, JsonCodec<CompactType>, HookCallbackListener> {
211 let (tx, rx) = mpsc::unbounded::<DbEvent>();
212
213 let listener = HookCallbackListener::new(rx);
214 let pool = PoolOptions::<Sqlite>::new()
215 .after_connect(move |conn, _meta| {
216 let mut tx = tx.clone();
217 Box::pin(async move {
218 let mut lock_handle = conn.lock_handle().await?;
219 lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
220 Ok(())
221 })
222 })
223 .connect_lazy(url)
224 .expect("Failed to create Sqlite pool");
225 SqliteStorage {
226 pool: pool.clone(),
227 job_type: PhantomData,
228 config: config.clone(),
229 codec: PhantomData,
230 sink: SqliteSink::new(&pool, config),
231 fetcher: listener,
232 }
233 }
234}
235
236impl<T, C, F> SqliteStorage<T, C, F> {
237 pub fn with_codec<D>(self) -> SqliteStorage<T, D, F> {
239 SqliteStorage {
240 sink: SqliteSink::new(&self.pool, &self.config),
241 pool: self.pool,
242 job_type: PhantomData,
243 config: self.config,
244 codec: PhantomData,
245 fetcher: self.fetcher,
246 }
247 }
248
249 pub fn config(&self) -> &Config {
251 &self.config
252 }
253
254 pub fn pool(&self) -> &Pool<Sqlite> {
256 &self.pool
257 }
258}
259
260impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SqliteFetcher>
261where
262 Args: Send + 'static + Unpin,
263 Decode: Codec<Args, Compact = CompactType> + 'static + Send,
264 Decode::Error: std::error::Error + Send + Sync + 'static,
265{
266 type Args = Args;
267 type IdType = Ulid;
268
269 type Context = SqlContext;
270
271 type Error = sqlx::Error;
272
273 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
274
275 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
276
277 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;
278
279 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
280 let pool = self.pool.clone();
281 let config = self.config.clone();
282 let worker = worker.clone();
283 let keep_alive = keep_alive_stream(pool, config, worker);
284 let reenqueue = reenqueue_orphaned_stream(
285 self.pool.clone(),
286 self.config.clone(),
287 *self.config.keep_alive(),
288 )
289 .map_ok(|_| ());
290 futures::stream::select(keep_alive, reenqueue).boxed()
291 }
292
293 fn middleware(&self) -> Self::Layer {
294 let lock = LockTaskLayer::new(self.pool.clone());
295 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
296 Stack::new(lock, ack)
297 }
298
299 fn poll(self, worker: &WorkerContext) -> Self::Stream {
300 self.poll_default(worker)
301 .map(|a| match a {
302 Ok(Some(task)) => Ok(Some(
303 task.try_map(|t| Decode::decode(&t))
304 .map_err(|e| sqlx::Error::Decode(e.into()))?,
305 )),
306 Ok(None) => Ok(None),
307 Err(e) => Err(e),
308 })
309 .boxed()
310 }
311}
312
313impl<Args, Decode: Send + 'static> BackendExt for SqliteStorage<Args, Decode, SqliteFetcher>
314where
315 Self: Backend<Args = Args, IdType = Ulid, Context = SqlContext, Error = sqlx::Error>,
316 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
317 Decode::Error: std::error::Error + Send + Sync + 'static,
318 Args: Send + 'static + Unpin,
319{
320 type Codec = Decode;
321 type Compact = CompactType;
322 type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
323
324 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
325 self.poll_default(worker).boxed()
326 }
327}
328
329impl<Args, Decode> Backend for SqliteStorage<Args, Decode, HookCallbackListener>
330where
331 Args: Send + 'static + Unpin,
332 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
333 Decode::Error: std::error::Error + Send + Sync + 'static,
334{
335 type Args = Args;
336 type IdType = Ulid;
337
338 type Context = SqlContext;
339
340 type Error = sqlx::Error;
341
342 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
343
344 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
345
346 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;
347
348 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
349 let pool = self.pool.clone();
350 let config = self.config.clone();
351 let worker = worker.clone();
352 let keep_alive = keep_alive_stream(pool, config, worker);
353 let reenqueue = reenqueue_orphaned_stream(
354 self.pool.clone(),
355 self.config.clone(),
356 *self.config.keep_alive(),
357 )
358 .map_ok(|_| ());
359 futures::stream::select(keep_alive, reenqueue).boxed()
360 }
361
362 fn middleware(&self) -> Self::Layer {
363 let lock = LockTaskLayer::new(self.pool.clone());
364 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
365 Stack::new(lock, ack)
366 }
367
368 fn poll(self, worker: &WorkerContext) -> Self::Stream {
369 self.poll_with_listener(worker)
370 .map(|a| match a {
371 Ok(Some(task)) => Ok(Some(
372 task.try_map(|t| Decode::decode(&t))
373 .map_err(|e| sqlx::Error::Decode(e.into()))?,
374 )),
375 Ok(None) => Ok(None),
376 Err(e) => Err(e),
377 })
378 .boxed()
379 }
380}
381
382impl<Args, Decode: Send + 'static> BackendExt for SqliteStorage<Args, Decode, HookCallbackListener>
383where
384 Self: Backend<Args = Args, IdType = Ulid, Context = SqlContext, Error = sqlx::Error>,
385 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
386 Decode::Error: std::error::Error + Send + Sync + 'static,
387 Args: Send + 'static + Unpin,
388{
389 type Codec = Decode;
390 type Compact = CompactType;
391 type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
392
393 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
394 self.poll_with_listener(worker).boxed()
395 }
396}
397
398impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, HookCallbackListener> {
399 fn poll_with_listener(
400 self,
401 worker: &WorkerContext,
402 ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + Send + 'static
403 {
404 let pool = self.pool.clone();
405 let config = self.config.clone();
406 let worker = worker.clone();
407 let register_worker = initial_heartbeat(
408 self.pool.clone(),
409 self.config.clone(),
410 worker.clone(),
411 "SqliteStorageWithHook",
412 );
413 let register_worker = stream::once(register_worker.map_ok(|_| None));
414 let eager_fetcher: SqlitePollFetcher<CompactType, Decode> =
415 SqlitePollFetcher::new(&self.pool, &self.config, &worker);
416 let lazy_fetcher = self
417 .fetcher
418 .filter(|a| {
419 ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
420 })
421 .inspect(|db_event| {
422 log::debug!("Received new job event: {db_event:?}");
423 })
424 .ready_chunks(self.config.buffer_size())
425 .then(move |_| fetch_next(pool.clone(), config.clone(), worker.clone()))
426 .flat_map(|res| match res {
427 Ok(tasks) => stream::iter(tasks).map(Ok).boxed(),
428 Err(e) => stream::iter(vec![Err(e)]).boxed(),
429 })
430 .map(|res| match res {
431 Ok(task) => Ok(Some(task)),
432 Err(e) => Err(e),
433 });
434
435 register_worker.chain(select(lazy_fetcher, eager_fetcher))
436 }
437}
438
439impl<Args, Decode: Send + 'static, F> SqliteStorage<Args, Decode, F> {
440 fn poll_default(
441 self,
442 worker: &WorkerContext,
443 ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + Send + 'static
444 {
445 let fut = initial_heartbeat(
446 self.pool.clone(),
447 self.config().clone(),
448 worker.clone(),
449 "SqliteStorage",
450 );
451 let register = stream::once(fut.map(|_| Ok(None)));
452 register.chain(SqlitePollFetcher::<CompactType, Decode>::new(
453 &self.pool,
454 &self.config,
455 worker,
456 ))
457 }
458}
459
460#[cfg(test)]
461mod tests {
462 use std::time::Duration;
463
464 use apalis::prelude::*;
465 use apalis_workflow::*;
466 use chrono::Local;
467 use serde::{Deserialize, Serialize};
468 use sqlx::SqlitePool;
469
470 use super::*;
471
472 #[tokio::test]
473 async fn basic_worker() {
474 const ITEMS: usize = 10;
475 let pool = SqlitePool::connect(":memory:").await.unwrap();
476 SqliteStorage::setup(&pool).await.unwrap();
477
478 let mut backend = SqliteStorage::new(&pool);
479
480 let mut start = 0;
481
482 let mut items = stream::repeat_with(move || {
483 start += 1;
484 start
485 })
486 .take(ITEMS);
487 backend.push_stream(&mut items).await.unwrap();
488
489 println!("Starting worker at {}", Local::now());
490
491 async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
492 if ITEMS == item {
493 wrk.stop().unwrap();
494 }
495 Ok(())
496 }
497
498 let worker = WorkerBuilder::new("rango-tango-1")
499 .backend(backend)
500 .build(send_reminder);
501 worker.run().await.unwrap();
502 }
503
504 #[tokio::test]
505 async fn hooked_worker() {
506 const ITEMS: usize = 20;
507
508 let lazy_strategy = StrategyBuilder::new()
509 .apply(IntervalStrategy::new(Duration::from_secs(5)))
510 .build();
511 let config = Config::new("rango-tango-queue")
512 .with_poll_interval(lazy_strategy)
513 .set_buffer_size(5);
514 let backend = SqliteStorage::new_with_callback(":memory:", &config);
515 let pool = backend.pool().clone();
516 SqliteStorage::setup(&pool).await.unwrap();
517
518 tokio::spawn(async move {
519 tokio::time::sleep(Duration::from_secs(2)).await;
520 let mut start = 0;
521
522 let items = stream::repeat_with(move || {
523 start += 1;
524
525 Task::builder(serde_json::to_vec(&start).unwrap())
526 .with_ctx(SqlContext::new().with_priority(start))
527 .build()
528 })
529 .take(ITEMS)
530 .collect::<Vec<_>>()
531 .await;
532 crate::sink::push_tasks(pool.clone(), config, items)
533 .await
534 .unwrap();
535 });
536
537 async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
538 if item == 1 {
540 apalis_core::timer::sleep(Duration::from_secs(1)).await;
541 wrk.stop().unwrap();
542 }
543 Ok(())
544 }
545
546 let worker = WorkerBuilder::new("rango-tango-1")
547 .backend(backend)
548 .build(send_reminder);
549 worker.run().await.unwrap();
550 }
551
552 #[tokio::test]
553 async fn test_workflow() {
554 let workflow = Workflow::new("odd-numbers-workflow")
555 .and_then(|a: usize| async move { Ok::<_, BoxDynError>((0..=a).collect::<Vec<_>>()) })
556 .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
557 .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
558 .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
559 .delay_for(Duration::from_millis(1000))
560 .and_then(|a: Vec<usize>| async move {
561 println!("Sum: {}", a.iter().sum::<usize>());
562 Err::<(), BoxDynError>("Intentional Error".into())
563 });
564
565 let mut sqlite = SqliteStorage::new_with_callback(
566 ":memory:",
567 &Config::new("workflow-queue").with_poll_interval(
568 StrategyBuilder::new()
569 .apply(IntervalStrategy::new(Duration::from_millis(100)))
570 .build(),
571 ),
572 );
573
574 SqliteStorage::setup(sqlite.pool()).await.unwrap();
575
576 sqlite.push_start(100usize).await.unwrap();
577
578 let worker = WorkerBuilder::new("rango-tango")
579 .backend(sqlite)
580 .on_event(|ctx, ev| {
581 println!("On Event = {:?}", ev);
582 if matches!(ev, Event::Error(_)) {
583 ctx.stop().unwrap();
584 }
585 })
586 .build(workflow);
587 worker.run().await.unwrap();
588 }
589
590 #[tokio::test]
591 async fn test_workflow_complete() {
592 #[derive(Debug, Serialize, Deserialize, Clone)]
593 struct PipelineConfig {
594 min_confidence: f32,
595 enable_sentiment: bool,
596 }
597
598 #[derive(Debug, Serialize, Deserialize)]
599 struct UserInput {
600 text: String,
601 }
602
603 #[derive(Debug, Serialize, Deserialize)]
604 struct Classified {
605 text: String,
606 label: String,
607 confidence: f32,
608 }
609
610 #[derive(Debug, Serialize, Deserialize)]
611 struct Summary {
612 text: String,
613 sentiment: Option<String>,
614 }
615
616 let workflow = Workflow::new("text-pipeline")
617 .and_then(|input: UserInput, mut worker: WorkerContext| async move {
619 worker.emit(&Event::custom(format!(
620 "Preprocessing input: {}",
621 input.text
622 )));
623 let processed = input.text.to_lowercase();
624 Ok::<_, BoxDynError>(processed)
625 })
626 .and_then(|text: String| async move {
628 let confidence = 0.85; let items = text.split_whitespace().collect::<Vec<_>>();
630 let results = items
631 .into_iter()
632 .map(|x| Classified {
633 text: x.to_string(),
634 label: if x.contains("rust") {
635 "Tech"
636 } else {
637 "General"
638 }
639 .to_string(),
640 confidence,
641 })
642 .collect::<Vec<_>>();
643 Ok::<_, BoxDynError>(results)
644 })
645 .filter_map(
647 |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
648 )
649 .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
650 let cfg = config.enable_sentiment;
651 async move {
652 if !cfg {
653 return Some(Summary {
654 text: c.text,
655 sentiment: None,
656 });
657 }
658
659 let sentiment = if c.text.contains("delightful") {
661 "positive"
662 } else {
663 "neutral"
664 };
665 Some(Summary {
666 text: c.text,
667 sentiment: Some(sentiment.to_string()),
668 })
669 }
670 })
671 .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
672 worker.emit(&Event::Custom(Box::new(format!(
673 "Generated {} summaries",
674 a.len()
675 ))));
676 worker.stop()
677 });
678
679 let mut sqlite =
680 SqliteStorage::new_with_callback(":memory:", &Config::new("text-pipeline"));
681
682 SqliteStorage::setup(sqlite.pool()).await.unwrap();
683
684 let input = UserInput {
685 text: "Rust makes systems programming delightful!".to_string(),
686 };
687 sqlite.push_start(input).await.unwrap();
688
689 let worker = WorkerBuilder::new("rango-tango")
690 .backend(sqlite)
691 .data(PipelineConfig {
692 min_confidence: 0.8,
693 enable_sentiment: true,
694 })
695 .on_event(|ctx, ev| match ev {
696 Event::Custom(msg) => {
697 if let Some(m) = msg.downcast_ref::<String>() {
698 println!("Custom Message: {}", m);
699 }
700 }
701 Event::Error(_) => {
702 println!("On Error = {:?}", ev);
703 ctx.stop().unwrap();
704 }
705 _ => {
706 println!("On Event = {:?}", ev);
707 }
708 })
709 .build(workflow);
710 worker.run().await.unwrap();
711 }
712}