use std::{
    collections::HashMap,
    future::ready,
    marker::PhantomData,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use crate::{
    CompactType, Config, InsertEvent, PgContext, PgTask, PgTaskId, PostgresStorage,
    ack::{LockTaskLayer, PgAck},
    fetcher::PgPollFetcher,
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
};
use crate::{from_row::PgTaskRow, sink::PgSink};
use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue, shared::MakeShared},
    layers::Stack,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
    channel::mpsc::{self, Receiver, Sender},
    future::{BoxFuture, Shared},
    lock::Mutex,
    stream::{self, BoxStream, select},
};
use sqlx::{PgPool, postgres::PgListener};
use ulid::Ulid;
pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
    pool: PgPool,
    registry: Arc<Mutex<HashMap<String, Sender<PgTaskId>>>>,
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<(Compact, Codec)>,
}

impl SharedPostgresStorage {
    pub fn new(pool: PgPool) -> Self {
        let registry: Arc<Mutex<HashMap<String, Sender<PgTaskId>>>> =
            Arc::new(Mutex::new(HashMap::default()));
        let p = pool.clone();
        let instances = registry.clone();
        Self {
            pool,
            drive: async move {
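                // NOTE: failing to connect or `LISTEN` panics the shared
                // drive future that every cloned storage polls; the failure
                // is early and loud rather than silently dropping insert
                // notifications.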
                let mut listener = PgListener::connect_with(&p).await.unwrap();
                listener.listen("apalis::job::insert").await.unwrap();
                listener
                    .into_stream()
                    .filter_map(|notification| {
                        let instances = instances.clone();
                        async move {
                            let pg_notification = notification.ok()?;
                            let payload = pg_notification.payload();
                            let ev: InsertEvent = serde_json::from_str(payload).ok()?;
                            let instances = instances.lock().await;
                            instances.contains_key(&ev.job_type).then_some(ev)
                        }
                    })
                    .for_each(|ev| {
                        let instances = instances.clone();
                        async move {
                            let mut instances = instances.lock().await;
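                            // The registration may have been removed (or its
                            // receiver dropped) since the filter ran above;
                            // drop the event rather than panic.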
                            if let Some(sender) = instances.get_mut(&ev.job_type) {
                                let _ = sender.send(ev.id).await;
                            }
                        }
                    })
                    .await;
            }
            .boxed()
            .shared(),
            registry,
            _marker: PhantomData,
        }
    }
}

#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),

    #[error("registry locked")]
    RegistryLocked,
}

impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
    type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
    type Config = Config;
    type MakeError = SharedPostgresError;

    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
    where
        Self::Config: Default,
    {
        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
    }

    fn make_shared_with_config(
        &mut self,
        config: Self::Config,
    ) -> Result<Self::Backend, Self::MakeError> {
        let (tx, rx) = mpsc::channel(config.buffer_size());
        let mut r = self
            .registry
            .try_lock()
            .ok_or(SharedPostgresError::RegistryLocked)?;
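        // Check before inserting so that a duplicate registration does not
        // replace the existing queue's sender.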
        let queue = config.queue().to_string();
        if r.contains_key(&queue) {
            return Err(SharedPostgresError::NamespaceExists(queue));
        }
        r.insert(queue, tx);
        let sink = PgSink::new(&self.pool, &config);
        Ok(PostgresStorage {
            _marker: PhantomData,
            config,
            fetcher: SharedFetcher {
                poller: self.drive.clone(),
                receiver: Arc::new(Mutex::new(rx)),
            },
            pool: self.pool.clone(),
            sink,
        })
    }
}

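/// A stream of task ids pushed by the shared `LISTEN` loop.
///
/// Polling the fetcher also polls the shared drive future, so the
/// notification loop makes progress as long as at least one worker built
/// from the same [`SharedPostgresStorage`] is being polled.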
#[derive(Clone, Debug)]
pub struct SharedFetcher {
    poller: Shared<BoxFuture<'static, ()>>,
    receiver: Arc<Mutex<Receiver<PgTaskId>>>,
}

impl Stream for SharedFetcher {
    type Item = PgTaskId;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
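        // Poll the shared drive future so the LISTEN loop advances; its
        // output is ignored because task ids arrive on the channel below.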
        let _ = this.poller.poll_unpin(cx);

        if let Some(mut rx) = this.receiver.try_lock() {
            rx.poll_next_unpin(cx)
        } else {
            // `futures::lock::Mutex::try_lock` does not register a waker, so
            // wake immediately to retry instead of risking a permanent stall.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, Self::Error>;

    type Beat = BoxStream<'static, Result<(), Self::Error>>;

    type Context = PgContext;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
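        // Run both maintenance loops on one beat stream: worker keep-alive
        // plus re-enqueueing tasks orphaned by dead workers.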
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
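        // Decode each compact task into `Args`, surfacing codec failures
        // as `sqlx::Error::Decode`.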
        self.poll_shared(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;

    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_shared(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, SharedFetcher> {
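    /// Builds the compact task stream shared by [`Backend::poll`] and
    /// [`BackendExt::poll_compact`]: an initial worker registration,
    /// followed by a race between the notification-driven fetcher and the
    /// periodic poller.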
    fn poll_shared(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<PgTask<CompactType>>, sqlx::Error>> + 'static {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SharedPostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
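        // Lazy path: ids pushed by the shared LISTEN loop are batched up to
        // `buffer_size` and claimed for this worker in a single transaction.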
        let lazy_fetcher = self
            .fetcher
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|batch| match batch {
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();
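        // Eager path: periodic polling picks up backlog and any inserts
        // whose notifications were missed.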
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher))
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

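    // Requires a running Postgres; set `DATABASE_URL` before running.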
    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut store = SharedPostgresStorage::new(pool);

        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push_stream(&mut stream::iter(vec![HashMap::<String, String>::new()]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        async fn send_reminder<T>(
            _: T,
            _task_id: PgTaskId,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}