use postgres_protocol::escape::{escape_identifier, escape_literal};
use tokio_postgres::Client;
/// Crate-internal wrapper around a [`tokio_postgres::Client`].
///
/// The methods in the accompanying `impl` send `LISTEN`, `UNLISTEN` and
/// `NOTIFY` statements, and a backend-PID query, through the wrapped client.
pub(crate) struct PgClient {
/// Connection handle that every method issues its statements through.
client: Client,
}
impl PgClient {
    /// Wraps an already-connected [`Client`].
    pub fn new(client: Client) -> Self {
        Self { client }
    }

    /// Issues `LISTEN <channel>` on the wrapped connection.
    ///
    /// The channel name is passed through `escape_identifier` before being
    /// spliced into the statement, so arbitrary names cannot break out of the
    /// identifier position.
    ///
    /// Returns whatever row count `Client::execute` reports for the statement.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `Client::execute` if the statement fails.
    pub async fn listen(&self, channel: &str) -> Result<u64, crate::tokio_postgres::Error> {
        let channel = escape_identifier(channel);
        let listen_cmd = format!("LISTEN {channel}");
        self.client.execute(&listen_cmd, &[]).await
    }

    /// Issues `UNLISTEN <channel>` on the wrapped connection.
    ///
    /// The channel name is escaped with `escape_identifier` exactly as in
    /// [`PgClient::listen`].
    ///
    /// Returns whatever row count `Client::execute` reports for the statement.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `Client::execute` if the statement fails.
    pub async fn unlisten(&self, channel: &str) -> Result<u64, crate::tokio_postgres::Error> {
        let channel = escape_identifier(channel);
        let unlisten_cmd = format!("UNLISTEN {channel}");
        self.client.execute(&unlisten_cmd, &[]).await
    }

    /// Issues `NOTIFY <channel>` or `NOTIFY <channel>, <payload>`.
    ///
    /// The channel is escaped with `escape_identifier`; when a payload is
    /// given it is escaped with `escape_literal` and appended after a comma.
    /// With `None`, the statement is sent without a payload.
    ///
    /// Returns whatever row count `Client::execute` reports for the statement.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `Client::execute` if the statement fails.
    pub async fn notify(
        &self,
        channel: &str,
        payload: Option<&str>,
    ) -> Result<u64, crate::tokio_postgres::Error> {
        let channel = escape_identifier(channel);
        let notify_cmd = match payload {
            Some(payload) => {
                let payload = escape_literal(payload);
                format!("NOTIFY {channel}, {payload}")
            }
            None => format!("NOTIFY {channel}"),
        };
        self.client.execute(&notify_cmd, &[]).await
    }

    /// Runs `SELECT pg_backend_pid()` and returns the first column of the
    /// single result row as an `i32`.
    ///
    /// The value is also logged at debug level.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails, does not yield exactly one row, or
    /// the first column cannot be decoded as `i32`.
    pub async fn get_pid(&self) -> Result<i32, crate::tokio_postgres::Error> {
        let row = self
            .client
            .query_one("SELECT pg_backend_pid()", &[])
            .await?;
        // `try_get` surfaces a decode/type mismatch as an `Error`; `get` would panic.
        let pid: i32 = row.try_get(0)?;
        log::debug!("get_pid: {pid}");
        Ok(pid)
    }
}