cdbc_pg/
listener.rs

1use cdbc::describe::Describe;
2use cdbc::error::Error;
3use cdbc::executor::{Execute, Executor};
4use cdbc::pool::PoolOptions;
5use cdbc::pool::{Pool, PoolConnection};
6use crate::message::{MessageFormat, Notification};
7use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
8use either::Either;
9use std::fmt::{self, Debug};
10use std::io;
11use std::str::from_utf8;
12use mco::{chan, co};
13use mco::std::sync::channel::{Receiver, Sender};
14use mco::std::sync::channel;
15use cdbc::io::chan_stream::ChanStream;
16
/// A stream of asynchronous notifications from Postgres.
///
/// This listener will auto-reconnect. If the active
/// connection being used ever dies, this listener will detect that event, create a
/// new connection, will re-subscribe to all of the originally specified channels, and will resume
/// operations as normal.
pub struct PgListener {
    /// Pool that connections are acquired from, both for the initial connection and
    /// whenever a replacement connection is needed.
    pool: Pool<Postgres>,
    /// The connection currently in use. It is `None` after `try_recv` has observed an
    /// aborted connection, until `connect_if_needed` installs a new one.
    connection: Option<PoolConnection<Postgres>>,
    /// Receiving half of the notification buffer; `try_recv` checks it before reading
    /// from the connection.
    buffer_rx: Receiver<Notification>,
    /// Sending half of the notification buffer. It is installed on the connection's
    /// stream while a connection is held and parked here only while there is none.
    buffer_tx: Option<Sender<Notification>>,
    /// Names of the channels subscribed through this listener; they are replayed as
    /// `LISTEN` statements when a new connection is set up.
    channels: Vec<String>,
}
30
/// An asynchronous notification from Postgres.
///
/// Thin wrapper around the decoded `Notification` message; its contents are read
/// through the `process_id`, `channel` and `payload` accessors.
pub struct PgNotification(Notification);
33
34impl PgListener {
35    pub fn connect(uri: &str) -> Result<Self, Error> {
36        // Create a pool of 1 without timeouts (as they don't apply here)
37        // We only use the pool to handle re-connections
38        let pool = PoolOptions::<Postgres>::new()
39            .max_connections(1)
40            .max_lifetime(None)
41            .idle_timeout(None)
42            .connect(uri)
43            ?;
44
45        Self::connect_with(&pool)
46    }
47
48    pub fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
49        // Pull out an initial connection
50        let mut connection = pool.acquire()?;
51
52        // Setup a notification buffer
53        let (sender, receiver) = chan!();
54        connection.stream.notifications = Some(sender);
55
56        Ok(Self {
57            pool: pool.clone(),
58            connection: Some(connection),
59            buffer_rx: receiver,
60            buffer_tx: None,
61            channels: Vec::new(),
62        })
63    }
64
65    /// Starts listening for notifications on a channel.
66    /// The channel name is quoted here to ensure case sensitivity.
67    pub fn listen(&mut self, channel: &str) -> Result<(), Error> {
68        self.connection()
69            .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
70            ?;
71
72        self.channels.push(channel.to_owned());
73
74        Ok(())
75    }
76
77    /// Starts listening for notifications on all channels.
78    pub fn listen_all<'a>(
79        &mut self,
80        channels: impl IntoIterator<Item = &'a str>,
81    ) -> Result<(), Error> {
82        let beg = self.channels.len();
83        self.channels.extend(channels.into_iter().map(|s| s.into()));
84
85        self.connection
86            .as_mut()
87            .unwrap()
88            .execute(&*build_listen_all_query(&self.channels[beg..]))
89            ?;
90
91        Ok(())
92    }
93
94    /// Stops listening for notifications on a channel.
95    /// The channel name is quoted here to ensure case sensitivity.
96    pub fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
97        self.connection()
98            .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
99            ?;
100
101        if let Some(pos) = self.channels.iter().position(|s| s == channel) {
102            self.channels.remove(pos);
103        }
104
105        Ok(())
106    }
107
108    /// Stops listening for notifications on all channels.
109    pub fn unlisten_all(&mut self) -> Result<(), Error> {
110        self.connection().execute("UNLISTEN *")?;
111
112        self.channels.clear();
113
114        Ok(())
115    }
116
117    #[inline]
118    fn connect_if_needed(&mut self) -> Result<(), Error> {
119        if self.connection.is_none() {
120            let mut connection = self.pool.acquire()?;
121            connection.stream.notifications = self.buffer_tx.take();
122
123            connection
124                .execute(&*build_listen_all_query(&self.channels))
125                ?;
126
127            self.connection = Some(connection);
128        }
129
130        Ok(())
131    }
132
133    #[inline]
134    fn connection(&mut self) -> &mut PgConnection {
135        self.connection.as_mut().unwrap()
136    }
137
138    /// Receives the next notification available from any of the subscribed channels.
139    ///
140    /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
141    /// call to `recv()`, and should be entirely transparent (as long as it was just an
142    /// intermittent network failure or long-lived connection reaper).
143    ///
144    /// As notifications are transient, any received while the connection was lost, will not
145    /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
146    /// do something before, please see [`try_recv`](Self::try_recv).
147    ///
148    /// # Example
149    ///
150    /// ```rust,no_run
151    /// # use cdbc_pg::PgListener;
152    /// # use cdbc::error::Error;
153    /// #
154    /// # #[cfg(feature = "_rt-async-std")]
155    /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move {
156    /// # let mut listener = PgListener::connect("postgres:// ...")?;
157    /// loop {
158    ///     // ask for next notification, re-connecting (transparently) if needed
159    ///     let notification = listener.recv()?;
160    ///
161    ///     // handle notification, do something interesting
162    /// }
163    /// # Ok(())
164    /// # }).unwrap();
165    /// ```
166    pub fn recv(&mut self) -> Result<PgNotification, Error> {
167        loop {
168            if let Some(notification) = self.try_recv()? {
169                return Ok(notification);
170            }
171        }
172    }
173
    /// Receives the next notification available from any of the subscribed channels.
    ///
    /// A notification that is already sitting in the listener's buffer is returned
    /// first. Otherwise messages are read from the connection until a notification
    /// arrives.
    ///
    /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
    /// reconnected on the next call to `try_recv()`.
    ///
    /// # Errors
    ///
    /// Returns an error if (re-)connecting fails, if reading from the connection fails
    /// with anything other than an I/O error of kind `ConnectionAborted`, or if a
    /// received notification cannot be decoded.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use cdbc_pg::PgListener;
    /// # use cdbc::error::Error;
    /// #
    /// # fn main() -> Result<(), Error> {
    /// let mut listener = PgListener::connect("postgres:// ...")?;
    /// loop {
    ///     // start handling notifications, connecting if needed
    ///     while let Some(notification) = listener.try_recv()? {
    ///         // handle notification
    ///     }
    ///
    ///     // connection lost, do something interesting
    /// }
    /// # }
    /// ```
    pub fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
        // Flush the buffer first, if anything
        // This would only fill up if this listener is used as a connection
        if let Ok(notification) = self.buffer_rx.try_recv() {
            return Ok(Some(PgNotification(notification)));
        }

        loop {
            // Ensure we have an active connection to work with.
            self.connect_if_needed()?;

            let message = match self.connection().stream.recv_unchecked() {
                Ok(message) => message,

                // The connection is dead: park the notification sender on the listener
                // so the next connection can reuse it, drop the connection, and report
                // the loss to the caller. Reconnecting happens on the next call.
                Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
                    self.buffer_tx = self.connection().stream.notifications.take();
                    self.connection = None;

                    // lost connection
                    return Ok(None);
                }

                // Forward other errors
                Err(error) => {
                    return Err(error);
                }
            };

            match message.format {
                // We've received an async notification, return it.
                MessageFormat::NotificationResponse => {
                    return Ok(Some(PgNotification(message.decode()?)));
                }

                // Mark the connection as ready for another query
                MessageFormat::ReadyForQuery => {
                    self.connection().pending_ready_for_query_count -= 1;
                }

                // Ignore unexpected messages and keep reading.
                _ => {}
            }
        }
    }
