apalis_postgres/
lib.rs

#![doc = include_str!("../README.md")]
//!
//! [`PostgresStorageWithListener`]: crate::PostgresStorage
//! [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
use std::{fmt::Debug, marker::PhantomData};

use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::{config::Config, from_row::TaskRow};
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{PgPool, postgres::PgConnectOptions, postgres::PgListener, postgres::Postgres};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod fetcher;
mod from_row;
mod queries;
pub mod shared;
pub mod sink;

/// Task context for the Postgres backend.
pub type PgContext = apalis_sql::context::SqlContext<PgPool>;

/// A task as stored and retrieved by the Postgres backend.
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// The identifier type for Postgres-backed tasks.
pub type PgTaskId = TaskId<Ulid>;

/// The compact representation task arguments are stored as.
pub type CompactType = Vec<u8>;

#[doc = features_table! {
    setup = r#"
        # {
        #   use apalis_postgres::PostgresStorage;
        #   use sqlx::PgPool;
        #   let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
        #   PostgresStorage::setup(&pool).await.unwrap();
        #   PostgresStorage::new(&pool)
        # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
///
/// [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

/// A marker fetcher for notification-driven storage.
///
/// It does no fetching itself: task retrieval is handled by
/// [`PostgresStorage::poll_with_notify`], which listens on the
/// `apalis::job::insert` channel and falls back to interval polling.
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
    /// Run the bundled migrations, creating the tables this backend needs.
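    ///
    /// A minimal sketch, assuming a reachable database behind `DATABASE_URL`:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), sqlx::Error> {
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap()).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```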
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Get the Postgres migrations without running them.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
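    /// Creates a storage backed by `pool`, using a default [`Config`] whose
    /// queue is named after the type name of `Args`.
    ///
    /// A minimal sketch, assuming an already connected [`PgPool`]:
    ///
    /// ```no_run
    /// # async fn example(pool: apalis_postgres::PgPool) {
    /// use apalis_postgres::PostgresStorage;
    ///
    /// let storage: PostgresStorage<u32> = PostgresStorage::new(&pool);
    /// # }
    /// ```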
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

    /// Creates a new `PostgresStorage` with the given [`Config`].
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

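    /// Creates a storage that is driven by Postgres `LISTEN`/`NOTIFY` in
    /// addition to interval polling, so freshly inserted tasks are picked up
    /// without waiting for the next poll.
    ///
    /// A minimal sketch, assuming an already connected [`PgPool`]:
    ///
    /// ```no_run
    /// # async fn example(pool: apalis_postgres::PgPool) {
    /// use apalis_postgres::{Config, PostgresStorage};
    ///
    /// let config = Config::new("emails");
    /// let storage = PostgresStorage::<u32>::new_with_notify(&pool, &config);
    /// # }
    /// ```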
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// Returns a reference to the pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a reference to the config.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
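    /// Swaps the codec used to encode and decode task arguments.
    ///
    /// A minimal sketch; `MyCodec` is a hypothetical type implementing
    /// [`Codec`](apalis_core::backend::codec::Codec) over the same compact type:
    ///
    /// ```ignore
    /// let storage = PostgresStorage::<u32>::new(&pool).with_codec::<MyCodec>();
    /// ```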
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            // Rebuild the sink so its codec type parameter becomes `NewCodec`.
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Renew this worker's lease and, in parallel, re-enqueue tasks whose
        // workers have stopped sending heartbeats.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Decode each compact task into `Args` as it is fetched.
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        // Register the worker first, then hand over to the interval poller.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
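    /// Streams tasks by merging two sources: a lazy fetcher fed by
    /// `LISTEN apalis::job::insert` notifications, and the eager interval
    /// poller used by the plain storage. Notified tasks are claimed
    /// immediately, while missed notifications are still swept up by polling.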
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Set up the LISTEN connection; failing to establish it is fatal.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        let lazy_fetcher = fetcher
            // Keep only insert events for this queue; listener errors and
            // malformed payloads are skipped.
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            // Batch notified IDs and claim them for this worker in one query.
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

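/// Payload published on the `apalis::job::insert` channel whenever a task is
/// inserted. The notifying trigger is expected to emit JSON of the shape:
///
/// ```json
/// { "job_type": "emails", "id": "01ARZ3NDEKTSV4RRFFQ69G5FAV" }
/// ```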
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: PgTaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        // Push a task after the worker has started; the notification should
        // wake the worker well before the 6s poll interval elapses.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: Preprocess input (e.g., tokenize, lowercase)
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            // Step 2: Classify text
            .and_then(|text: String| async move {
                let confidence = 0.85; // pretend model confidence
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Step 3: Filter out low-confidence predictions (hard-coded
            // threshold; `min_confidence` from the config is not consulted here)
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            // Step 4: Optionally attach a sentiment label
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    // pretend we run a sentiment model
                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            // Step 5: Report the results and stop the worker
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}