use std::{
collections::HashMap,
marker::PhantomData,
pin::Pin,
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll},
};
/// Sending half of one fetcher's channel; items are either a notified task id or a listener error.
type RegistrySender = Sender<Result<PgTaskId, Error>>;
use apalis_codec::json::JsonCodec;
use apalis_core::{backend::shared::MakeShared, worker::context::WorkerContext};
use diesel::RunQueryDsl;
use futures::{
Stream, StreamExt, TryFutureExt,
channel::mpsc::{self, Receiver, Sender},
};
use ulid::Ulid;
use crate::{
CompactType, Config, Error, PgPool, PgTask, PgTaskId, PostgresStorage, fetcher::PgPollFetcher,
queries, sink::PgSink,
};
/// One registered consumer: its registration id paired with the sender that feeds it.
type RegistryEntry = (Ulid, RegistrySender);
/// Registered consumers grouped by queue name.
type RegistryMap = HashMap<String, Vec<RegistryEntry>>;
/// The registry map behind a mutex, reference-counted so several owners can hold it.
type SharedRegistry = Arc<Mutex<RegistryMap>>;
/// Factory that hands out `PostgresStorage` backends sharing one sender registry
/// and one background LISTEN thread.
pub struct SharedPostgresStorage<Codec = JsonCodec<CompactType>> {
/// Connection pool cloned into every backend and into the listener thread.
pool: PgPool,
/// Per-queue senders that the listener thread forwards notified task ids to.
registry: SharedRegistry,
/// `true` once a listener has been requested; reset to `false` when the listener exits.
listener_alive: Arc<AtomicBool>,
/// Carries the `Codec` type parameter without storing a value of it.
_marker: PhantomData<Codec>,
}
impl<Codec> SharedPostgresStorage<Codec> {
    /// Creates a factory backed by `pool` with an empty registry.
    ///
    /// No thread is started here; `listener_alive` begins as `false` and the
    /// listener is spawned when a backend is first registered.
    #[must_use]
    pub fn new(pool: PgPool) -> Self {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        Self {
            pool,
            registry,
            listener_alive: Arc::new(AtomicBool::new(false)),
            _marker: PhantomData,
        }
    }

