apalis_postgres/
shared.rs

use std::{
    collections::HashMap,
    future::ready,
    marker::PhantomData,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use crate::{
    CompactType, Config, InsertEvent, PgContext, PgTask, PgTaskId, PostgresStorage,
    ack::{LockTaskLayer, PgAck},
    fetcher::PgPollFetcher,
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
};
use crate::{from_row::PgTaskRow, sink::PgSink};
use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue, shared::MakeShared},
    layers::Stack,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
    channel::mpsc::{self, Receiver, Sender},
    future::{BoxFuture, Shared},
    lock::Mutex,
    stream::{self, BoxStream, select},
};
use sqlx::{PgPool, postgres::PgListener};
use ulid::Ulid;

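/// Shared Postgres storage that multiplexes a single `LISTEN` connection across
/// several queues. Each registered queue gets a [`PostgresStorage`] backed by a
/// [`SharedFetcher`] that receives the ids of newly inserted tasks from the shared
/// listener.
///
/// A minimal usage sketch (`Email` is a stand-in args type; see the test at the
/// bottom of this file for a runnable variant):
///
/// ```ignore
/// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
/// let mut store = SharedPostgresStorage::new(pool);
/// // Each `make_shared` call registers one queue; the args type is inferred from
/// // how the returned backend is used.
/// let mut email_store = store.make_shared()?;
/// email_store.push(Email::default()).await?;
/// ```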
pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
    pool: PgPool,
    registry: Arc<Mutex<HashMap<String, Sender<PgTaskId>>>>,
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<(Compact, Codec)>,
}

impl SharedPostgresStorage {
    pub fn new(pool: PgPool) -> Self {
        let registry: Arc<Mutex<HashMap<String, Sender<PgTaskId>>>> =
            Arc::new(Mutex::new(HashMap::default()));
        let p = pool.clone();
        let instances = registry.clone();
        Self {
            pool,
            // Shared background future: listen for insert notifications and fan the
            // task ids out to whichever queue registered for that job type.
            drive: async move {
                let mut listener = PgListener::connect_with(&p)
                    .await
                    .expect("could not connect to postgres for LISTEN");
                listener
                    .listen("apalis::job::insert")
                    .await
                    .expect("could not LISTEN on apalis::job::insert");
                listener
                    .into_stream()
                    .filter_map(|notification| {
                        let instances = instances.clone();
                        async move {
                            // Ignore malformed notifications and job types nobody registered.
                            let pg_notification = notification.ok()?;
                            let payload = pg_notification.payload();
                            let ev: InsertEvent = serde_json::from_str(payload).ok()?;
                            let instances = instances.lock().await;
                            instances.contains_key(&ev.job_type).then_some(ev)
                        }
                    })
                    .for_each(|ev| {
                        let instances = instances.clone();
                        async move {
                            let mut instances = instances.lock().await;
                            let sender = instances.get_mut(&ev.job_type).unwrap();
                            sender.send(ev.id).await.unwrap();
                        }
                    })
                    .await;
            }
            .boxed()
            .shared(),
            registry,
            _marker: PhantomData,
        }
    }
}
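    /// Creates a shared storage on top of a [`PgPool`]. Nothing is spawned here:
    /// the `LISTEN`-driven routing loop is a [`Shared`] future that only makes
    /// progress while at least one [`SharedFetcher`] created from this storage is
    /// being polled.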
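/// Errors returned when registering a queue on a [`SharedPostgresStorage`].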
#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    /// The queue was already registered with this shared storage
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),

    /// The registry lock could not be acquired
    #[error("registry locked")]
    RegistryLocked,
}

impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
    type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
    type Config = Config;
    type MakeError = SharedPostgresError;
    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
    where
        Self::Config: Default,
    {
        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
    }
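    /// Registers the queue named by `config` and returns a backend whose fetcher is
    /// fed task ids by the shared listener. Fails if that queue is already
    /// registered or if the registry lock is currently held elsewhere.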
    fn make_shared_with_config(
        &mut self,
        config: Self::Config,
    ) -> Result<Self::Backend, Self::MakeError> {
        let (tx, rx) = mpsc::channel(config.buffer_size());
        let mut r = self
            .registry
            .try_lock()
            .ok_or(SharedPostgresError::RegistryLocked)?;
        // Refuse to clobber an existing registration for the same queue.
        if r.contains_key(&config.queue().to_string()) {
            return Err(SharedPostgresError::NamespaceExists(
                config.queue().to_string(),
            ));
        }
        r.insert(config.queue().to_string(), tx);
        let sink = PgSink::new(&self.pool, &config);
        Ok(PostgresStorage {
            _marker: PhantomData,
            config,
            fetcher: SharedFetcher {
                poller: self.drive.clone(),
                receiver: Arc::new(Mutex::new(rx)),
            },
            pool: self.pool.clone(),
            sink,
        })
    }
}

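/// Task-id stream handed to each backend created by [`SharedPostgresStorage`].
/// Polling it also drives the shared `LISTEN` future, so notifications keep flowing
/// without anything being spawned on an executor.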
#[derive(Clone, Debug)]
pub struct SharedFetcher {
    poller: Shared<BoxFuture<'static, ()>>,
    receiver: Arc<Mutex<Receiver<PgTaskId>>>,
}

impl Stream for SharedFetcher {
    type Item = PgTaskId;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Keep the shared listener future alive by polling it; its output is ignored
        let _ = this.poller.poll_unpin(cx);

        // Delegate the actual items to this queue's receiver
        let mut receiver = this.receiver.try_lock();
        if let Some(ref mut rx) = receiver {
            rx.poll_next_unpin(cx)
        } else {
            // Another clone currently holds the receiver lock
            Poll::Pending
        }
    }
}

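// With a `SharedFetcher`, the backend merges two task sources: ids pushed over
// LISTEN/NOTIFY and the regular interval-based fetcher (see `poll_shared` below).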
impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, Self::Error>;

    type Beat = BoxStream<'static, Result<(), Self::Error>>;

    type Context = PgContext;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_shared(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_shared(worker).boxed()
    }
}

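// Shared polling strategy: register the worker first, then merge a "lazy" fetcher
// that resolves notified task ids into rows with an "eager" poll-based fetcher, so
// notifications give low latency while interval polling catches anything the
// listener missed.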
impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, SharedFetcher> {
    fn poll_shared(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<PgTask<CompactType>>, sqlx::Error>> + 'static {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        // Announce the worker once before yielding any tasks.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SharedPostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        // Resolve batches of notified task ids into full task rows in one transaction.
        let lazy_fetcher = self
            .fetcher
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|batch| match batch {
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();
        // The regular interval-based fetcher picks up anything the listener missed.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher))
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut store = SharedPostgresStorage::new(pool);

        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push_stream(&mut stream::iter(vec![HashMap::<String, String>::new()]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        async fn send_reminder<T>(
            _: T,
            _task_id: PgTaskId,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}