1use std::{
2 collections::HashMap,
3 marker::PhantomData,
4 pin::Pin,
5 sync::{
6 Arc, Mutex,
7 atomic::{AtomicBool, Ordering},
8 },
9 task::{Context, Poll},
10};
11
12type RegistrySender = Sender<Result<PgTaskId, Error>>;
19
20use apalis_codec::json::JsonCodec;
21use apalis_core::{backend::shared::MakeShared, worker::context::WorkerContext};
22use diesel::RunQueryDsl;
23use futures::{
24 Stream, StreamExt, TryFutureExt,
25 channel::mpsc::{self, Receiver, Sender},
26};
27use ulid::Ulid;
28
29use crate::{
30 CompactType, Config, Error, PgPool, PgTask, PgTaskId, PostgresStorage, fetcher::PgPollFetcher,
31 queries, sink::PgSink,
32};
33
34type RegistryEntry = (Ulid, RegistrySender);
41type RegistryMap = HashMap<String, Vec<RegistryEntry>>;
42type SharedRegistry = Arc<Mutex<RegistryMap>>;
43
44pub struct SharedPostgresStorage<Codec = JsonCodec<CompactType>> {
51 pool: PgPool,
52 registry: SharedRegistry,
53 listener_alive: Arc<AtomicBool>,
62 _marker: PhantomData<Codec>,
63}
64
65impl<Codec> SharedPostgresStorage<Codec> {
66 #[must_use]
68 pub fn new(pool: PgPool) -> Self {
69 let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
70 Self {
71 pool,
72 registry,
73 listener_alive: Arc::new(AtomicBool::new(false)),
74 _marker: PhantomData,
75 }
76 }
77
78 fn spawn_registry_listener(&self) {
79 let pool = self.pool.clone();
80 let registry = self.registry.clone();
81 let listener_alive = self.listener_alive.clone();
82 if let Err(error) = std::thread::Builder::new()
83 .name("apalis-postgres-shared-listener".to_owned())
84 .spawn(move || {
85 let mut conn = match pool.get() {
86 Ok(conn) => conn,
87 Err(error) => {
88 exit_listener(
89 ®istry,
90 &listener_alive,
91 Some(format!(
92 "failed to get pooled connection for shared LISTEN: {error}"
93 )),
94 );
95 return;
96 }
97 };
98 if let Err(error) =
99 diesel::sql_query("LISTEN \"apalis::job::insert\"").execute(&mut conn)
100 {
101 exit_listener(
102 ®istry,
103 &listener_alive,
104 Some(format!("failed to start shared LISTEN listener: {error}")),
105 );
106 return;
107 }
108 loop {
109 for notification in conn.notifications_iter() {
110 let notification = match notification {
111 Ok(notification) => notification,
112 Err(error) => {
113 exit_listener(
114 ®istry,
115 &listener_alive,
116 Some(format!("failed to receive shared notification: {error}")),
117 );
118 return;
119 }
120 };
121 let Ok(event) =
122 serde_json::from_str::<crate::InsertEvent>(¬ification.payload)
123 else {
124 continue;
125 };
126 let (event_queue, ids) = event.into_ids();
127 let Ok(mut registry) = registry.lock() else {
128 listener_alive.store(false, Ordering::Release);
131 return;
132 };
133 if let Some(senders) = registry.get_mut(&event_queue) {
134 for id in ids {
139 senders.retain_mut(|(_, sender)| {
140 match sender.try_send(Ok(id)) {
141 Ok(()) => true,
142 Err(error) if error.is_disconnected() => false,
143 Err(_) => true,
148 }
149 });
150 }
151 if senders.is_empty() {
152 registry.remove(&event_queue);
153 }
154 }
155 }
156 match registry.lock() {
157 Ok(registry) if registry.is_empty() => {
158 listener_alive.store(false, Ordering::Release);
166 drop(registry);
167 return;
168 }
169 Ok(_) => {}
170 Err(_) => {
171 listener_alive.store(false, Ordering::Release);
173 return;
174 }
175 }
176 std::thread::sleep(queries::NOTIFY_LISTENER_POLL_INTERVAL);
177 }
178 })
179 {
180 exit_listener(
181 &self.registry,
182 &self.listener_alive,
183 Some(format!("failed to spawn listener: {error}")),
184 );
185 }
186 }
187}
188
189fn exit_listener(registry: &SharedRegistry, listener_alive: &AtomicBool, error: Option<String>) {
198 match registry.lock() {
199 Ok(mut guard) => {
200 if let Some(message) = error {
201 broadcast_notify_error_locked(&mut guard, message);
202 }
203 listener_alive.store(false, Ordering::Release);
204 drop(guard);
205 }
206 Err(_) => {
207 listener_alive.store(false, Ordering::Release);
210 }
211 }
212}
213
/// Test-only wrapper: locks `registry` and broadcasts `message` to every sender.
///
/// Does nothing when the registry mutex is poisoned.
#[cfg(test)]
fn broadcast_notify_error(registry: &SharedRegistry, message: String) {
    if let Ok(mut guard) = registry.lock() {
        broadcast_notify_error_locked(&mut guard, message);
    }
}
221
222fn broadcast_notify_error_locked(registry: &mut RegistryMap, message: String) {
223 registry.retain(|_, senders| {
224 senders.retain_mut(|(_, sender)| {
225 match sender.try_send(Err(Error::NotifyListener(message.clone()))) {
226 Ok(()) => true,
227 Err(error) => !error.is_disconnected(),
228 }
229 });
230 !senders.is_empty()
231 });
232}
233
234impl<Codec> std::fmt::Debug for SharedPostgresStorage<Codec> {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("SharedPostgresStorage")
237 .finish_non_exhaustive()
238 }
239}
240
241#[derive(Debug, thiserror::Error)]
243#[non_exhaustive]
244pub enum SharedPostgresError {
245 #[error("registry lock poisoned")]
247 RegistryLocked,
248}
249
250impl<Args, Codec> MakeShared<Args> for SharedPostgresStorage<Codec> {
251 type Backend = PostgresStorage<Args, Codec, SharedFetcher>;
252 type Config = Config;
253 type MakeError = SharedPostgresError;
254
255 fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
256 where
257 Self::Config: Default,
258 {
259 self.make_shared_with_config(Config::new(std::any::type_name::<Args>()))
260 }
261
262 fn make_shared_with_config(
263 &mut self,
264 config: Self::Config,
265 ) -> Result<Self::Backend, Self::MakeError> {
266 let (sender, receiver) = mpsc::channel(
267 config
268 .buffer_size()
269 .clamp(1, crate::queries::NOTIFY_CHANNEL_CAPACITY_MAX),
270 );
271 let mut registry = self
272 .registry
273 .lock()
274 .map_err(|_| SharedPostgresError::RegistryLocked)?;
275 let queue = config.queue().to_string();
276 let registration_id = Ulid::new();
290 registry
291 .entry(queue)
292 .or_default()
293 .push((registration_id, sender));
294 let should_spawn_listener = !self.listener_alive.swap(true, Ordering::AcqRel);
295 drop(registry);
296
297 if should_spawn_listener {
298 self.spawn_registry_listener();
299 }
300
301 let registration = Arc::new(SharedRegistration {
302 id: registration_id,
303 queue: config.queue().to_string(),
304 registry: self.registry.clone(),
305 pool: self.pool.clone(),
306 });
307
308 Ok(PostgresStorage {
309 _marker: PhantomData,
310 sink: PgSink::new(&self.pool, &config),
311 pool: self.pool.clone(),
312 config,
313 fetcher: SharedFetcher {
314 receiver,
315 _registration: registration,
316 },
317 lease_token: crate::queries::worker::mint_lease_token().into(),
318 })
319 }
320}
321
322struct SharedRegistration {
323 id: Ulid,
328 queue: String,
329 registry: SharedRegistry,
330 pool: PgPool,
331}
332
333impl std::fmt::Debug for SharedRegistration {
334 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335 f.debug_struct("SharedRegistration")
336 .field("queue", &self.queue)
337 .finish_non_exhaustive()
338 }
339}
340
341impl Drop for SharedRegistration {
342 fn drop(&mut self) {
343 let became_empty = match self.registry.lock() {
344 Ok(mut registry) => {
345 if let Some(senders) = registry.get_mut(&self.queue) {
349 senders.retain(|(id, _)| *id != self.id);
350 if senders.is_empty() {
351 registry.remove(&self.queue);
352 }
353 }
354 registry.is_empty()
355 }
356 Err(_) => false,
357 };
358 if became_empty {
365 let pool = self.pool.clone();
368 let _ = std::thread::Builder::new()
369 .name("apalis-postgres-shared-drop".to_owned())
370 .spawn(move || {
371 if let Ok(mut conn) = pool.get() {
372 let _ = diesel::sql_query("SELECT pg_notify('apalis::job::insert', '')")
373 .execute(&mut conn);
374 }
375 });
376 }
377 }
378}
379
380pub struct SharedFetcher {
391 receiver: Receiver<Result<PgTaskId, Error>>,
392 _registration: Arc<SharedRegistration>,
393}
394
395impl std::fmt::Debug for SharedFetcher {
396 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 f.debug_struct("SharedFetcher").finish_non_exhaustive()
398 }
399}
400
401impl Stream for SharedFetcher {
402 type Item = Result<PgTaskId, Error>;
403
404 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
405 Pin::new(&mut self.get_mut().receiver).poll_next(cx)
406 }
407}
408
409impl crate::fetcher::PgFetcherSource for SharedFetcher {
410 const STORAGE_NAME: &'static str = "SharedPostgresStorage";
411
412 fn into_compact_stream(
413 self,
414 pool: PgPool,
415 config: Config,
416 worker: WorkerContext,
417 lease_token: std::sync::Arc<str>,
418 ) -> apalis_core::backend::TaskStream<PgTask<CompactType>, Error> {
419 let register_worker = queries::initial_heartbeat(
420 pool.clone(),
421 config.clone(),
422 worker.clone(),
423 Self::STORAGE_NAME,
424 lease_token,
425 )
426 .map_ok(|_| None);
427
428 let lazy_fetcher = queries::batch_ids_into_tasks(
429 pool.clone(),
430 config.queue().to_string(),
431 worker.name().to_owned(),
432 config.buffer_size().max(1),
433 self,
434 )
435 .boxed();
436
437 let eager_fetcher = PgPollFetcher::<CompactType>::new(&pool, &config, &worker);
438 let combined = futures::stream::select(lazy_fetcher, eager_fetcher);
439 crate::fetcher::register_then_stream(register_worker, combined)
440 }
441}
442
#[cfg(test)]
mod tests {
    use apalis_core::backend::{Backend, BackendExt, shared::MakeShared};
    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    /// Values captured from a freshly built shared backend for assertions.
    struct SharedObservation {
        queue: String,
        buffer_size: usize,
        debug: String,
    }

