Skip to main content

qail_pg/driver/
notification.rs

1//! LISTEN/NOTIFY support for PostgreSQL connections.
2//!
3//! PostgreSQL sends `NotificationResponse` messages asynchronously when
4//! a channel the connection is LISTENing on receives a NOTIFY.
5//!
6//! This module provides:
7//! - `Notification` struct — channel name + payload
8//! - `listen()` / `unlisten()` — subscribe/unsubscribe to channels
9//! - `poll_notifications()` — drain buffered notifications (non-blocking)
10//! - `recv_notification()` — block-wait for the next notification
11
12use super::{
13    PgConnection, PgError, PgResult, io::MAX_MESSAGE_SIZE, is_ignorable_session_message,
14    unexpected_backend_message,
15};
16use crate::protocol::PgEncoder;
17
/// A notification received from PostgreSQL LISTEN/NOTIFY.
///
/// Derives `PartialEq`/`Eq` so notifications can be compared directly
/// (e.g. in tests or when de-duplicating).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// The PID of the notifying backend process
    pub process_id: i32,
    /// The channel name
    pub channel: String,
    /// The payload (may be empty)
    pub payload: String,
}
28
29#[inline]
30fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
31    if matches!(
32        err,
33        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
34    ) {
35        conn.mark_io_desynced();
36    }
37    Err(err)
38}
39
40impl PgConnection {
41    /// Subscribe to a notification channel.
42    ///
43    /// ```ignore
44    /// conn.listen("price_calendar_changed").await?;
45    /// ```
46    pub async fn listen(&mut self, channel: &str) -> PgResult<()> {
47        // Channel names are identifiers, quote them to prevent injection
48        let sql = format!("LISTEN \"{}\"", channel.replace('"', "\"\""));
49        self.execute_simple(&sql).await
50    }
51
52    /// Unsubscribe from a notification channel.
53    pub async fn unlisten(&mut self, channel: &str) -> PgResult<()> {
54        let sql = format!("UNLISTEN \"{}\"", channel.replace('"', "\"\""));
55        self.execute_simple(&sql).await
56    }
57
58    /// Unsubscribe from all notification channels.
59    pub async fn unlisten_all(&mut self) -> PgResult<()> {
60        self.execute_simple("UNLISTEN *").await
61    }
62
63    /// Drain all buffered notifications without blocking.
64    ///
65    /// Notifications arrive asynchronously from PostgreSQL and are buffered
66    /// whenever `recv()` encounters a `NotificationResponse`. This method
67    /// returns all currently buffered notifications.
68    pub fn poll_notifications(&mut self) -> Vec<Notification> {
69        self.notifications.drain(..).collect()
70    }
71
72    /// Wait for the next notification, blocking until one arrives.
73    ///
74    /// Unlike `recv()`, this does NOT use the 30-second Slowloris timeout
75    /// guard. LISTEN connections idle for long periods — that's normal,
76    /// not a DoS attack.
77    ///
78    /// Useful for a dedicated LISTEN connection in a background task.
79    pub async fn recv_notification(&mut self) -> PgResult<Notification> {
80        use crate::protocol::BackendMessage;
81
82        // Return buffered notification immediately if available
83        if let Some(n) = self.notifications.pop_front() {
84            return Ok(n);
85        }
86
87        // Send empty query to flush any pending notifications from server
88        let bytes = PgEncoder::try_encode_query_string("")?;
89        self.write_all_with_timeout(&bytes, "stream write").await?;
90
91        // Read messages — use recv() for the initial empty query response
92        // (which completes quickly), then switch to no-timeout reads
93        let mut got_ready = false;
94        loop {
95            // Try to decode from the existing buffer first
96            if self.buffer.len() >= 5 {
97                let msg_len = u32::from_be_bytes([
98                    self.buffer[1],
99                    self.buffer[2],
100                    self.buffer[3],
101                    self.buffer[4],
102                ]) as usize;
103
104                if msg_len < 4 {
105                    return return_with_desync(
106                        self,
107                        PgError::Protocol(format!(
108                            "Invalid message length: {} (minimum 4)",
109                            msg_len
110                        )),
111                    );
112                }
113
114                if msg_len > MAX_MESSAGE_SIZE {
115                    return return_with_desync(
116                        self,
117                        PgError::Protocol(format!(
118                            "Message too large: {} bytes (max {})",
119                            msg_len, MAX_MESSAGE_SIZE
120                        )),
121                    );
122                }
123
124                if self.buffer.len() > msg_len {
125                    let msg_bytes = self.buffer.split_to(msg_len + 1);
126                    let (msg, _) = match BackendMessage::decode(&msg_bytes) {
127                        Ok(decoded) => decoded,
128                        Err(err) => return return_with_desync(self, PgError::Protocol(err)),
129                    };
130
131                    match msg {
132                        BackendMessage::NotificationResponse {
133                            process_id,
134                            channel,
135                            payload,
136                        } => {
137                            return Ok(Notification {
138                                process_id,
139                                channel,
140                                payload,
141                            });
142                        }
143                        BackendMessage::EmptyQueryResponse => continue,
144                        BackendMessage::NoticeResponse(_) => continue,
145                        BackendMessage::ParameterStatus { .. } => continue,
146                        BackendMessage::CommandComplete(_) => continue,
147                        BackendMessage::ReadyForQuery(_) => {
148                            got_ready = true;
149                            // Check buffer for notifications that arrived with this batch
150                            if let Some(n) = self.notifications.pop_front() {
151                                return Ok(n);
152                            }
153                            continue;
154                        }
155                        BackendMessage::ErrorResponse(err) => {
156                            return Err(PgError::QueryServer(err.into()));
157                        }
158                        msg if is_ignorable_session_message(&msg) => continue,
159                        other => {
160                            return return_with_desync(
161                                self,
162                                unexpected_backend_message("listen/notify wait", &other),
163                            );
164                        }
165                    }
166                }
167            }
168
169            // Read from socket — use tokio read (no timeout!) if we've
170            // already gotten ReadyForQuery (now we're just waiting for NOTIFY)
171            if self.buffer.capacity() - self.buffer.len() < 65536 {
172                self.buffer.reserve(131072);
173            }
174
175            if got_ready {
176                // LISTEN connections can stay idle for hours (empty buffer),
177                // but a partially buffered backend frame should still timeout
178                // to fail-closed on slowloris-style partial writes.
179                let n = if self.buffer.is_empty() {
180                    self.read_without_timeout().await?
181                } else {
182                    self.read_with_timeout().await?
183                };
184                if n == 0 {
185                    return return_with_desync(
186                        self,
187                        PgError::Connection("Connection closed".to_string()),
188                    );
189                }
190            } else {
191                // Initial flush — use the normal timeout to avoid hanging
192                // if the server is unresponsive during the empty query
193                let n = self.read_with_timeout().await?;
194                if n == 0 {
195                    return return_with_desync(
196                        self,
197                        PgError::Connection("Connection closed".to_string()),
198                    );
199                }
200            }
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::return_with_desync;
    use crate::driver::{PgConnection, PgError};

    /// Build a `PgConnection` on one end of an in-process Unix socket pair,
    /// with empty buffers/caches, so helpers can be exercised without a server.
    #[cfg(unix)]
    fn test_conn() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (local_end, _remote_end) = UnixStream::pair().expect("unix stream pair");
        let cache_size = NonZeroUsize::new(2).expect("non-zero");

        PgConnection {
            // Transport and I/O buffers.
            stream: PgStream::Unix(local_end),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            // Statement bookkeeping.
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_size),
            column_info_cache: HashMap::new(),
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
            // Session identity / protocol negotiation.
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            // Async notifications and replication state.
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn notification_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();
        let protocol_err = PgError::Protocol("bad notify frame".to_string());

        let outcome = return_with_desync::<()>(&mut conn, protocol_err);
        let returned = outcome.expect_err("protocol error must be returned");
        let rendered = returned.to_string();

        assert!(rendered.contains("bad notify frame"));
        assert!(conn.is_io_desynced());
    }
}