apalis_postgres/
lib.rs

#![doc = include_str!("../README.md")]
//!
//! [`PostgresStorageWithListener`]: crate::PostgresStorage
//! [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
use std::{fmt::Debug, marker::PhantomData};

use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::{config::Config, from_row::TaskRow};
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{PgPool, postgres::PgConnectOptions, postgres::PgListener, postgres::Postgres};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod fetcher;
mod from_row;
mod queries;
pub mod shared;
pub mod sink;

/// The SQL task context used by Postgres-backed tasks.
pub type PgContext = apalis_sql::context::SqlContext<PgPool>;

/// A [`Task`] stored in Postgres, carrying a [`PgContext`] and a ULID id.
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// A ULID-backed task identifier.
pub type PgTaskId = TaskId<Ulid>;

/// The compact wire format that task arguments are serialized into.
pub type CompactType = Vec<u8>;

#[doc = features_table! {
    setup = r#"
        # {
        #   use apalis_postgres::PostgresStorage;
        #   use sqlx::PgPool;
        #   let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
        #   PostgresStorage::setup(&pool).await.unwrap();
        #   PostgresStorage::new(&pool)
        # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
///
/// [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

/// A marker fetcher for notify-based storage: task delivery is driven by
/// Postgres `LISTEN`/`NOTIFY` (with a polling fallback), so this fetcher
/// itself does nothing.
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
    /// Runs the storage migrations. Requires the `migrate` feature.
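    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a running, reachable Postgres behind
    /// `DATABASE_URL`:
    ///
    /// ```rust,no_run
    /// # use apalis_postgres::{PgPool, PostgresStorage};
    /// # async fn example() -> Result<(), sqlx::Error> {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap()).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```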
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Returns the Postgres migrations without running them.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
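    /// Creates a new `PostgresStorage` with a default [`Config`] derived from
    /// the type name of `Args`.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a migrated database behind `DATABASE_URL`;
    /// the worker name and handler are illustrative:
    ///
    /// ```rust,no_run
    /// # use apalis_core::{error::BoxDynError, worker::builder::WorkerBuilder};
    /// # use apalis_core::worker::context::WorkerContext;
    /// # use apalis_postgres::{PgPool, PostgresStorage};
    /// # async fn example() {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
    ///     .await
    ///     .unwrap();
    /// let backend: PostgresStorage<u32> = PostgresStorage::new(&pool);
    ///
    /// async fn handle(args: u32, _wrk: WorkerContext) -> Result<(), BoxDynError> {
    ///     println!("got {args}");
    ///     Ok(())
    /// }
    ///
    /// let worker = WorkerBuilder::new("worker-1").backend(backend).build(handle);
    /// worker.run().await.unwrap();
    /// # }
    /// ```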
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

    /// Creates a new `PostgresStorage` with the provided [`Config`].
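    ///
    /// # Example
    ///
    /// A minimal sketch; the queue name and poll interval are illustrative:
    ///
    /// ```rust,no_run
    /// # use std::time::Duration;
    /// # use apalis_core::backend::poll_strategy::{IntervalStrategy, StrategyBuilder};
    /// # use apalis_postgres::{Config, PgPool, PostgresStorage};
    /// # async fn example() {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
    ///     .await
    ///     .unwrap();
    /// let config = Config::new("emails").with_poll_interval(
    ///     StrategyBuilder::new()
    ///         .apply(IntervalStrategy::new(Duration::from_secs(1)))
    ///         .build(),
    /// );
    /// let backend: PostgresStorage<u32> = PostgresStorage::new_with_config(&pool, &config);
    /// # }
    /// ```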
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

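    /// Creates a notify-based storage: pushed tasks are announced over
    /// Postgres `LISTEN`/`NOTIFY` on the `apalis::job::insert` channel, so
    /// workers wake immediately instead of waiting for the next poll tick,
    /// while the configured poll interval still acts as a fallback.
    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the crate's tests; the queue name and
    /// priority are illustrative:
    ///
    /// ```rust,no_run
    /// # use apalis_core::backend::TaskSink;
    /// # use apalis_core::task::Task;
    /// # use apalis_postgres::{Config, PgContext, PgPool, PostgresStorage};
    /// # use futures::{stream, StreamExt};
    /// # async fn example() {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
    ///     .await
    ///     .unwrap();
    /// let config = Config::new("emails");
    /// let mut backend = PostgresStorage::new_with_notify(&pool, &config);
    ///
    /// // Listening workers are notified of these pushes immediately.
    /// let mut items = stream::repeat_with(|| {
    ///     Task::builder(42u32)
    ///         .with_ctx(PgContext::new().with_priority(1))
    ///         .build()
    /// })
    /// .take(1);
    /// backend.push_all(&mut items).await.unwrap();
    /// # }
    /// ```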
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// Returns a reference to the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a reference to the storage configuration.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
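    /// Swaps the codec used to encode and decode task arguments; the sink is
    /// rebuilt for the new codec type.
    ///
    /// # Example
    ///
    /// A minimal sketch; swapping to `JsonCodec` is a no-op here since it is
    /// already the default, and is shown purely for illustration:
    ///
    /// ```rust,no_run
    /// # use apalis_codec::json::JsonCodec;
    /// # use apalis_postgres::{PgPool, PostgresStorage};
    /// # async fn example() {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
    ///     .await
    ///     .unwrap();
    /// let backend = PostgresStorage::<u32>::new(&pool).with_codec::<JsonCodec<Vec<u8>>>();
    /// # }
    /// ```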
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            // The sink is parameterized over the codec, so it is rebuilt for
            // the new codec type.
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically renew this worker's lease and re-enqueue tasks
        // orphaned by workers whose heartbeats have stopped.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    /// Registers the worker with an initial heartbeat, then streams tasks
    /// from the interval poll fetcher.
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
    /// Polls for tasks by combining `LISTEN`/`NOTIFY` wake-ups on the
    /// `apalis::job::insert` channel with the regular interval poller, so new
    /// tasks are picked up immediately while missed notifications are still
    /// recovered by polling.
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Subscribe to insert notifications; panics if the listener cannot
        // connect or subscribe.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        let lazy_fetcher = fetcher
            // Keep only notifications addressed to this queue.
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            // Batch the notified ids and lock/fetch them in one transaction.
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            // Flatten each batch back into a stream of task results.
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
                    Ok(t) => Ok(t),
                    Err(e) => Err(e),
                }))
                .boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // Run the notify-driven fetcher alongside the interval poller.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

/// Payload of an `apalis::job::insert` notification: the queue a task was
/// pushed to and its id.
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: PgTaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use the notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: Preprocess input (e.g., tokenize, lowercase)
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            // Step 2: Classify text
            .and_then(|text: String| async move {
                let confidence = 0.85; // pretend model confidence
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Step 3: Filter out low-confidence predictions
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            // Step 4: Optionally attach sentiment, based on the pipeline config
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    // pretend we run a sentiment model
                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            // Step 5: Report the summaries and stop the worker
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}
672}