apalis_sqlite/
lib.rs

//! # apalis-sqlite
//!
//! Background task processing for Rust using apalis and SQLite.
//!
//! ## Features
//!
//! - **Reliable job queue** using SQLite as the backend.
//! - **Multiple storage types**: standard polling and event-driven (hooked) storage.
//! - **Custom codecs** for serializing/deserializing job arguments.
//! - **Heartbeat and orphaned job re-enqueueing** for robust job processing.
//! - **Integration with apalis workers and middleware.**
//!
//! ## Storage Types
//!
//! - [`SqliteStorage`]: Standard polling-based storage.
//! - [`SqliteStorageWithHook`]: Event-driven storage using SQLite update hooks for low-latency job fetching.
//! - [`SharedSqliteStorage`]: Shared storage for multiple job types.
//!
//! The naming is designed to clearly indicate the storage mechanism and its capabilities, but under the hood each is the same `SqliteStorage` struct with a different configuration, as the sketch below shows.
//!
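//! A minimal sketch (with a placeholder `MyJob` type) of choosing between the two constructors:
//!
//! ```rust,no_run
//! # use apalis_sqlite::{SqliteStorage, Config};
//! # use sqlx::SqlitePool;
//! # struct MyJob;
//! # async fn demo(pool: SqlitePool) {
//! // Polling-based storage; the queue name defaults to the job's type name.
//! let _polling = SqliteStorage::<MyJob, (), ()>::new(&pool);
//! // Event-driven storage backed by SQLite update hooks.
//! let config = Config::new("my-queue");
//! let _hooked = SqliteStorage::<MyJob, (), ()>::new_with_callback(&pool, &config);
//! # }
//! ```
//!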
//! ## Examples
//!
//! ### Basic Worker Example
//!
//! ```rust
//! # use apalis_sqlite::{SqliteStorage, SqlContext};
//! # use apalis_core::task::Task;
//! # use apalis_core::worker::context::WorkerContext;
//! # use sqlx::SqlitePool;
//! # use futures::stream;
//! # use std::time::Duration;
//! # use apalis_core::error::BoxDynError;
//! # use futures::StreamExt;
//! # use futures::SinkExt;
//! # use apalis_core::worker::builder::WorkerBuilder;
//! # use apalis_core::backend::TaskSink;
//! #[tokio::main]
//! async fn main() {
//!     let pool = SqlitePool::connect(":memory:").await.unwrap();
//!     SqliteStorage::setup(&pool).await.unwrap();
//!     let mut backend = SqliteStorage::new(&pool);
//!
//!     let mut start = 0usize;
//!     let mut items = stream::repeat_with(move || {
//!         start += 1;
//!         start
//!     })
//!     .take(10);
//!     backend.push_stream(&mut items).await.unwrap();
//!
//!     async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
//!         if item == 10 {
//!             wrk.stop().unwrap();
//!         }
//!         Ok(())
//!     }
//!
//!     let worker = WorkerBuilder::new("worker-1")
//!         .backend(backend)
//!         .build(send_reminder);
//!     worker.run().await.unwrap();
//! }
//! ```
//!
//! ### Hooked Worker Example (Event-driven)
//!
//! ```rust,no_run
//! # use apalis_sqlite::{SqliteStorage, SqlContext, Config};
//! # use apalis_core::task::Task;
//! # use apalis_core::worker::context::WorkerContext;
//! # use apalis_core::backend::poll_strategy::{IntervalStrategy, StrategyBuilder};
//! # use sqlx::SqlitePool;
//! # use futures::stream;
//! # use std::time::Duration;
//! # use apalis_core::error::BoxDynError;
//! # use futures::StreamExt;
//! # use futures::SinkExt;
//! # use apalis_core::worker::builder::WorkerBuilder;
//!
//! #[tokio::main]
//! async fn main() {
//!     let pool = SqlitePool::connect(":memory:").await.unwrap();
//!     SqliteStorage::setup(&pool).await.unwrap();
//!
//!     let lazy_strategy = StrategyBuilder::new()
//!         .apply(IntervalStrategy::new(Duration::from_secs(5)))
//!         .build();
//!     let config = Config::new("queue")
//!         .with_poll_interval(lazy_strategy)
//!         .set_buffer_size(5);
//!     let backend = SqliteStorage::new_with_callback(&pool, &config);
//!
//!     tokio::spawn({
//!         let pool = pool.clone();
//!         let config = config.clone();
//!         async move {
//!             tokio::time::sleep(Duration::from_secs(2)).await;
//!             let mut start = 0;
//!             let items = stream::repeat_with(move || {
//!                 start += 1;
//!                 Task::builder(serde_json::to_vec(&start).unwrap())
//!                     .run_after(Duration::from_secs(1))
//!                     .with_ctx(SqlContext::new().with_priority(start))
//!                     .build()
//!             })
//!             .take(20)
//!             .collect::<Vec<_>>()
//!             .await;
//!             // push encoded tasks
//!             apalis_sqlite::sink::push_tasks(pool, config, items).await.unwrap();
//!         }
//!     });
//!
//!     async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
//!         if item == 1 {
//!             apalis_core::timer::sleep(Duration::from_secs(1)).await;
//!             wrk.stop().unwrap();
//!         }
//!         Ok(())
//!     }
//!
//!     let worker = WorkerBuilder::new("worker-2")
//!         .backend(backend)
//!         .build(send_reminder);
//!     worker.run().await.unwrap();
//! }
//! ```
//!
//! ### Workflow Example
//!
//! ```rust,no_run
//! # use apalis_sqlite::{SqliteStorage, SqlContext, Config};
//! # use apalis_core::task::Task;
//! # use apalis_core::worker::context::WorkerContext;
//! # use sqlx::SqlitePool;
//! # use futures::stream;
//! # use std::time::Duration;
//! # use apalis_core::error::BoxDynError;
//! # use futures::StreamExt;
//! # use futures::SinkExt;
//! # use apalis_core::worker::builder::WorkerBuilder;
//! # use apalis_workflow::Workflow;
//! # use apalis_workflow::WorkflowError;
//! # use apalis_core::worker::event::Event;
//! # use apalis_core::backend::WeakTaskSink;
//! # use apalis_core::worker::ext::event_listener::EventListenerExt;
//! #[tokio::main]
//! async fn main() {
//!     let workflow = Workflow::new("odd-numbers-workflow")
//!         .then(|a: usize| async move {
//!             Ok::<_, WorkflowError>((0..=a).collect::<Vec<_>>())
//!         })
//!         .filter_map(|x| async move {
//!             if x % 2 != 0 { Some(x) } else { None }
//!         })
//!         .filter_map(|x| async move {
//!             if x % 3 != 0 { Some(x) } else { None }
//!         })
//!         .filter_map(|x| async move {
//!             if x % 5 != 0 { Some(x) } else { None }
//!         })
//!         .delay_for(Duration::from_millis(1000))
//!         .then(|a: Vec<usize>| async move {
//!             println!("Sum: {}", a.iter().sum::<usize>());
//!             Ok::<(), WorkflowError>(())
//!         });
//!
//!     let pool = SqlitePool::connect(":memory:").await.unwrap();
//!     SqliteStorage::setup(&pool).await.unwrap();
//!     let mut sqlite = SqliteStorage::new_in_queue(&pool, "test-workflow");
//!
//!     sqlite.push(100usize).await.unwrap();
//!
//!     let worker = WorkerBuilder::new("rango-tango")
//!         .backend(sqlite)
//!         .on_event(|ctx, ev| {
//!             println!("On Event = {:?}", ev);
//!             if matches!(ev, Event::Error(_)) {
//!                 ctx.stop().unwrap();
//!             }
//!         })
//!         .build(workflow);
//!
//!     worker.run().await.unwrap();
//! }
//! ```
//!
//! ## Observability
//!
//! You can track your jobs using [apalis-board](https://github.com/apalis-dev/apalis-board).
//! ![Task](https://github.com/apalis-dev/apalis-board/raw/master/screenshots/task.png)
//!
//! ## License
//!
//! Licensed under either of Apache License, Version 2.0 or MIT license at your option.
//!
//! [`SqliteStorageWithHook`]: crate::SqliteStorage
use std::{fmt, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    features_table,
    layers::Stack,
    task::Task,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::context::SqlContext;
use futures::{
    FutureExt, StreamExt, TryFutureExt, TryStreamExt,
    channel::mpsc,
    future::ready,
    stream::{self, BoxStream, select},
};
use libsqlite3_sys::{sqlite3, sqlite3_update_hook};
use sqlx::{Pool, Sqlite};
use std::ffi::c_void;
use ulid::Ulid;

use crate::{
    ack::{LockTaskLayer, SqliteAck},
    callback::{HookCallbackListener, update_hook_callback},
    fetcher::{SqliteFetcher, SqlitePollFetcher, fetch_next},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::SqliteSink,
};

mod ack;
mod callback;
mod config;
pub mod fetcher;
pub mod queries;
mod shared;
pub mod sink;

mod from_row {
    use chrono::{TimeZone, Utc};

    /// Raw row as read back from the SQLite jobs table; converted into
    /// [`apalis_sql::from_row::TaskRow`] via [`TryInto`].
    #[derive(Debug)]
    pub(crate) struct SqliteTaskRow {
        pub(crate) job: Vec<u8>,
        pub(crate) id: Option<String>,
        pub(crate) job_type: Option<String>,
        pub(crate) status: Option<String>,
        pub(crate) attempts: Option<i64>,
        pub(crate) max_attempts: Option<i64>,
        pub(crate) run_at: Option<i64>,
        pub(crate) last_result: Option<String>,
        pub(crate) lock_at: Option<i64>,
        pub(crate) lock_by: Option<String>,
        pub(crate) done_at: Option<i64>,
        pub(crate) priority: Option<i64>,
        pub(crate) metadata: Option<String>,
    }

    impl TryInto<apalis_sql::from_row::TaskRow> for SqliteTaskRow {
        type Error = sqlx::Error;

        fn try_into(self) -> Result<apalis_sql::from_row::TaskRow, Self::Error> {
            Ok(apalis_sql::from_row::TaskRow {
                job: self.job,
                id: self
                    .id
                    .ok_or_else(|| sqlx::Error::Protocol("Missing id".into()))?,
                job_type: self
                    .job_type
                    .ok_or_else(|| sqlx::Error::Protocol("Missing job_type".into()))?,
                status: self
                    .status
                    .ok_or_else(|| sqlx::Error::Protocol("Missing status".into()))?,
                attempts: self
                    .attempts
                    .ok_or_else(|| sqlx::Error::Protocol("Missing attempts".into()))?
                    as usize,
                max_attempts: self.max_attempts.map(|v| v as usize),
                run_at: self
                    .run_at
                    .map(|ts| {
                        Utc.timestamp_opt(ts, 0)
                            .single()
                            .ok_or_else(|| sqlx::Error::Protocol("Invalid run_at timestamp".into()))
                    })
                    .transpose()?,
                last_result: self
                    .last_result
                    .map(|res| serde_json::from_str(&res).unwrap_or(serde_json::Value::Null)),
                lock_at: self
                    .lock_at
                    .map(|ts| {
                        Utc.timestamp_opt(ts, 0)
                            .single()
                            .ok_or_else(|| sqlx::Error::Protocol("Invalid lock_at timestamp".into()))
                    })
                    .transpose()?,
                lock_by: self.lock_by,
                done_at: self
                    .done_at
                    .map(|ts| {
                        Utc.timestamp_opt(ts, 0)
                            .single()
                            .ok_or_else(|| sqlx::Error::Protocol("Invalid done_at timestamp".into()))
                    })
                    .transpose()?,
                priority: self.priority.map(|v| v as usize),
                metadata: self
                    .metadata
                    .map(|meta| serde_json::from_str(&meta).unwrap_or(serde_json::Value::Null)),
            })
        }
    }
}

/// A task stored in SQLite, carrying [`SqlContext`] context and identified by a [`Ulid`].
pub type SqliteTask<Args> = Task<Args, SqlContext, Ulid>;
pub use callback::{CallbackListener, DbEvent};
pub use config::Config;
pub use shared::{SharedPostgresError, SharedSqliteStorage};
pub use sqlx::SqlitePool;

/// The compact representation tasks are stored as (raw bytes).
pub type CompactType = Vec<u8>;

const INSERT_OPERATION: &str = "INSERT";
const JOBS_TABLE: &str = "Jobs";

/// SqliteStorage is a storage backend for apalis using SQLite as the database.
///
/// It supports both standard polling and event-driven (hooked) storage mechanisms.
///
#[doc = features_table! {
    setup = r#"
        # {
        #   use apalis_sqlite::SqliteStorage;
        #   use sqlx::SqlitePool;
        #   let pool = SqlitePool::connect(":memory:").await.unwrap();
        #   SqliteStorage::setup(&pool).await.unwrap();
        #   SqliteStorage::new(&pool)
        # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedSqliteStorage`]", false),
    Workflow => supported("Flexible enough to support workflows", true),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct SqliteStorage<T, C, Fetcher> {
    pool: Pool<Sqlite>,
    job_type: PhantomData<T>,
    config: Config,
    codec: PhantomData<C>,
    #[pin]
    sink: SqliteSink<T, CompactType, C>,
    #[pin]
    fetcher: Fetcher,
}

impl<T, C, F> fmt::Debug for SqliteStorage<T, C, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqliteStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}

impl<T, C, F: Clone> Clone for SqliteStorage<T, C, F> {
    fn clone(&self) -> Self {
        SqliteStorage {
            sink: self.sink.clone(),
            pool: self.pool.clone(),
            job_type: PhantomData,
            config: self.config.clone(),
            codec: self.codec,
            fetcher: self.fetcher.clone(),
        }
    }
}

impl SqliteStorage<(), (), ()> {
    /// Perform migrations for storage
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
        // Write-ahead logging lets readers proceed while a writer is active.
        sqlx::query("PRAGMA journal_mode = 'WAL';")
            .execute(pool)
            .await?;
        // temp_store = 2 keeps temporary tables and indices in memory.
        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
        // NORMAL is sufficient for durability under WAL and faster than FULL.
        sqlx::query("PRAGMA synchronous = NORMAL;")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA cache_size = 64000;")
            .execute(pool)
            .await?;
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Get sqlite migrations without running them
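    ///
    /// A minimal sketch of running the bundled migrations yourself (note that
    /// this skips the PRAGMA tuning applied by [`SqliteStorage::setup`]):
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::SqliteStorage;
    /// # async fn demo(pool: sqlx::SqlitePool) -> Result<(), sqlx::Error> {
    /// SqliteStorage::migrations().run(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```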
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<T> SqliteStorage<T, (), ()> {
    /// Create a new SqliteStorage with the default JSON codec and polling fetcher
    pub fn new(
        pool: &Pool<Sqlite>,
    ) -> SqliteStorage<
        T,
        JsonCodec<CompactType>,
        fetcher::SqliteFetcher<T, CompactType, JsonCodec<CompactType>>,
    > {
        let config = Config::new(std::any::type_name::<T>());
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            sink: SqliteSink::new(pool, &config),
            config,
            codec: PhantomData,
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Create a new SqliteStorage consuming tasks from the given queue
    pub fn new_in_queue(
        pool: &Pool<Sqlite>,
        queue: &str,
    ) -> SqliteStorage<
        T,
        JsonCodec<CompactType>,
        fetcher::SqliteFetcher<T, CompactType, JsonCodec<CompactType>>,
    > {
        let config = Config::new(queue);
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            sink: SqliteSink::new(pool, &config),
            config,
            codec: PhantomData,
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

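    /// Create a new SqliteStorage that encodes task arguments with a custom codec.
    ///
    /// A minimal sketch, assuming a codec whose compact form is [`CompactType`]
    /// (the built-in [`JsonCodec`] stands in for a custom implementation here):
    ///
    /// ```rust,no_run
    /// # use apalis_sqlite::{SqliteStorage, Config, CompactType};
    /// # use apalis_core::backend::codec::json::JsonCodec;
    /// # async fn demo(pool: sqlx::SqlitePool) {
    /// let config = Config::new("my-queue");
    /// let _storage =
    ///     SqliteStorage::<u32, (), ()>::new_with_codec::<JsonCodec<CompactType>>(&pool, &config);
    /// # }
    /// ```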
    pub fn new_with_codec<Codec>(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, Codec, fetcher::SqliteFetcher<T, CompactType, Codec>> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Create a new SqliteStorage with the provided [`Config`]
    pub fn new_with_config(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<
        T,
        JsonCodec<CompactType>,
        fetcher::SqliteFetcher<T, CompactType, JsonCodec<CompactType>>,
    > {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Create an event-driven SqliteStorage that uses SQLite update hooks for low-latency fetching
    pub fn new_with_callback(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, JsonCodec<CompactType>, HookCallbackListener> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: HookCallbackListener,
        }
    }

    /// Create an event-driven SqliteStorage with a custom codec
    pub fn new_with_codec_callback<Codec>(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, Codec, HookCallbackListener> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: HookCallbackListener,
        }
    }
}

impl<T, C, F> SqliteStorage<T, C, F> {
    /// Get a reference to the storage's [`Config`]
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SqliteFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;

    type Context = SqlContext;

    type Codec = Decode;

    type Compact = CompactType;

    type Error = sqlx::Error;

    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically refresh this worker's keep-alive record...
        let keep_alive = keep_alive_stream(pool, config, worker);
        // ...and re-enqueue tasks orphaned by workers that stopped heartbeating.
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(lock, ack)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let fut = initial_heartbeat(
            self.pool.clone(),
            self.config().clone(),
            worker.clone(),
            "SqliteStorage",
        );
        let register = stream::once(fut.map(|_| Ok(None)));
        register
            .chain(SqlitePollFetcher::<Args, CompactType, Decode>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, HookCallbackListener>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;

    type Context = SqlContext;

    type Codec = Decode;

    type Compact = CompactType;

    type Error = sqlx::Error;

    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(lock, ack)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let (tx, rx) = mpsc::unbounded::<DbEvent>();

        let listener = CallbackListener::new(rx);

        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SqliteStorageWithHook",
        );
        let p = pool.clone();
        let register_worker = stream::once(register_worker.map_ok(|_| None));
        let register_listener = stream::once(async move {
            // Still somewhat experimental: register an update hook so that INSERTs
            // into the jobs table wake the fetcher without waiting for the next poll.
            let mut conn = p.acquire().await?;
            // Get the raw sqlite3* handle
            let handle: *mut sqlite3 = conn.lock_handle().await?.as_raw_handle().as_ptr();

            // Box the sender so it has a stable memory address; the box is
            // intentionally leaked because the hook may outlive this scope.
            let tx_box = Box::new(tx);
            let tx_ptr = Box::into_raw(tx_box) as *mut c_void;
            unsafe {
                sqlite3_update_hook(handle, Some(update_hook_callback), tx_ptr);
            }
            Ok(None)
        });
        let eager_fetcher: SqlitePollFetcher<Args, CompactType, Decode> =
            SqlitePollFetcher::new(&self.pool, &self.config, &worker);
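        // The lazy fetcher reacts to update-hook events: it batches up to
        // `buffer_size` notifications, then fetches the next chunk of ready tasks.
        // The eager (interval) fetcher above covers tasks inserted before the hook
        // was registered, as well as tasks scheduled to run later.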
        let lazy_fetcher = listener
            .filter(|a| ready(a.operation() == INSERT_OPERATION && a.table_name() == JOBS_TABLE))
            .inspect(|db_event| {
                log::debug!("Received DB event: {db_event:?}");
            })
            .ready_chunks(self.config.buffer_size())
            .then(move |_| fetch_next::<Args, Decode>(pool.clone(), config.clone(), worker.clone()))
            .flat_map(|res| match res {
                Ok(tasks) => stream::iter(tasks).map(Ok).boxed(),
                Err(e) => stream::iter(vec![Err(e)]).boxed(),
            })
            .map(|res| match res {
                Ok(task) => Ok(Some(task)),
                Err(e) => Err(e),
            });

        register_worker
            .chain(register_listener)
            .chain(select(lazy_fetcher, eager_fetcher))
            .boxed()
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_workflow::{Workflow, WorkflowError};
    use chrono::Local;

    use apalis_core::{
        backend::{
            WeakTaskSink,
            poll_strategy::{IntervalStrategy, StrategyBuilder},
        },
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        const ITEMS: usize = 10;
        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();

        let mut backend = SqliteStorage::new(&pool);

        let mut start = 0;

        let mut items = stream::repeat_with(move || {
            start += 1;
            start
        })
        .take(ITEMS);
        backend.push_stream(&mut items).await.unwrap();

        println!("Starting worker at {}", Local::now());

        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            if ITEMS == item {
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn hooked_worker() {
        const ITEMS: usize = 20;
        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();

        let lazy_strategy = StrategyBuilder::new()
            .apply(IntervalStrategy::new(Duration::from_secs(5)))
            .build();
        let config = Config::new("rango-tango-queue")
            .with_poll_interval(lazy_strategy)
            .set_buffer_size(5);
        let backend = SqliteStorage::new_with_callback(&pool, &config);

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut start = 0;

            let items = stream::repeat_with(move || {
                start += 1;

                Task::builder(serde_json::to_vec(&start).unwrap())
                    // .run_after(Duration::from_secs(1))
                    .with_ctx(SqlContext::new().with_priority(start))
                    .build()
            })
            .take(ITEMS)
            .collect::<Vec<_>>()
            .await;
            sink::push_tasks(pool, config, items).await.unwrap();
        });

        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            // Priority is in reverse order
            if item == 1 {
                apalis_core::timer::sleep(Duration::from_secs(1)).await;
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow() {
        let workflow = Workflow::new("odd-numbers-workflow")
            .then(|a: usize| async move { Ok::<_, WorkflowError>((0..=a).collect::<Vec<_>>()) })
            .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
            .delay_for(Duration::from_millis(1000))
            .then(|a: Vec<usize>| async move {
                println!("Sum: {}", a.iter().sum::<usize>());
                Err::<(), WorkflowError>(WorkflowError::MissingContextError)
            });

        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();
        let mut sqlite = SqliteStorage::new_with_callback(
            &pool,
            &Config::new("workflow-queue").with_poll_interval(
                StrategyBuilder::new()
                    .apply(IntervalStrategy::new(Duration::from_millis(100)))
                    .build(),
            ),
        );

        sqlite.push(100usize).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .on_event(|ctx, ev| {
                println!("On Event = {:?}", ev);
                if matches!(ev, Event::Error(_)) {
                    ctx.stop().unwrap();
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: Preprocess input (e.g., tokenize, lowercase)
            .then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, WorkflowError>(processed)
            })
            // Step 2: Classify text
            .then(|text: String| async move {
                let confidence = 0.85; // pretend model confidence
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, WorkflowError>(results)
            })
            // Step 3: Filter out low-confidence predictions
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    // pretend we run a sentiment model
                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();
        let mut sqlite = SqliteStorage::new_with_callback(&pool, &Config::new("text-pipeline"));

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        sqlite.push(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {}", m);
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {:?}", ev);
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {:?}", ev);
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}