use crate::pg_connection_listener::{
command_task, drive_while, listener_task, raw_connect, ListenerTaskContext,
NotificationDispatcher,
};
use crate::tokio_postgres::{MakeTlsConnect, Socket, TlsConnect};
use dashmap::{DashMap, DashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinSet;
use crate::pg_client::{fetch_backend_pid, PgClient};
use crate::pg_pubsub_options::PgPubSubOptions;
/// Handle to a PostgreSQL LISTEN/NOTIFY pub/sub connection.
///
/// Subscriptions are created with `listen`; notifications are published with `notify` and
/// `notify_batch`. The background tasks spawned by `connect` are owned by `tasks`, so they
/// are torn down together with this handle.
pub struct PgPubSubConnection {
    /// Client used to publish notifications; also shared with the background tasks.
    pg_client: Arc<PgClient>,
    /// Per-channel fan-out state, shared with the notification dispatcher and command task.
    listeners: Arc<DashMap<Box<str>, Listener>>,
    /// Capacity of the broadcast channel created for each newly listened channel.
    channel_capacity: usize,
    /// Queue into the background command task.
    cmd_tx: mpsc::UnboundedSender<Command>,
    // Never read: the set is held only so the background tasks live exactly as long as this
    // handle (dropping a `JoinSet` aborts every task still in it).
    #[allow(dead_code)]
    tasks: JoinSet<()>,
}
/// Requests sent to the background command task over the unbounded command queue.
pub(crate) enum Command {
    /// Ask for `LISTEN` on `channel`; the outcome (`Ok(())` or the driver error) is
    /// reported back through `response`.
    Listen {
        channel: Box<str>,
        response: oneshot::Sender<Result<(), tokio_postgres::Error>>,
    },
    /// One subscriber of `channel` went away: a `Subscription` was dropped or a failed or
    /// cancelled `listen` was rolled back.
    Unsub {
        channel: Box<str>,
    },
    // NOTE(review): not sent from this file; presumably asks the command task to UNLISTEN
    // `channel` once no subscribers remain — confirm in `pg_connection_listener`.
    UnlistenIfEmpty {
        channel: Box<str>,
    },
}

impl Command {
    /// Returns the name of the channel this command targets, regardless of variant.
    pub(crate) fn channel(&self) -> &str {
        match self {
            Self::Listen { channel, .. } => channel,
            Self::Unsub { channel } => channel,
            Self::UnlistenIfEmpty { channel } => channel,
        }
    }
}
/// A notification delivered to a `Subscription`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Notification {
    /// Name of the channel the notification arrived on.
    pub channel: Arc<str>,
    /// Payload text of the notification.
    pub payload: Arc<str>,
    /// Process ID attached to the notification.
    // NOTE(review): presumably the PID of the notifying backend as reported by PostgreSQL;
    // confirm where `Notification` values are built.
    pub process_id: i32,
}

/// Shared per-channel state: one broadcast sender fanned out to every subscription.
pub(crate) struct Listener {
    /// Sender whose receivers (`subscribe()`) back each `Subscription` of this channel.
    pub send_channel: broadcast::Sender<Notification>,
    /// Number of subscriptions counted for this channel; `listen` increments it.
    // NOTE(review): decremented outside this file (presumably by the command task on
    // `Command::Unsub`).
    pub listener_count: AtomicUsize,
    /// Channel name as a cheaply clonable `Arc<str>`.
    pub channel: Arc<str>,
}
/// Undoes the subscriber-count increment made by `listen` unless explicitly disarmed.
///
/// While armed (`key` is `Some`), dropping the guard sends `Command::Unsub` for `key`. This
/// covers both an error return and cancellation of the `listen` future.
struct ListenRollbackGuard<'a> {
    /// Channel to unsubscribe on drop; `None` once disarmed.
    key: Option<Box<str>>,
    /// Queue into the command task that receives the rollback.
    cmd_tx: &'a mpsc::UnboundedSender<Command>,
}

impl ListenRollbackGuard<'_> {
    /// Consumes the guard without sending anything.
    ///
    /// Clearing `key` first means the `Drop` impl that runs afterwards is a no-op.
    fn disarm(mut self) {
        self.key.take();
    }
}

