// apalis_postgres/shared.rs

1use std::{
2    collections::HashMap,
3    future::ready,
4    marker::PhantomData,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use crate::{
11    CompactType, Config, InsertEvent, PgTask, PostgresStorage,
12    ack::{LockTaskLayer, PgAck},
13    context::PgContext,
14    fetcher::PgPollFetcher,
15    queries::{
16        keep_alive::{initial_heartbeat, keep_alive_stream},
17        reenqueue_orphaned::reenqueue_orphaned_stream,
18    },
19};
20use crate::{from_row::PgTaskRow, sink::PgSink};
21use apalis_core::{
22    backend::{
23        Backend, TaskStream,
24        codec::{Codec, json::JsonCodec},
25        shared::MakeShared,
26    },
27    layers::Stack,
28    task::task_id::TaskId,
29    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
30};
31use apalis_sql::from_row::TaskRow;
32use futures::{
33    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
34    channel::mpsc::{self, Receiver, Sender},
35    future::{BoxFuture, Shared},
36    lock::Mutex,
37    stream::{self, BoxStream, select},
38};
39use sqlx::{PgPool, postgres::PgListener};
40use ulid::Ulid;
41
/// A Postgres-backed storage whose `LISTEN/NOTIFY` connection is shared by
/// every backend created through [`MakeShared`].
///
/// A single background future (`drive`) listens on the `apalis::job::insert`
/// channel and fans inserted task ids out to per-queue senders registered in
/// `registry`.
pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
    // Connection pool cloned into each backend built from this storage.
    pool: PgPool,
    // Queue name -> sender used to fan out inserted task ids to that queue's fetcher.
    registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>>,
    // Single shared future driving the Postgres `LISTEN` loop; cloned into every fetcher.
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<(Compact, Codec)>,
}
48
49impl SharedPostgresStorage {
50    pub fn new(pool: PgPool) -> Self {
51        let registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>> =
52            Arc::new(Mutex::new(HashMap::default()));
53        let p = pool.clone();
54        let instances = registry.clone();
55        Self {
56            pool,
57            drive: async move {
58                let mut listener = PgListener::connect_with(&p).await.unwrap();
59                listener.listen("apalis::job::insert").await.unwrap();
60                listener
61                    .into_stream()
62                    .filter_map(|notification| {
63                        let instances = instances.clone();
64                        async move {
65                            let pg_notification = notification.ok()?;
66                            let payload = pg_notification.payload();
67                            let ev: InsertEvent = serde_json::from_str(payload).ok()?;
68                            let instances = instances.lock().await;
69                            if instances.get(&ev.job_type).is_some() {
70                                return Some(ev);
71                            }
72                            None
73                        }
74                    })
75                    .for_each(|ev| {
76                        let instances = instances.clone();
77                        async move {
78                            let mut instances = instances.lock().await;
79                            let sender = instances.get_mut(&ev.job_type).unwrap();
80                            sender.send(ev.id).await.unwrap();
81                        }
82                    })
83                    .await;
84            }
85            .boxed()
86            .shared(),
87            registry,
88            _marker: PhantomData,
89        }
90    }
91}
/// Errors returned when creating a shared backend from a [`SharedPostgresStorage`].
#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    /// The queue namespace was already registered with this storage
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),

    /// The shared registry mutex could not be acquired without blocking
    #[error("registry locked")]
    RegistryLocked,
}
102
103impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
104    type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
105    type Config = Config;
106    type MakeError = SharedPostgresError;
107    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
108    where
109        Self::Config: Default,
110    {
111        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
112    }
113    fn make_shared_with_config(
114        &mut self,
115        config: Self::Config,
116    ) -> Result<Self::Backend, Self::MakeError> {
117        let (tx, rx) = mpsc::channel(config.buffer_size());
118        let mut r = self
119            .registry
120            .try_lock()
121            .ok_or(SharedPostgresError::RegistryLocked)?;
122        if r.insert(config.queue().to_string(), tx).is_some() {
123            return Err(SharedPostgresError::NamespaceExists(
124                config.queue().to_string(),
125            ));
126        }
127        let sink = PgSink::new(&self.pool, &config);
128        Ok(PostgresStorage {
129            _marker: PhantomData,
130            config,
131            fetcher: SharedFetcher {
132                poller: self.drive.clone(),
133                receiver: Arc::new(Mutex::new(rx)),
134            },
135            pool: self.pool.clone(),
136            sink,
137        })
138    }
139}
140
/// Fetcher that yields task ids pushed by the shared `LISTEN` driver.
#[derive(Clone, Debug)]
pub struct SharedFetcher {
    // Shared driver future; polled on every `poll_next` to keep the
    // notification loop running as long as any fetcher is alive.
    poller: Shared<BoxFuture<'static, ()>>,
    // Receiving end of this queue's task-id channel, shared across clones.
    receiver: Arc<Mutex<Receiver<TaskId>>>,
}
146
147impl Stream for SharedFetcher {
148    type Item = TaskId;
149
150    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151        let this = self.get_mut();
152        // Keep the poller alive by polling it, but ignoring the output
153        let _ = this.poller.poll_unpin(cx);
154
155        // Delegate actual items to receiver
156        let mut receiver = this.receiver.try_lock();
157        if let Some(ref mut rx) = receiver {
158            rx.poll_next_unpin(cx)
159        } else {
160            Poll::Pending
161        }
162    }
163}
164
/// Backend implementation for storages built via [`SharedPostgresStorage`]:
/// task ids arrive over the shared `LISTEN/NOTIFY` channel (lazy path) and a
/// polling fetcher covers tasks missed by notifications (eager path).
impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Compact = CompactType;

    type IdType = Ulid;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, Self::Error>;

    type Beat = BoxStream<'static, Result<(), Self::Error>>;

    type Codec = Decode;

    type Context = PgContext;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    // Merges the worker keep-alive stream with a stream that re-enqueues
    // tasks orphaned by dead workers (using the keep-alive interval as the
    // staleness threshold).
    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    // Task locking runs before acknowledgement in the middleware stack.
    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        // One-shot registration heartbeat emitted before any tasks; yields
        // `Ok(None)` so it fits the task-stream item shape.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SharedPostgresStorage",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        // Lazy path: batch up notified ids (up to buffer_size), then lock and
        // fetch them in a single transaction per batch.
        let lazy_fetcher = self
            .fetcher
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/lock_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        // Row -> generic TaskRow -> typed task; decode failures
                        // are surfaced as protocol errors on the stream.
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task::<Decode, Args, Ulid>()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            // Flatten each batch result: per-task results on success, a
            // single error item if the whole batch failed.
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
                    Ok(t) => Ok(t),
                    Err(e) => Err(e),
                }))
                .boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // Eager path: periodic DB polling catches tasks whose notifications
        // were missed (e.g. inserted before this worker registered).
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<Args, CompactType, Decode>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}
268
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    // Integration test: requires a live Postgres at localhost/apalis_dev.
    // Exercises two backends (different Args types) sharing one storage, each
    // consuming one pushed task and stopping its worker.
    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect("postgres://postgres:postgres@localhost/apalis_dev")
            .await
            .unwrap();
        let mut store = SharedPostgresStorage::new(pool);

        // Two shared backends with distinct queue namespaces inferred from
        // their Args types (HashMap<String, String> and i32).
        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push_stream(&mut stream::iter(vec![HashMap::<String, String>::new()]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        // Generic handler: waits long enough for both tasks to be picked up,
        // then stops its own worker so `run()` terminates.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}