Skip to main content

qail_pg/driver/
notification.rs

//! LISTEN/NOTIFY support for PostgreSQL connections.
//!
//! PostgreSQL sends `NotificationResponse` messages asynchronously when
//! a channel the connection is LISTENing on receives a NOTIFY.
//!
//! This module provides:
//! - `Notification` struct — sender PID, channel name, and payload
//! - `listen()` / `unlisten()` / `unlisten_all()` — subscribe/unsubscribe to channels
//! - `poll_notifications()` — drain buffered notifications (non-blocking)
//! - `recv_notification()` — block-wait for the next notification
11
12use super::{PgConnection, PgResult};
13use crate::protocol::PgEncoder;
14use tokio::io::AsyncWriteExt;
15
/// A notification received from PostgreSQL LISTEN/NOTIFY.
///
/// Built from a backend `NotificationResponse` message. The type is a plain
/// value: it can be cloned, compared, and hashed (e.g. to deduplicate
/// notifications in a `HashSet`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Notification {
    /// The PID of the notifying backend process
    pub process_id: i32,
    /// The channel name
    pub channel: String,
    /// The payload (may be empty)
    pub payload: String,
}
26
27impl PgConnection {
28    /// Subscribe to a notification channel.
29    ///
30    /// ```ignore
31    /// conn.listen("price_calendar_changed").await?;
32    /// ```
33    pub async fn listen(&mut self, channel: &str) -> PgResult<()> {
34        // Channel names are identifiers, quote them to prevent injection
35        let sql = format!("LISTEN \"{}\"", channel.replace('"', "\"\""));
36        self.execute_simple(&sql).await
37    }
38
39    /// Unsubscribe from a notification channel.
40    pub async fn unlisten(&mut self, channel: &str) -> PgResult<()> {
41        let sql = format!("UNLISTEN \"{}\"", channel.replace('"', "\"\""));
42        self.execute_simple(&sql).await
43    }
44
45    /// Unsubscribe from all notification channels.
46    pub async fn unlisten_all(&mut self) -> PgResult<()> {
47        self.execute_simple("UNLISTEN *").await
48    }
49
50    /// Drain all buffered notifications without blocking.
51    ///
52    /// Notifications arrive asynchronously from PostgreSQL and are buffered
53    /// whenever `recv()` encounters a `NotificationResponse`. This method
54    /// returns all currently buffered notifications.
55    pub fn poll_notifications(&mut self) -> Vec<Notification> {
56        self.notifications.drain(..).collect()
57    }
58
59    /// Wait for the next notification, blocking until one arrives.
60    ///
61    /// Unlike `recv()`, this does NOT use the 30-second Slowloris timeout
62    /// guard. LISTEN connections idle for long periods — that's normal,
63    /// not a DoS attack.
64    ///
65    /// Useful for a dedicated LISTEN connection in a background task.
66    pub async fn recv_notification(&mut self) -> PgResult<Notification> {
67        use crate::protocol::BackendMessage;
68        use tokio::io::AsyncReadExt;
69
70        // Return buffered notification immediately if available
71        if let Some(n) = self.notifications.pop_front() {
72            return Ok(n);
73        }
74
75        // Send empty query to flush any pending notifications from server
76        let bytes = PgEncoder::encode_query_string("");
77        self.stream.write_all(&bytes).await?;
78
79        // Read messages — use recv() for the initial empty query response
80        // (which completes quickly), then switch to no-timeout reads
81        let mut got_ready = false;
82        loop {
83            // Try to decode from the existing buffer first
84            if self.buffer.len() >= 5 {
85                let msg_len = u32::from_be_bytes([
86                    self.buffer[1],
87                    self.buffer[2],
88                    self.buffer[3],
89                    self.buffer[4],
90                ]) as usize;
91
92                if self.buffer.len() > msg_len {
93                    let msg_bytes = self.buffer.split_to(msg_len + 1);
94                    let (msg, _) =
95                        BackendMessage::decode(&msg_bytes).map_err(super::PgError::Protocol)?;
96
97                    match msg {
98                        BackendMessage::NotificationResponse {
99                            process_id,
100                            channel,
101                            payload,
102                        } => {
103                            return Ok(Notification {
104                                process_id,
105                                channel,
106                                payload,
107                            });
108                        }
109                        BackendMessage::EmptyQueryResponse => continue,
110                        BackendMessage::ReadyForQuery(_) => {
111                            got_ready = true;
112                            // Check buffer for notifications that arrived with this batch
113                            if let Some(n) = self.notifications.pop_front() {
114                                return Ok(n);
115                            }
116                            continue;
117                        }
118                        _ => continue,
119                    }
120                }
121            }
122
123            // Read from socket — use tokio read (no timeout!) if we've
124            // already gotten ReadyForQuery (now we're just waiting for NOTIFY)
125            if self.buffer.capacity() - self.buffer.len() < 65536 {
126                self.buffer.reserve(131072);
127            }
128
129            if got_ready {
130                // No timeout — LISTEN connections idle for hours, that's fine
131                let n = self
132                    .stream
133                    .read_buf(&mut self.buffer)
134                    .await
135                    .map_err(|e| super::PgError::Connection(format!("Read error: {e}")))?;
136                if n == 0 {
137                    return Err(super::PgError::Connection("Connection closed".to_string()));
138                }
139            } else {
140                // Initial flush — use the normal timeout to avoid hanging
141                // if the server is unresponsive during the empty query
142                let n = self.read_with_timeout().await?;
143                if n == 0 {
144                    return Err(super::PgError::Connection("Connection closed".to_string()));
145                }
146            }
147        }
148    }
149}