use std::{ffi::c_void, fmt, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    layers::Stack,
    task::Task,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::context::SqlContext;
use futures::{
    FutureExt, StreamExt, TryFutureExt, TryStreamExt,
    channel::mpsc,
    future::ready,
    stream::{self, BoxStream, select},
};
use libsqlite3_sys::{sqlite3, sqlite3_update_hook};
use sqlx::{Pool, Sqlite};
use ulid::Ulid;

use crate::{
    ack::{LockTaskLayer, SqliteAck},
    callback::{HookCallbackListener, update_hook_callback},
    fetcher::{SqliteFetcher, SqlitePollFetcher, fetch_next},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::SqliteSink,
};

mod ack;
mod callback;
mod config;
pub mod fetcher;
pub mod queries;
mod shared;
pub mod sink;

mod from_row {
    use chrono::{TimeZone, Utc};

    #[derive(Debug)]
    pub(crate) struct SqliteTaskRow {
        pub(crate) job: Vec<u8>,
        pub(crate) id: Option<String>,
        pub(crate) job_type: Option<String>,
        pub(crate) status: Option<String>,
        pub(crate) attempts: Option<i64>,
        pub(crate) max_attempts: Option<i64>,
        pub(crate) run_at: Option<i64>,
        pub(crate) last_result: Option<String>,
        pub(crate) lock_at: Option<i64>,
        pub(crate) lock_by: Option<String>,
        pub(crate) done_at: Option<i64>,
        pub(crate) priority: Option<i64>,
        pub(crate) metadata: Option<String>,
    }

    impl TryInto<apalis_sql::from_row::TaskRow> for SqliteTaskRow {
        type Error = sqlx::Error;

        fn try_into(self) -> Result<apalis_sql::from_row::TaskRow, Self::Error> {
            // Convert an optional UNIX timestamp, surfacing invalid values as
            // protocol errors instead of panicking mid-conversion.
            let to_datetime = |ts: i64, field: &str| {
                Utc.timestamp_opt(ts, 0)
                    .single()
                    .ok_or_else(|| sqlx::Error::Protocol(format!("Invalid {field} timestamp")))
            };
            Ok(apalis_sql::from_row::TaskRow {
                job: self.job,
                id: self
                    .id
                    .ok_or_else(|| sqlx::Error::Protocol("Missing id".into()))?,
                job_type: self
                    .job_type
                    .ok_or_else(|| sqlx::Error::Protocol("Missing job_type".into()))?,
                status: self
                    .status
                    .ok_or_else(|| sqlx::Error::Protocol("Missing status".into()))?,
                attempts: self
                    .attempts
                    .ok_or_else(|| sqlx::Error::Protocol("Missing attempts".into()))?
                    as usize,
                max_attempts: self.max_attempts.map(|v| v as usize),
                run_at: self.run_at.map(|ts| to_datetime(ts, "run_at")).transpose()?,
                last_result: self
                    .last_result
                    .map(|res| serde_json::from_str(&res).unwrap_or(serde_json::Value::Null)),
                lock_at: self.lock_at.map(|ts| to_datetime(ts, "lock_at")).transpose()?,
                lock_by: self.lock_by,
                done_at: self.done_at.map(|ts| to_datetime(ts, "done_at")).transpose()?,
                priority: self.priority.map(|v| v as usize),
                metadata: self
                    .metadata
                    .map(|meta| serde_json::from_str(&meta).unwrap_or(serde_json::Value::Null)),
            })
        }
    }
}

pub type SqliteTask<Args> = Task<Args, SqlContext, Ulid>;
pub use callback::{CallbackListener, DbEvent};
pub use config::Config;
pub use shared::{SharedSqliteError, SharedSqliteStorage};
pub use sqlx::SqlitePool;

/// The compact representation persisted for each task: raw bytes produced by
/// the configured codec.
pub type CompactType = Vec<u8>;

const INSERT_OPERATION: &str = "INSERT";
const JOBS_TABLE: &str = "Jobs";

/// A SQLite-backed task storage.
///
/// `T` is the job argument type, `C` the codec used to encode arguments into
/// [`CompactType`], and `Fetcher` the strategy used to pull tasks (interval
/// polling or update-hook driven).
#[pin_project::pin_project]
pub struct SqliteStorage<T, C, Fetcher> {
    pool: Pool<Sqlite>,
    job_type: PhantomData<T>,
    config: Config,
    codec: PhantomData<C>,
    #[pin]
    sink: SqliteSink<T, CompactType, C>,
    #[pin]
    fetcher: Fetcher,
}

impl<T, C, F> fmt::Debug for SqliteStorage<T, C, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqliteStorage")
            .field("pool", &self.pool)
            .field("job_type", &std::any::type_name::<T>())
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}

impl<T, C, F: Clone> Clone for SqliteStorage<T, C, F> {
    fn clone(&self) -> Self {
        SqliteStorage {
            sink: self.sink.clone(),
            pool: self.pool.clone(),
            job_type: PhantomData,
            config: self.config.clone(),
            codec: self.codec,
            fetcher: self.fetcher.clone(),
        }
    }
}

impl SqliteStorage<(), (), ()> {
    /// Applies recommended pragmas (WAL journaling, tuned cache and
    /// synchronization settings) and runs the bundled migrations.
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Sqlite>) -> Result<(), sqlx::Error> {
        sqlx::query("PRAGMA journal_mode = 'WAL';")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA temp_store = 2;").execute(pool).await?;
        sqlx::query("PRAGMA synchronous = NORMAL;")
            .execute(pool)
            .await?;
        sqlx::query("PRAGMA cache_size = 64000;")
            .execute(pool)
            .await?;
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Returns the migrator for the migrations shipped with this crate.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}
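
// Setup sketch (assumes the `migrate` feature and a reachable database;
// `MyJob` is a placeholder type):
//
//     let pool = SqlitePool::connect(":memory:").await?;
//     SqliteStorage::setup(&pool).await?;
//     let storage = SqliteStorage::<MyJob, (), ()>::new(&pool);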

impl<T> SqliteStorage<T, (), ()> {
    /// Creates a JSON-encoded, polling storage on a queue named after `T`'s
    /// type name.
    pub fn new(
        pool: &Pool<Sqlite>,
    ) -> SqliteStorage<
        T,
        JsonCodec<CompactType>,
        fetcher::SqliteFetcher<T, CompactType, JsonCodec<CompactType>>,
    > {
        let config = Config::new(std::any::type_name::<T>());
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            sink: SqliteSink::new(pool, &config),
            config,
            codec: PhantomData,
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Like [`SqliteStorage::new`], but targets an explicitly named queue.
    pub fn new_in_queue(
        pool: &Pool<Sqlite>,
        queue: &str,
    ) -> SqliteStorage<
        T,
        JsonCodec<CompactType>,
        fetcher::SqliteFetcher<T, CompactType, JsonCodec<CompactType>>,
    > {
        let config = Config::new(queue);
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            sink: SqliteSink::new(pool, &config),
            config,
            codec: PhantomData,
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Creates a polling storage that encodes arguments with a custom codec.
    pub fn new_with_codec<Codec>(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, Codec, fetcher::SqliteFetcher<T, CompactType, Codec>> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Creates a JSON-encoded polling storage from an existing [`Config`].
    pub fn new_with_config(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<
        T,
        JsonCodec<CompactType>,
        fetcher::SqliteFetcher<T, CompactType, JsonCodec<CompactType>>,
    > {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: fetcher::SqliteFetcher {
                _marker: PhantomData,
            },
        }
    }

    /// Creates a JSON-encoded storage driven by SQLite's update hook: inserts
    /// into the jobs table wake the fetcher immediately, with interval polling
    /// as a fallback.
    pub fn new_with_callback(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, JsonCodec<CompactType>, HookCallbackListener> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: HookCallbackListener,
        }
    }

    /// Like [`SqliteStorage::new_with_callback`], but with a custom codec.
    pub fn new_with_codec_callback<Codec>(
        pool: &Pool<Sqlite>,
        config: &Config,
    ) -> SqliteStorage<T, Codec, HookCallbackListener> {
        SqliteStorage {
            pool: pool.clone(),
            job_type: PhantomData,
            config: config.clone(),
            codec: PhantomData,
            sink: SqliteSink::new(pool, config),
            fetcher: HookCallbackListener,
        }
    }
}
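
// Codec selection sketch: any codec whose compact form is `Vec<u8>` satisfies
// the `Codec<Args, Compact = CompactType>` bound used by the Backend impls
// below. `MyJob` and `MyCodec` are placeholders:
//
//     let storage = SqliteStorage::<MyJob, (), ()>::new_with_codec::<MyCodec>(
//         &pool,
//         &Config::new("my-queue"),
//     );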

impl<T, C, F> SqliteStorage<T, C, F> {
    /// Returns this storage's configuration.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SqliteFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;
    type Context = SqlContext;
    type Codec = Decode;
    type Compact = CompactType;
    type Error = sqlx::Error;
    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        // Keep this worker's lease fresh while re-enqueueing tasks orphaned by
        // workers that stopped heartbeating.
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(lock, ack)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Register the worker once, then hand over to the interval-based fetcher.
        let fut = initial_heartbeat(
            self.pool.clone(),
            self.config().clone(),
            worker.clone(),
            "SqliteStorage",
        );
        let register = stream::once(fut.map(|_| Ok(None)));
        register
            .chain(SqlitePollFetcher::<Args, CompactType, Decode>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}
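
// Wiring sketch (mirrors the tests below): either Backend impl plugs into a
// worker the same way. `send_reminder` stands in for any compatible task fn.
//
//     let worker = WorkerBuilder::new("worker-1")
//         .backend(storage)
//         .build(send_reminder);
//     worker.run().await?;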

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, HookCallbackListener>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;
    type Context = SqlContext;
    type Codec = Decode;
    type Compact = CompactType;
    type Error = sqlx::Error;
    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(lock, ack)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let (tx, rx) = mpsc::unbounded::<DbEvent>();

        let listener = CallbackListener::new(rx);

        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SqliteStorageWithHook",
        );
        let p = pool.clone();
        let register_worker = stream::once(
            register_worker
                .and_then(|_| async move {
                    let mut conn = p.acquire().await?;
                    let handle: *mut sqlite3 =
                        conn.lock_handle().await?.as_raw_handle().as_ptr();

                    // Leak the sender on purpose: the C-side hook holds the raw
                    // pointer for the life of the connection.
                    let tx_ptr = Box::into_raw(Box::new(tx)) as *mut c_void;

                    unsafe {
                        sqlite3_update_hook(handle, Some(update_hook_callback), tx_ptr);
                    }
                    Ok(())
                })
                .map(|_| Ok(None)),
        );
        // Poll on the configured interval as a fallback, and fetch eagerly
        // whenever the update hook reports inserts into the jobs table.
        let eager_fetcher: SqlitePollFetcher<Args, CompactType, Decode> =
            SqlitePollFetcher::new(&self.pool, &self.config, &worker);
        let lazy_fetcher = listener
            .filter(|a| ready(a.operation() == INSERT_OPERATION && a.table_name() == JOBS_TABLE))
            .ready_chunks(self.config.buffer_size())
            .then(move |_| fetch_next::<Args, Decode>(pool.clone(), config.clone(), worker.clone()))
            .flat_map(|res| match res {
                Ok(tasks) => stream::iter(tasks).map(Ok).boxed(),
                Err(e) => stream::iter(vec![Err(e)]).boxed(),
            })
            .map(|res| res.map(Some));

        register_worker
            .chain(select(lazy_fetcher, eager_fetcher))
            .boxed()
    }
}
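
// Hooked-storage sketch (mirrors `hooked_worker` below): a long poll interval
// is safe because inserts wake the fetcher via the update hook. `MyJob` is a
// placeholder type.
//
//     let config = Config::new("my-queue").with_poll_interval(
//         StrategyBuilder::new()
//             .apply(IntervalStrategy::new(Duration::from_secs(5)))
//             .build(),
//     );
//     let storage = SqliteStorage::<MyJob, (), ()>::new_with_callback(&pool, &config);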

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_workflow::{WorkFlow, WorkflowError};
    use chrono::Local;

    use apalis_core::{
        backend::{
            WeakTaskSink,
            poll_strategy::{IntervalStrategy, StrategyBuilder},
        },
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        const ITEMS: usize = 10;
        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();

        let mut backend = SqliteStorage::new(&pool);

        let mut start = 0;

        let mut items = stream::repeat_with(move || {
            start += 1;
            start
        })
        .take(ITEMS);
        backend.push_stream(&mut items).await.unwrap();

        println!("Starting worker at {}", Local::now());

        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            if item == ITEMS {
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn hooked_worker() {
        const ITEMS: usize = 20;
        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();

        let lazy_strategy = StrategyBuilder::new()
            .apply(IntervalStrategy::new(Duration::from_secs(5)))
            .build();
        let config = Config::new("rango-tango-queue")
            .with_poll_interval(lazy_strategy)
            .set_buffer_size(5);
        let backend = SqliteStorage::new_with_callback(&pool, &config);

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut start = 0;

            let items = stream::repeat_with(move || {
                start += 1;

                Task::builder(serde_json::to_vec(&start).unwrap())
                    .run_after(Duration::from_secs(1))
                    .with_ctx(SqlContext::new().with_priority(start))
                    .build()
            })
            .take(ITEMS)
            .collect::<Vec<_>>()
            .await;
            sink::push_tasks(pool, config, items).await.unwrap();
        });

        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            if item == 1 {
                apalis_core::timer::sleep(Duration::from_secs(1)).await;
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow() {
        let workflow = WorkFlow::new("odd-numbers-workflow")
            .then(|a: usize| async move { Ok::<_, WorkflowError>((0..=a).collect::<Vec<_>>()) })
            .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
            .delay_for(Duration::from_millis(1000))
            .then(|a: Vec<usize>| async move {
                println!("Sum: {}", a.iter().sum::<usize>());
                Err::<(), WorkflowError>(WorkflowError::MissingContextError)
            });

        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();
        let mut sqlite = SqliteStorage::new_with_callback(
            &pool,
            &Config::new("workflow-queue").with_poll_interval(
                StrategyBuilder::new()
                    .apply(IntervalStrategy::new(Duration::from_millis(100)))
                    .build(),
            ),
        );

        sqlite.push(100usize).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .on_event(|ctx, ev| {
                println!("On Event = {:?}", ev);
                if matches!(ev, Event::Error(_)) {
                    ctx.stop().unwrap();
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = WorkFlow::new("text-pipeline")
            .then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, WorkflowError>(processed)
            })
            .then(|text: String| async move {
                // Stand-in classifier: fixed confidence, label by keyword.
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, WorkflowError>(results)
            })
            // Drop low-confidence classifications.
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let enable_sentiment = config.enable_sentiment;
                async move {
                    if !enable_sentiment {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();
        let mut sqlite = SqliteStorage::new_with_callback(&pool, &Config::new("text-pipeline"));

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        sqlite.push(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(sqlite)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {}", m);
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {:?}", ev);
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {:?}", ev);
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}