apalis_postgres/
shared.rs

1use std::{
2    collections::HashMap,
3    future::ready,
4    marker::PhantomData,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use crate::{
11    CompactType, Config, InsertEvent, PgContext, PgTask, PostgresStorage,
12    ack::{LockTaskLayer, PgAck},
13    fetcher::PgPollFetcher,
14    queries::{
15        keep_alive::{initial_heartbeat, keep_alive_stream},
16        reenqueue_orphaned::reenqueue_orphaned_stream,
17    },
18};
19use crate::{from_row::PgTaskRow, sink::PgSink};
20use apalis_core::{
21    backend::{
22        Backend, BackendExt, TaskStream,
23        codec::{Codec, json::JsonCodec},
24        shared::MakeShared,
25    },
26    layers::Stack,
27    task::task_id::TaskId,
28    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
29};
30use apalis_sql::from_row::TaskRow;
31use futures::{
32    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
33    channel::mpsc::{self, Receiver, Sender},
34    future::{BoxFuture, Shared},
35    lock::Mutex,
36    stream::{self, BoxStream, select},
37};
38use sqlx::{PgPool, postgres::PgListener};
39use ulid::Ulid;
40
/// A Postgres-backed storage that multiplexes several task queues over a
/// single shared `LISTEN` connection (see [`SharedPostgresStorage::new`]).
pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
    pool: PgPool,
    /// Maps a queue name (`Config::queue`) to the sender half feeding that
    /// queue's [`SharedFetcher`].
    registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>>,
    /// Future driving the shared `PgListener`; cloned into every fetcher and
    /// polled there so it stays alive as long as at least one worker does.
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<(Compact, Codec)>,
}
47
48impl SharedPostgresStorage {
49    pub fn new(pool: PgPool) -> Self {
50        let registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>> =
51            Arc::new(Mutex::new(HashMap::default()));
52        let p = pool.clone();
53        let instances = registry.clone();
54        Self {
55            pool,
56            drive: async move {
57                let mut listener = PgListener::connect_with(&p).await.unwrap();
58                listener.listen("apalis::job::insert").await.unwrap();
59                listener
60                    .into_stream()
61                    .filter_map(|notification| {
62                        let instances = instances.clone();
63                        async move {
64                            let pg_notification = notification.ok()?;
65                            let payload = pg_notification.payload();
66                            let ev: InsertEvent = serde_json::from_str(payload).ok()?;
67                            let instances = instances.lock().await;
68                            if instances.get(&ev.job_type).is_some() {
69                                return Some(ev);
70                            }
71                            None
72                        }
73                    })
74                    .for_each(|ev| {
75                        let instances = instances.clone();
76                        async move {
77                            let mut instances = instances.lock().await;
78                            let sender = instances.get_mut(&ev.job_type).unwrap();
79                            sender.send(ev.id).await.unwrap();
80                        }
81                    })
82                    .await;
83            }
84            .boxed()
85            .shared(),
86            registry,
87            _marker: PhantomData,
88        }
89    }
90}
/// Errors returned when registering a queue on a [`SharedPostgresStorage`].
#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    /// A backend for this queue name was already created on this storage.
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),

    /// The registry mutex was held (e.g. the notification loop was
    /// mid-dispatch), so the queue could not be registered right now.
    #[error("registry locked")]
    RegistryLocked,
}
101
102impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
103    type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
104    type Config = Config;
105    type MakeError = SharedPostgresError;
106    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
107    where
108        Self::Config: Default,
109    {
110        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
111    }
112    fn make_shared_with_config(
113        &mut self,
114        config: Self::Config,
115    ) -> Result<Self::Backend, Self::MakeError> {
116        let (tx, rx) = mpsc::channel(config.buffer_size());
117        let mut r = self
118            .registry
119            .try_lock()
120            .ok_or(SharedPostgresError::RegistryLocked)?;
121        if r.insert(config.queue().to_string(), tx).is_some() {
122            return Err(SharedPostgresError::NamespaceExists(
123                config.queue().to_string(),
124            ));
125        }
126        let sink = PgSink::new(&self.pool, &config);
127        Ok(PostgresStorage {
128            _marker: PhantomData,
129            config,
130            fetcher: SharedFetcher {
131                poller: self.drive.clone(),
132                receiver: Arc::new(Mutex::new(rx)),
133            },
134            pool: self.pool.clone(),
135            sink,
136        })
137    }
138}
139
/// Per-queue stream of task ids pushed by the shared notification listener.
#[derive(Clone, Debug)]
pub struct SharedFetcher {
    /// Clone of the shared listener-driving future; polled on every
    /// `poll_next` so the LISTEN loop keeps running.
    poller: Shared<BoxFuture<'static, ()>>,
    /// Receiver half of this queue's channel (sender lives in the registry).
    receiver: Arc<Mutex<Receiver<TaskId>>>,
}
145
146impl Stream for SharedFetcher {
147    type Item = TaskId;
148
149    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        let this = self.get_mut();
151        // Keep the poller alive by polling it, but ignoring the output
152        let _ = this.poller.poll_unpin(cx);
153
154        // Delegate actual items to receiver
155        let mut receiver = this.receiver.try_lock();
156        if let Some(ref mut rx) = receiver {
157            rx.poll_next_unpin(cx)
158        } else {
159            Poll::Pending
160        }
161    }
162}
163
164impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
165where
166    Args: Send + 'static + Unpin,
167    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
168    Decode::Error: std::error::Error + Send + Sync + 'static,
169{
170    type Args = Args;
171
172    type IdType = Ulid;
173
174    type Error = sqlx::Error;
175
176    type Stream = TaskStream<PgTask<Args>, Self::Error>;
177
178    type Beat = BoxStream<'static, Result<(), Self::Error>>;
179
180    type Context = PgContext;
181
182    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;
183
184    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
185        let pool = self.pool.clone();
186        let config = self.config.clone();
187        let worker = worker.clone();
188        let keep_alive = keep_alive_stream(pool, config, worker);
189        let reenqueue = reenqueue_orphaned_stream(
190            self.pool.clone(),
191            self.config.clone(),
192            *self.config.keep_alive(),
193        )
194        .map_ok(|_| ());
195        futures::stream::select(keep_alive, reenqueue).boxed()
196    }
197
198    fn middleware(&self) -> Self::Layer {
199        Stack::new(
200            LockTaskLayer::new(self.pool.clone()),
201            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
202        )
203    }
204
205    fn poll(self, worker: &WorkerContext) -> Self::Stream {
206        self.poll_shared(worker)
207            .map(|a| match a {
208                Ok(Some(task)) => Ok(Some(
209                    task.try_map(|t| Decode::decode(&t))
210                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
211                )),
212                Ok(None) => Ok(None),
213                Err(e) => Err(e),
214            })
215            .boxed()
216    }
217}
218
impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    /// Streams tasks without decoding — payloads stay in compact form.
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_shared(worker).boxed()
    }
}
233
234impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, SharedFetcher> {
235    fn poll_shared(
236        self,
237        worker: &WorkerContext,
238    ) -> impl Stream<Item = Result<Option<PgTask<CompactType>>, sqlx::Error>> + 'static {
239        let pool = self.pool.clone();
240        let worker_id = worker.name().to_owned();
241        let register_worker = initial_heartbeat(
242            self.pool.clone(),
243            self.config.clone(),
244            worker.clone(),
245            "SharedPostgresStorage",
246        )
247        .map_ok(|_| None);
248        let register = stream::once(register_worker);
249        let lazy_fetcher = self
250            .fetcher
251            .map(|t| t.to_string())
252            .ready_chunks(self.config.buffer_size())
253            .then(move |ids| {
254                let pool = pool.clone();
255                let worker_id = worker_id.clone();
256                async move {
257                    let mut tx = pool.begin().await?;
258                    let res: Vec<_> = sqlx::query_file_as!(
259                        PgTaskRow,
260                        "queries/task/queue_by_id.sql",
261                        &ids,
262                        &worker_id
263                    )
264                    .fetch(&mut *tx)
265                    .map(|r| {
266                        let row: TaskRow = r?.try_into()?;
267                        Ok(Some(
268                            row.try_into_task_compact()
269                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
270                        ))
271                    })
272                    .collect()
273                    .await;
274                    tx.commit().await?;
275                    Ok::<_, sqlx::Error>(res)
276                }
277            })
278            .flat_map(|vec| match vec {
279                Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
280                    Ok(t) => Ok(t),
281                    Err(e) => Err(e),
282                }))
283                .boxed(),
284                Err(e) => stream::once(ready(Err(e))).boxed(),
285            })
286            .boxed();
287        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
288            &self.pool,
289            &self.config,
290            worker,
291        ));
292        register.chain(select(lazy_fetcher, eager_fetcher))
293    }
294}
295
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    /// End-to-end smoke test: two queues (map- and int-typed) sharing one
    /// storage, one task pushed to each, both workers stop themselves after
    /// handling a task.
    ///
    /// NOTE(review): requires a live Postgres at
    /// `postgres://postgres:postgres@localhost/apalis_dev` — integration, not
    /// unit, test.
    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect("postgres://postgres:postgres@localhost/apalis_dev")
            .await
            .unwrap();
        let mut store = SharedPostgresStorage::new(pool);

        // Two backends on the same storage, one per args type/queue.
        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push_stream(&mut stream::iter(vec![HashMap::<String, String>::new()]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        // Generic handler: sleeps briefly, then asks its worker to stop so
        // `run()` below terminates.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}