    /// Spawns the background thread that forwards insert notifications to
    /// registered senders.
    ///
    /// The thread takes a connection from the pool, issues
    /// `LISTEN "apalis::job::insert"`, and then repeatedly:
    ///
    /// 1. walks `notifications_iter()`, parses each payload as an
    ///    `InsertEvent`, and `try_send`s every id to the senders registered
    ///    for the event's queue;
    /// 2. exits (clearing `listener_alive`) once the registry is empty;
    /// 3. otherwise sleeps for `NOTIFY_LISTENER_POLL_INTERVAL`.
    ///
    /// Connection, `LISTEN`, notification, and thread-spawn failures all go
    /// through `exit_listener`, which broadcasts the error text to the
    /// registered senders and clears `listener_alive`.
    fn spawn_registry_listener(&self) {
        let pool = self.pool.clone();
        let registry = self.registry.clone();
        let listener_alive = self.listener_alive.clone();
        if let Err(error) = std::thread::Builder::new()
            .name("apalis-postgres-shared-listener".to_owned())
            .spawn(move || {
                let mut conn = match pool.get() {
                    Ok(conn) => conn,
                    Err(error) => {
                        exit_listener(
                            &registry,
                            &listener_alive,
                            Some(format!(
                                "failed to get pooled connection for shared LISTEN: {error}"
                            )),
                        );
                        return;
                    }
                };
                if let Err(error) =
                    diesel::sql_query("LISTEN \"apalis::job::insert\"").execute(&mut conn)
                {
                    exit_listener(
                        &registry,
                        &listener_alive,
                        Some(format!("failed to start shared LISTEN listener: {error}")),
                    );
                    return;
                }
                loop {
                    // NOTE(review): this loop relies on `notifications_iter()` ending once the
                    // pending notifications are consumed; confirm against the diesel version in use.
                    for notification in conn.notifications_iter() {
                        let notification = match notification {
                            Ok(notification) => notification,
                            Err(error) => {
                                exit_listener(
                                    &registry,
                                    &listener_alive,
                                    Some(format!("failed to receive shared notification: {error}")),
                                );
                                return;
                            }
                        };
                        // Payloads that are not a valid `InsertEvent` are skipped.
                        let Ok(event) =
                            serde_json::from_str::<crate::InsertEvent>(&notification.payload)
                        else {
                            continue;
                        };
                        let (event_queue, ids) = event.into_ids();
                        // A poisoned registry ends the listener without broadcasting.
                        let Ok(mut entries) = registry.lock() else {
                            listener_alive.store(false, Ordering::Release);
                            return;
                        };
                        if let Some(senders) = entries.get_mut(&event_queue) {
                            for id in ids {
                                senders.retain_mut(|(_, sender)| {
                                    match sender.try_send(Ok(id)) {
                                        Ok(()) => true,
                                        // The receiver is gone: forget this sender.
                                        Err(error) if error.is_disconnected() => false,
                                        // Any other failure (e.g. a full channel) keeps the
                                        // sender; this id is simply not delivered to it.
                                        Err(_) => true,
                                    }
                                });
                            }
                            if senders.is_empty() {
                                entries.remove(&event_queue);
                            }
                        }
                    }
                    match registry.lock() {
                        // Nobody is registered any more: clear the flag while the lock is
                        // still held, then stop.
                        Ok(entries) if entries.is_empty() => {
                            listener_alive.store(false, Ordering::Release);
                            drop(entries);
                            return;
                        }
                        Ok(_) => {}
                        Err(_) => {
                            listener_alive.store(false, Ordering::Release);
                            return;
                        }
                    }
                    std::thread::sleep(queries::NOTIFY_LISTENER_POLL_INTERVAL);
                }
            })
        {
            exit_listener(
                &self.registry,
                &self.listener_alive,
                Some(format!("failed to spawn listener: {error}")),
            );
        }
    }
}
/// Marks the listener as stopped, first broadcasting `error` (when given) to
/// every registered sender.
///
/// The broadcast only happens if the registry lock is not poisoned. In both
/// cases `listener_alive` is cleared before the lock result is released.
fn exit_listener(registry: &SharedRegistry, listener_alive: &AtomicBool, error: Option<String>) {
    let mut lock_result = registry.lock();
    if let (Ok(entries), Some(message)) = (lock_result.as_mut(), error) {
        broadcast_notify_error_locked(entries, message);
    }
    listener_alive.store(false, Ordering::Release);
    drop(lock_result);
}
/// Test-only wrapper: locks `registry` and broadcasts `message` as a listener
/// error. Does nothing when the lock is poisoned.
#[cfg(test)]
fn broadcast_notify_error(registry: &SharedRegistry, message: String) {
    if let Ok(mut entries) = registry.lock() {
        broadcast_notify_error_locked(&mut entries, message);
    }
}
fn broadcast_notify_error_locked(registry: &mut RegistryMap, message: String) {
registry.retain(|_, senders| {
senders.retain_mut(|(_, sender)| {
match sender.try_send(Err(Error::NotifyListener(message.clone()))) {
Ok(()) => true,
Err(error) => !error.is_disconnected(),
}
});
!senders.is_empty()
});
}
impl<Codec> std::fmt::Debug for SharedPostgresStorage<Codec> {
    /// Prints only the type name; all fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("SharedPostgresStorage");
        repr.finish_non_exhaustive()
    }
}
/// Errors returned while creating a shared backend.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SharedPostgresError {
/// The registry mutex was poisoned, so the new backend could not be registered.
#[error("registry lock poisoned")]
RegistryLocked,
}
impl<Args, Codec> MakeShared<Args> for SharedPostgresStorage<Codec> {
    type Backend = PostgresStorage<Args, Codec, SharedFetcher>;
    type Config = Config;
    type MakeError = SharedPostgresError;

    /// Builds a backend whose config is `Config::new` applied to the type name of `Args`.
    ///
    /// # Errors
    ///
    /// Returns [`SharedPostgresError::RegistryLocked`] if the registry mutex is poisoned.
    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
    where
        Self::Config: Default,
    {
        let default_config = Config::new(std::any::type_name::<Args>());
        <Self as MakeShared<Args>>::make_shared_with_config(self, default_config)
    }