impl Drop for ListenRollbackGuard<'_> {
    /// Sends `Command::Unsub` if still armed. A closed queue is logged, not propagated.
    fn drop(&mut self) {
        let Some(channel) = self.key.take() else {
            return;
        };
        let rollback = Command::Unsub { channel };
        if let Err(err) = self.cmd_tx.send(rollback) {
            log::error!("Failed to roll back listener: {err}");
        }
    }
}
pub struct Subscription {
channel: Box<str>,
receiver: broadcast::Receiver<Notification>,
cmd_tx: mpsc::UnboundedSender<Command>,
}
impl Subscription {
pub fn channel(&self) -> &str {
&self.channel
}
pub async fn recv(&mut self) -> Result<Notification, RecvError> {
self.receiver.recv().await.map_err(|err| match err {
broadcast::error::RecvError::Closed => RecvError::Closed,
broadcast::error::RecvError::Lagged(n) => RecvError::Lagged(n),
})
}
}
/// Error returned by `Subscription::recv`.
#[derive(Debug)]
#[non_exhaustive]
pub enum RecvError {
    /// No further notifications can be received on this subscription.
    Closed,
    /// The subscription fell behind; carries the number of notifications skipped.
    Lagged(u64),
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Closed => f.write_str("subscription closed"),
            Self::Lagged(skipped) => {
                write!(f, "subscription lagged, {skipped} notifications dropped")
            }
        }
    }
}

impl std::error::Error for RecvError {}
impl Drop for Subscription {
    /// Reports this subscriber's departure by sending `Command::Unsub` for its channel.
    ///
    /// If the command task is already gone, the failure is logged rather than propagated.
    fn drop(&mut self) {
        // The channel name is moved out because the subscription is going away anyway.
        let channel = std::mem::take(&mut self.channel);
        log::debug!("Unsubscribing from channel {channel}");
        let unsub = Command::Unsub { channel };
        if let Err(err) = self.cmd_tx.send(unsub) {
            log::error!("Error when unsubscribing: {err}");
        }
    }
}
/// Errors returned by `PgPubSubConnection` operations.
#[derive(Debug)]
#[non_exhaustive]
pub enum PubSubError {
    /// The channel name is empty or longer than 63 bytes.
    InvalidChannelName,
    /// The notification payload is 8000 bytes or longer.
    InvalidPayload,
    /// The `LISTEN` command failed.
    ListenError(tokio_postgres::Error),
    /// The `NOTIFY` command failed.
    NotifyError(tokio_postgres::Error),
    /// The command queue, or the reply to a queued command, was closed.
    Closed,
}

impl std::fmt::Display for PubSubError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants without a wrapped error print a fixed message; the two command failures
        // share one template.
        let (command, err) = match self {
            Self::InvalidChannelName => return f.write_str("invalid channel name"),
            Self::InvalidPayload => {
                return f.write_str("notification payload must be shorter than 8000 bytes");
            }
            Self::Closed => return f.write_str("PgPubSub connection closed"),
            Self::ListenError(err) => ("LISTEN", err),
            Self::NotifyError(err) => ("NOTIFY", err),
        };
        write!(f, "{command} command failed: {err}")
    }
}

