Skip to main content

apalis_pgmq/
lib.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use apalis_codec::json::JsonCodec;
4use apalis_core::{
5    backend::{
6        Backend, TaskStream,
7        codec::Codec,
8        poll_strategy::{PollContext, PollStrategyExt},
9    },
10    task::{Task, attempt::Attempt, task_id::TaskId},
11    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
12};
13use chrono::{DateTime, Utc};
14use futures::{
15    StreamExt,
16    stream::{self, BoxStream},
17};
18use serde::{Serialize, de::DeserializeOwned};
19use serde_json::Value;
20use sqlx::{PgPool, Postgres};
21
22use crate::{
23    config::Config, context::PgMqContext, errors::PgmqError, fetch::fetch_messages, sink::PgMqSink,
24};
25
26mod ack;
27mod config;
28mod context;
29mod errors;
30mod fetch;
31pub mod query;
32mod sink;
33mod util;
34
/// Prefix string for queue objects.
///
/// NOTE(review): presumably used by the `query` module when naming PGMQ queue tables — confirm.
pub const QUEUE_PREFIX: &str = "q";

/// Prefix string for archive objects.
///
/// NOTE(review): presumably used by the `query` module when naming PGMQ archive tables — confirm.
pub const ARCHIVE_PREFIX: &str = "a";