    /// Registers a new sender under `config.queue()` and returns a backend
    /// whose fetcher reads from the matching receiver.
    ///
    /// The channel capacity is `config.buffer_size()` clamped to
    /// `1..=NOTIFY_CHANNEL_CAPACITY_MAX`. The listener thread is spawned only
    /// when this call flips `listener_alive` from `false` to `true`.
    ///
    /// # Errors
    ///
    /// Returns [`SharedPostgresError::RegistryLocked`] if the registry mutex is poisoned.
    fn make_shared_with_config(
        &mut self,
        config: Self::Config,
    ) -> Result<Self::Backend, Self::MakeError> {
        let capacity = config
            .buffer_size()
            .clamp(1, crate::queries::NOTIFY_CHANNEL_CAPACITY_MAX);
        let (sender, receiver) = mpsc::channel(capacity);

        // Register the sender and claim the listener flag under one lock
        // acquisition; the guard is released at the end of this scope, before
        // any thread is spawned.
        let (registration_id, should_spawn_listener) = {
            let mut entries = self
                .registry
                .lock()
                .map_err(|_| SharedPostgresError::RegistryLocked)?;
            let id = Ulid::new();
            entries
                .entry(config.queue().to_string())
                .or_default()
                .push((id, sender));
            (id, !self.listener_alive.swap(true, Ordering::AcqRel))
        };

        if should_spawn_listener {
            self.spawn_registry_listener();
        }

        let fetcher = SharedFetcher {
            receiver,
            _registration: Arc::new(SharedRegistration {
                id: registration_id,
                queue: config.queue().to_string(),
                registry: Arc::clone(&self.registry),
                pool: self.pool.clone(),
            }),
        };
        // Built before `config` is moved into the backend below.
        let sink = PgSink::new(&self.pool, &config);

        Ok(PostgresStorage {
            _marker: PhantomData,
            sink,
            pool: self.pool.clone(),
            config,
            fetcher,
            lease_token: crate::queries::worker::mint_lease_token().into(),
        })
    }
}
/// Registry membership of one fetcher; its `Drop` impl removes the matching sender.
struct SharedRegistration {
/// Id of the registry entry to remove on drop.
id: Ulid,
/// Queue name under which the entry was registered.
queue: String,
/// Registry the entry lives in.
registry: SharedRegistry,
/// Pool used on drop to issue a `pg_notify` when the registry becomes empty.
pool: PgPool,
}
impl std::fmt::Debug for SharedRegistration {
    /// Prints the type name and the `queue` field; the other fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("SharedRegistration");
        repr.field("queue", &self.queue);
        repr.finish_non_exhaustive()
    }
}
impl Drop for SharedRegistration {
    /// Removes this registration's sender from the registry.
    ///
    /// If that leaves the whole registry empty, a detached thread sends an
    /// empty-payload `pg_notify` on `apalis::job::insert`. Failures on that
    /// path are ignored, and a poisoned registry lock makes the drop a no-op.
    fn drop(&mut self) {
        // The guard is moved into the closure, so the lock is released before
        // any thread is spawned below.
        let registry_drained = self.registry.lock().is_ok_and(|mut entries| {
            let queue_vacated = entries.get_mut(&self.queue).is_some_and(|senders| {
                senders.retain(|entry| entry.0 != self.id);
                senders.is_empty()
            });
            if queue_vacated {
                entries.remove(&self.queue);
            }
            entries.is_empty()
        });
        if !registry_drained {
            return;
        }

        let pool = self.pool.clone();
        let wake_listener = move || {
            let Ok(mut conn) = pool.get() else {
                return;
            };
            let _ = diesel::sql_query("SELECT pg_notify('apalis::job::insert', '')")
                .execute(&mut conn);
        };
        let _ = std::thread::Builder::new()
            .name("apalis-postgres-shared-drop".to_owned())
            .spawn(wake_listener);
    }
}
/// Stream of task ids (or listener errors) delivered through the shared registry.
pub struct SharedFetcher {
/// Receiving half of the channel whose sender is stored in the registry.
receiver: Receiver<Result<PgTaskId, Error>>,
/// Held only for its `Drop` impl, which removes the matching sender from the registry.
_registration: Arc<SharedRegistration>,
}
impl std::fmt::Debug for SharedFetcher {
    /// Prints only the type name; all fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("SharedFetcher");
        repr.finish_non_exhaustive()
    }
}
impl Stream for SharedFetcher {
    type Item = Result<PgTaskId, Error>;

