use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, watch};
/// Capacity passed to `broadcast::channel` for every registered channel.
const CHANNEL_BUFFER_SIZE: usize = 256;
/// Delay before the first retry after a failed connect; reset after every successful connect.
const INITIAL_BACKOFF: Duration = Duration::from_millis(500);
/// Ceiling for the retry delay, which doubles after each consecutive failed connect.
const MAX_BACKOFF: Duration = Duration::from_millis(30_000);
/// Fans PostgreSQL `LISTEN`/`NOTIFY` payloads out to in-process subscribers.
///
/// Every channel name registered at construction owns one broadcast sender, and
/// a watch channel publishes a connection generation counter.
pub struct PgNotifyBus {
    /// Registered channel name -> broadcast sender that receives its payloads.
    senders: Arc<HashMap<String, broadcast::Sender<String>>>,
    /// Publishes the connection generation counter (starts at 0).
    reconnect_tx: watch::Sender<u64>,
    /// Pool the listener connection is opened from.
    pool: sqlx::PgPool,
}
impl PgNotifyBus {
    /// Creates a bus with one broadcast channel per name in `channels`.
    ///
    /// Nothing is sent to PostgreSQL here; call [`PgNotifyBus::run`] to start
    /// listening. Duplicate names collapse into a single entry.
    pub fn new(pool: sqlx::PgPool, channels: &[&str]) -> Self {
        let mut senders = HashMap::with_capacity(channels.len());
        for &ch in channels {
            let (tx, _) = broadcast::channel(CHANNEL_BUFFER_SIZE);
            senders.insert(ch.to_string(), tx);
        }
        let (reconnect_tx, _) = watch::channel(0u64);
        Self {
            pool,
            senders: Arc::new(senders),
            reconnect_tx,
        }
    }

    /// Returns a new receiver for `channel`, or `None` if that channel was not
    /// registered in [`PgNotifyBus::new`].
    pub fn subscribe(&self, channel: &str) -> Option<broadcast::Receiver<String>> {
        self.senders.get(channel).map(|tx| tx.subscribe())
    }

    /// Returns the names of all registered channels, in unspecified order.
    pub fn channels(&self) -> Vec<&str> {
        self.senders.keys().map(String::as_str).collect()
    }

    /// Returns a receiver for the connection generation counter.
    ///
    /// The value starts at 0 and is incremented each time [`PgNotifyBus::run`]
    /// establishes a listener connection. A change therefore signals that the
    /// listener was (re)connected and that notifications sent in between may
    /// have been missed.
    pub fn subscribe_reconnects(&self) -> watch::Receiver<u64> {
        self.reconnect_tx.subscribe()
    }

    /// Runs the listen/dispatch loop until shutdown is requested.
    ///
    /// Opens a `PgListener` from the pool, `LISTEN`s on every registered
    /// channel, and forwards each notification's payload to that channel's
    /// broadcast subscribers. A failed connect is retried after a delay that
    /// starts at `INITIAL_BACKOFF` and doubles up to `MAX_BACKOFF`; a lost
    /// connection is re-established immediately. Every successful connect
    /// increments the generation published by
    /// [`PgNotifyBus::subscribe_reconnects`].
    ///
    /// Shutdown is requested by sending `true` on `shutdown` or by dropping its
    /// sender.
    pub async fn run(&self, mut shutdown: watch::Receiver<bool>) {
        let channel_names: Vec<String> = self.senders.keys().cloned().collect();
        let mut backoff = INITIAL_BACKOFF;
        loop {
            // `changed()` only reports values this receiver has not seen yet, so a
            // shutdown requested before `run` started must be checked explicitly.
            // Bound with `let` so the borrow guard is dropped before any `.await`.
            let stop_requested = *shutdown.borrow();
            if stop_requested {
                tracing::debug!("PgNotifyBus: shutting down");
                return;
            }

            let listener = match self.connect_and_listen(&channel_names).await {
                Ok(listener) => listener,
                Err(e) => {
                    tracing::warn!(error = %e, "PgNotifyBus: connect/listen failed, retrying");
                    tokio::select! {
                        biased;
                        () = Self::wait_for_shutdown(&mut shutdown) => {
                            tracing::debug!("PgNotifyBus: shutdown during reconnect");
                            return;
                        }
                        () = tokio::time::sleep(backoff) => {}
                    }
                    backoff = (backoff * 2).min(MAX_BACKOFF);
                    continue;
                }
            };

            backoff = INITIAL_BACKOFF;
            self.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
            tracing::info!(
                channels = ?channel_names,
                "PgNotifyBus: listening on {} channel(s)",
                channel_names.len(),
            );

            if self.recv_loop(listener, &mut shutdown).await {
                tracing::debug!("PgNotifyBus: shutting down");
                return;
            }
            tracing::warn!("PgNotifyBus: connection lost, reconnecting");
        }
    }

    /// Opens a `PgListener` from the pool and `LISTEN`s on each of `channels`.
    ///
    /// # Errors
    /// Returns the first error from connecting or from any `LISTEN` command.
    async fn connect_and_listen(
        &self,
        channels: &[String],
    ) -> Result<sqlx::postgres::PgListener, sqlx::Error> {
        let mut listener = sqlx::postgres::PgListener::connect_with(&self.pool).await?;
        for ch in channels {
            listener.listen(ch).await?;
        }
        Ok(listener)
    }

