use std::fmt::Write;
use postgres_protocol::escape::{escape_identifier, escape_literal};
use tokio::sync::RwLock;
use tokio_postgres::Client;
pub(crate) struct PgClient {
client: RwLock<Client>,
}
impl PgClient {
pub fn new(client: Client) -> Self {
Self {
client: RwLock::new(client),
}
}
pub async fn replace(&self, client: Client) {
*self.client.write().await = client;
}
pub async fn listen(&self, channel: &str) -> Result<(), crate::tokio_postgres::Error> {
let channel = escape_identifier(channel);
let client = self.client.read().await;
client.batch_execute(&format!("LISTEN {channel}")).await
}
pub async fn unlisten(&self, channel: &str) -> Result<(), crate::tokio_postgres::Error> {
let channel = escape_identifier(channel);
let client = self.client.read().await;
client.batch_execute(&format!("UNLISTEN {channel}")).await
}
pub async fn notify(
&self,
channel: &str,
payload: Option<&str>,
) -> Result<(), crate::tokio_postgres::Error> {
let channel = escape_identifier(channel);
let cmd = match payload {
Some(payload) => {
let payload = escape_literal(payload);
format!("NOTIFY {channel}, {payload}")
}
None => format!("NOTIFY {channel}"),
};
let client = self.client.read().await;
client.batch_execute(&cmd).await
}
pub async fn notify_batch(
&self,
items: &[(&str, Option<&str>)],
) -> Result<(), crate::tokio_postgres::Error> {
let Some(cmd) = build_notify_batch_sql(items) else {
return Ok(());
};
let client = self.client.read().await;
client.batch_execute(&cmd).await
}
}
/// Queries `SELECT pg_backend_pid()` on `client` and returns the value of the
/// first column of the single result row.
///
/// The value is also logged at debug level.
///
/// # Errors
///
/// Returns the driver error if the query fails, or if the first column cannot
/// be decoded as an `i32`.
pub(crate) async fn fetch_backend_pid(
    client: &Client,
) -> Result<i32, crate::tokio_postgres::Error> {
    let row = client.query_one("SELECT pg_backend_pid()", &[]).await?;
    // `try_get` surfaces a decode/type mismatch as an `Err`, whereas `get`
    // would panic inside this otherwise fallible function.
    let pid: i32 = row.try_get(0)?;
    log::debug!("fetch_backend_pid: {pid}");
    Ok(pid)
}
pub(crate) fn build_relisten_sql<'a>(
channels: impl IntoIterator<Item = &'a str>,
) -> Option<String> {
let mut cmd = String::new();
for channel in channels {
let channel = escape_identifier(channel);
let _ = write!(&mut cmd, "LISTEN {channel};");
}
(!cmd.is_empty()).then_some(cmd)
}
fn build_notify_batch_sql(items: &[(&str, Option<&str>)]) -> Option<String> {
if items.is_empty() {
return None;
}
let mut cmd = String::new();
for (channel, payload) in items {
let channel = escape_identifier(channel);
match payload {
Some(payload) => {
let payload = escape_literal(payload);
let _ = write!(&mut cmd, "NOTIFY {channel}, {payload};");
}
None => {
let _ = write!(&mut cmd, "NOTIFY {channel};");
}
}
}
Some(cmd)
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- build_notify_batch_sql ----

    /// No items means no SQL is produced.
    #[test]
    fn empty_batch_returns_none() {
        assert_eq!(build_notify_batch_sql(&[]), None);
    }

    /// A payload is rendered as a quoted literal after the channel.
    #[test]
    fn single_item_with_payload() {
        let items = [("foo", Some("hello"))];
        let sql = build_notify_batch_sql(&items).unwrap();
        assert_eq!(sql, r#"NOTIFY "foo", 'hello';"#);
    }

    /// Without a payload the statement has no comma-separated argument.
    #[test]
    fn single_item_without_payload() {
        let items = [("foo", None)];
        let sql = build_notify_batch_sql(&items).unwrap();
        assert!(sql.contains(r#"NOTIFY "foo";"#));
        assert!(!sql.contains(", "));
    }

    /// Statements appear back to back in input order.
    #[test]
    fn multiple_items_concatenated_in_order() {
        let items = [("a", Some("1")), ("b", None), ("c", Some("3"))];
        let sql = build_notify_batch_sql(&items).unwrap();
        let expected = r#"NOTIFY "a", '1';NOTIFY "b";NOTIFY "c", '3';"#;
        assert_eq!(sql, expected);
    }

    /// The generated batch never contains explicit transaction commands.
    #[test]
    fn batch_has_no_explicit_transaction_commands() {
        let sql = build_notify_batch_sql(&[("a", Some("1")), ("b", None)]).unwrap();
        for keyword in ["BEGIN", "COMMIT"] {
            assert!(!sql.contains(keyword), "got: {sql}");
        }
    }

    /// Quotes inside payloads and channel names are doubled.
    #[test]
    fn channel_and_payload_are_escaped() {
        let sql = build_notify_batch_sql(&[("ch", Some("a'b"))]).unwrap();
        assert!(sql.contains("'a''b'"), "got: {sql}");

        let sql = build_notify_batch_sql(&[(r#"ch"x"#, None)]).unwrap();
        assert!(sql.contains(r#""ch""x""#), "got: {sql}");
    }

    // ---- build_relisten_sql ----

    /// No channels means no SQL is produced.
    #[test]
    fn relisten_sql_is_empty_for_no_channels() {
        assert_eq!(build_relisten_sql([]), None);
    }

    /// Each channel gets its own escaped `LISTEN` statement.
    #[test]
    fn relisten_sql_lists_each_channel_escaped() {
        let sql = build_relisten_sql(["foo", r#"we"ird"#]).unwrap();
        assert_eq!(sql, r#"LISTEN "foo";LISTEN "we""ird";"#);
    }
}