    /// Delegates to the inner channel receiver.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let fetcher = self.get_mut();
        fetcher.receiver.poll_next_unpin(cx)
    }
}
impl crate::fetcher::PgFetcherSource for SharedFetcher {
    const STORAGE_NAME: &'static str = "SharedPostgresStorage";

    /// Builds the task stream for `worker`.
    ///
    /// The id stream from this fetcher is passed through
    /// `queries::batch_ids_into_tasks` and merged via `futures::stream::select`
    /// with a `PgPollFetcher`. The merged stream is handed to
    /// `register_then_stream` together with the initial-heartbeat future.
    fn into_compact_stream(
        self,
        pool: PgPool,
        config: Config,
        worker: WorkerContext,
        lease_token: std::sync::Arc<str>,
    ) -> apalis_core::backend::TaskStream<PgTask<CompactType>, Error> {
        let heartbeat = queries::initial_heartbeat(
            pool.clone(),
            config.clone(),
            worker.clone(),
            Self::STORAGE_NAME,
            lease_token,
        )
        .map_ok(|_| None);

        let queue_name = config.queue().to_string();
        let worker_name = worker.name().to_owned();
        let batch_size = config.buffer_size().max(1);
        let notified =
            queries::batch_ids_into_tasks(pool.clone(), queue_name, worker_name, batch_size, self)
                .boxed();

        let polled = PgPollFetcher::<CompactType>::new(&pool, &config, &worker);

        crate::fetcher::register_then_stream(heartbeat, futures::stream::select(notified, polled))
    }
}
#[cfg(test)]
mod tests {
    use apalis_core::backend::{Backend, BackendExt, shared::MakeShared};
    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    /// Values captured from a freshly built shared backend for later assertions.
    struct SharedObservation {
        queue: String,
        buffer_size: usize,
        debug: String,
    }