/// Name of the Postgres schema associated with this backend.
pub const PGMQ_SCHEMA: &str = "apalis_pgmq";
38
/// A task as produced by this backend: arguments of type `Args`, a [`PgMqContext`] carrying
/// queue metadata, and an `i64` task id (the PGMQ message id, see `Backend::poll`).
pub type PgMqTask<Args> = Task<Args, PgMqContext, i64>;
40
/// An apalis [`Backend`] that stores tasks in a PGMQ queue inside Postgres.
///
/// `Args` is the task argument type. `Codec` converts `Args` to and from the raw bytes stored in
/// the queue; it defaults to a JSON codec.
pub struct PGMQueue<Args, Codec = JsonCodec<Vec<u8>>> {
    // Connection pool used to read batches of messages when polling.
    connection: PgPool,
    // Queue name, polling and batching settings, plus the codec type.
    config: Config<Codec>,
    // Sink used to push tasks; built from the same pool and config in `new_with_config`.
    sink: PgMqSink<Args, Codec>,
    // Ties the `Args` type parameter to the struct without storing a value.
    _args: PhantomData<Args>,
}
47
48impl<Args, C> Clone for PGMQueue<Args, C> {
49    fn clone(&self) -> Self {
50        Self {
51            connection: self.connection.clone(),
52            config: self.config.clone(),
53            sink: self.sink.clone(),
54            _args: self._args,
55        }
56    }
57}
58
59impl PGMQueue<()> {
60    pub async fn setup<'c, E: sqlx::Executor<'c, Database = Postgres>>(
61        executor: E,
62    ) -> Result<bool, PgmqError> {
63        sqlx::query("CREATE EXTENSION IF NOT EXISTS pgmq CASCADE;")
64            .execute(executor)
65            .await
66            .map(|_| true)
67            .map_err(PgmqError::from)
68    }
69    async fn create<'c, E>(queue_name: &str, executor: E) -> Result<(), PgmqError>
70    where
71        E: sqlx::Acquire<'c, Database = Postgres>,
72    {
73        let mut tx = executor.begin().await?;
74        let setup = query::init_queue_client_only(queue_name, false)?;
75        for q in setup {
76            sqlx::query(&q).execute(&mut *tx).await?;
77        }
78        tx.commit().await?;
79        Ok(())
80    }
81}
82
83impl<Args: Serialize + DeserializeOwned> PGMQueue<Args> {
84    /// initialize a PGMQ connection with your own SQLx Postgres connection pool
85    pub async fn new(pool: PgPool, queue_name: &str) -> Self {
86        let config: Config<JsonCodec<Vec<u8>>> =
87            Config::default().with_queue(queue_name.to_string());
88        PGMQueue::new_with_config(pool, config).await
89    }
90}
91
92impl<Args, C: Codec<Args, Compact = Vec<u8>>> PGMQueue<Args, C> {
93    pub async fn new_with_config(pool: PgPool, config: Config<C>) -> Self {
94        PGMQueue::create(config.queue(), &pool)
95            .await
96            .expect("Queue to be created");
97        Self {
98            sink: PgMqSink::new(pool.clone(), config.clone()),
99            connection: pool,
100            config,
101            _args: PhantomData,
102        }
103    }
104
105    async fn read_batch(
106        config: Config<C>,
107        connection: PgPool,
108    ) -> Result<Option<Vec<Message>>, PgmqError> {
109        let query = &query::read(
110            config.queue(),
111            config.visibility_timeout().as_secs() as i32,
112            config.buffer_size() as i32,
113        )?;
114        let messages = fetch_messages(query, &connection).await?;
115        Ok(messages)
116    }
117}
118
/// A raw row read from the queue, before its payload is decoded into a task.
///
/// NOTE(review): presumably populated by `fetch::fetch_messages` — confirm.
struct Message {
    // Message id; becomes the task id.
    msg_id: i64,
    // Time at which the message becomes visible again; used as the task's run-at timestamp.
    visibility_time: DateTime<Utc>,
    // Number of times the message has been read; used as the task's attempt count.
    read_count: i32,
    // Time the message was enqueued; copied into `PgMqContext`.
    enqueued_at: DateTime<Utc>,
    // Encoded task arguments, decoded with the backend's codec.
    message: Vec<u8>,
    // Message headers; polling rejects the message unless this is a JSON object.
    headers: Value,
}
127
128impl<Args, C> Backend for PGMQueue<Args, C>
129where
130    Args: Send + Sync + 'static + Unpin,
131    C: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
132    C::Error: std::error::Error + Send + Sync + 'static,
133{
134    type Args = Args;
135
136    type Context = PgMqContext;
137
138    type Beat = BoxStream<'static, Result<(), PgmqError>>;
139
140    type Error = PgmqError;
141
142    type IdType = i64;
143
144    type Layer = AcknowledgeLayer<Self>;
145
146    type Stream = TaskStream<PgMqTask<Args>, PgmqError>;
147
148    fn heartbeat(&self, _worker: &WorkerContext) -> Self::Beat {
149        Box::pin(stream::pending())
150    }
151
152    fn middleware(&self) -> Self::Layer {
153        AcknowledgeLayer::new(self.clone())
154    }
155
156    fn poll(self, worker: &WorkerContext) -> Self::Stream {
157        let ctx = PollContext::new(worker.clone(), Arc::default());
158        let poller = self.config.poll_strategy().clone().build_stream(&ctx);
159        stream::unfold(
160            (self, poller, Vec::new()),
161            |(backend, mut poller, mut buf)| async move {
162                if let Some(msg) = buf.pop() {
163                    return Some((Ok(msg), (backend, poller, buf)));
164                }
165
166                poller.next().await;
167
168                match Self::read_batch(backend.config.clone(), backend.connection.clone()).await {
169                    Ok(Some(messages)) => {
170                        buf = messages;
171                        buf.reverse();
172                        let msg = buf.pop().unwrap();
173                        Some((Ok(msg), (backend, poller, buf)))
174                    }
175                    Ok(None) => None,
176                    Err(e) => Some((Err(e), (backend, poller, buf))),
177                }
178            },
179        )
180        .map(|res| match res {
181            Ok(raw) => {
182                let args =
183                    C::decode(&raw.message).map_err(|e| PgmqError::ParsingError(e.into()))?;
184                let ctx = PgMqContext {
185                    enqueued_at: raw.enqueued_at,
186                    headers: raw
187                        .headers
188                        .as_object()
189                        .cloned()
190                        .ok_or(PgmqError::ParsingError("Headers are not an object".into()))?,
191                };
192                let task = Task::builder(args)
193                    .with_task_id(TaskId::new(raw.msg_id))
194                    .with_attempt(Attempt::new_with_value(raw.read_count as usize))
195                    .run_at_timestamp(raw.visibility_time.timestamp() as u64)
196                    .with_ctx(ctx)
197                    .build();
198                Ok(Some(task))
199            }
200            Err(e) => Err(e),
201        })
202        .boxed()
203    }
204}
205
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env, time::Duration};

    use apalis_core::{error::BoxDynError, worker::builder::WorkerBuilder};
    use futures::SinkExt;

    use super::*;

    /// Test handler: waits two seconds, then asks the worker to stop so `basic_worker` finishes.
    async fn send_reminder(
        _: HashMap<String, String>,
        wrk: WorkerContext,
    ) -> Result<(), BoxDynError> {
        tokio::time::sleep(Duration::from_secs(2)).await;
        wrk.stop().unwrap();
        Ok(())
    }

    /// End-to-end check against a live Postgres reachable through `DATABASE_URL`.
    ///
    /// Installs the extension, enqueues one task, then runs a worker until the handler stops it.
    #[tokio::test]
    async fn basic_worker() {
        let database_url = env::var("DATABASE_URL").unwrap();
        let pool = PgPool::connect(&database_url).await.unwrap();
        PGMQueue::setup(&pool).await.unwrap();

        let mut backend = PGMQueue::new(pool, "basic_test").await;
        backend.send(Task::new(HashMap::new())).await.unwrap();

        WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder)
            .run()
            .await
            .unwrap();
    }
}