Skip to main content

qail_pg/driver/
notification.rs

1//! LISTEN/NOTIFY support for PostgreSQL connections.
2//!
3//! PostgreSQL sends `NotificationResponse` messages asynchronously when
4//! a channel the connection is LISTENing on receives a NOTIFY.
5//!
6//! This module provides:
7//! - `Notification` struct — channel name + payload
8//! - `listen()` / `unlisten()` — subscribe/unsubscribe to channels
9//! - `poll_notifications()` — drain buffered notifications (non-blocking)
10//! - `recv_notification()` — block-wait for the next notification
11
12use super::{
13    PgConnection, PgError, PgResult, io::MAX_MESSAGE_SIZE, is_ignorable_session_message,
14    unexpected_backend_message,
15};
16use crate::protocol::PgEncoder;
17
/// A notification received from PostgreSQL LISTEN/NOTIFY.
///
/// Built from a backend `NotificationResponse` message. Values can be compared
/// for equality, which is convenient for de-duplication and in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// The PID of the notifying backend process
    pub process_id: i32,
    /// The channel name
    pub channel: String,
    /// The payload (may be empty)
    pub payload: String,
}
28
29#[inline]
30fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
31    if matches!(
32        err,
33        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
34    ) {
35        conn.mark_io_desynced();
36    }
37    Err(err)
38}
39
40impl PgConnection {
41    /// Subscribe to a notification channel.
42    ///
43    /// ```ignore
44    /// conn.listen("price_calendar_changed").await?;
45    /// ```
46    pub async fn listen(&mut self, channel: &str) -> PgResult<()> {
47        // Channel names are identifiers, quote them to prevent injection
48        let sql = format!("LISTEN \"{}\"", channel.replace('"', "\"\""));
49        self.execute_simple(&sql).await
50    }
51
52    /// Unsubscribe from a notification channel.
53    pub async fn unlisten(&mut self, channel: &str) -> PgResult<()> {
54        let sql = format!("UNLISTEN \"{}\"", channel.replace('"', "\"\""));
55        self.execute_simple(&sql).await
56    }
57
58    /// Unsubscribe from all notification channels.
59    pub async fn unlisten_all(&mut self) -> PgResult<()> {
60        self.execute_simple("UNLISTEN *").await
61    }
62
63    /// Drain all buffered notifications without blocking.
64    ///
65    /// Notifications arrive asynchronously from PostgreSQL and are buffered
66    /// whenever `recv()` encounters a `NotificationResponse`. This method
67    /// returns all currently buffered notifications.
68    pub fn poll_notifications(&mut self) -> Vec<Notification> {
69        self.notifications.drain(..).collect()
70    }
71
72    /// Wait for the next notification, blocking until one arrives.
73    ///
74    /// Unlike `recv()`, this does NOT use the 30-second Slowloris timeout
75    /// guard. LISTEN connections idle for long periods — that's normal,
76    /// not a DoS attack.
77    ///
78    /// Useful for a dedicated LISTEN connection in a background task.
79    pub async fn recv_notification(&mut self) -> PgResult<Notification> {
80        use crate::protocol::BackendMessage;
81
82        // Return buffered notification immediately if available
83        if let Some(n) = self.notifications.pop_front() {
84            return Ok(n);
85        }
86
87        // Send empty query to flush any pending notifications from server
88        let bytes = PgEncoder::try_encode_query_string("")?;
89        self.write_all_with_timeout(&bytes, "stream write").await?;
90
91        // Read messages — use recv() for the initial empty query response
92        // (which completes quickly), then switch to no-timeout reads
93        let mut got_ready = false;
94        loop {
95            // Try to decode from the existing buffer first
96            if self.buffer.len() >= 5 {
97                let msg_len = u32::from_be_bytes([
98                    self.buffer[1],
99                    self.buffer[2],
100                    self.buffer[3],
101                    self.buffer[4],
102                ]) as usize;
103
104                if msg_len < 4 {
105                    return return_with_desync(
106                        self,
107                        PgError::Protocol(format!(
108                            "Invalid message length: {} (minimum 4)",
109                            msg_len
110                        )),
111                    );
112                }
113
114                if msg_len > MAX_MESSAGE_SIZE {
115                    return return_with_desync(
116                        self,
117                        PgError::Protocol(format!(
118                            "Message too large: {} bytes (max {})",
119                            msg_len, MAX_MESSAGE_SIZE
120                        )),
121                    );
122                }
123
124                if self.buffer.len() > msg_len {
125                    let msg_bytes = self.buffer.split_to(msg_len + 1);
126                    let (msg, _) = match BackendMessage::decode(&msg_bytes) {
127                        Ok(decoded) => decoded,
128                        Err(err) => return return_with_desync(self, PgError::Protocol(err)),
129                    };
130
131                    match msg {
132                        BackendMessage::NotificationResponse {
133                            process_id,
134                            channel,
135                            payload,
136                        } => {
137                            let notification = Notification {
138                                process_id,
139                                channel,
140                                payload,
141                            };
142                            if got_ready {
143                                return Ok(notification);
144                            }
145                            self.notifications.push_back(notification);
146                            continue;
147                        }
148                        BackendMessage::EmptyQueryResponse => continue,
149                        BackendMessage::NoticeResponse(_) => continue,
150                        BackendMessage::ParameterStatus { .. } => continue,
151                        BackendMessage::CommandComplete(_) => continue,
152                        BackendMessage::ReadyForQuery(_) => {
153                            got_ready = true;
154                            // Check buffer for notifications that arrived with this batch
155                            if let Some(n) = self.notifications.pop_front() {
156                                return Ok(n);
157                            }
158                            continue;
159                        }
160                        BackendMessage::ErrorResponse(err) => {
161                            return Err(PgError::QueryServer(err.into()));
162                        }
163                        msg if is_ignorable_session_message(&msg) => continue,
164                        other => {
165                            return return_with_desync(
166                                self,
167                                unexpected_backend_message("listen/notify wait", &other),
168                            );
169                        }
170                    }
171                }
172            }
173
174            // Read from socket — use tokio read (no timeout!) if we've
175            // already gotten ReadyForQuery (now we're just waiting for NOTIFY)
176            if self.buffer.capacity() - self.buffer.len() < 65536 {
177                self.buffer.reserve(131072);
178            }
179
180            if got_ready {
181                // LISTEN connections can stay idle for hours (empty buffer),
182                // but a partially buffered backend frame should still timeout
183                // to fail-closed on slowloris-style partial writes.
184                let n = if self.buffer.is_empty() {
185                    self.read_without_timeout().await?
186                } else {
187                    self.read_with_timeout().await?
188                };
189                if n == 0 {
190                    return return_with_desync(
191                        self,
192                        PgError::Connection("Connection closed".to_string()),
193                    );
194                }
195            } else {
196                // Initial flush — use the normal timeout to avoid hanging
197                // if the server is unresponsive during the empty query
198                let n = self.read_with_timeout().await?;
199                if n == 0 {
200                    return return_with_desync(
201                        self,
202                        PgError::Connection("Connection closed".to_string()),
203                    );
204                }
205            }
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::{Notification, return_with_desync};
    use crate::driver::{PgConnection, PgError};

    /// Build a connection over one end of a Unix socket pair; the peer end is
    /// returned so writes (e.g. the empty flush query) do not fail.
    #[cfg(unix)]
    fn test_conn_with_peer() -> (PgConnection, tokio::net::UnixStream) {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, peer) = UnixStream::pair().expect("unix stream pair");
        (
            PgConnection {
                stream: PgStream::Unix(unix_stream),
                buffer: BytesMut::with_capacity(1024),
                write_buf: BytesMut::with_capacity(1024),
                sql_buf: BytesMut::with_capacity(256),
                params_buf: Vec::new(),
                prepared_statements: HashMap::new(),
                stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
                column_info_cache: HashMap::new(),
                process_id: 0,
                cancel_key_bytes: Vec::new(),
                requested_protocol_minor: PgConnection::default_protocol_minor(),
                negotiated_protocol_minor: PgConnection::default_protocol_minor(),
                notifications: VecDeque::new(),
                replication_stream_active: false,
                replication_mode_enabled: false,
                last_replication_wal_end: None,
                io_desynced: false,
                pending_statement_closes: Vec::new(),
                draining_statement_closes: false,
            },
            peer,
        )
    }

    #[cfg(unix)]
    fn test_conn() -> PgConnection {
        test_conn_with_peer().0
    }

    /// Append a backend frame (type tag, length incl. itself, payload).
    #[cfg(unix)]
    fn push_backend_frame(conn: &mut PgConnection, msg_type: u8, payload: &[u8]) {
        conn.buffer.extend_from_slice(&[msg_type]);
        conn.buffer
            .extend_from_slice(&((payload.len() + 4) as u32).to_be_bytes());
        conn.buffer.extend_from_slice(payload);
    }

    /// Encode a `NotificationResponse` body: pid, channel\0, payload\0.
    #[cfg(unix)]
    fn notification_payload(process_id: i32, channel: &str, payload: &str) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&process_id.to_be_bytes());
        bytes.extend_from_slice(channel.as_bytes());
        bytes.push(0);
        bytes.extend_from_slice(payload.as_bytes());
        bytes.push(0);
        bytes
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn notification_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();

        let err =
            return_with_desync::<()>(&mut conn, PgError::Protocol("bad notify frame".to_string()))
                .expect_err("protocol error must be returned");

        assert!(err.to_string().contains("bad notify frame"));
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_notification_drains_empty_query_before_returning_pre_ready_notify() {
        let (mut conn, _peer) = test_conn_with_peer();
        let payload = notification_payload(42, "jobs", "ready");

        push_backend_frame(&mut conn, b'A', &payload);
        push_backend_frame(&mut conn, b'I', &[]);
        push_backend_frame(&mut conn, b'Z', b"I");

        let notification = conn
            .recv_notification()
            .await
            .expect("pre-ready notification should be returned after flush drain");

        assert_eq!(notification.process_id, 42);
        assert_eq!(notification.channel, "jobs");
        assert_eq!(notification.payload, "ready");
        assert!(
            conn.buffer.is_empty(),
            "empty-query flush frames must not remain buffered"
        );
        assert!(!conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_notification_returns_notify_arriving_after_ready() {
        let (mut conn, _peer) = test_conn_with_peer();
        let payload = notification_payload(7, "jobs", "later");

        push_backend_frame(&mut conn, b'I', &[]);
        push_backend_frame(&mut conn, b'Z', b"I");
        push_backend_frame(&mut conn, b'A', &payload);

        let notification = conn
            .recv_notification()
            .await
            .expect("post-ready notification should be returned directly");

        assert_eq!(notification.process_id, 7);
        assert_eq!(notification.channel, "jobs");
        assert_eq!(notification.payload, "later");
        assert!(conn.buffer.is_empty());
        assert!(conn.notifications.is_empty());
        assert!(!conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_notification_returns_queued_notification_without_flush() {
        let (mut conn, peer) = test_conn_with_peer();
        conn.notifications.push_back(Notification {
            process_id: 9,
            channel: "jobs".to_string(),
            payload: "queued".to_string(),
        });

        let notification = conn
            .recv_notification()
            .await
            .expect("queued notification should be returned immediately");

        assert_eq!(notification.process_id, 9);
        assert_eq!(notification.payload, "queued");
        assert!(conn.notifications.is_empty());

        let mut scratch = [0u8; 16];
        let read = peer.try_read(&mut scratch);
        assert!(
            matches!(&read, Err(e) if e.kind() == std::io::ErrorKind::WouldBlock),
            "no flush query should be written when a notification is queued: {:?}",
            read
        );
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_notification_rejects_short_frame_length_and_marks_desync() {
        let (mut conn, _peer) = test_conn_with_peer();
        // Type tag plus a declared length of 3, below the protocol minimum of 4.
        conn.buffer.extend_from_slice(&[b'A', 0, 0, 0, 3]);

        let err = conn
            .recv_notification()
            .await
            .expect_err("a frame length below 4 must be rejected");

        assert!(matches!(err, PgError::Protocol(_)));
        assert!(err.to_string().contains("Invalid message length"));
        assert!(conn.is_io_desynced());
    }
}
312}