245
    /// Consume this listener, returning a [`ChanStream`] of notifications.
    ///
    /// Each item is obtained by calling [`recv`](PgListener::recv) in an endless loop,
    /// so the backing connection will be automatically reconnected should it be lost.
    ///
    /// This has the same potential drawbacks as [`recv`](PgListener::recv).
    pub fn into_stream(mut self) -> ChanStream<PgNotification> {
        chan_stream!( {
            loop {
                // NOTE(review): how the `?` on a failed `recv()` surfaces to the consumer
                // is decided by the `chan_stream!` macro — confirm against its definition.
                r#yield!(self.recv()?);
            }
        })
    }
259}
260
261impl Drop for PgListener {
262    fn drop(&mut self) {
263        if let Some(mut conn) = self.connection.take() {
264            let fut = move || {
265                let _ = conn.execute("UNLISTEN *");
266
267                // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
268                // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
269                // https://github.com/launchbadge/sqlx/issues/1389
270                conn.return_to_pool();
271            };
272
273            // Unregister any listeners before returning the connection to the pool.
274            co!(fut);
275        }
276    }
277}
278
279impl<'c> Executor for &'c mut PgListener {
280    type Database = Postgres;
281
282    fn fetch_many<'q, E: 'q>(
283        &mut self,
284        query: E,
285    ) -> ChanStream<Either<PgQueryResult, PgRow>>
286    where
287        E: Execute<'q, Self::Database>,
288    {
289        self.connection().fetch_many(query)
290    }
291
292    fn fetch_optional<'q, E: 'q>(
293        &mut self,
294        query: E,
295    ) -> Result<Option<PgRow>, Error>
296    where E: Execute<'q, Self::Database>,
297    {
298        self.connection().fetch_optional(query)
299    }
300
301    fn prepare_with<'q>(
302        &mut self,
303        query: &'q str,
304        parameters: &'q [PgTypeInfo],
305    ) -> Result<PgStatement, Error>
306    where
307    {
308        self.connection().prepare_with(query, parameters)
309    }
310
311    #[doc(hidden)]
312    fn describe< 'q>(
313        &mut self,
314        query: &'q str,
315    ) -> Result<Describe<Self::Database>, Error>
316    where
317    {
318        self.connection().describe(query)
319    }
320}
321
322impl PgNotification {
323    /// The process ID of the notifying backend process.
324    #[inline]
325    pub fn process_id(&self) -> u32 {
326        self.0.process_id
327    }
328
329    /// The channel that the notify has been raised on. This can be thought
330    /// of as the message topic.
331    #[inline]
332    pub fn channel(&self) -> &str {
333        from_utf8(&self.0.channel).unwrap()
334    }
335
336    /// The payload of the notification. An empty payload is received as an
337    /// empty string.
338    #[inline]
339    pub fn payload(&self) -> &str {
340        from_utf8(&self.0.payload).unwrap()
341    }
342}
343
344impl Debug for PgListener {
345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346        f.debug_struct("PgListener").finish()
347    }
348}
349
350impl Debug for PgNotification {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        f.debug_struct("PgNotification")
353            .field("process_id", &self.process_id())
354            .field("channel", &self.channel())
355            .field("payload", &self.payload())
356            .finish()
357    }
358}
359
/// Prepares `name` for use inside a double-quoted SQL identifier.
///
/// The name is cut off at the first NUL byte (if any), and every double quote in
/// what remains is doubled. The surrounding quotes are not added here.
fn ident(name: &str) -> String {
    // Everything from the first NUL byte onwards is discarded; `split` always
    // yields at least one (possibly empty) piece.
    let visible = name.split('\0').next().unwrap_or(name);

    // Any double quotes must be escaped
    visible.replace('"', "\"\"")
}
370
371fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
372    channels.into_iter().fold(String::new(), |mut acc, chan| {
373        acc.push_str(r#"LISTEN ""#);
374        acc.push_str(&ident(chan.as_ref()));
375        acc.push_str(r#"";"#);
376        acc
377    })
378}
379
#[test]
fn test_build_listen_all_query_with_single_channel() {
    // A lone channel yields exactly one quoted, semicolon-terminated statement.
    let channels = ["test"];
    let query = build_listen_all_query(&channels);

    assert_eq!(query, r#"LISTEN "test";"#);
}
385
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    // Statements are concatenated in input order with no separator besides `;`.
    let channels = ["channel.0", "channel.1"];
    let query = build_listen_all_query(&channels);

    assert_eq!(query, r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}