1use super::{
13 PgConnection, PgError, PgResult, io::MAX_MESSAGE_SIZE, is_ignorable_session_message,
14 unexpected_backend_message,
15};
16use crate::protocol::PgEncoder;
17
/// A `NOTIFY` delivered to this session on a channel it is listening on.
///
/// Values are produced from backend `NotificationResponse` messages by
/// `PgConnection::recv_notification` and returned in arrival order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// Process ID of the backend that issued the `NOTIFY`.
    pub process_id: i32,
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Payload string carried by the notification (empty when `NOTIFY` gave none).
    pub payload: String,
}
28
29#[inline]
30fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
31 if matches!(
32 err,
33 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
34 ) {
35 conn.mark_io_desynced();
36 }
37 Err(err)
38}
39
40impl PgConnection {
41 pub async fn listen(&mut self, channel: &str) -> PgResult<()> {
47 let sql = format!("LISTEN \"{}\"", channel.replace('"', "\"\""));
49 self.execute_simple(&sql).await
50 }
51
52 pub async fn unlisten(&mut self, channel: &str) -> PgResult<()> {
54 let sql = format!("UNLISTEN \"{}\"", channel.replace('"', "\"\""));
55 self.execute_simple(&sql).await
56 }
57
58 pub async fn unlisten_all(&mut self) -> PgResult<()> {
60 self.execute_simple("UNLISTEN *").await
61 }
62
63 pub fn poll_notifications(&mut self) -> Vec<Notification> {
69 self.notifications.drain(..).collect()
70 }
71
72 pub async fn recv_notification(&mut self) -> PgResult<Notification> {
80 use crate::protocol::BackendMessage;
81
82 if let Some(n) = self.notifications.pop_front() {
84 return Ok(n);
85 }
86
87 let bytes = PgEncoder::try_encode_query_string("")?;
89 self.write_all_with_timeout(&bytes, "stream write").await?;
90
91 let mut got_ready = false;
94 loop {
95 if self.buffer.len() >= 5 {
97 let msg_len = u32::from_be_bytes([
98 self.buffer[1],
99 self.buffer[2],
100 self.buffer[3],
101 self.buffer[4],
102 ]) as usize;
103
104 if msg_len < 4 {
105 return return_with_desync(
106 self,
107 PgError::Protocol(format!(
108 "Invalid message length: {} (minimum 4)",
109 msg_len
110 )),
111 );
112 }
113
114 if msg_len > MAX_MESSAGE_SIZE {
115 return return_with_desync(
116 self,
117 PgError::Protocol(format!(
118 "Message too large: {} bytes (max {})",
119 msg_len, MAX_MESSAGE_SIZE
120 )),
121 );
122 }
123
124 if self.buffer.len() > msg_len {
125 let msg_bytes = self.buffer.split_to(msg_len + 1);
126 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
127 Ok(decoded) => decoded,
128 Err(err) => return return_with_desync(self, PgError::Protocol(err)),
129 };
130
131 match msg {
132 BackendMessage::NotificationResponse {
133 process_id,
134 channel,
135 payload,
136 } => {
137 let notification = Notification {
138 process_id,
139 channel,
140 payload,
141 };
142 if got_ready {
143 return Ok(notification);
144 }
145 self.notifications.push_back(notification);
146 continue;
147 }
148 BackendMessage::EmptyQueryResponse => continue,
149 BackendMessage::NoticeResponse(_) => continue,
150 BackendMessage::ParameterStatus { .. } => continue,
151 BackendMessage::CommandComplete(_) => continue,
152 BackendMessage::ReadyForQuery(_) => {
153 got_ready = true;
154 if let Some(n) = self.notifications.pop_front() {
156 return Ok(n);
157 }
158 continue;
159 }
160 BackendMessage::ErrorResponse(err) => {
161 return Err(PgError::QueryServer(err.into()));
162 }
163 msg if is_ignorable_session_message(&msg) => continue,
164 other => {
165 return return_with_desync(
166 self,
167 unexpected_backend_message("listen/notify wait", &other),
168 );
169 }
170 }
171 }
172 }
173
174 if self.buffer.capacity() - self.buffer.len() < 65536 {
177 self.buffer.reserve(131072);
178 }
179
180 if got_ready {
181 let n = if self.buffer.is_empty() {
185 self.read_without_timeout().await?
186 } else {
187 self.read_with_timeout().await?
188 };
189 if n == 0 {
190 return return_with_desync(
191 self,
192 PgError::Connection("Connection closed".to_string()),
193 );
194 }
195 } else {
196 let n = self.read_with_timeout().await?;
199 if n == 0 {
200 return return_with_desync(
201 self,
202 PgError::Connection("Connection closed".to_string()),
203 );
204 }
205 }
206 }
207 }
208}
209
210#[cfg(test)]
211mod tests {
212 use super::return_with_desync;
213 use crate::driver::{PgConnection, PgError};
214
215 #[cfg(unix)]
216 fn test_conn_with_peer() -> (PgConnection, tokio::net::UnixStream) {
217 use crate::driver::connection::StatementCache;
218 use crate::driver::stream::PgStream;
219 use bytes::BytesMut;
220 use std::collections::{HashMap, VecDeque};
221 use std::num::NonZeroUsize;
222 use tokio::net::UnixStream;
223
224 let (unix_stream, peer) = UnixStream::pair().expect("unix stream pair");
225 (
226 PgConnection {
227 stream: PgStream::Unix(unix_stream),
228 buffer: BytesMut::with_capacity(1024),
229 write_buf: BytesMut::with_capacity(1024),
230 sql_buf: BytesMut::with_capacity(256),
231 params_buf: Vec::new(),
232 prepared_statements: HashMap::new(),
233 stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
234 column_info_cache: HashMap::new(),
235 process_id: 0,
236 cancel_key_bytes: Vec::new(),
237 requested_protocol_minor: PgConnection::default_protocol_minor(),
238 negotiated_protocol_minor: PgConnection::default_protocol_minor(),
239 notifications: VecDeque::new(),
240 replication_stream_active: false,
241 replication_mode_enabled: false,
242 last_replication_wal_end: None,
243 io_desynced: false,
244 pending_statement_closes: Vec::new(),
245 draining_statement_closes: false,
246 },
247 peer,
248 )
249 }
250
251 #[cfg(unix)]
252 fn test_conn() -> PgConnection {
253 test_conn_with_peer().0
254 }
255
256 #[cfg(unix)]
257 fn push_backend_frame(conn: &mut PgConnection, msg_type: u8, payload: &[u8]) {
258 conn.buffer.extend_from_slice(&[msg_type]);
259 conn.buffer
260 .extend_from_slice(&((payload.len() + 4) as u32).to_be_bytes());
261 conn.buffer.extend_from_slice(payload);
262 }
263
264 #[cfg(unix)]
265 fn notification_payload(process_id: i32, channel: &str, payload: &str) -> Vec<u8> {
266 let mut bytes = Vec::new();
267 bytes.extend_from_slice(&process_id.to_be_bytes());
268 bytes.extend_from_slice(channel.as_bytes());
269 bytes.push(0);
270 bytes.extend_from_slice(payload.as_bytes());
271 bytes.push(0);
272 bytes
273 }
274
275 #[cfg(unix)]
276 #[tokio::test]
277 async fn notification_return_with_desync_marks_protocol_error() {
278 let mut conn = test_conn();
279
280 let err =
281 return_with_desync::<()>(&mut conn, PgError::Protocol("bad notify frame".to_string()))
282 .expect_err("protocol error must be returned");
283
284 assert!(err.to_string().contains("bad notify frame"));
285 assert!(conn.is_io_desynced());
286 }
287
288 #[cfg(unix)]
289 #[tokio::test]
290 async fn recv_notification_drains_empty_query_before_returning_pre_ready_notify() {
291 let (mut conn, _peer) = test_conn_with_peer();
292 let payload = notification_payload(42, "jobs", "ready");
293
294 push_backend_frame(&mut conn, b'A', &payload);
295 push_backend_frame(&mut conn, b'I', &[]);
296 push_backend_frame(&mut conn, b'Z', b"I");
297
298 let notification = conn
299 .recv_notification()
300 .await
301 .expect("pre-ready notification should be returned after flush drain");
302
303 assert_eq!(notification.process_id, 42);
304 assert_eq!(notification.channel, "jobs");
305 assert_eq!(notification.payload, "ready");
306 assert!(
307 conn.buffer.is_empty(),
308 "empty-query flush frames must not remain buffered"
309 );
310 assert!(!conn.is_io_desynced());
311 }
312}