apalis_postgres/
shared.rs

1use std::{
2    collections::HashMap,
3    future::ready,
4    marker::PhantomData,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use crate::{
11    CompactType, Config, InsertEvent, PgTask, PostgresStorage,
12    ack::{LockTaskLayer, PgAck},
13    context::PgContext,
14    fetcher::PgPollFetcher,
15    queries::{
16        keep_alive::{initial_heartbeat, keep_alive_stream},
17        reenqueue_orphaned::reenqueue_orphaned_stream,
18    },
19};
20use crate::{from_row::PgTaskRow, sink::PgSink};
21use apalis_core::{
22    backend::{
23        Backend, TaskStream,
24        codec::{Codec, json::JsonCodec},
25        shared::MakeShared,
26    },
27    layers::Stack,
28    task::task_id::TaskId,
29    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
30};
31use apalis_sql::from_row::TaskRow;
32use futures::{
33    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
34    channel::mpsc::{self, Receiver, Sender},
35    future::{BoxFuture, Shared},
36    lock::Mutex,
37    stream::{self, BoxStream, select},
38};
39use sqlx::{PgPool, postgres::PgListener};
40use ulid::Ulid;
41
/// A Postgres-backed storage that lets multiple task types share a single
/// `LISTEN/NOTIFY` connection.
///
/// Per-queue backends are created via [`MakeShared`]. Each backend's fetcher
/// polls the shared `drive` future, which owns the `PgListener` and forwards
/// insert notifications into the per-queue channels held in `registry`.
pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
    // Connection pool cloned into every backend created from this storage.
    pool: PgPool,
    // Maps queue name -> sender feeding that queue's `SharedFetcher`.
    registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>>,
    // Shared background future that listens for insert events and dispatches
    // them; cloned into every fetcher so any live worker keeps it running.
    drive: Shared<BoxFuture<'static, ()>>,
    // Carries the compact representation and codec type parameters only.
    _marker: PhantomData<(Compact, Codec)>,
}
48
49impl SharedPostgresStorage {
50    pub fn new(pool: PgPool) -> Self {
51        let registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>> =
52            Arc::new(Mutex::new(HashMap::default()));
53        let p = pool.clone();
54        let instances = registry.clone();
55        Self {
56            pool,
57            drive: async move {
58                let mut listener = PgListener::connect_with(&p).await.unwrap();
59                listener.listen("apalis::job::insert").await.unwrap();
60                listener
61                    .into_stream()
62                    .filter_map(|notification| {
63                        let instances = instances.clone();
64                        async move {
65                            let pg_notification = notification.ok()?;
66                            let payload = pg_notification.payload();
67                            let ev: InsertEvent = serde_json::from_str(payload).ok()?;
68                            let instances = instances.lock().await;
69                            if instances.get(&ev.job_type).is_some() {
70                                return Some(ev);
71                            }
72                            None
73                        }
74                    })
75                    .for_each(|ev| {
76                        let instances = instances.clone();
77                        async move {
78                            let mut instances = instances.lock().await;
79                            let sender = instances.get_mut(&ev.job_type).unwrap();
80                            sender.send(ev.id).await.unwrap();
81                        }
82                    })
83                    .await;
84            }
85            .boxed()
86            .shared(),
87            registry,
88            _marker: PhantomData,
89        }
90    }
91}
/// Errors returned when registering a queue on a shared Postgres storage.
#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    /// The queue name is already registered with this storage instance.
    /// (Previous doc said "Namespace not found", contradicting the message.)
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),

    /// The registry mutex could not be acquired without blocking.
    #[error("registry locked")]
    RegistryLocked,
}
102
103impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
104    type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
105    type Config = Config;
106    type MakeError = SharedPostgresError;
107    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
108    where
109        Self::Config: Default,
110    {
111        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
112    }
113    fn make_shared_with_config(
114        &mut self,
115        config: Self::Config,
116    ) -> Result<Self::Backend, Self::MakeError> {
117        let (tx, rx) = mpsc::channel(config.buffer_size());
118        let mut r = self
119            .registry
120            .try_lock()
121            .ok_or(SharedPostgresError::RegistryLocked)?;
122        if r.insert(config.queue().to_string(), tx).is_some() {
123            return Err(SharedPostgresError::NamespaceExists(
124                config.queue().to_string(),
125            ));
126        }
127        let sink = PgSink::new(&self.pool, &config);
128        Ok(PostgresStorage {
129            _marker: PhantomData,
130            config,
131            fetcher: SharedFetcher {
132                poller: self.drive.clone(),
133                receiver: rx,
134            },
135            pool: self.pool.clone(),
136            sink,
137        })
138    }
139}
140
/// Stream of task ids for a single queue, fed by the shared notification
/// driver owned by [`SharedPostgresStorage`].
pub struct SharedFetcher {
    // Clone of the shared LISTEN/NOTIFY driver; polled (output ignored) to
    // keep dispatch running while this fetcher is alive.
    poller: Shared<BoxFuture<'static, ()>>,
    // Receiving end of this queue's channel in the storage registry.
    receiver: Receiver<TaskId>,
}
145
146impl Stream for SharedFetcher {
147    type Item = TaskId;
148
149    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        let this = self.get_mut();
151        // Keep the poller alive by polling it, but ignoring the output
152        let _ = this.poller.poll_unpin(cx);
153
154        // Delegate actual items to receiver
155        this.receiver.poll_next_unpin(cx)
156    }
157}
158
// Backend wired to the shared LISTEN/NOTIFY fetcher: tasks arrive lazily via
// `SharedFetcher` notifications, with an eager `PgPollFetcher` merged in as a
// fallback poller.
impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Compact = CompactType;

    type IdType = Ulid;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, Self::Error>;

    type Beat = BoxStream<'static, Result<(), Self::Error>>;

    type Codec = Decode;

    type Context = PgContext;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    /// Merge of two periodic maintenance streams: worker keep-alive updates
    /// and re-enqueueing of orphaned tasks (using the keep-alive interval as
    /// the orphan threshold).
    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    /// Task middleware: lock each task row first, then acknowledge results.
    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    /// Build the task stream: one initial heartbeat to register the worker,
    /// then a select over the lazy (notification-driven) and eager (polling)
    /// fetchers.
    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        // Registers the worker once; mapped to `Ok(None)` so it can be
        // chained in front of the task stream without yielding a task.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SharedPostgresStorage",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        // Lazy path: batch notified task ids up to buffer_size, then lock and
        // fetch them in a single transaction per batch.
        let lazy_fetcher = self
            .fetcher
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    // NOTE(review): lock_by_id.sql presumably locks the rows
                    // for `worker_id` and returns them — confirm query text.
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/lock_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        // Row -> generic TaskRow -> typed task; decode errors
                        // are surfaced as sqlx protocol errors.
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task::<Decode, Args, Ulid>()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            // Flatten each batch result: a failed batch becomes a single
            // error item; a successful batch is re-emitted item by item.
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
                    Ok(t) => Ok(t),
                    Err(e) => Err(e),
                }))
                .boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // Eager path: conventional interval-based polling fetcher.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<Args, CompactType, Decode>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}
262
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    // Integration test, not a unit test: requires a live Postgres at
    // postgres://postgres:postgres@localhost/apalis_dev with the apalis
    // schema applied.
    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect("postgres://postgres:postgres@localhost/apalis_dev")
            .await
            .unwrap();
        let mut store = SharedPostgresStorage::new(pool);

        // Two queues sharing one LISTEN/NOTIFY connection: queue names come
        // from the type names (HashMap<String, String> and i32).
        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push_stream(&mut stream::iter(vec![HashMap::<String, String>::new()]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        // Generic handler shared by both workers; stops its worker after
        // handling one task so the test terminates.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        // Both workers must consume their task and stop cleanly.
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}