apalis_postgres/
lib.rs

#![doc = include_str!("../README.md")]
//!
//! [`PostgresStorageWithListener`]: crate::PostgresStorage
//! [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
use std::{fmt::Debug, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, BackendExt, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{
    PgPool,
    postgres::{PgConnectOptions, PgListener, Postgres},
};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod config;
mod fetcher;
mod from_row;

/// The context stored alongside every Postgres-backed task.
pub type PgContext = apalis_sql::context::SqlContext;
pub use config::Config;

mod queries;
/// Sharing one connection across multiple workers via [`SharedPostgresStorage`](shared::SharedPostgresStorage).
pub mod shared;
/// Pushing tasks into Postgres via [`PgSink`].
pub mod sink;

/// A task as stored in and fetched from Postgres.
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// The compact wire format task arguments are stored in.
pub type CompactType = Vec<u8>;

/// Marker fetcher that makes [`PostgresStorage`] consume Postgres
/// `LISTEN/NOTIFY` events instead of relying on polling alone.
/// Constructed via [`PostgresStorage::new_with_notify`].
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

#[doc = features_table! {
    setup = r#"
        # {
        #   use apalis_postgres::PostgresStorage;
        #   use sqlx::PgPool;
        #   let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
        #   PostgresStorage::setup(&pool).await.unwrap();
        #   PostgresStorage::new(&pool)
        # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
///
/// [`SharedPostgresStorage`]: crate::shared::SharedPostgresStorage
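///
/// ## Example
///
/// A minimal sketch of running a worker on this storage (assumes a reachable
/// `DATABASE_URL` pointing at a migrated database; mirrors this crate's tests):
///
/// ```rust,no_run
/// use apalis_core::{
///     error::BoxDynError,
///     worker::{builder::WorkerBuilder, context::WorkerContext},
/// };
/// use apalis_postgres::{PgPool, PostgresStorage};
///
/// # async fn demo() -> Result<(), BoxDynError> {
/// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
/// let backend = PostgresStorage::new(&pool);
///
/// // The handler's argument type determines the storage's `Args`.
/// async fn send_reminder(_args: u32, _wrk: WorkerContext) -> Result<(), BoxDynError> {
///     Ok(())
/// }
///
/// let worker = WorkerBuilder::new("worker-1")
///     .backend(backend)
///     .build(send_reminder);
/// worker.run().await.unwrap();
/// # Ok(())
/// # }
/// ```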
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
    /// Runs the bundled migrations for this storage against the given pool.
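    ///
    /// A minimal sketch (assumes a reachable `DATABASE_URL` and the
    /// `migrate` feature):
    ///
    /// ```rust,no_run
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// // Bring the schema up to date before creating any storages.
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```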
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Returns the Postgres migrations without running them.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
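    /// Creates a new `PostgresStorage`, using the type name of `Args` as the
    /// queue name.
    ///
    /// A minimal sketch (assumes a reachable, migrated `DATABASE_URL`):
    ///
    /// ```rust,no_run
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// let storage: PostgresStorage<u32> = PostgresStorage::new(&pool);
    /// # Ok(())
    /// # }
    /// ```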
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

    /// Creates a new `PostgresStorage` using the provided [`Config`].
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

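    /// Creates a storage that combines Postgres `LISTEN/NOTIFY` with polling,
    /// so new tasks are picked up as soon as they are pushed.
    ///
    /// A minimal sketch (assumes a reachable, migrated `DATABASE_URL`; the
    /// `"emails"` queue name is illustrative):
    ///
    /// ```rust,no_run
    /// use apalis_postgres::{Config, PgPool, PostgresStorage};
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// let config = Config::new("emails");
    /// let storage = PostgresStorage::<u32>::new_with_notify(&pool, &config);
    /// # Ok(())
    /// # }
    /// ```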
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// Returns a reference to the pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a reference to the config.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
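    /// Swaps the codec used to encode and decode task arguments, keeping the
    /// pool, config and fetcher.
    ///
    /// A minimal sketch, shown with the default [`JsonCodec`] purely for
    /// illustration:
    ///
    /// ```rust,no_run
    /// use apalis_core::backend::codec::json::JsonCodec;
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// let storage = PostgresStorage::<u32>::new(&pool).with_codec::<JsonCodec<Vec<u8>>>();
    /// # Ok(())
    /// # }
    /// ```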
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically signal that this worker is still alive...
        let keep_alive = keep_alive_stream(pool, config, worker);
        // ...and re-enqueue tasks orphaned by workers that stopped beating.
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Decode each fetched task's compact payload into `Args`.
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        // Register the worker with the backend before streaming any tasks.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically signal that this worker is still alive...
        let keep_alive = keep_alive_stream(pool, config, worker);
        // ...and re-enqueue tasks orphaned by workers that stopped beating.
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Decode each fetched task's compact payload into `Args`.
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
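    /// Streams tasks by combining `LISTEN/NOTIFY` with a regular poll fetcher.
    ///
    /// A listener subscribes to the `apalis::job::insert` channel, keeps
    /// notifications whose `job_type` matches this storage's queue, batches the
    /// notified ids up to the configured buffer size, and fetches those tasks
    /// directly. An eager poll fetcher runs alongside it so tasks pushed before
    /// the listener connected are still picked up.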
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Subscribe to the insert notifications emitted when tasks are pushed.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        let lazy_fetcher = fetcher
            // Keep only notifications addressed to this storage's queue.
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            // Batch ids so a burst of notifications becomes a single fetch.
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    // Fetch the notified tasks for this worker in one transaction.
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            // Flatten each batch, surfacing per-task and per-batch errors alike.
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // A poll fetcher runs alongside the listener to pick up tasks that
        // predate it or whose notifications were missed.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

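/// Payload of the `apalis::job::insert` notification emitted when a task is
/// pushed.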
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: TaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        // The poll interval is 6s and the task is pushed at 2s, so finishing
        // in under 4s means the NOTIFY path delivered it.
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: Preprocess input (e.g., tokenize, lowercase)
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            // Step 2: Classify text
            .and_then(|text: String| async move {
                let confidence = 0.85; // pretend model confidence
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Step 3: Filter out low-confidence predictions
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            // Step 4: Optionally attach sentiment to each surviving item
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    // pretend we run a sentiment model
                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            // Step 5: Report the summaries and stop the worker
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}