    /// Builds a pool with `build_unchecked`, so construction does not wait for a
    /// working database connection.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(std::time::Duration::from_millis(10))
            .build_unchecked(manager)
    }

    /// Debug rendering of a freshly created factory.
    fn shared_debug() -> String {
        let shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        format!("{shared:?}")
    }

    /// Builds a backend through `make_shared` and records its queue, buffer size, and debug text.
    fn make_default_shared() -> Result<SharedObservation, SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let storage = <SharedPostgresStorage as MakeShared<String>>::make_shared(&mut shared)?;
        Ok(SharedObservation {
            queue: storage.config.queue().to_string(),
            buffer_size: storage.config.buffer_size(),
            debug: format!("{storage:?}"),
        })
    }

    /// Builds a backend with an explicit queue name and buffer size and records what it exposes.
    fn make_configured_shared() -> Result<SharedObservation, SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let config = Config::new("shared-unit").set_buffer_size(3);
        let storage = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config,
        )?;
        Ok(SharedObservation {
            queue: storage.get_queue().to_string(),
            buffer_size: storage.config.buffer_size(),
            debug: format!("{:?}", storage.fetcher),
        })
    }

    /// Returns the type names of the middleware and compact stream produced by a shared backend.
    fn shared_trait_surfaces() -> Result<(String, String), SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let config = Config::new("shared-traits");
        let storage = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config,
        )?;
        let worker = WorkerContext::new::<()>("shared-trait-worker");
        let middleware_name = std::any::type_name_of_val(&storage.middleware()).to_owned();
        let stream_name = std::any::type_name_of_val(&storage.poll_compact(&worker)).to_owned();
        Ok((middleware_name, stream_name))
    }

    /// Formats a registration, lets it drop, and reports whether the registry ended up empty.
    fn registration_debug_and_drop() -> (String, bool) {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let (sender, _receiver) = mpsc::channel(1);
        let id = Ulid::new();
        registry
            .lock()
            .expect("fresh shared registry is not poisoned")
            .insert("shared-registration".to_owned(), vec![(id, sender)]);
        let debug = {
            let registration = SharedRegistration {
                id,
                queue: "shared-registration".to_owned(),
                registry: registry.clone(),
                pool: unchecked_pool(),
            };
            format!("{registration:?}")
        };
        let removed = registry
            .lock()
            .expect("fresh shared registry is not poisoned")
            .is_empty();
        (debug, removed)
    }

    /// Registers `target_queue` plus each sibling queue, drops the target's
    /// registration, and returns how many queues remain in the registry.
    fn drop_leaves_remaining(target_queue: &str, sibling_queues: &[&str]) -> usize {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let target_id = Ulid::new();
        {
            let mut reg = registry
                .lock()
                .expect("fresh shared registry is not poisoned");
            let (sender, _r) = mpsc::channel(1);
            reg.insert(target_queue.to_owned(), vec![(target_id, sender)]);
            for sibling in sibling_queues {
                let (sender, _r) = mpsc::channel(1);
                reg.insert((*sibling).to_owned(), vec![(Ulid::new(), sender)]);
            }
        }
        {
            let registration = SharedRegistration {
                id: target_id,
                queue: target_queue.to_owned(),
                registry: registry.clone(),
                pool: unchecked_pool(),
            };
            drop(registration);
        }
        registry
            .lock()
            .expect("fresh shared registry is not poisoned")
            .len()
    }

    /// Drops the only registration in the registry.
    fn drop_when_registry_empties() -> usize {
        drop_leaves_remaining("shared-only", &[])
    }

    /// Drops one registration while two other queues stay registered.
    fn drop_when_registry_has_siblings() -> usize {
        drop_leaves_remaining("shared-target", &["shared-other-a", "shared-other-b"])
    }

    /// Registers two senders on one queue, drops the first registration, and
    /// returns how many senders the queue still has.
    fn drop_one_of_two_keeps_sibling_sender() -> usize {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let queue = "shared-coexist".to_owned();
        let first_id = Ulid::new();
        let second_id = Ulid::new();
        let (first_sender, _first_rx) = mpsc::channel(1);
        let (second_sender, _second_rx) = mpsc::channel(1);
        registry
            .lock()
            .expect("fresh registry is not poisoned")
            .insert(
                queue.clone(),
                vec![(first_id, first_sender), (second_id, second_sender)],
            );
        drop(SharedRegistration {
            id: first_id,
            queue: queue.clone(),
            registry: registry.clone(),
            pool: unchecked_pool(),
        });
        let guard = registry.lock().expect("registry is not poisoned");
        guard.get(&queue).map_or(0, Vec::len)
    }

    /// Registers the same queue twice on one factory.
    fn double_make_shared_same_queue() -> Result<(), SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let config = Config::new("double-make-shared");
        let _first = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config.clone(),
        )?;
        let _second = <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(
            &mut shared,
            config,
        )?;
        Ok(())
    }

    /// Broadcasts an error to a registry holding one live and one disconnected
    /// sender; returns `(queues retained, queues initially present)`.
    fn broadcast_notify_error_observation() -> (usize, usize) {
        let registry: SharedRegistry = Arc::new(Mutex::new(HashMap::new()));
        let (alive_sender, _alive_receiver) = mpsc::channel(1);
        let (dead_sender, dead_receiver) = mpsc::channel::<Result<PgTaskId, Error>>(1);
        // Dropping the receiver makes `dead_sender` report "disconnected" on send.
        drop(dead_receiver);
        {
            let mut reg = registry.lock().expect("fresh registry is not poisoned");
            reg.insert("alive".to_owned(), vec![(Ulid::new(), alive_sender)]);
            reg.insert("dead".to_owned(), vec![(Ulid::new(), dead_sender)]);
        }
        let initial = registry.lock().expect("registry is not poisoned").len();
        broadcast_notify_error(&registry, "synthetic listener failure".to_owned());
        let retained = registry.lock().expect("registry is not poisoned").len();
        (retained, initial)
    }

    /// Assertion factory: passes when the debug string contains `expected`.
    fn debug_mentions_type(expected: &'static str) -> impl Fn(&String) -> AssertionResult {
        move |debug| {
            if debug.contains(expected) {
                Ok(())
            } else {
                Err(AssertionError::new(vec![format!(
                    "expected debug output containing {expected:?}, got {debug}"
                )]))
            }
        }
    }

    /// Passes when the backend uses the `String` type name as queue, a buffer
    /// size of 10, and a `SharedFetcher`.
    fn uses_default_queue(result: &SharedObservation) -> AssertionResult {
        if result.queue == std::any::type_name::<String>()
            && result.buffer_size == 10
            && result.debug.contains("SharedFetcher")
        {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected default shared storage: queue={:?}, buffer={}, debug={}",
                result.queue, result.buffer_size, result.debug
            )]))
        }
    }

    /// Passes when the backend reflects the explicit `shared-unit` / buffer-3 configuration.
    fn uses_configured_queue(result: &SharedObservation) -> AssertionResult {
        if result.queue == "shared-unit"
            && result.buffer_size == 3
            && result.debug.contains("SharedFetcher")
        {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected configured shared storage: queue={:?}, buffer={}, debug={}",
                result.queue, result.buffer_size, result.debug
            )]))
        }
    }

    /// Passes when the type names mention `PgMiddleware` and `Stream` respectively.
    fn constructs_backend_traits(result: &(String, String)) -> AssertionResult {
        if result.0.contains("PgMiddleware") && result.1.contains("Stream") {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected shared trait surfaces: {result:?}"
            )]))
        }
    }

    /// Passes when the debug text names `SharedRegistration` and the registry was emptied.
    fn removes_registration(result: &(String, bool)) -> AssertionResult {
        if result.0.contains("SharedRegistration") && result.1 {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "expected registration debug and drop cleanup, got {result:?}"
            )]))
        }
    }

    /// Poisons the registry mutex from another thread, then attempts a registration.
    fn make_shared_with_poisoned_registry() -> Result<(), SharedPostgresError> {
        let mut shared: SharedPostgresStorage = SharedPostgresStorage::new(unchecked_pool());
        let registry = shared.registry.clone();
        let join = std::thread::spawn(move || {
            let _guard = registry
                .lock()
                .expect("fresh registry lock is not poisoned");
            panic!("synthetic poisoning panic");
        });
        // The thread is expected to panic; only the poisoning side effect matters.
        let _ = join.join();
        let config = Config::new("poisoned-registry");
        <SharedPostgresStorage as MakeShared<String>>::make_shared_with_config(&mut shared, config)
            .map(|_| ())
    }

    /// Passes for the `RegistryLocked` variant (currently the only one).
    fn is_registry_locked(error: &SharedPostgresError) -> AssertionResult {
        match error {
            SharedPostgresError::RegistryLocked => Ok(()),
        }
    }

    lets_expect! {
        expect(shared_debug()) {
            to describes_the_shared_factory { debug_mentions_type("SharedPostgresStorage") }
        }
        expect(make_default_shared()) {
            when no_config_is_supplied {
                to uses_the_task_type_as_the_namespace { be_ok_and uses_default_queue }
            }
        }
        expect(make_configured_shared()) {
            when config_is_supplied {
                to exposes_the_queue_and_fetcher { be_ok_and uses_configured_queue }
            }
        }
        expect(shared_trait_surfaces()) {
            when backend_traits_are_requested {
                to builds_middleware_and_compact_stream { be_ok_and constructs_backend_traits }
            }
        }
        expect(registration_debug_and_drop()) {
            when registration_is_dropped {
                to removes_the_namespace_from_the_registry { removes_registration }
            }
        }
        expect(drop_when_registry_empties()) {
            when dropping_the_last_registration_empties_the_registry {
                to leaves_no_remaining_registrations { equal(0) }
            }
        }
        expect(drop_when_registry_has_siblings()) {
            when dropping_one_of_several_registrations {
                to keeps_sibling_registrations_intact { equal(2) }
            }
        }
        expect(drop_one_of_two_keeps_sibling_sender()) {
            when dropping_one_of_two_consumers_on_the_same_queue {
                to leaves_the_other_senders_sender_in_place { equal(1) }
            }
        }
        expect(double_make_shared_same_queue()) {
            when the_same_queue_is_registered_twice {
                to accepts_the_second_registration { be_ok }
            }
        }
        expect(broadcast_notify_error_observation()) {
            when listener_broadcasts_an_error_to_a_mixed_registry {
                to drops_disconnected_senders_and_keeps_live_ones { equal((1_usize, 2_usize)) }
            }
        }
        expect(make_shared_with_poisoned_registry()) {
            when the_registry_mutex_is_poisoned_by_a_panic_in_another_thread {
                to surfaces_registry_locked_rather_than_panicking_or_succeeding {
                    be_err_and is_registry_locked
                }
            }
        }
    }
}