sqlx_postgres/listener.rs
1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9use sqlx_core::acquire::Acquire;
10use sqlx_core::sql_str::{AssertSqlSafe, SqlStr};
11use sqlx_core::transaction::Transaction;
12use sqlx_core::Either;
13use tracing::Instrument;
14
15use crate::error::Error;
16use crate::executor::{Execute, Executor};
17use crate::message::{BackendMessageFormat, Notification};
18use crate::pool::PoolOptions;
19use crate::pool::{Pool, PoolConnection};
20use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
21
/// A stream of asynchronous notifications from Postgres.
///
/// This listener will auto-reconnect. If the active connection ever dies, the listener
/// detects that event, acquires a new connection, re-subscribes to every channel it has
/// recorded, and resumes operations as normal. Notifications sent while the connection was
/// lost are not delivered.
pub struct PgListener {
    // Source of connections: the initial one and every replacement after a lost connection.
    pool: Pool<Postgres>,
    // The connection currently in use; `None` after it was detected as dead and before a
    // replacement has been acquired.
    connection: Option<PoolConnection<Postgres>>,
    // Receiving half of the notification buffer, drained by `next_buffered()`.
    buffer_rx: mpsc::UnboundedReceiver<Notification>,
    // Sending half of the notification buffer while there is no connection; whenever a
    // connection is held, the sender lives on that connection's stream instead.
    buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
    // Channels to re-`LISTEN` on whenever a new connection is established.
    channels: Vec<String>,
    // See `PgListener::ignore_pool_close_event()`.
    ignore_close_event: bool,
    // See `PgListener::eager_reconnect()`.
    eager_reconnect: bool,
}
37
/// An asynchronous notification from Postgres.
///
/// Returned by [`PgListener::recv()`] and [`PgListener::try_recv()`]; inspect it with
/// [`channel()`](Self::channel), [`payload()`](Self::payload) and
/// [`process_id()`](Self::process_id).
pub struct PgNotification(Notification);
40
41impl PgListener {
42 pub async fn connect(url: &str) -> Result<Self, Error> {
43 // Create a pool of 1 without timeouts (as they don't apply here)
44 // We only use the pool to handle re-connections
45 let pool = PoolOptions::<Postgres>::new()
46 .max_connections(1)
47 .max_lifetime(None)
48 .idle_timeout(None)
49 .connect(url)
50 .await?;
51
52 let mut this = Self::connect_with(&pool).await?;
53 // We don't need to handle close events
54 this.ignore_close_event = true;
55
56 Ok(this)
57 }
58
59 pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
60 // Pull out an initial connection
61 let mut connection = pool.acquire().await?;
62
63 // Setup a notification buffer
64 let (sender, receiver) = mpsc::unbounded();
65 connection.inner.stream.notifications = Some(sender);
66
67 Ok(Self {
68 pool: pool.clone(),
69 connection: Some(connection),
70 buffer_rx: receiver,
71 buffer_tx: None,
72 channels: Vec::new(),
73 ignore_close_event: false,
74 eager_reconnect: true,
75 })
76 }
77
78 /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
79 ///
80 /// By default, when [`Pool::close()`] is called on the pool this listener is using
81 /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
82 /// cancelled and `Err(PoolClosed)` is returned.
83 ///
84 /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
85 /// including the one being used by this listener.
86 ///
87 /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
88 /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
89 /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
90 /// on the attempt to acquire a new connection anyway.
91 ///
92 /// However, if you want `PgListener` to ignore the close event and continue waiting for a
93 /// message as long as it can, set this to `true`.
94 ///
95 /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
96 /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
97 pub fn ignore_pool_close_event(&mut self, val: bool) {
98 self.ignore_close_event = val;
99 }
100
101 /// Set whether a lost connection in `try_recv()` should be re-established before it returns
102 /// `Ok(None)`, or on the next call to `try_recv()`.
103 ///
104 /// By default, this is `true` and the connection is re-established before returning `Ok(None)`.
105 ///
106 /// If this is set to `false` then notifications will continue to be lost until the next call
107 /// to `try_recv()`. If your recovery logic uses a different database connection then
108 /// notifications that occur after it completes may be lost without any way to tell that they
109 /// have been.
110 pub fn eager_reconnect(&mut self, val: bool) {
111 self.eager_reconnect = val;
112 }
113
114 /// Starts listening for notifications on a channel.
115 /// The channel name is quoted here to ensure case sensitivity.
116 pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
117 self.connection()
118 .await?
119 .execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel))))
120 .await?;
121
122 self.channels.push(channel.to_owned());
123
124 Ok(())
125 }
126
127 /// Starts listening for notifications on all channels.
128 pub async fn listen_all(
129 &mut self,
130 channels: impl IntoIterator<Item = &str>,
131 ) -> Result<(), Error> {
132 let beg = self.channels.len();
133 self.channels.extend(channels.into_iter().map(|s| s.into()));
134
135 let query = build_listen_all_query(&self.channels[beg..]);
136 self.connection()
137 .await?
138 .execute(AssertSqlSafe(query))
139 .await?;
140
141 Ok(())
142 }
143
144 /// Stops listening for notifications on a channel.
145 /// The channel name is quoted here to ensure case sensitivity.
146 pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
147 // use RAW connection and do NOT re-connect automatically, since this is not required for
148 // UNLISTEN (we've disconnected anyways)
149 if let Some(connection) = self.connection.as_mut() {
150 connection
151 .execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel))))
152 .await?;
153 }
154
155 if let Some(pos) = self.channels.iter().position(|s| s == channel) {
156 self.channels.remove(pos);
157 }
158
159 Ok(())
160 }
161
162 /// Stops listening for notifications on all channels.
163 pub async fn unlisten_all(&mut self) -> Result<(), Error> {
164 // use RAW connection and do NOT re-connect automatically, since this is not required for
165 // UNLISTEN (we've disconnected anyways)
166 if let Some(connection) = self.connection.as_mut() {
167 connection.execute("UNLISTEN *").await?;
168 }
169
170 self.channels.clear();
171
172 Ok(())
173 }
174
175 #[inline]
176 async fn connect_if_needed(&mut self) -> Result<(), Error> {
177 if self.connection.is_none() {
178 let mut connection = self.pool.acquire().await?;
179 connection.inner.stream.notifications = self.buffer_tx.take();
180
181 connection
182 .execute(AssertSqlSafe(build_listen_all_query(&self.channels)))
183 .await?;
184
185 self.connection = Some(connection);
186 }
187
188 Ok(())
189 }
190
191 #[inline]
192 async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
193 // Ensure we have an active connection to work with.
194 self.connect_if_needed().await?;
195
196 Ok(self.connection.as_mut().unwrap())
197 }
198
199 /// Receives the next notification available from any of the subscribed channels.
200 ///
201 /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
202 /// call to `recv()`, and should be entirely transparent (as long as it was just an
203 /// intermittent network failure or long-lived connection reaper).
204 ///
205 /// As notifications are transient, any received while the connection was lost, will not
206 /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
207 /// do something before, please see [`try_recv`](Self::try_recv).
208 ///
209 /// # Example
210 ///
211 /// ```rust,no_run
212 /// # use sqlx::postgres::PgListener;
213 /// #
214 /// # sqlx::__rt::test_block_on(async move {
215 /// let mut listener = PgListener::connect("postgres:// ...").await?;
216 /// loop {
217 /// // ask for next notification, re-connecting (transparently) if needed
218 /// let notification = listener.recv().await?;
219 ///
220 /// // handle notification, do something interesting
221 /// }
222 /// # Result::<(), sqlx::Error>::Ok(())
223 /// # }).unwrap();
224 /// ```
225 pub async fn recv(&mut self) -> Result<PgNotification, Error> {
226 loop {
227 if let Some(notification) = self.try_recv().await? {
228 return Ok(notification);
229 }
230 }
231 }
232
233 /// Receives the next notification available from any of the subscribed channels.
234 ///
235 /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
236 /// reconnected either immediately, or on the next call to `try_recv()`, depending on
237 /// the value of [`eager_reconnect`].
238 ///
239 /// # Example
240 ///
241 /// ```rust,no_run
242 /// # use sqlx::postgres::PgListener;
243 /// #
244 /// # sqlx::__rt::test_block_on(async move {
245 /// # let mut listener = PgListener::connect("postgres:// ...").await?;
246 /// loop {
247 /// // start handling notifications, connecting if needed
248 /// while let Some(notification) = listener.try_recv().await? {
249 /// // handle notification
250 /// }
251 ///
252 /// // connection lost, do something interesting
253 /// }
254 /// # Result::<(), sqlx::Error>::Ok(())
255 /// # }).unwrap();
256 /// ```
257 ///
258 /// [`eager_reconnect`]: PgListener::eager_reconnect
259 pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
260 // Flush the buffer first, if anything
261 // This would only fill up if this listener is used as a connection
262 if let Some(notification) = self.next_buffered() {
263 return Ok(Some(notification));
264 }
265
266 // Fetch our `CloseEvent` listener, if applicable.
267 let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
268
269 loop {
270 let next_message = self.connection().await?.inner.stream.recv_unchecked();
271
272 let res = if let Some(ref mut close_event) = close_event {
273 // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
274 // before `next_message` returns, or if the pool was already closed
275 close_event.do_until(next_message).await?
276 } else {
277 next_message.await
278 };
279
280 let message = match res {
281 Ok(message) => message,
282
283 // The connection is dead, ensure that it is dropped,
284 // update self state, and loop to try again.
285 Err(Error::Io(err))
286 if matches!(
287 err.kind(),
288 io::ErrorKind::ConnectionAborted |
289 io::ErrorKind::UnexpectedEof |
290 // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html)
291 io::ErrorKind::TimedOut |
292 io::ErrorKind::BrokenPipe
293 ) =>
294 {
295 if let Some(mut conn) = self.connection.take() {
296 self.buffer_tx = conn.inner.stream.notifications.take();
297 // Close the connection in a background task, so we can continue.
298 conn.close_on_drop();
299 }
300
301 if self.eager_reconnect {
302 self.connect_if_needed().await?;
303 }
304
305 // lost connection
306 return Ok(None);
307 }
308
309 // Forward other errors
310 Err(error) => {
311 return Err(error);
312 }
313 };
314
315 match message.format {
316 // We've received an async notification, return it.
317 BackendMessageFormat::NotificationResponse => {
318 return Ok(Some(PgNotification(message.decode()?)));
319 }
320
321 // Mark the connection as ready for another query
322 BackendMessageFormat::ReadyForQuery => {
323 self.connection().await?.inner.pending_ready_for_query_count -= 1;
324 }
325
326 // Ignore unexpected messages
327 _ => {}
328 }
329 }
330 }
331
332 /// Receives the next notification that already exists in the connection buffer, if any.
333 ///
334 /// This is similar to `try_recv`, except it will not wait if the connection has not yet received a notification.
335 ///
336 /// This is helpful if you want to retrieve all buffered notifications and process them in batches.
337 pub fn next_buffered(&mut self) -> Option<PgNotification> {
338 if let Ok(Some(notification)) = self.buffer_rx.try_next() {
339 Some(PgNotification(notification))
340 } else {
341 None
342 }
343 }
344
345 /// Consume this listener, returning a `Stream` of notifications.
346 ///
347 /// The backing connection will be automatically reconnected should it be lost.
348 ///
349 /// This has the same potential drawbacks as [`recv`](PgListener::recv).
350 ///
351 pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
352 Box::pin(try_stream! {
353 loop {
354 r#yield!(self.recv().await?);
355 }
356 })
357 }
358}
359
360impl Drop for PgListener {
361 fn drop(&mut self) {
362 if let Some(mut conn) = self.connection.take() {
363 let fut = async move {
364 let _ = conn.execute("UNLISTEN *").await;
365
366 // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
367 // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
368 // https://github.com/launchbadge/sqlx/issues/1389
369 conn.return_to_pool().await;
370 };
371
372 // Unregister any listeners before returning the connection to the pool.
373 crate::rt::spawn(fut.in_current_span());
374 }
375 }
376}
377
378impl<'c> Acquire<'c> for &'c mut PgListener {
379 type Database = Postgres;
380 type Connection = &'c mut PgConnection;
381
382 fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
383 self.connection().boxed()
384 }
385
386 fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
387 self.connection().and_then(|c| c.begin()).boxed()
388 }
389}
390
391impl<'c> Executor<'c> for &'c mut PgListener {
392 type Database = Postgres;
393
394 fn fetch_many<'e, 'q, E>(
395 self,
396 query: E,
397 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
398 where
399 'c: 'e,
400 E: Execute<'q, Self::Database>,
401 'q: 'e,
402 E: 'q,
403 {
404 futures_util::stream::once(async move {
405 // need some basic type annotation to help the compiler a bit
406 let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
407 res
408 })
409 .try_flatten()
410 .boxed()
411 }
412
413 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
414 where
415 'c: 'e,
416 E: Execute<'q, Self::Database>,
417 'q: 'e,
418 E: 'q,
419 {
420 async move { self.connection().await?.fetch_optional(query).await }.boxed()
421 }
422
423 fn prepare_with<'e>(
424 self,
425 query: SqlStr,
426 parameters: &'e [PgTypeInfo],
427 ) -> BoxFuture<'e, Result<PgStatement, Error>>
428 where
429 'c: 'e,
430 {
431 async move {
432 self.connection()
433 .await?
434 .prepare_with(query, parameters)
435 .await
436 }
437 .boxed()
438 }
439
440 #[doc(hidden)]
441 #[cfg(feature = "offline")]
442 fn describe<'e>(
443 self,
444 query: SqlStr,
445 ) -> BoxFuture<'e, Result<crate::describe::Describe<Self::Database>, Error>>
446 where
447 'c: 'e,
448 {
449 async move { self.connection().await?.describe(query).await }.boxed()
450 }
451}
452
453impl PgNotification {
454 /// The process ID of the notifying backend process.
455 #[inline]
456 pub fn process_id(&self) -> u32 {
457 self.0.process_id
458 }
459
460 /// The channel that the notify has been raised on. This can be thought
461 /// of as the message topic.
462 #[inline]
463 pub fn channel(&self) -> &str {
464 from_utf8(&self.0.channel).unwrap()
465 }
466
467 /// The payload of the notification. An empty payload is received as an
468 /// empty string.
469 #[inline]
470 pub fn payload(&self) -> &str {
471 from_utf8(&self.0.payload).unwrap()
472 }
473}
474
475impl Debug for PgListener {
476 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477 f.debug_struct("PgListener").finish()
478 }
479}
480
481impl Debug for PgNotification {
482 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
483 f.debug_struct("PgNotification")
484 .field("process_id", &self.process_id())
485 .field("channel", &self.channel())
486 .field("payload", &self.payload())
487 .finish()
488 }
489}
490
/// Escapes `name` for use between double quotes as a Postgres identifier.
///
/// Everything from the first NUL byte onward is discarded, and every `"` is doubled so the
/// name cannot terminate the quoted identifier early.
fn ident(name: &str) -> String {
    // `split` always yields at least one piece (possibly empty), namely the text before any NUL.
    let before_nul = name.split('\0').next().unwrap_or(name);

    before_nul.replace('"', "\"\"")
}
501
502fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
503 channels.into_iter().fold(String::new(), |mut acc, chan| {
504 acc.push_str(r#"LISTEN ""#);
505 acc.push_str(&ident(chan.as_ref()));
506 acc.push_str(r#"";"#);
507 acc
508 })
509}
510
#[test]
fn test_build_listen_all_query_with_single_channel() {
    let output = build_listen_all_query(["test"]);
    assert_eq!(output.as_str(), r#"LISTEN "test";"#);
}

#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let output = build_listen_all_query(["channel.0", "channel.1"]);
    assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}

#[test]
fn test_build_listen_all_query_with_no_channels() {
    let output = build_listen_all_query(std::iter::empty::<&str>());
    assert_eq!(output.as_str(), "");
}

#[test]
fn test_build_listen_all_query_escapes_channel_names() {
    // Embedded double quotes are doubled so a name cannot close the quoted identifier early,
    // and everything from the first NUL byte onward is dropped.
    let output = build_listen_all_query([r#"a"b"#, "c\0; DROP TABLE x"]);
    assert_eq!(output.as_str(), r#"LISTEN "a""b";LISTEN "c";"#);
}