use std::{marker::PhantomData, sync::Arc};

use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{
        Backend, TaskStream,
        codec::Codec,
        poll_strategy::{PollContext, PollStrategyExt},
    },
    task::{Task, attempt::Attempt, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use chrono::{DateTime, Utc};
use futures::{
    StreamExt,
    stream::{self, BoxStream},
};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use sqlx::{PgPool, Postgres};

use crate::{
    config::Config, context::PgMqContext, errors::PgmqError, fetch::fetch_messages, sink::PgMqSink,
};

mod ack;
mod config;
mod context;
mod errors;
mod fetch;
pub mod query;
mod sink;
mod util;
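
// Prefixes of pgmq's per-queue tables ("q_<queue>" for live messages,
// "a_<queue>" for archived ones) and the schema this crate works in.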
pub const QUEUE_PREFIX: &str = "q";
pub const ARCHIVE_PREFIX: &str = "a";
pub const PGMQ_SCHEMA: &str = "apalis_pgmq";
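
/// A [`Task`] as produced by this backend; pgmq message IDs are `i64`.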
pub type PgMqTask<Args> = Task<Args, PgMqContext, i64>;
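
/// A task backend built on the `pgmq` Postgres extension.
///
/// Pushing goes through the internal [`PgMqSink`]; pulling polls the queue in
/// batches according to the poll strategy in [`Config`].
///
/// A minimal sketch, assuming the crate is named `apalis_pgmq` and a Postgres
/// instance with the `pgmq` extension available is reachable:
///
/// ```no_run
/// use apalis_pgmq::PGMQueue;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let pool = sqlx::PgPool::connect("postgres://localhost/app").await?;
/// PGMQueue::setup(&pool).await?; // install the extension once per database
/// let queue: PGMQueue<String> = PGMQueue::new(pool, "emails").await;
/// # Ok(())
/// # }
/// ```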
pub struct PGMQueue<Args, Codec = JsonCodec<Vec<u8>>> {
    connection: PgPool,
    config: Config<Codec>,
    sink: PgMqSink<Args, Codec>,
    _args: PhantomData<Args>,
}

// Manual impl: a derived `Clone` would add `Args: Clone` and `C: Clone`
// bounds that the backend itself does not need.
impl<Args, C> Clone for PGMQueue<Args, C> {
    fn clone(&self) -> Self {
        Self {
            connection: self.connection.clone(),
            config: self.config.clone(),
            sink: self.sink.clone(),
            _args: self._args,
        }
    }
}

impl PGMQueue<()> {
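    /// Installs the `pgmq` extension if it is not already present.
    ///
    /// Run this once per database before creating any queues.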
    pub async fn setup<'c, E: sqlx::Executor<'c, Database = Postgres>>(
        executor: E,
    ) -> Result<bool, PgmqError> {
        sqlx::query("CREATE EXTENSION IF NOT EXISTS pgmq CASCADE;")
            .execute(executor)
            .await
            .map(|_| true)
            .map_err(PgmqError::from)
    }
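
    /// Creates the queue's backing objects, applying every statement from
    /// [`query::init_queue_client_only`] within a single transaction.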
    async fn create<'c, E>(queue_name: &str, executor: E) -> Result<(), PgmqError>
    where
        E: sqlx::Acquire<'c, Database = Postgres>,
    {
        let mut tx = executor.begin().await?;
        let setup = query::init_queue_client_only(queue_name, false)?;
        for q in setup {
            sqlx::query(&q).execute(&mut *tx).await?;
        }
        tx.commit().await?;
        Ok(())
    }
}

impl<Args: Serialize + DeserializeOwned> PGMQueue<Args> {
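    /// Sets up and connects to `queue_name` using the default configuration
    /// and JSON codec.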
    pub async fn new(pool: PgPool, queue_name: &str) -> Self {
        let config: Config<JsonCodec<Vec<u8>>> =
            Config::default().with_queue(queue_name.to_string());
        PGMQueue::new_with_config(pool, config).await
    }
}

impl<Args, C: Codec<Args, Compact = Vec<u8>>> PGMQueue<Args, C> {
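    /// Sets up and connects to the queue named in `config`.
    ///
    /// # Panics
    ///
    /// Panics if the queue's backing objects cannot be created.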
    pub async fn new_with_config(pool: PgPool, config: Config<C>) -> Self {
        PGMQueue::create(config.queue(), &pool)
            .await
            .expect("failed to create queue");
        Self {
            sink: PgMqSink::new(pool.clone(), config.clone()),
            connection: pool,
            config,
            _args: PhantomData,
        }
    }
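
    /// Reads up to `buffer_size` messages, hiding each one from other
    /// consumers for the configured visibility timeout.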
    async fn read_batch(
        config: Config<C>,
        connection: PgPool,
    ) -> Result<Option<Vec<Message>>, PgmqError> {
        let query = &query::read(
            config.queue(),
            config.visibility_timeout().as_secs() as i32,
            config.buffer_size() as i32,
        )?;
        let messages = fetch_messages(query, &connection).await?;
        Ok(messages)
    }
}
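
/// A message row as fetched by [`fetch_messages`]: the raw payload plus
/// pgmq's delivery metadata.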
struct Message {
    msg_id: i64,
    visibility_time: DateTime<Utc>,
    read_count: i32,
    enqueued_at: DateTime<Utc>,
    message: Vec<u8>,
    headers: Value,
}

impl<Args, C> Backend for PGMQueue<Args, C>
where
    Args: Send + Sync + 'static + Unpin,
    C: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
    C::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Context = PgMqContext;

    type Beat = BoxStream<'static, Result<(), PgmqError>>;

    type Error = PgmqError;

    type IdType = i64;

    type Layer = AcknowledgeLayer<Self>;

    type Stream = TaskStream<PgMqTask<Args>, PgmqError>;

    fn heartbeat(&self, _worker: &WorkerContext) -> Self::Beat {
        // pgmq needs no keep-alive, so the beat is a stream that never yields.
        Box::pin(stream::pending())
    }

    fn middleware(&self) -> Self::Layer {
        // Acknowledge each task back to the queue once the worker finishes it.
        AcknowledgeLayer::new(self.clone())
    }
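
    // The stream alternates between draining a locally buffered batch and
    // asking the poll strategy when to fetch the next one.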
    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let ctx = PollContext::new(worker.clone(), Arc::default());
        let poller = self.config.poll_strategy().clone().build_stream(&ctx);
        stream::unfold(
            (self, poller, Vec::new()),
            |(backend, mut poller, mut buf)| async move {
                // Hand out anything left over from the previous batch first.
                if let Some(msg) = buf.pop() {
                    return Some((Ok(msg), (backend, poller, buf)));
                }

                // Wait until the poll strategy says it is time to fetch again.
                poller.next().await;

                match Self::read_batch(backend.config.clone(), backend.connection.clone()).await {
                    Ok(Some(messages)) => {
                        // Reverse so that `pop` yields messages in fetch order.
                        buf = messages;
                        buf.reverse();
                        // An empty batch ends the stream, just like `Ok(None)`.
                        buf.pop().map(|msg| (Ok(msg), (backend, poller, buf)))
                    }
                    Ok(None) => None,
                    Err(e) => Some((Err(e), (backend, poller, buf))),
                }
            },
        )
        .map(|res| match res {
            Ok(raw) => {
                // Decode the payload and surface pgmq's metadata on the task.
                let args =
                    C::decode(&raw.message).map_err(|e| PgmqError::ParsingError(e.into()))?;
                let ctx = PgMqContext {
                    enqueued_at: raw.enqueued_at,
                    headers: raw
                        .headers
                        .as_object()
                        .cloned()
                        .ok_or(PgmqError::ParsingError("Headers are not an object".into()))?,
                };
                let task = Task::builder(args)
                    .with_task_id(TaskId::new(raw.msg_id))
                    .with_attempt(Attempt::new_with_value(raw.read_count as usize))
                    .run_at_timestamp(raw.visibility_time.timestamp() as u64)
                    .with_ctx(ctx)
                    .build();
                Ok(Some(task))
            }
            Err(e) => Err(e),
        })
        .boxed()
    }
}

#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env, time::Duration};

    use apalis_core::{error::BoxDynError, worker::builder::WorkerBuilder};
    use futures::SinkExt;

    use super::*;
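
    // Requires a reachable Postgres with the pgmq extension available;
    // point `DATABASE_URL` at it before running.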
    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();

        PGMQueue::setup(&pool).await.unwrap();
        let mut backend = PGMQueue::new(pool, "basic_test").await;

        backend.send(Task::new(HashMap::new())).await.unwrap();

        // The handler stops the worker so the test can finish.
        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }
}