Skip to main content

sqlx_postgres/
listener.rs

1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9use sqlx_core::acquire::Acquire;
10use sqlx_core::sql_str::{AssertSqlSafe, SqlStr};
11use sqlx_core::transaction::Transaction;
12use sqlx_core::Either;
13use tracing::Instrument;
14
15use crate::error::Error;
16use crate::executor::{Execute, Executor};
17use crate::message::{BackendMessageFormat, Notification};
18use crate::pool::PoolOptions;
19use crate::pool::{Pool, PoolConnection};
20use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
21
22/// A stream of asynchronous notifications from Postgres.
23///
24/// This listener will auto-reconnect. If the active
25/// connection being used ever dies, this listener will detect that event, create a
26/// new connection, will re-subscribe to all of the originally specified channels, and will resume
27/// operations as normal.
28pub struct PgListener {
29    pool: Pool<Postgres>,
30    connection: Option<PoolConnection<Postgres>>,
31    buffer_rx: mpsc::UnboundedReceiver<Notification>,
32    buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
33    channels: Vec<String>,
34    ignore_close_event: bool,
35    eager_reconnect: bool,
36}
37
38/// An asynchronous notification from Postgres.
39pub struct PgNotification(Notification);
40
41impl PgListener {
42    pub async fn connect(url: &str) -> Result<Self, Error> {
43        // Create a pool of 1 without timeouts (as they don't apply here)
44        // We only use the pool to handle re-connections
45        let pool = PoolOptions::<Postgres>::new()
46            .max_connections(1)
47            .max_lifetime(None)
48            .idle_timeout(None)
49            .connect(url)
50            .await?;
51
52        let mut this = Self::connect_with(&pool).await?;
53        // We don't need to handle close events
54        this.ignore_close_event = true;
55
56        Ok(this)
57    }
58
59    pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
60        // Pull out an initial connection
61        let mut connection = pool.acquire().await?;
62
63        // Setup a notification buffer
64        let (sender, receiver) = mpsc::unbounded();
65        connection.inner.stream.notifications = Some(sender);
66
67        Ok(Self {
68            pool: pool.clone(),
69            connection: Some(connection),
70            buffer_rx: receiver,
71            buffer_tx: None,
72            channels: Vec::new(),
73            ignore_close_event: false,
74            eager_reconnect: true,
75        })
76    }
77
78    /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
79    ///
80    /// By default, when [`Pool::close()`] is called on the pool this listener is using
81    /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
82    /// cancelled and `Err(PoolClosed)` is returned.
83    ///
84    /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
85    /// including the one being used by this listener.
86    ///
87    /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
88    /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
89    /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
90    /// on the attempt to acquire a new connection anyway.
91    ///
92    /// However, if you want `PgListener` to ignore the close event and continue waiting for a
93    /// message as long as it can, set this to `true`.
94    ///
95    /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
96    /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
97    pub fn ignore_pool_close_event(&mut self, val: bool) {
98        self.ignore_close_event = val;
99    }
100
101    /// Set whether a lost connection in `try_recv()` should be re-established before it returns
102    /// `Ok(None)`, or on the next call to `try_recv()`.
103    ///
104    /// By default, this is `true` and the connection is re-established before returning `Ok(None)`.
105    ///
106    /// If this is set to `false` then notifications will continue to be lost until the next call
107    /// to `try_recv()`. If your recovery logic uses a different database connection then
108    /// notifications that occur after it completes may be lost without any way to tell that they
109    /// have been.
110    pub fn eager_reconnect(&mut self, val: bool) {
111        self.eager_reconnect = val;
112    }
113
114    /// Starts listening for notifications on a channel.
115    /// The channel name is quoted here to ensure case sensitivity.
116    pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
117        self.connection()
118            .await?
119            .execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel))))
120            .await?;
121
122        self.channels.push(channel.to_owned());
123
124        Ok(())
125    }
126
127    /// Starts listening for notifications on all channels.
128    pub async fn listen_all(
129        &mut self,
130        channels: impl IntoIterator<Item = &str>,
131    ) -> Result<(), Error> {
132        let beg = self.channels.len();
133        self.channels.extend(channels.into_iter().map(|s| s.into()));
134
135        let query = build_listen_all_query(&self.channels[beg..]);
136        self.connection()
137            .await?
138            .execute(AssertSqlSafe(query))
139            .await?;
140
141        Ok(())
142    }
143
144    /// Stops listening for notifications on a channel.
145    /// The channel name is quoted here to ensure case sensitivity.
146    pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
147        // use RAW connection and do NOT re-connect automatically, since this is not required for
148        // UNLISTEN (we've disconnected anyways)
149        if let Some(connection) = self.connection.as_mut() {
150            connection
151                .execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel))))
152                .await?;
153        }
154
155        if let Some(pos) = self.channels.iter().position(|s| s == channel) {
156            self.channels.remove(pos);
157        }
158
159        Ok(())
160    }
161
162    /// Stops listening for notifications on all channels.
163    pub async fn unlisten_all(&mut self) -> Result<(), Error> {
164        // use RAW connection and do NOT re-connect automatically, since this is not required for
165        // UNLISTEN (we've disconnected anyways)
166        if let Some(connection) = self.connection.as_mut() {
167            connection.execute("UNLISTEN *").await?;
168        }
169
170        self.channels.clear();
171
172        Ok(())
173    }
174
175    #[inline]
176    async fn connect_if_needed(&mut self) -> Result<(), Error> {
177        if self.connection.is_none() {
178            let mut connection = self.pool.acquire().await?;
179            connection.inner.stream.notifications = self.buffer_tx.take();
180
181            connection
182                .execute(AssertSqlSafe(build_listen_all_query(&self.channels)))
183                .await?;
184
185            self.connection = Some(connection);
186        }
187
188        Ok(())
189    }
190
191    #[inline]
192    async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
193        // Ensure we have an active connection to work with.
194        self.connect_if_needed().await?;
195
196        Ok(self.connection.as_mut().unwrap())
197    }
198
199    /// Receives the next notification available from any of the subscribed channels.
200    ///
201    /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
202    /// call to `recv()`, and should be entirely transparent (as long as it was just an
203    /// intermittent network failure or long-lived connection reaper).
204    ///
205    /// As notifications are transient, any received while the connection was lost, will not
206    /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
207    /// do something before, please see [`try_recv`](Self::try_recv).
208    ///
209    /// # Example
210    ///
211    /// ```rust,no_run
212    /// # use sqlx::postgres::PgListener;
213    /// #
214    /// # sqlx::__rt::test_block_on(async move {
215    /// let mut listener = PgListener::connect("postgres:// ...").await?;
216    /// loop {
217    ///     // ask for next notification, re-connecting (transparently) if needed
218    ///     let notification = listener.recv().await?;
219    ///
220    ///     // handle notification, do something interesting
221    /// }
222    /// # Result::<(), sqlx::Error>::Ok(())
223    /// # }).unwrap();
224    /// ```
225    pub async fn recv(&mut self) -> Result<PgNotification, Error> {
226        loop {
227            if let Some(notification) = self.try_recv().await? {
228                return Ok(notification);
229            }
230        }
231    }
232
233    /// Receives the next notification available from any of the subscribed channels.
234    ///
235    /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
236    /// reconnected either immediately, or on the next call to `try_recv()`, depending on
237    /// the value of [`eager_reconnect`].
238    ///
239    /// # Example
240    ///
241    /// ```rust,no_run
242    /// # use sqlx::postgres::PgListener;
243    /// #
244    /// # sqlx::__rt::test_block_on(async move {
245    /// # let mut listener = PgListener::connect("postgres:// ...").await?;
246    /// loop {
247    ///     // start handling notifications, connecting if needed
248    ///     while let Some(notification) = listener.try_recv().await? {
249    ///         // handle notification
250    ///     }
251    ///
252    ///     // connection lost, do something interesting
253    /// }
254    /// # Result::<(), sqlx::Error>::Ok(())
255    /// # }).unwrap();
256    /// ```
257    ///
258    /// [`eager_reconnect`]: PgListener::eager_reconnect
259    pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
260        // Flush the buffer first, if anything
261        // This would only fill up if this listener is used as a connection
262        if let Some(notification) = self.next_buffered() {
263            return Ok(Some(notification));
264        }
265
266        // Fetch our `CloseEvent` listener, if applicable.
267        let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
268
269        loop {
270            let next_message = self.connection().await?.inner.stream.recv_unchecked();
271
272            let res = if let Some(ref mut close_event) = close_event {
273                // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
274                // before `next_message` returns, or if the pool was already closed
275                close_event.do_until(next_message).await?
276            } else {
277                next_message.await
278            };
279
280            let message = match res {
281                Ok(message) => message,
282
283                // The connection is dead, ensure that it is dropped,
284                // update self state, and loop to try again.
285                Err(Error::Io(err))
286                    if matches!(
287                        err.kind(),
288                        io::ErrorKind::ConnectionAborted |
289                        io::ErrorKind::UnexpectedEof |
290                        // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html)
291                        io::ErrorKind::TimedOut |
292                        io::ErrorKind::BrokenPipe
293                    ) =>
294                {
295                    if let Some(mut conn) = self.connection.take() {
296                        self.buffer_tx = conn.inner.stream.notifications.take();
297                        // Close the connection in a background task, so we can continue.
298                        conn.close_on_drop();
299                    }
300
301                    if self.eager_reconnect {
302                        self.connect_if_needed().await?;
303                    }
304
305                    // lost connection
306                    return Ok(None);
307                }
308
309                // Forward other errors
310                Err(error) => {
311                    return Err(error);
312                }
313            };
314
315            match message.format {
316                // We've received an async notification, return it.
317                BackendMessageFormat::NotificationResponse => {
318                    return Ok(Some(PgNotification(message.decode()?)));
319                }
320
321                // Mark the connection as ready for another query
322                BackendMessageFormat::ReadyForQuery => {
323                    self.connection().await?.inner.pending_ready_for_query_count -= 1;
324                }
325
326                // Ignore unexpected messages
327                _ => {}
328            }
329        }
330    }
331
332    /// Receives the next notification that already exists in the connection buffer, if any.
333    ///
334    /// This is similar to `try_recv`, except it will not wait if the connection has not yet received a notification.
335    ///
336    /// This is helpful if you want to retrieve all buffered notifications and process them in batches.
337    pub fn next_buffered(&mut self) -> Option<PgNotification> {
338        if let Ok(Some(notification)) = self.buffer_rx.try_next() {
339            Some(PgNotification(notification))
340        } else {
341            None
342        }
343    }
344
345    /// Consume this listener, returning a `Stream` of notifications.
346    ///
347    /// The backing connection will be automatically reconnected should it be lost.
348    ///
349    /// This has the same potential drawbacks as [`recv`](PgListener::recv).
350    ///
351    pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
352        Box::pin(try_stream! {
353            loop {
354                r#yield!(self.recv().await?);
355            }
356        })
357    }
358}
359
360impl Drop for PgListener {
361    fn drop(&mut self) {
362        if let Some(mut conn) = self.connection.take() {
363            let fut = async move {
364                let _ = conn.execute("UNLISTEN *").await;
365
366                // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
367                // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
368                // https://github.com/launchbadge/sqlx/issues/1389
369                conn.return_to_pool().await;
370            };
371
372            // Unregister any listeners before returning the connection to the pool.
373            crate::rt::spawn(fut.in_current_span());
374        }
375    }
376}
377
378impl<'c> Acquire<'c> for &'c mut PgListener {
379    type Database = Postgres;
380    type Connection = &'c mut PgConnection;
381
382    fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
383        self.connection().boxed()
384    }
385
386    fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
387        self.connection().and_then(|c| c.begin()).boxed()
388    }
389}
390
391impl<'c> Executor<'c> for &'c mut PgListener {
392    type Database = Postgres;
393
394    fn fetch_many<'e, 'q, E>(
395        self,
396        query: E,
397    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
398    where
399        'c: 'e,
400        E: Execute<'q, Self::Database>,
401        'q: 'e,
402        E: 'q,
403    {
404        futures_util::stream::once(async move {
405            // need some basic type annotation to help the compiler a bit
406            let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
407            res
408        })
409        .try_flatten()
410        .boxed()
411    }
412
413    fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
414    where
415        'c: 'e,
416        E: Execute<'q, Self::Database>,
417        'q: 'e,
418        E: 'q,
419    {
420        async move { self.connection().await?.fetch_optional(query).await }.boxed()
421    }
422
423    fn prepare_with<'e>(
424        self,
425        query: SqlStr,
426        parameters: &'e [PgTypeInfo],
427    ) -> BoxFuture<'e, Result<PgStatement, Error>>
428    where
429        'c: 'e,
430    {
431        async move {
432            self.connection()
433                .await?
434                .prepare_with(query, parameters)
435                .await
436        }
437        .boxed()
438    }
439
440    #[doc(hidden)]
441    #[cfg(feature = "offline")]
442    fn describe<'e>(
443        self,
444        query: SqlStr,
445    ) -> BoxFuture<'e, Result<crate::describe::Describe<Self::Database>, Error>>
446    where
447        'c: 'e,
448    {
449        async move { self.connection().await?.describe(query).await }.boxed()
450    }
451}
452
453impl PgNotification {
454    /// The process ID of the notifying backend process.
455    #[inline]
456    pub fn process_id(&self) -> u32 {
457        self.0.process_id
458    }
459
460    /// The channel that the notify has been raised on. This can be thought
461    /// of as the message topic.
462    #[inline]
463    pub fn channel(&self) -> &str {
464        from_utf8(&self.0.channel).unwrap()
465    }
466
467    /// The payload of the notification. An empty payload is received as an
468    /// empty string.
469    #[inline]
470    pub fn payload(&self) -> &str {
471        from_utf8(&self.0.payload).unwrap()
472    }
473}
474
475impl Debug for PgListener {
476    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477        f.debug_struct("PgListener").finish()
478    }
479}
480
481impl Debug for PgNotification {
482    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
483        f.debug_struct("PgNotification")
484            .field("process_id", &self.process_id())
485            .field("channel", &self.channel())
486            .field("payload", &self.payload())
487            .finish()
488    }
489}
490
/// Escapes `name` so it can be placed between double quotes as a Postgres identifier.
///
/// Everything from the first NUL character onward is discarded, and every `"` is doubled.
fn ident(name: &str) -> String {
    let before_nul = match name.find('\0') {
        Some(index) => &name[..index],
        None => name,
    };

    before_nul.replace('"', "\"\"")
}
501
502fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
503    channels.into_iter().fold(String::new(), |mut acc, chan| {
504        acc.push_str(r#"LISTEN ""#);
505        acc.push_str(&ident(chan.as_ref()));
506        acc.push_str(r#"";"#);
507        acc
508    })
509}
510
#[test]
fn test_build_listen_all_query_with_single_channel() {
    let query = build_listen_all_query(["test"]);
    let expected = "LISTEN \"test\";";
    assert_eq!(query, expected);
}
516
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let channels = ["channel.0", "channel.1"];
    let expected = "LISTEN \"channel.0\";LISTEN \"channel.1\";";
    assert_eq!(build_listen_all_query(channels), expected);
}