qail_pg/driver/
notification.rs1use super::{
13 PgConnection, PgError, PgResult, io::MAX_MESSAGE_SIZE, is_ignorable_session_message,
14 unexpected_backend_message,
15};
16use crate::protocol::PgEncoder;
17
/// An asynchronous notification received on a channel this session is
/// listening on, decoded from a `NotificationResponse` backend message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// Process ID carried in the `NotificationResponse` (the backend that raised it).
    pub process_id: i32,
    /// Name of the channel the notification was raised on.
    pub channel: String,
    /// Payload string carried with the notification.
    pub payload: String,
}
28
29#[inline]
30fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
31 if matches!(
32 err,
33 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
34 ) {
35 conn.mark_io_desynced();
36 }
37 Err(err)
38}
39
40impl PgConnection {
41 pub async fn listen(&mut self, channel: &str) -> PgResult<()> {
47 let sql = format!("LISTEN \"{}\"", channel.replace('"', "\"\""));
49 self.execute_simple(&sql).await
50 }
51
52 pub async fn unlisten(&mut self, channel: &str) -> PgResult<()> {
54 let sql = format!("UNLISTEN \"{}\"", channel.replace('"', "\"\""));
55 self.execute_simple(&sql).await
56 }
57
58 pub async fn unlisten_all(&mut self) -> PgResult<()> {
60 self.execute_simple("UNLISTEN *").await
61 }
62
63 pub fn poll_notifications(&mut self) -> Vec<Notification> {
69 self.notifications.drain(..).collect()
70 }
71
72 pub async fn recv_notification(&mut self) -> PgResult<Notification> {
80 use crate::protocol::BackendMessage;
81
82 if let Some(n) = self.notifications.pop_front() {
84 return Ok(n);
85 }
86
87 let bytes = PgEncoder::try_encode_query_string("")?;
89 self.write_all_with_timeout(&bytes, "stream write").await?;
90
91 let mut got_ready = false;
94 loop {
95 if self.buffer.len() >= 5 {
97 let msg_len = u32::from_be_bytes([
98 self.buffer[1],
99 self.buffer[2],
100 self.buffer[3],
101 self.buffer[4],
102 ]) as usize;
103
104 if msg_len < 4 {
105 return return_with_desync(
106 self,
107 PgError::Protocol(format!(
108 "Invalid message length: {} (minimum 4)",
109 msg_len
110 )),
111 );
112 }
113
114 if msg_len > MAX_MESSAGE_SIZE {
115 return return_with_desync(
116 self,
117 PgError::Protocol(format!(
118 "Message too large: {} bytes (max {})",
119 msg_len, MAX_MESSAGE_SIZE
120 )),
121 );
122 }
123
124 if self.buffer.len() > msg_len {
125 let msg_bytes = self.buffer.split_to(msg_len + 1);
126 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
127 Ok(decoded) => decoded,
128 Err(err) => return return_with_desync(self, PgError::Protocol(err)),
129 };
130
131 match msg {
132 BackendMessage::NotificationResponse {
133 process_id,
134 channel,
135 payload,
136 } => {
137 return Ok(Notification {
138 process_id,
139 channel,
140 payload,
141 });
142 }
143 BackendMessage::EmptyQueryResponse => continue,
144 BackendMessage::NoticeResponse(_) => continue,
145 BackendMessage::ParameterStatus { .. } => continue,
146 BackendMessage::CommandComplete(_) => continue,
147 BackendMessage::ReadyForQuery(_) => {
148 got_ready = true;
149 if let Some(n) = self.notifications.pop_front() {
151 return Ok(n);
152 }
153 continue;
154 }
155 BackendMessage::ErrorResponse(err) => {
156 return Err(PgError::QueryServer(err.into()));
157 }
158 msg if is_ignorable_session_message(&msg) => continue,
159 other => {
160 return return_with_desync(
161 self,
162 unexpected_backend_message("listen/notify wait", &other),
163 );
164 }
165 }
166 }
167 }
168
169 if self.buffer.capacity() - self.buffer.len() < 65536 {
172 self.buffer.reserve(131072);
173 }
174
175 if got_ready {
176 let n = if self.buffer.is_empty() {
180 self.read_without_timeout().await?
181 } else {
182 self.read_with_timeout().await?
183 };
184 if n == 0 {
185 return return_with_desync(
186 self,
187 PgError::Connection("Connection closed".to_string()),
188 );
189 }
190 } else {
191 let n = self.read_with_timeout().await?;
194 if n == 0 {
195 return return_with_desync(
196 self,
197 PgError::Connection("Connection closed".to_string()),
198 );
199 }
200 }
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::return_with_desync;
    use crate::driver::{PgConnection, PgError};

    /// Build a `PgConnection` over one end of an in-process Unix socket pair,
    /// with empty buffers, caches and queues, so helpers in this module can be
    /// exercised without a live server.
    #[cfg(unix)]
    fn test_conn() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        // The remote end is held until this function returns, matching the
        // lifetime of an unused peer socket.
        let (local, _remote) = UnixStream::pair().expect("unix stream pair");
        let cache_capacity = NonZeroUsize::new(2).expect("non-zero");

        PgConnection {
            stream: PgStream::Unix(local),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_capacity),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        }
    }

    /// A protocol error passed through `return_with_desync` must come back
    /// unchanged and leave the connection flagged as I/O-desynced.
    #[cfg(unix)]
    #[tokio::test]
    async fn notification_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();
        let protocol_err = PgError::Protocol("bad notify frame".to_string());

        let returned = return_with_desync::<()>(&mut conn, protocol_err)
            .expect_err("protocol error must be returned");
        let rendered = returned.to_string();

        assert!(rendered.contains("bad notify frame"));
        assert!(conn.is_io_desynced());
    }
}