1use std::{
2 collections::HashMap,
3 future::ready,
4 marker::PhantomData,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use crate::{
11 CompactType, Config, InsertEvent, PgContext, PgTask, PostgresStorage,
12 ack::{LockTaskLayer, PgAck},
13 fetcher::PgPollFetcher,
14 queries::{
15 keep_alive::{initial_heartbeat, keep_alive_stream},
16 reenqueue_orphaned::reenqueue_orphaned_stream,
17 },
18};
19use crate::{from_row::PgTaskRow, sink::PgSink};
20use apalis_core::{
21 backend::{
22 Backend, BackendExt, TaskStream,
23 codec::{Codec, json::JsonCodec},
24 shared::MakeShared,
25 },
26 layers::Stack,
27 task::task_id::TaskId,
28 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
29};
30use apalis_sql::from_row::TaskRow;
31use futures::{
32 FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
33 channel::mpsc::{self, Receiver, Sender},
34 future::{BoxFuture, Shared},
35 lock::Mutex,
36 stream::{self, BoxStream, select},
37};
38use sqlx::{PgPool, postgres::PgListener};
39use ulid::Ulid;
40
41pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
42 pool: PgPool,
43 registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>>,
44 drive: Shared<BoxFuture<'static, ()>>,
45 _marker: PhantomData<(Compact, Codec)>,
46}
47
48impl SharedPostgresStorage {
49 pub fn new(pool: PgPool) -> Self {
50 let registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>> =
51 Arc::new(Mutex::new(HashMap::default()));
52 let p = pool.clone();
53 let instances = registry.clone();
54 Self {
55 pool,
56 drive: async move {
57 let mut listener = PgListener::connect_with(&p).await.unwrap();
58 listener.listen("apalis::job::insert").await.unwrap();
59 listener
60 .into_stream()
61 .filter_map(|notification| {
62 let instances = instances.clone();
63 async move {
64 let pg_notification = notification.ok()?;
65 let payload = pg_notification.payload();
66 let ev: InsertEvent = serde_json::from_str(payload).ok()?;
67 let instances = instances.lock().await;
68 if instances.get(&ev.job_type).is_some() {
69 return Some(ev);
70 }
71 None
72 }
73 })
74 .for_each(|ev| {
75 let instances = instances.clone();
76 async move {
77 let mut instances = instances.lock().await;
78 let sender = instances.get_mut(&ev.job_type).unwrap();
79 sender.send(ev.id).await.unwrap();
80 }
81 })
82 .await;
83 }
84 .boxed()
85 .shared(),
86 registry,
87 _marker: PhantomData,
88 }
89 }
90}
91#[derive(Debug, thiserror::Error)]
92pub enum SharedPostgresError {
93 #[error("namespace already exists: {0}")]
95 NamespaceExists(String),
96
97 #[error("registry locked")]
99 RegistryLocked,
100}
101
102impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
103 type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
104 type Config = Config;
105 type MakeError = SharedPostgresError;
106 fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
107 where
108 Self::Config: Default,
109 {
110 Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
111 }
112 fn make_shared_with_config(
113 &mut self,
114 config: Self::Config,
115 ) -> Result<Self::Backend, Self::MakeError> {
116 let (tx, rx) = mpsc::channel(config.buffer_size());
117 let mut r = self
118 .registry
119 .try_lock()
120 .ok_or(SharedPostgresError::RegistryLocked)?;
121 if r.insert(config.queue().to_string(), tx).is_some() {
122 return Err(SharedPostgresError::NamespaceExists(
123 config.queue().to_string(),
124 ));
125 }
126 let sink = PgSink::new(&self.pool, &config);
127 Ok(PostgresStorage {
128 _marker: PhantomData,
129 config,
130 fetcher: SharedFetcher {
131 poller: self.drive.clone(),
132 receiver: Arc::new(Mutex::new(rx)),
133 },
134 pool: self.pool.clone(),
135 sink,
136 })
137 }
138}
139
140#[derive(Clone, Debug)]
141pub struct SharedFetcher {
142 poller: Shared<BoxFuture<'static, ()>>,
143 receiver: Arc<Mutex<Receiver<TaskId>>>,
144}
145
146impl Stream for SharedFetcher {
147 type Item = TaskId;
148
149 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 let this = self.get_mut();
151 let _ = this.poller.poll_unpin(cx);
153
154 let mut receiver = this.receiver.try_lock();
156 if let Some(ref mut rx) = receiver {
157 rx.poll_next_unpin(cx)
158 } else {
159 Poll::Pending
160 }
161 }
162}
163
164impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
165where
166 Args: Send + 'static + Unpin,
167 Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
168 Decode::Error: std::error::Error + Send + Sync + 'static,
169{
170 type Args = Args;
171
172 type IdType = Ulid;
173
174 type Error = sqlx::Error;
175
176 type Stream = TaskStream<PgTask<Args>, Self::Error>;
177
178 type Beat = BoxStream<'static, Result<(), Self::Error>>;
179
180 type Context = PgContext;
181
182 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;
183
184 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
185 let pool = self.pool.clone();
186 let config = self.config.clone();
187 let worker = worker.clone();
188 let keep_alive = keep_alive_stream(pool, config, worker);
189 let reenqueue = reenqueue_orphaned_stream(
190 self.pool.clone(),
191 self.config.clone(),
192 *self.config.keep_alive(),
193 )
194 .map_ok(|_| ());
195 futures::stream::select(keep_alive, reenqueue).boxed()
196 }
197
198 fn middleware(&self) -> Self::Layer {
199 Stack::new(
200 LockTaskLayer::new(self.pool.clone()),
201 AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
202 )
203 }
204
205 fn poll(self, worker: &WorkerContext) -> Self::Stream {
206 self.poll_shared(worker)
207 .map(|a| match a {
208 Ok(Some(task)) => Ok(Some(
209 task.try_map(|t| Decode::decode(&t))
210 .map_err(|e| sqlx::Error::Decode(e.into()))?,
211 )),
212 Ok(None) => Ok(None),
213 Err(e) => Err(e),
214 })
215 .boxed()
216 }
217}
218
219impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
220where
221 Args: Send + 'static + Unpin,
222 Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
223 Decode::Error: std::error::Error + Send + Sync + 'static,
224{
225 type Compact = CompactType;
226
227 type Codec = Decode;
228 type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
229 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
230 self.poll_shared(worker).boxed()
231 }
232}
233
234impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, SharedFetcher> {
235 fn poll_shared(
236 self,
237 worker: &WorkerContext,
238 ) -> impl Stream<Item = Result<Option<PgTask<CompactType>>, sqlx::Error>> + 'static {
239 let pool = self.pool.clone();
240 let worker_id = worker.name().to_owned();
241 let register_worker = initial_heartbeat(
242 self.pool.clone(),
243 self.config.clone(),
244 worker.clone(),
245 "SharedPostgresStorage",
246 )
247 .map_ok(|_| None);
248 let register = stream::once(register_worker);
249 let lazy_fetcher = self
250 .fetcher
251 .map(|t| t.to_string())
252 .ready_chunks(self.config.buffer_size())
253 .then(move |ids| {
254 let pool = pool.clone();
255 let worker_id = worker_id.clone();
256 async move {
257 let mut tx = pool.begin().await?;
258 let res: Vec<_> = sqlx::query_file_as!(
259 PgTaskRow,
260 "queries/task/queue_by_id.sql",
261 &ids,
262 &worker_id
263 )
264 .fetch(&mut *tx)
265 .map(|r| {
266 let row: TaskRow = r?.try_into()?;
267 Ok(Some(
268 row.try_into_task_compact()
269 .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
270 ))
271 })
272 .collect()
273 .await;
274 tx.commit().await?;
275 Ok::<_, sqlx::Error>(res)
276 }
277 })
278 .flat_map(|vec| match vec {
279 Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
280 Ok(t) => Ok(t),
281 Err(e) => Err(e),
282 }))
283 .boxed(),
284 Err(e) => stream::once(ready(Err(e))).boxed(),
285 })
286 .boxed();
287 let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
288 &self.pool,
289 &self.config,
290 worker,
291 ));
292 register.chain(select(lazy_fetcher, eager_fetcher))
293 }
294}
295
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    /// Handler used by both workers: waits two seconds, then stops the worker that ran it.
    async fn send_reminder<T, I>(
        _: T,
        _task_id: TaskId<I>,
        wrk: WorkerContext,
    ) -> Result<(), BoxDynError> {
        tokio::time::sleep(Duration::from_secs(2)).await;
        wrk.stop().unwrap();
        Ok(())
    }

    /// Two differently-typed backends created from one shared storage each process their own
    /// task. Requires a live Postgres at the URL below.
    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect("postgres://postgres:postgres@localhost/apalis_dev")
            .await
            .unwrap();
        let mut shared = SharedPostgresStorage::new(pool);

        // Queue names come from `type_name::<Args>()`, so these two backends do not collide.
        let mut map_backend = shared.make_shared().unwrap();
        let mut int_backend = shared.make_shared().unwrap();

        let mut maps = stream::iter(vec![HashMap::<String, String>::new()]);
        map_backend.push_stream(&mut maps).await.unwrap();
        int_backend.push(99).await.unwrap();

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_backend)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_backend)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}