1use std::{
2 collections::HashMap,
3 future::ready,
4 marker::PhantomData,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use crate::{
11 CompactType, Config, InsertEvent, PgTask, PostgresStorage,
12 ack::{LockTaskLayer, PgAck},
13 context::PgContext,
14 fetcher::PgPollFetcher,
15 queries::{
16 keep_alive::{initial_heartbeat, keep_alive_stream},
17 reenqueue_orphaned::reenqueue_orphaned_stream,
18 },
19};
20use crate::{from_row::PgTaskRow, sink::PgSink};
21use apalis_core::{
22 backend::{
23 Backend, TaskStream,
24 codec::{Codec, json::JsonCodec},
25 shared::MakeShared,
26 },
27 layers::Stack,
28 task::task_id::TaskId,
29 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
30};
31use apalis_sql::from_row::TaskRow;
32use futures::{
33 FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
34 channel::mpsc::{self, Receiver, Sender},
35 future::{BoxFuture, Shared},
36 lock::Mutex,
37 stream::{self, BoxStream, select},
38};
39use sqlx::{PgPool, postgres::PgListener};
40use ulid::Ulid;
41
42pub struct SharedPostgresStorage<Compact = CompactType, Codec = JsonCodec<CompactType>> {
43 pool: PgPool,
44 registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>>,
45 drive: Shared<BoxFuture<'static, ()>>,
46 _marker: PhantomData<(Compact, Codec)>,
47}
48
49impl SharedPostgresStorage {
50 pub fn new(pool: PgPool) -> Self {
51 let registry: Arc<Mutex<HashMap<String, Sender<TaskId>>>> =
52 Arc::new(Mutex::new(HashMap::default()));
53 let p = pool.clone();
54 let instances = registry.clone();
55 Self {
56 pool,
57 drive: async move {
58 let mut listener = PgListener::connect_with(&p).await.unwrap();
59 listener.listen("apalis::job::insert").await.unwrap();
60 listener
61 .into_stream()
62 .filter_map(|notification| {
63 let instances = instances.clone();
64 async move {
65 let pg_notification = notification.ok()?;
66 let payload = pg_notification.payload();
67 let ev: InsertEvent = serde_json::from_str(payload).ok()?;
68 let instances = instances.lock().await;
69 if instances.get(&ev.job_type).is_some() {
70 return Some(ev);
71 }
72 None
73 }
74 })
75 .for_each(|ev| {
76 let instances = instances.clone();
77 async move {
78 let mut instances = instances.lock().await;
79 let sender = instances.get_mut(&ev.job_type).unwrap();
80 sender.send(ev.id).await.unwrap();
81 }
82 })
83 .await;
84 }
85 .boxed()
86 .shared(),
87 registry,
88 _marker: PhantomData,
89 }
90 }
91}
92#[derive(Debug, thiserror::Error)]
93pub enum SharedPostgresError {
94 #[error("namespace already exists: {0}")]
96 NamespaceExists(String),
97
98 #[error("registry locked")]
100 RegistryLocked,
101}
102
103impl<Args, Compact, Codec> MakeShared<Args> for SharedPostgresStorage<Compact, Codec> {
104 type Backend = PostgresStorage<Args, Compact, Codec, SharedFetcher>;
105 type Config = Config;
106 type MakeError = SharedPostgresError;
107 fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
108 where
109 Self::Config: Default,
110 {
111 Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
112 }
113 fn make_shared_with_config(
114 &mut self,
115 config: Self::Config,
116 ) -> Result<Self::Backend, Self::MakeError> {
117 let (tx, rx) = mpsc::channel(config.buffer_size());
118 let mut r = self
119 .registry
120 .try_lock()
121 .ok_or(SharedPostgresError::RegistryLocked)?;
122 if r.insert(config.queue().to_string(), tx).is_some() {
123 return Err(SharedPostgresError::NamespaceExists(
124 config.queue().to_string(),
125 ));
126 }
127 let sink = PgSink::new(&self.pool, &config);
128 Ok(PostgresStorage {
129 _marker: PhantomData,
130 config,
131 fetcher: SharedFetcher {
132 poller: self.drive.clone(),
133 receiver: rx,
134 },
135 pool: self.pool.clone(),
136 sink,
137 })
138 }
139}
140
141pub struct SharedFetcher {
142 poller: Shared<BoxFuture<'static, ()>>,
143 receiver: Receiver<TaskId>,
144}
145
146impl Stream for SharedFetcher {
147 type Item = TaskId;
148
149 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 let this = self.get_mut();
151 let _ = this.poller.poll_unpin(cx);
153
154 this.receiver.poll_next_unpin(cx)
156 }
157}
158
159impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, SharedFetcher>
160where
161 Args: Send + 'static + Unpin,
162 Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
163 Decode::Error: std::error::Error + Send + Sync + 'static,
164{
165 type Args = Args;
166
167 type Compact = CompactType;
168
169 type IdType = Ulid;
170
171 type Error = sqlx::Error;
172
173 type Stream = TaskStream<PgTask<Args>, Self::Error>;
174
175 type Beat = BoxStream<'static, Result<(), Self::Error>>;
176
177 type Codec = Decode;
178
179 type Context = PgContext;
180
181 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;
182
183 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
184 let pool = self.pool.clone();
185 let config = self.config.clone();
186 let worker = worker.clone();
187 let keep_alive = keep_alive_stream(pool, config, worker);
188 let reenqueue = reenqueue_orphaned_stream(
189 self.pool.clone(),
190 self.config.clone(),
191 *self.config.keep_alive(),
192 )
193 .map_ok(|_| ());
194 futures::stream::select(keep_alive, reenqueue).boxed()
195 }
196
197 fn middleware(&self) -> Self::Layer {
198 Stack::new(
199 LockTaskLayer::new(self.pool.clone()),
200 AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
201 )
202 }
203
204 fn poll(self, worker: &WorkerContext) -> Self::Stream {
205 let pool = self.pool.clone();
206 let worker_id = worker.name().to_owned();
207 let register_worker = initial_heartbeat(
208 self.pool.clone(),
209 self.config.clone(),
210 worker.clone(),
211 "SharedPostgresStorage",
212 )
213 .map(|_| Ok(None));
214 let register = stream::once(register_worker);
215 let lazy_fetcher = self
216 .fetcher
217 .map(|t| t.to_string())
218 .ready_chunks(self.config.buffer_size())
219 .then(move |ids| {
220 let pool = pool.clone();
221 let worker_id = worker_id.clone();
222 async move {
223 let mut tx = pool.begin().await?;
224 let res: Vec<_> = sqlx::query_file_as!(
225 PgTaskRow,
226 "queries/task/lock_by_id.sql",
227 &ids,
228 &worker_id
229 )
230 .fetch(&mut *tx)
231 .map(|r| {
232 let row: TaskRow = r?.try_into()?;
233 Ok(Some(
234 row.try_into_task::<Decode, Args, Ulid>()
235 .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
236 ))
237 })
238 .collect()
239 .await;
240 tx.commit().await?;
241 Ok::<_, sqlx::Error>(res)
242 }
243 })
244 .flat_map(|vec| match vec {
245 Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
246 Ok(t) => Ok(t),
247 Err(e) => Err(e),
248 }))
249 .boxed(),
250 Err(e) => stream::once(ready(Err(e))).boxed(),
251 })
252 .boxed();
253
254 let eager_fetcher = StreamExt::boxed(PgPollFetcher::<Args, CompactType, Decode>::new(
255 &self.pool,
256 &self.config,
257 worker,
258 ));
259 register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
260 }
261}
262
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{backend::TaskSink, error::BoxDynError, worker::builder::WorkerBuilder};

    use super::*;

    /// Runs two workers on backends created from one shared storage, each with
    /// a different argument type, and waits for both to stop.
    ///
    /// Needs a Postgres instance reachable at the hard-coded URL below.
    #[tokio::test]
    async fn basic_worker() {
        /// Handler used by both workers: waits two seconds, then stops the
        /// worker that ran it.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let pool = PgPool::connect("postgres://postgres:postgres@localhost/apalis_dev")
            .await
            .unwrap();
        let mut shared = SharedPostgresStorage::new(pool);

        // The argument type of each backend is inferred from what is pushed
        // into it below.
        let mut map_backend = shared.make_shared().unwrap();
        let mut int_backend = shared.make_shared().unwrap();

        let mut map_tasks = stream::iter(vec![HashMap::<String, String>::new()]);
        map_backend.push_stream(&mut map_tasks).await.unwrap();
        int_backend.push(99).await.unwrap();

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_backend)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_backend)
            .build(send_reminder);

        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}