impl std::error::Error for PubSubError {
    /// Exposes the wrapped `tokio_postgres::Error` for the two command-failure variants.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::InvalidChannelName | Self::InvalidPayload | Self::Closed => None,
            Self::ListenError(err) | Self::NotifyError(err) => Some(err),
        }
    }
}
impl PgPubSubConnection {
    /// Opens the pub/sub connection and spawns its two background tasks.
    ///
    /// * The listener task (`listener_task`) receives the PostgreSQL connection, a
    ///   `NotificationDispatcher`, the connection parameters and the TLS connector.
    /// * The command task (`command_task`) consumes the `Command` queue. That queue is fed by
    ///   `listen`, by dropped `Subscription`s, and by the dispatcher through its own
    ///   `cmd_tx` clone.
    ///
    /// When `suppress_own_notifications` is set, this session's backend PID is fetched
    /// before the listener task starts, and is passed to it.
    ///
    /// # Errors
    ///
    /// Returns the `tokio_postgres::Error` from connecting or from fetching the backend PID.
    pub(crate) async fn connect<T>(
        options: PgPubSubOptions<T>,
    ) -> Result<Self, tokio_postgres::Error>
    where
        T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
        <T as MakeTlsConnect<Socket>>::Stream: Send + 'static,
        <T as MakeTlsConnect<Socket>>::TlsConnect: Send,
        <<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let PgPubSubOptions {
            connection_params,
            channel_capacity,
            suppress_own_notifications,
            tls,
        } = options;
        let (client, connection) = raw_connect(&connection_params, tls.clone()).await?;
        let listener_map: Arc<DashMap<Box<str>, Listener>> = Default::default();
        let pending_unlisten: Arc<DashSet<Box<str>>> = Default::default();
        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
        let dispatcher = NotificationDispatcher {
            listener_map: Arc::clone(&listener_map),
            pending_unlisten: Arc::clone(&pending_unlisten),
            cmd_tx: cmd_tx.clone(),
            suppress_own_notifications,
        };
        // The PID query needs the connection to be polled, so `drive_while` drives it until
        // the query resolves and then hands the connection back (as an `Option`).
        let (backend_pid, connection) = if suppress_own_notifications {
            let (pid, connection) =
                drive_while(connection, &dispatcher, None, fetch_backend_pid(&client)).await;
            (Some(pid?), connection)
        } else {
            (None, Some(connection))
        };
        let pg_client = Arc::new(PgClient::new(client));
        let mut tasks: JoinSet<()> = JoinSet::new();
        let ctx = ListenerTaskContext {
            dispatcher,
            connection_params,
            tls,
            pg_client: Arc::clone(&pg_client),
        };
        tasks.spawn(listener_task(connection, backend_pid, ctx));
        let cmd_pg_client = Arc::clone(&pg_client);
        let cmd_listener_map = Arc::clone(&listener_map);
        tasks.spawn(async move {
            command_task(cmd_rx, cmd_listener_map, pending_unlisten, cmd_pg_client).await;
        });
        Ok(PgPubSubConnection {
            pg_client,
            listeners: listener_map,
            channel_capacity,
            cmd_tx,
            tasks,
        })
    }

    /// Subscribes to `channel`.
    ///
    /// Every call increments the channel's subscriber count. Only the call that moves the
    /// count from 0 to 1 sends `Command::Listen` and waits for the command task's answer.
    /// Later callers get their receiver immediately. If that wait fails, or this future is
    /// dropped while waiting, the increment is undone by sending `Command::Unsub`.
    ///
    /// NOTE(review): a concurrent caller that sees a non-zero count while the first caller
    /// is still waiting returns before `LISTEN` is confirmed. It keeps its subscription even
    /// if that `LISTEN` then fails, so it may end up subscribed to a channel the server is
    /// not listening on.
    ///
    /// # Errors
    ///
    /// * `PubSubError::InvalidChannelName` if `channel` is empty or longer than 63 bytes.
    /// * `PubSubError::Closed` if the command task or its reply channel is gone.
    /// * `PubSubError::ListenError` if the `LISTEN` command failed.
    pub async fn listen(&self, channel: &str) -> Result<Subscription, PubSubError> {
        if !valid_channel_name(channel) {
            return Err(PubSubError::InvalidChannelName);
        }
        let key: Box<str> = channel.into();
        // The count is bumped, and the receiver created, while the map entry is held. This
        // keeps concurrent `listen` calls for the same channel from racing on the 0 -> 1
        // transition.
        let (receiver, listen_response_rx) = {
            let entry = self.listeners.entry(key.clone()).or_insert_with(|| {
                let (sender, _) = broadcast::channel(self.channel_capacity);
                Listener {
                    send_channel: sender,
                    listener_count: AtomicUsize::new(0),
                    channel: Arc::from(channel),
                }
            });
            let prev = entry.listener_count.fetch_add(1, Ordering::Relaxed);
            let receiver = entry.send_channel.subscribe();
            // Only the first subscriber asks the command task to issue LISTEN.
            let listen_rx = if prev == 0 {
                let (response_tx, response_rx) = oneshot::channel();
                if self
                    .cmd_tx
                    .send(Command::Listen {
                        channel: key.clone(),
                        response: response_tx,
                    })
                    .is_err()
                {
                    entry.listener_count.fetch_sub(1, Ordering::Relaxed);
                    return Err(PubSubError::Closed);
                }
                Some(response_rx)
            } else {
                None
            };
            (receiver, listen_rx)
        };
        // From here until `disarm`, an early return or cancellation undoes the increment.
        let rollback = ListenRollbackGuard {
            key: Some(key),
            cmd_tx: &self.cmd_tx,
        };
        if let Some(rx) = listen_response_rx {
            rx.await
                .map_err(|_| PubSubError::Closed)?
                .map_err(PubSubError::ListenError)?;
        }
        rollback.disarm();
        Ok(Subscription {
            channel: channel.into(),
            receiver,
            cmd_tx: self.cmd_tx.clone(),
        })
    }

    /// Publishes one notification on `channel` with an optional `payload`.
    ///
    /// # Errors
    ///
    /// * `PubSubError::InvalidChannelName` if `channel` is empty or longer than 63 bytes.
    /// * `PubSubError::InvalidPayload` if `payload` is 8000 bytes or longer.
    /// * `PubSubError::NotifyError` if the client call fails.
    pub async fn notify(&self, channel: &str, payload: Option<&str>) -> Result<(), PubSubError> {
        Self::validate_notification(channel, payload)?;
        self.notify_cmd(channel, payload).await
    }

    /// Publishes several notifications with a single client call.
    ///
    /// Every item is validated before anything is sent, so one invalid item rejects the
    /// whole batch. An empty batch succeeds immediately without contacting the server.
    ///
    /// # Errors
    ///
    /// * `PubSubError::InvalidChannelName` or `PubSubError::InvalidPayload` for the first
    ///   invalid item (channel checked before payload).
    /// * `PubSubError::NotifyError` if the client call fails.
    pub async fn notify_batch(&self, items: &[(&str, Option<&str>)]) -> Result<(), PubSubError> {
        items
            .iter()
            .try_for_each(|&(channel, payload)| Self::validate_notification(channel, payload))?;
        if items.is_empty() {
            // Nothing to publish: skip the round-trip to the server.
            return Ok(());
        }
        log::debug!("Notifying batch of {} items", items.len());
        self.pg_client
            .notify_batch(items)
            .await
            .map_err(PubSubError::NotifyError)
    }

    /// Sends an already validated notification through the shared client.
    async fn notify_cmd(&self, channel: &str, payload: Option<&str>) -> Result<(), PubSubError> {
        log::debug!(
            "Notifying on channel {channel} and payload {payload_str}",
            payload_str = payload.unwrap_or_default()
        );
        self.pg_client
            .notify(channel, payload)
            .await
            .map_err(PubSubError::NotifyError)
    }

    /// Applies the client-side limits shared by `notify` and `notify_batch`.
    ///
    /// The channel name is checked before the payload.
    fn validate_notification(channel: &str, payload: Option<&str>) -> Result<(), PubSubError> {
        if !valid_channel_name(channel) {
            return Err(PubSubError::InvalidChannelName);
        }
        if !valid_payload(payload) {
            return Err(PubSubError::InvalidPayload);
        }
        Ok(())
    }
}
/// Returns whether `channel` has an acceptable length for a PostgreSQL channel name.
///
/// The length is measured in bytes: it must be non-empty and at most 63 bytes, matching
/// PostgreSQL's default identifier limit.
fn valid_channel_name(channel: &str) -> bool {
    const MAX_IDENTIFIER_BYTES: usize = 63;
    let byte_len = channel.len();
    byte_len != 0 && byte_len <= MAX_IDENTIFIER_BYTES
}
/// Returns whether `payload` fits the NOTIFY payload limit.
///
/// A missing payload is always accepted. A present payload must be shorter than 8000
/// bytes (bytes, not characters), which is PostgreSQL's default limit.
fn valid_payload(payload: Option<&str>) -> bool {
    const PAYLOAD_LIMIT_BYTES: usize = 8000;
    match payload {
        Some(text) => text.len() < PAYLOAD_LIMIT_BYTES,
        None => true,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn command_channel_returns_target_for_each_variant() {
        let (response_tx, _response_rx) = oneshot::channel();
        let listen = Command::Listen {
            channel: "alpha".into(),
            response: response_tx,
        };
        assert_eq!(listen.channel(), "alpha");
        let unsub = Command::Unsub {
            channel: "beta".into(),
        };
        assert_eq!(unsub.channel(), "beta");
        let unlisten = Command::UnlistenIfEmpty {
            channel: "gamma".into(),
        };
        assert_eq!(unlisten.channel(), "gamma");
    }

    #[test]
    fn rollback_guard_drops_into_unsub_when_armed() {
        let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
        {
            let _guard = ListenRollbackGuard {
                key: Some("foo".into()),
                cmd_tx: &cmd_tx,
            };
        }
        match cmd_rx.try_recv().expect("expected an Unsub on guard drop") {
            Command::Unsub { channel } => assert_eq!(&*channel, "foo"),
            _ => panic!("expected Command::Unsub variant"),
        }
        assert!(cmd_rx.try_recv().is_err(), "exactly one command expected");
    }

    #[test]
    fn rollback_guard_is_silent_after_disarm() {
        let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel();
        {
            let guard = ListenRollbackGuard {
                key: Some("foo".into()),
                cmd_tx: &cmd_tx,
            };
            guard.disarm();
        }
        assert!(
            cmd_rx.try_recv().is_err(),
            "disarmed guard must not send anything on drop"
        );
    }

    #[test]
    fn rollback_guard_logs_but_does_not_panic_when_funnel_is_gone() {
        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
        drop(cmd_rx);
        let guard = ListenRollbackGuard {
            key: Some("foo".into()),
            cmd_tx: &cmd_tx,
        };
        // The send inside `Drop` fails here; it must only log.
        drop(guard);
    }

    #[test]
    fn valid_channel_name_accepts_one_to_sixty_three_bytes() {
        assert!(valid_channel_name("a"));
        assert!(valid_channel_name(&"a".repeat(63)));
    }

    #[test]
    fn valid_channel_name_rejects_empty_and_oversize() {
        assert!(!valid_channel_name(""));
        assert!(!valid_channel_name(&"a".repeat(64)));
        assert!(!valid_channel_name(&"a".repeat(1000)));
    }

    #[test]
    fn valid_channel_name_limit_is_measured_in_bytes() {
        // U+00E9 is two bytes in UTF-8: 31 chars = 62 bytes, 32 chars = 64 bytes.
        assert!(valid_channel_name(&"\u{e9}".repeat(31)));
        assert!(!valid_channel_name(&"\u{e9}".repeat(32)));
    }

    #[test]
    fn valid_payload_accepts_none_empty_and_up_to_7999_bytes() {
        assert!(valid_payload(None));
        assert!(valid_payload(Some("")));
        assert!(valid_payload(Some(&"a".repeat(7999))));
    }

    #[test]
    fn valid_payload_rejects_8000_bytes_and_above() {
        assert!(!valid_payload(Some(&"a".repeat(8000))));
        assert!(!valid_payload(Some(&"a".repeat(100_000))));
    }

    #[test]
    fn valid_payload_limit_is_measured_in_bytes() {
        // 3999 two-byte chars = 7998 bytes (accepted); 4000 = 8000 bytes (rejected).
        assert!(valid_payload(Some(&"\u{e9}".repeat(3999))));
        assert!(!valid_payload(Some(&"\u{e9}".repeat(4000))));
    }

    #[test]
    fn recv_error_display_describes_variant() {
        assert_eq!(RecvError::Closed.to_string(), "subscription closed");
        assert_eq!(
            RecvError::Lagged(5).to_string(),
            "subscription lagged, 5 notifications dropped"
        );
    }

    #[test]
    fn pub_sub_validation_errors_have_messages_and_no_source() {
        use std::error::Error as _;
        assert_eq!(
            PubSubError::InvalidChannelName.to_string(),
            "invalid channel name"
        );
        assert_eq!(
            PubSubError::InvalidPayload.to_string(),
            "notification payload must be shorter than 8000 bytes"
        );
        assert_eq!(PubSubError::Closed.to_string(), "PgPubSub connection closed");
        assert!(PubSubError::InvalidChannelName.source().is_none());
        assert!(PubSubError::InvalidPayload.source().is_none());
        assert!(PubSubError::Closed.source().is_none());
    }
}