Skip to main content

qail_pg/driver/
notification.rs

1//! LISTEN/NOTIFY support for PostgreSQL connections.
2//!
3//! PostgreSQL sends `NotificationResponse` messages asynchronously when
4//! a channel the connection is LISTENing on receives a NOTIFY.
5//!
6//! This module provides:
7//! - `Notification` struct — channel name + payload
8//! - `listen()` / `unlisten()` — subscribe/unsubscribe to channels
9//! - `poll_notifications()` — drain buffered notifications (non-blocking)
10//! - `recv_notification()` — block-wait for the next notification
11
12use super::{
13    PgConnection, PgResult, io::MAX_MESSAGE_SIZE, is_ignorable_session_message,
14    unexpected_backend_message,
15};
16use crate::protocol::PgEncoder;
17
/// A notification received from PostgreSQL LISTEN/NOTIFY.
///
/// Plain data copied out of a backend `NotificationResponse` message; it is
/// cheap to clone and can be compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// The PID of the notifying backend process
    pub process_id: i32,
    /// The channel name
    pub channel: String,
    /// The payload (may be empty)
    pub payload: String,
}
28
29impl PgConnection {
30    /// Subscribe to a notification channel.
31    ///
32    /// ```ignore
33    /// conn.listen("price_calendar_changed").await?;
34    /// ```
35    pub async fn listen(&mut self, channel: &str) -> PgResult<()> {
36        // Channel names are identifiers, quote them to prevent injection
37        let sql = format!("LISTEN \"{}\"", channel.replace('"', "\"\""));
38        self.execute_simple(&sql).await
39    }
40
41    /// Unsubscribe from a notification channel.
42    pub async fn unlisten(&mut self, channel: &str) -> PgResult<()> {
43        let sql = format!("UNLISTEN \"{}\"", channel.replace('"', "\"\""));
44        self.execute_simple(&sql).await
45    }
46
47    /// Unsubscribe from all notification channels.
48    pub async fn unlisten_all(&mut self) -> PgResult<()> {
49        self.execute_simple("UNLISTEN *").await
50    }
51
52    /// Drain all buffered notifications without blocking.
53    ///
54    /// Notifications arrive asynchronously from PostgreSQL and are buffered
55    /// whenever `recv()` encounters a `NotificationResponse`. This method
56    /// returns all currently buffered notifications.
57    pub fn poll_notifications(&mut self) -> Vec<Notification> {
58        self.notifications.drain(..).collect()
59    }
60
61    /// Wait for the next notification, blocking until one arrives.
62    ///
63    /// Unlike `recv()`, this does NOT use the 30-second Slowloris timeout
64    /// guard. LISTEN connections idle for long periods — that's normal,
65    /// not a DoS attack.
66    ///
67    /// Useful for a dedicated LISTEN connection in a background task.
68    pub async fn recv_notification(&mut self) -> PgResult<Notification> {
69        use crate::protocol::BackendMessage;
70
71        // Return buffered notification immediately if available
72        if let Some(n) = self.notifications.pop_front() {
73            return Ok(n);
74        }
75
76        // Send empty query to flush any pending notifications from server
77        let bytes = PgEncoder::try_encode_query_string("")?;
78        self.write_all_with_timeout(&bytes, "stream write").await?;
79
80        // Read messages — use recv() for the initial empty query response
81        // (which completes quickly), then switch to no-timeout reads
82        let mut got_ready = false;
83        loop {
84            // Try to decode from the existing buffer first
85            if self.buffer.len() >= 5 {
86                let msg_len = u32::from_be_bytes([
87                    self.buffer[1],
88                    self.buffer[2],
89                    self.buffer[3],
90                    self.buffer[4],
91                ]) as usize;
92
93                if msg_len < 4 {
94                    return Err(super::PgError::Protocol(format!(
95                        "Invalid message length: {} (minimum 4)",
96                        msg_len
97                    )));
98                }
99
100                if msg_len > MAX_MESSAGE_SIZE {
101                    return Err(super::PgError::Protocol(format!(
102                        "Message too large: {} bytes (max {})",
103                        msg_len, MAX_MESSAGE_SIZE
104                    )));
105                }
106
107                if self.buffer.len() > msg_len {
108                    let msg_bytes = self.buffer.split_to(msg_len + 1);
109                    let (msg, _) =
110                        BackendMessage::decode(&msg_bytes).map_err(super::PgError::Protocol)?;
111
112                    match msg {
113                        BackendMessage::NotificationResponse {
114                            process_id,
115                            channel,
116                            payload,
117                        } => {
118                            return Ok(Notification {
119                                process_id,
120                                channel,
121                                payload,
122                            });
123                        }
124                        BackendMessage::EmptyQueryResponse => continue,
125                        BackendMessage::NoticeResponse(_) => continue,
126                        BackendMessage::ParameterStatus { .. } => continue,
127                        BackendMessage::CommandComplete(_) => continue,
128                        BackendMessage::ReadyForQuery(_) => {
129                            got_ready = true;
130                            // Check buffer for notifications that arrived with this batch
131                            if let Some(n) = self.notifications.pop_front() {
132                                return Ok(n);
133                            }
134                            continue;
135                        }
136                        BackendMessage::ErrorResponse(err) => {
137                            return Err(super::PgError::QueryServer(err.into()));
138                        }
139                        msg if is_ignorable_session_message(&msg) => continue,
140                        other => {
141                            return Err(unexpected_backend_message("listen/notify wait", &other));
142                        }
143                    }
144                }
145            }
146
147            // Read from socket — use tokio read (no timeout!) if we've
148            // already gotten ReadyForQuery (now we're just waiting for NOTIFY)
149            if self.buffer.capacity() - self.buffer.len() < 65536 {
150                self.buffer.reserve(131072);
151            }
152
153            if got_ready {
154                // LISTEN connections can stay idle for hours (empty buffer),
155                // but a partially buffered backend frame should still timeout
156                // to fail-closed on slowloris-style partial writes.
157                let n = if self.buffer.is_empty() {
158                    self.read_without_timeout().await?
159                } else {
160                    self.read_with_timeout().await?
161                };
162                if n == 0 {
163                    return Err(super::PgError::Connection("Connection closed".to_string()));
164                }
165            } else {
166                // Initial flush — use the normal timeout to avoid hanging
167                // if the server is unresponsive during the empty query
168                let n = self.read_with_timeout().await?;
169                if n == 0 {
170                    return Err(super::PgError::Connection("Connection closed".to_string()));
171                }
172            }
173        }
174    }
175}