    /// Forwards notifications from `listener` until shutdown or connection loss.
    ///
    /// Returns `true` when shutdown was requested and `false` when the
    /// connection was lost or errored, meaning the caller should reconnect.
    async fn recv_loop(
        &self,
        mut listener: sqlx::postgres::PgListener,
        shutdown: &mut watch::Receiver<bool>,
    ) -> bool {
        loop {
            tokio::select! {
                biased;
                () = Self::wait_for_shutdown(shutdown) => return true,
                notification = listener.try_recv() => match notification {
                    Ok(Some(n)) => self.dispatch(n.channel(), n.payload()),
                    // `try_recv` yields `None` when the connection drops. `recv` would
                    // instead reconnect internally, hiding the gap from subscribers.
                    // Returning lets `run` open a fresh listener and bump the
                    // reconnect generation.
                    Ok(None) => {
                        tracing::warn!("PgNotifyBus: listener connection closed");
                        return false;
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "PgNotifyBus: recv error");
                        return false;
                    }
                },
            }
        }
    }

    /// Sends `payload` to the broadcast subscribers of `channel`, if it is registered.
    fn dispatch(&self, channel: &str, payload: &str) {
        match self.senders.get(channel) {
            Some(tx) => {
                // A broadcast send only fails when there are no receivers, which is fine.
                let _ = tx.send(payload.to_owned());
            }
            None => {
                tracing::debug!(
                    channel = channel,
                    "PgNotifyBus: notification on unregistered channel, ignoring",
                );
            }
        }
    }

    /// Resolves once `shutdown` holds `true` or its sender has been dropped.
    ///
    /// A dropped sender must count as shutdown: after that, `changed()` returns
    /// `Err` immediately on every call, so waiting on it in a `biased` select
    /// would spin without ever polling the other branch. The only await point is
    /// `changed()`, so this future can be dropped and recreated inside `select!`
    /// without losing a shutdown signal.
    async fn wait_for_shutdown(shutdown: &mut watch::Receiver<bool>) {
        loop {
            // Bound with `let` so the borrow guard is released before awaiting.
            let requested = *shutdown.borrow();
            if requested {
                return;
            }
            if shutdown.changed().await.is_err() {
                return;
            }
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    fn make_bus(channels: &[&str]) -> PgNotifyBus {
        let pool = sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap();
        PgNotifyBus::new(pool, channels)
    }

    #[tokio::test]
    async fn subscribe_returns_receiver_for_registered_channel() {
        let bus = make_bus(&["forge_changes", "forge_jobs_available"]);
        assert!(bus.subscribe("forge_changes").is_some());
        assert!(bus.subscribe("forge_jobs_available").is_some());
    }

    #[tokio::test]
    async fn subscribe_returns_none_for_unknown_channel() {
        let bus = make_bus(&["forge_changes"]);
        assert!(bus.subscribe("forge_nonexistent").is_none());
    }

    #[tokio::test]
    async fn channels_returns_all_registered_names() {
        let bus = make_bus(&[
            "forge_changes",
            "forge_jobs_available",
            "forge_workflow_wakeup",
        ]);
        let mut names = bus.channels();
        names.sort();
        assert_eq!(
            names,
            vec![
                "forge_changes",
                "forge_jobs_available",
                "forge_workflow_wakeup"
            ],
        );
    }

    #[tokio::test]
    async fn fan_out_delivers_to_all_subscribers() {
        let bus = make_bus(&["test_channel"]);
        let mut rx1 = bus.subscribe("test_channel").unwrap();
        let mut rx2 = bus.subscribe("test_channel").unwrap();
        let tx = bus.senders.get("test_channel").unwrap();
        tx.send("hello".to_string()).unwrap();
        assert_eq!(rx1.recv().await.unwrap(), "hello");
        assert_eq!(rx2.recv().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn send_without_subscribers_does_not_error() {
        let bus = make_bus(&["test_channel"]);
        let tx = bus.senders.get("test_channel").unwrap();
        let _ = tx.send("orphan".to_string());
    }

    #[tokio::test]
    async fn empty_channels_list_produces_empty_bus() {
        let bus = make_bus(&[]);
        assert!(bus.channels().is_empty());
        assert!(bus.subscribe("anything").is_none());
    }

    #[tokio::test]
    async fn reconnect_subscriber_starts_at_zero_and_observes_ticks() {
        let bus = make_bus(&["test_channel"]);
        let mut rx = bus.subscribe_reconnects();
        assert_eq!(*rx.borrow(), 0, "fresh bus starts at generation 0");
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 1, "first connect publishes generation 1");
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 2, "reconnect publishes generation 2");
    }

    #[tokio::test]
    async fn run_returns_without_connecting_when_shutdown_already_requested() {
        let bus = make_bus(&["test_channel"]);
        // The initial value `true` is already "seen", so `changed()` alone would never fire.
        let (_shutdown_tx, shutdown_rx) = watch::channel(true);
        let finished = tokio::time::timeout(Duration::from_secs(1), bus.run(shutdown_rx)).await;
        assert!(
            finished.is_ok(),
            "run should exit before attempting a connection"
        );
        assert_eq!(
            *bus.subscribe_reconnects().borrow(),
            0,
            "no connection should have been made"
        );
    }
}