    /// Builds a one-connection pool with `build_unchecked` and a 10 ms
    /// connection timeout, pointing at `postgres://127.0.0.1:1/not-used`.
    ///
    /// NOTE(review): presumably chosen so these tests need no database - confirm.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(std::time::Duration::from_millis(10))
            .build_unchecked(manager)
    }

    /// Returns the `Debug` rendering of a new factory.
    fn shared_debug() -> String {
        let shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        format!("{shared:?}")
    }

    /// Builds a backend through `make_shared` and records its queue, buffer
    /// size and `Debug` output.
    fn make_default_shared() -> Result<SharedObservation, SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let storage = <SharedPostgresStorage as MakeShared<String>>::make_shared(&mut shared)?;
        Ok(SharedObservation {
            queue: storage.config.queue().to_string(),
            buffer_size: storage.config.buffer_size(),
            debug: format!("{storage:?}"),
        })
    }

    /// Builds a backend for the `shared-unit` queue with buffer size 3 and
    /// records its queue, buffer size and the fetcher's `Debug` output.
    fn make_configured_shared() -> Result<SharedObservation, SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let config = Config::new("shared-unit").set_buffer_size(3);
        let storage = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config,
        )?;
        Ok(SharedObservation {
            queue: storage.get_queue().to_string(),
            buffer_size: storage.config.buffer_size(),
            debug: format!("{:?}", storage.fetcher),
        })
    }

    /// Returns the type names of the middleware and the compact stream
    /// produced by a shared backend.
    fn shared_trait_surfaces() -> Result<(String, String), SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let config = Config::new("shared-traits");
        let storage = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config,
        )?;
        let worker = WorkerContext::new::<()>("shared-trait-worker");
        let middleware_name = std::any::type_name_of_val(&storage.middleware()).to_owned();
        let stream_name = std::any::type_name_of_val(&storage.poll_compact(&worker)).to_owned();
        Ok((middleware_name, stream_name))
    }

    /// Formats a registration with `Debug`, lets it drop, and reports whether
    /// the registry is empty afterwards.
    fn registration_debug_and_drop() -> (String, bool) {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let (sender, _receiver) = mpsc::channel(1);
        let id = Ulid::new();
        registry
            .lock()
            .expect("fresh shared registry is not poisoned")
            .insert("shared-registration".to_owned(), vec![(id, sender)]);

        let debug = {
            let registration = SharedRegistration {
                id,
                queue: "shared-registration".to_owned(),
                registry: registry.clone(),
                pool: unchecked_pool(),
            };
            // `registration` is dropped at the end of this block.
            format!("{registration:?}")
        };

        let removed = registry
            .lock()
            .expect("fresh shared registry is not poisoned")
            .is_empty();
        (debug, removed)
    }

    /// Registers `target_queue` plus one entry per sibling queue, drops the
    /// target registration, and returns how many queues remain.
    fn drop_leaves_remaining(target_queue: &str, sibling_queues: &[&str]) -> usize {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let target_id = Ulid::new();
        {
            let mut reg = registry
                .lock()
                .expect("fresh shared registry is not poisoned");
            let (sender, _r) = mpsc::channel(1);
            reg.insert(target_queue.to_owned(), vec![(target_id, sender)]);
            for sibling in sibling_queues {
                let (sender, _r) = mpsc::channel(1);
                reg.insert((*sibling).to_owned(), vec![(Ulid::new(), sender)]);
            }
        }

        {
            let registration = SharedRegistration {
                id: target_id,
                queue: target_queue.to_owned(),
                registry: registry.clone(),
                pool: unchecked_pool(),
            };
            drop(registration);
        }

        registry
            .lock()
            .expect("fresh shared registry is not poisoned")
            .len()
    }

    /// Drops the only registration in the registry.
    fn drop_when_registry_empties() -> usize {
        drop_leaves_remaining("shared-only", &[])
    }

    /// Drops one registration while two other queues stay registered.
    fn drop_when_registry_has_siblings() -> usize {
        drop_leaves_remaining("shared-target", &["shared-other-a", "shared-other-b"])
    }

    /// Registers two senders on one queue, drops the first registration, and
    /// returns how many senders remain on that queue.
    fn drop_one_of_two_keeps_sibling_sender() -> usize {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let queue = "shared-coexist".to_owned();
        let first_id = Ulid::new();
        let second_id = Ulid::new();
        let (first_sender, _first_rx) = mpsc::channel(1);
        let (second_sender, _second_rx) = mpsc::channel(1);
        registry
            .lock()
            .expect("fresh registry is not poisoned")
            .insert(
                queue.clone(),
                vec![(first_id, first_sender), (second_id, second_sender)],
            );

        drop(SharedRegistration {
            id: first_id,
            queue: queue.clone(),
            registry: registry.clone(),
            pool: unchecked_pool(),
        });

        let guard = registry.lock().expect("registry is not poisoned");
        guard.get(&queue).map_or(0, Vec::len)
    }

    /// Registers the same queue twice on one factory.
    fn double_make_shared_same_queue() -> Result<(), SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let config = Config::new("double-make-shared");
        let _first = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config.clone(),
        )?;
        let _second = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config,
        )?;
        Ok(())
    }

    /// Broadcasts an error to a registry holding one live and one disconnected
    /// sender; returns `(queues after, queues before)`.
    fn broadcast_notify_error_observation() -> (usize, usize) {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let (alive_sender, _alive_receiver) = mpsc::channel(1);
        let (dead_sender, dead_receiver) = mpsc::channel::<Result<PgTaskId, Error>>(1);
        // Dropping the receiver disconnects `dead_sender`.
        drop(dead_receiver);
        {
            let mut reg = registry.lock().expect("fresh registry is not poisoned");
            reg.insert("alive".to_owned(), vec![(Ulid::new(), alive_sender)]);
            reg.insert("dead".to_owned(), vec![(Ulid::new(), dead_sender)]);
        }

        let initial = registry.lock().expect("registry is not poisoned").len();
        broadcast_notify_error(&registry, "synthetic listener failure".to_owned());
        let retained = registry.lock().expect("registry is not poisoned").len();
        (retained, initial)
    }

    /// Assertion: the debug string contains `expected`.
    fn debug_mentions_type(expected: &'static str) -> impl Fn(&String) -> AssertionResult {
        move |debug| {
            if debug.contains(expected) {
                Ok(())
            } else {
                Err(AssertionError::new(vec![format!(
                    "expected debug output containing {expected:?}, got {debug}"
                )]))
            }
        }
    }

    /// Assertion: queue is the type name of `String`, buffer size is 10 and
    /// the debug output mentions `SharedFetcher`.
    fn uses_default_queue(result: &SharedObservation) -> AssertionResult {
        if result.queue == std::any::type_name::<String>()
            && result.buffer_size == 10
            && result.debug.contains("SharedFetcher")
        {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected default shared storage: queue={:?}, buffer={}, debug={}",
                result.queue, result.buffer_size, result.debug
            )]))
        }
    }

    /// Assertion: queue is `shared-unit`, buffer size is 3 and the debug
    /// output mentions `SharedFetcher`.
    fn uses_configured_queue(result: &SharedObservation) -> AssertionResult {
        if result.queue == "shared-unit"
            && result.buffer_size == 3
            && result.debug.contains("SharedFetcher")
        {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected configured shared storage: queue={:?}, buffer={}, debug={}",
                result.queue, result.buffer_size, result.debug
            )]))
        }
    }

    /// Assertion: the type names mention `PgMiddleware` and `Stream`.
    fn constructs_backend_traits(result: &(String, String)) -> AssertionResult {
        if result.0.contains("PgMiddleware") && result.1.contains("Stream") {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected shared trait surfaces: {result:?}"
            )]))
        }
    }

    /// Assertion: the debug output names `SharedRegistration` and the registry
    /// was emptied by the drop.
    fn removes_registration(result: &(String, bool)) -> AssertionResult {
        if result.0.contains("SharedRegistration") && result.1 {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "expected registration debug and drop cleanup, got {result:?}"
            )]))
        }
    }

    /// Poisons the registry mutex by panicking in another thread while the
    /// lock is held, then attempts to create a shared backend.
    fn make_shared_with_poisoned_registry() -> Result<(), SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let registry = shared.registry.clone();
        let join = std::thread::spawn(move || {
            let _guard = registry
                .lock()
                .expect("fresh registry lock is not poisoned");
            panic!("synthetic poisoning panic");
        });
        // The join result is the expected panic; it is deliberately ignored.
        let _ = join.join();
        let config = Config::new("poisoned-registry");
        <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(&mut shared, config)
            .map(|_| ())
    }

    /// Assertion: the error is `SharedPostgresError::RegistryLocked`.
    fn is_registry_locked(error: &SharedPostgresError) -> AssertionResult {
        match error {
            SharedPostgresError::RegistryLocked => Ok(()),
        }
    }

    lets_expect! {
        expect(shared_debug()) {
            to describes_the_shared_factory { debug_mentions_type("SharedPostgresStorage") }
        }

        expect(make_default_shared()) {
            when no_config_is_supplied {
                to uses_the_task_type_as_the_namespace { be_ok_and uses_default_queue }
            }
        }

        expect(make_configured_shared()) {
            when config_is_supplied {
                to exposes_the_queue_and_fetcher { be_ok_and uses_configured_queue }
            }
        }

        expect(shared_trait_surfaces()) {
            when backend_traits_are_requested {
                to builds_middleware_and_compact_stream { be_ok_and constructs_backend_traits }
            }
        }

        expect(registration_debug_and_drop()) {
            when registration_is_dropped {
                to removes_the_namespace_from_the_registry { removes_registration }
            }
        }

        expect(drop_when_registry_empties()) {
            when dropping_the_last_registration_empties_the_registry {
                to leaves_no_remaining_registrations { equal(0) }
            }
        }

        expect(drop_when_registry_has_siblings()) {
            when dropping_one_of_several_registrations {
                to keeps_sibling_registrations_intact { equal(2) }
            }
        }

        expect(drop_one_of_two_keeps_sibling_sender()) {
            when dropping_one_of_two_consumers_on_the_same_queue {
                to leaves_the_other_senders_sender_in_place { equal(1) }
            }
        }

        expect(double_make_shared_same_queue()) {
            when the_same_queue_is_registered_twice {
                to accepts_the_second_registration { be_ok }
            }
        }

        expect(broadcast_notify_error_observation()) {
            when listener_broadcasts_an_error_to_a_mixed_registry {
                to drops_disconnected_senders_and_keeps_live_ones { equal((1_usize, 2_usize)) }
            }
        }

        expect(make_shared_with_poisoned_registry()) {
            when the_registry_mutex_is_poisoned_by_a_panic_in_another_thread {
                to surfaces_registry_locked_rather_than_panicking_or_succeeding {
                    be_err_and is_registry_locked
                }
            }
        }
    }
}