// apalis_postgres/lib.rs

#![doc = include_str!("../README.md")]
//!
//! [`PostgresStorageWithListener`]: crate::PostgresStorage
//! [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
use std::marker::PhantomData;

use apalis_core::{
    backend::{
        Backend, BackendExt, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{PgPool, postgres::PgConnectOptions, postgres::PgListener, postgres::Postgres};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod config;
mod fetcher;
mod from_row;

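/// Per-task context stored alongside each Postgres task (scheduling metadata
/// such as attempts, priority, and lock information), re-exported from
/// `apalis_sql`.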
pub type PgContext = apalis_sql::context::SqlContext;
pub use config::Config;

mod queries;
pub mod shared;
pub mod sink;

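/// A [`Task`] as stored by this backend: arguments `Args`, a [`PgContext`],
/// and a [`Ulid`] task id.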
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

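/// The compact (wire) representation of task arguments: raw bytes produced
/// by the configured codec (JSON by default).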
pub type CompactType = Vec<u8>;

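/// Marker fetcher type selecting the `LISTEN/NOTIFY`-driven storage built by
/// [`PostgresStorage::new_with_notify`], which reacts to inserts immediately
/// instead of relying on polling alone.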
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

#[doc = features_table! {
    setup = r#"
        # {
        #   use apalis_postgres::PostgresStorage;
        #   use sqlx::PgPool;
        #   let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
        #   PostgresStorage::setup(&pool).await.unwrap();
        #   PostgresStorage::new(&pool)
        # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
    /// Performs the migrations required by this storage.
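    ///
    /// A minimal sketch of one-time setup at startup (not compiled as a
    /// doc-test; assumes the `migrate` feature and a reachable `DATABASE_URL`):
    ///
    /// ```rust,ignore
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// ```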
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Returns the Postgres migrations without running them.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
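    /// Creates a storage for `Args` using a default [`Config`] whose queue
    /// name is derived from the type name of `Args`.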
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

    /// Creates a new `PostgresStorage` with the provided [`Config`].
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

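    /// Creates a storage whose fetcher is woken by Postgres `LISTEN/NOTIFY`
    /// on task insert, with interval polling retained as a fallback.
    ///
    /// A minimal sketch (not compiled as a doc-test; `Email` is a
    /// hypothetical argument type and `DATABASE_URL` must be reachable):
    ///
    /// ```rust,ignore
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// let config = Config::new("email");
    /// let storage = PostgresStorage::<Email>::new_with_notify(&pool, &config);
    /// ```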
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// Returns a reference to the pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a reference to the config.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
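    /// Rebuilds this storage with a different codec type; the sink is
    /// recreated so it encodes with `NewCodec` from then on.
    ///
    /// A minimal sketch (not compiled as a doc-test; `MyCodec` stands in for
    /// any `Codec<Args, Compact = Vec<u8>>` implementation):
    ///
    /// ```rust,ignore
    /// let storage = PostgresStorage::<Email>::new(&pool).with_codec::<MyCodec>();
    /// ```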
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

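    // The heartbeat does double duty: it keeps this worker's registration
    // alive and periodically re-enqueues tasks orphaned by dead workers.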
    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

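    // Poll raw (compact) tasks and decode their arguments with the
    // configured codec, mapping codec failures to `sqlx::Error::Decode`.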
    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
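    /// Builds the notify-driven task stream:
    /// 1. registers the worker with an initial heartbeat,
    /// 2. listens on `apalis::job::insert`, batching matching task ids and
    ///    fetching them inside a transaction,
    /// 3. merges in an interval poller so tasks inserted while the listener
    ///    was down are still picked up.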
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
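        // Set up a LISTEN subscription on `apalis::job::insert`. Note that
        // listener setup failures currently panic instead of surfacing as
        // stream errors.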
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        let lazy_fetcher = fetcher
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|batch| match batch {
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

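        // Safety net: a regular poll fetcher runs alongside the notification
        // stream, catching tasks whose notifications were missed.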
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

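/// Payload of the `apalis::job::insert` notification, e.g.
/// `{"job_type":"email","id":"01ARZ3NDEKTSV4RRFFQ69G5FAV"}` (values here are
/// illustrative only).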
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: TaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or("postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or("postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use the notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: Preprocess input (e.g., tokenize, lowercase)
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            // Step 2: Classify text
            .and_then(|text: String| async move {
                let confidence = 0.85; // pretend model confidence
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Step 3: Filter out low-confidence predictions
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            // Step 4: Optionally attach sentiment, producing summaries
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    // pretend we run a sentiment model
                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            // Step 5: Report the summaries and stop the worker
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or("postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}