use std::{
    collections::HashMap,
    future::ready,
    marker::PhantomData,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use crate::{
    CompactType, Config, InsertEvent, PgTask, PostgresStorage,
    ack::{LockTaskLayer, PgAck},
    context::PgContext,
    fetcher::PgPollFetcher,
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
};
use crate::{from_row::PgTaskRow, sink::PgSink};
use apalis_core::{
    backend::{
        Backend, TaskStream,
        codec::{Codec, json::JsonCodec},
        shared::MakeShared,
    },
    layers::Stack,
    task::task_id::TaskId,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
    channel::mpsc::{self, Receiver, Sender},
    future::{BoxFuture, Shared},
    lock::Mutex,
    stream::{self, BoxStream, select},
};
use sqlx::{PgPool, postgres::PgListener};
use ulid::Ulid;

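/// A storage multiplexer that lets many queues share one Postgres
/// `LISTEN/NOTIFY` connection.
///
/// A single background future (`drive`) listens on the `apalis::job::insert`
/// channel and routes each [`InsertEvent`] to the sender registered for its
/// `job_type`, so backends created via [`MakeShared`] receive only the task
/// ids of their own queue.
///
/// A minimal usage sketch (assumes a reachable database; each backend's task
/// type is inferred from how it is used):
///
/// ```ignore
/// let pool = PgPool::connect("postgres://localhost/apalis_dev").await?;
/// let mut store = SharedPostgresStorage::new(pool);
/// // One backend per task type, both fed by the same listener.
/// let mut int_store = store.make_shared()?;
/// int_store.push(99).await?;
/// ```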
pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
    pool: PgPool,
    registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>>,
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<(Compact, Codec)>,
}

impl SharedPostgresStorage {
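    /// Creates a shared storage from a connection pool.
    ///
    /// The listener future is stored as a lazy [`Shared`] future: nothing
    /// connects or listens until a fetcher created from this storage is first
    /// polled, and a failed `LISTEN` setup panics inside that future.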
    pub fn new(pool: PgPool) -> Self {
        let registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>> =
            Arc::new(Mutex::new(HashMap::default()));
        let p = pool.clone();
        let instances = registry.clone();
        Self {
            pool,
            drive: async move {
                // Shared by every queue; there is no caller to report a
                // failed LISTEN setup to, so it panics here.
                let mut listener = PgListener::connect_with(&p).await.unwrap();
                listener.listen("apalis::job::insert").await.unwrap();
                listener
                    .into_stream()
                    .filter_map(|notification| {
                        let instances = instances.clone();
                        async move {
                            let pg_notification = notification.ok()?;
                            let payload = pg_notification.payload();
                            let ev: InsertEvent = serde_json::from_str(payload).ok()?;
                            // Forward only events for queues registered here.
                            let instances = instances.lock().await;
                            instances.contains_key(&ev.job_type).then_some(ev)
                        }
                    })
                    .for_each(|ev| {
                        let instances = instances.clone();
                        async move {
                            let mut instances = instances.lock().await;
                            // The receiving backend may have been dropped since
                            // the filter above; ignore send failures rather than
                            // panicking the drive future shared by all queues.
                            if let Some(sender) = instances.get_mut(&ev.job_type) {
                                let _ = sender.send(ev.id).await;
                            }
                        }
                    })
                    .await;
            }
            .boxed()
            .shared(),
            registry,
            _marker: PhantomData,
        }
    }
}
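
/// Errors that can occur when registering a queue with [`SharedPostgresStorage`].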
#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    #[error("namespace already exists: {0}")]
    NamespaceExists(String),

    #[error("registry locked")]
    RegistryLocked,
}

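// Registering a queue reserves its name in the shared registry and returns a
// `PostgresStorage` whose fetcher is fed task ids by the shared listener.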
impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
    type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
    type Config = Config;
    type MakeError = SharedPostgresError;

    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
    where
        Self::Config: Default,
    {
        self.make_shared_with_config(Config::new(std::any::type_name::<Args>()))
    }

    fn make_shared_with_config(
        &mut self,
        config: Self::Config,
    ) -> Result<Self::Backend, Self::MakeError> {
        let (tx, rx) = mpsc::channel(config.buffer_size());
        let mut r = self
            .registry
            .try_lock()
            .ok_or(SharedPostgresError::RegistryLocked)?;
        // Check before inserting so a duplicate registration cannot clobber
        // the sender of a queue that is already running.
        let queue = config.queue().to_string();
        if r.contains_key(&queue) {
            return Err(SharedPostgresError::NamespaceExists(queue));
        }
        r.insert(queue, tx);
        let sink = PgSink::new(&self.pool, &config);
        Ok(PostgresStorage {
            _marker: PhantomData,
            config,
            fetcher: SharedFetcher {
                poller: self.drive.clone(),
                receiver: Arc::new(Mutex::new(rx)),
            },
            pool: self.pool.clone(),
            sink,
        })
    }
}

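/// A stream of task ids for a single queue, backed by the shared listener.
///
/// Each poll first drives the shared `LISTEN/NOTIFY` future and then drains
/// this queue's receiver.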
#[derive(Clone, Debug)]
pub struct SharedFetcher {
    poller: Shared<BoxFuture<'static, ()>>,
    receiver: Arc<Mutex<Receiver<TaskId>>>,
}

impl Stream for SharedFetcher {
    type Item = TaskId;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // The drive future never completes; polling it keeps the shared
        // listener running and registers this task's waker.
        let _ = this.poller.poll_unpin(cx);

        if let Some(mut rx) = this.receiver.try_lock() {
            rx.poll_next_unpin(cx)
        } else {
            // The lock is held elsewhere; wake immediately so this stream is
            // polled again instead of stalling with no registered waker.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

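// Backend implementation for storages created by `make_shared`: tasks are
// consumed both via notifications (lazy) and via periodic polling (eager).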
impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type Compact = CompactType;
    type IdType = Ulid;
    type Error = sqlx::Error;
    type Stream = TaskStream<PgTask<Args>, Self::Error>;
    type Beat = BoxStream<'static, Result<(), Self::Error>>;
    type Codec = Decode;
    type Context = PgContext;
    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

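    /// Combines periodic keep-alive updates for this worker with a sweep that
    /// re-enqueues tasks orphaned by workers that stopped heartbeating.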
    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        select(keep_alive, reenqueue).boxed()
    }

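    /// Middleware applied to every task: lock the task row for this worker,
    /// then acknowledge the outcome back to Postgres.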
    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

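    /// Registers the worker with an initial heartbeat, then merges the
    /// notification-driven (lazy) fetcher with the table-polling (eager) one.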
    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "SharedPostgresStorage",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        // Ids pushed by the shared listener are locked in batches inside a
        // single transaction, then decoded into tasks.
        let lazy_fetcher = self
            .fetcher
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/lock_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task::<Decode, Args, Ulid>()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|batch| match batch {
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // The polling fetcher picks up tasks whose notifications were missed,
        // e.g. those inserted before the listener connected.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<Args, CompactType, Decode>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

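// Integration test: requires a reachable Postgres with the apalis schema,
// e.g. `postgres://postgres:postgres@localhost/apalis_dev`.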
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        let pool = PgPool::connect("postgres://postgres:postgres@localhost/apalis_dev")
            .await
            .unwrap();
        let mut store = SharedPostgresStorage::new(pool);

        // Two independently typed queues sharing one LISTEN connection.
        let mut map_store = store.make_shared().unwrap();
        let mut int_store = store.make_shared().unwrap();

        map_store
            .push_stream(&mut stream::iter(vec![HashMap::<String, String>::new()]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}
312}