postgres_notify/
lib.rs

1//!
2//! [`PGNotifier`] makes it easy to subscribe to PostgreSQL notifications.
3//!
4//! There are few examples in Rust that show how to capture these notifications
5//! mostly because tokio_postgres examples spawn off the connection half such
6//! that you can't listen for notifications anymore. [`PGNotifier`] also spawns
7//! a task for the connection, but it also listens for notifications.
8//!
//! [`PGNotifier`] maintains two lists of callback functions, which are called
//! every time it receives a notification. These two lists match the types
//! of notifications sent by Postgres: `NOTIFY` and `RAISE`.
12//!
13//! # LISTEN/NOTIFY
14//!
15//! For a very long time (at least since version 7.1) postgres has supported
16//! asynchronous notifications based on LISTEN/NOTIFY commands. This allows
17//! the database to send notifications to the client in an "out-of-band"
18//! channel.
19//!
20//! Once the client has issued a `LISTEN <channel>` command, the database will
21//! send notifications to the client whenever a `NOTIFY <channel> <payload>`
22//! is issued on the database regardless of which session has issued it.
23//! This can act as a cheap alternative to a pubsub system.
24//!
25//! When calling `subscribe_notify` with a channel name, [`PGNotifier`] will
26//! call the supplied closure upon receiving a NOTIFY message but only if it
27//! matches the requested channel name.
28//!
//! ```rust,ignore
//! use postgres_notify::PGNotifier;
//!
//! let mut notifier = PGNotifier::spawn(client, conn);
//!
//! notifier.subscribe_notify("test-channel", |notify| {
//!     println!("[{}]: {}", &notify.channel, &notify.payload);
//! }).await?;
//! ```
38//!
39//!
40//! # RAISE/LOGS
41//!
//! Logs in PostgreSQL are created by issuing `RAISE <level> <message>` commands
//! within your functions, stored procedures and scripts. When such a command is
//! issued, [`PGNotifier`] receives a notification even while the call is still
//! in progress, which allows a user to capture the execution log in realtime.
//!
//! [`PGNotifier`] simplifies log collection in two ways. First, it provides the
//! `subscribe_raise` function, which registers a callback. Second, it
//! provides the [`capture_log`](PGNotifier::capture_log) and
//! [`with_captured_log`](PGNotifier::with_captured_log) functions.
51//!
//! ```rust,ignore
//! use postgres_notify::PGNotifier;
//!
//! let mut notifier = PGNotifier::spawn(client, conn);
//!
//! notifier.subscribe_raise(|notice| {
//!     // Will print the messages below to stdout
//!     println!("{}", &notice);
//! });
//!
//! // Will capture the notices in a Vec. DEBUG and LOG messages are only sent
//! // to the client when `client_min_messages` allows them, hence the `set`.
//! let (_, log) = notifier.with_captured_log(async |client| {
//!     client.batch_execute(r#"
//!        set client_min_messages to 'debug';
//!        do $$
//!        begin
//!            raise debug 'this is a DEBUG notification';
//!            raise log 'this is a LOG notification';
//!            raise info 'this is a INFO notification';
//!            raise notice 'this is a NOTICE notification';
//!            raise warning 'this is a WARNING notification';
//!        end;
//!        $$
//!     "#).await
//! }).await?;
//!
//! println!("{:#?}", &log);
//! ```
80//!
81//! You can look at the unit tests for a more in-depth example.
82//!
83#[cfg(feature = "chrono")]
84use chrono::{DateTime, SecondsFormat, Utc};
85#[cfg(not(feature = "chrono"))]
86use std::time::SystemTime;
87
88use {
89    futures::{StreamExt, stream},
90    std::{
91        collections::BTreeMap,
92        fmt::{self, Display},
93        str::FromStr,
94        sync::{Arc, RwLock},
95    },
96    tokio::{
97        io::{AsyncRead, AsyncWrite},
98        task::JoinHandle,
99    },
100    tokio_postgres::{
101        AsyncMessage, Client as PGClient, Connection as PGConnection, Notification, error::DbError,
102    },
103};
104
105/// Type used to store callbacks for LISTEN/NOTIFY calls.
106pub type NotifyCallbacks =
107    Arc<RwLock<BTreeMap<String, Vec<Box<dyn for<'a> Fn(&'a PGNotify) + Send + Sync + 'static>>>>>;
108
109/// Type used to store callbacks for RAISE &lt;level&gt; &lt;message&gt; calls.
110pub type RaiseCallbacks =
111    Arc<RwLock<Vec<Box<dyn for<'a> Fn(&'a PGRaise) + Send + Sync + 'static>>>>;
112
113///
114/// Forwards PostgreSQL `NOTIFY` and `RAISE` commands to subscribers.
115///
116pub struct PGNotifier {
117    pub client: PGClient,
118    listen_handle: JoinHandle<()>,
119    log: Arc<RwLock<Option<Vec<PGRaise>>>>,
120    raise_callbacks: RaiseCallbacks,
121    notify_callbacks: NotifyCallbacks,
122}
123
124impl PGNotifier {
125    ///
126    /// Spawns a new postgres client/connection pair.
127    ///
128    pub fn spawn<S, T>(client: PGClient, mut conn: PGConnection<S, T>) -> Self
129    where
130        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
131        T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
132    {
133        let log = Arc::new(RwLock::new(Some(Vec::default())));
134        let notify_callbacks: NotifyCallbacks = Arc::new(RwLock::new(BTreeMap::new()));
135        let raise_callbacks: RaiseCallbacks = Arc::new(RwLock::new(Vec::new()));
136
137        // Spawn the connection and poll for messages on it.
138        let listen_handle = {
139            //
140            let log = log.clone();
141            let notify_callbacks = notify_callbacks.clone();
142            let raise_callbacks = raise_callbacks.clone();
143
144            tokio::spawn(async move {
145                //
146                let mut stream =
147                    stream::poll_fn(move |cx| conn.poll_message(cx).map_err(|e| panic!("{}", e)));
148
149                while let Some(msg) = stream.next().await {
150                    match msg {
151                        Ok(AsyncMessage::Notice(raise)) => {
152                            Self::handle_raise(&raise_callbacks, &log, raise)
153                        }
154                        Ok(AsyncMessage::Notification(notice)) => {
155                            Self::handle_notify(&notify_callbacks, notice)
156                        }
157                        _ => {
158                            #[cfg(feature = "tracing")]
159                            tracing::error!("connection to the server was closed");
160                            #[cfg(not(feature = "tracing"))]
161                            eprintln!("connection to the server was closed");
162                            break;
163                        }
164                    }
165                }
166            })
167        };
168
169        Self {
170            client,
171            listen_handle,
172            log,
173            notify_callbacks,
174            raise_callbacks,
175        }
176    }
177
178    ///
179    /// Handles the notification of LISTEN/NOTIFY subscribers.
180    ///
181    fn handle_notify(callbacks: &NotifyCallbacks, note: Notification) {
182        let notice = PGNotify::new(note.channel(), note.payload());
183        if let Ok(guard) = callbacks.read() {
184            if let Some(cbs) = guard.get(note.channel()) {
185                for callback in cbs.iter() {
186                    callback(&notice);
187                }
188            }
189        }
190    }
191
192    ///
193    /// Handles the notification of `RAISE <level> <message>` subscribers.
194    ///
195    fn handle_raise(
196        callbacks: &RaiseCallbacks,
197        log: &Arc<RwLock<Option<Vec<PGRaise>>>>,
198        raise: DbError,
199    ) {
200        let log_item = PGRaise {
201            #[cfg(feature = "chrono")]
202            timestamp: Utc::now(),
203            #[cfg(not(feature = "chrono"))]
204            timestamp: SystemTime::now(),
205            level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
206            message: raise.message().into(),
207        };
208
209        if let Ok(guard) = callbacks.read() {
210            for callback in guard.iter() {
211                callback(&log_item);
212            }
213        }
214
215        if let Ok(mut guard) = log.write() {
216            guard.as_mut().map(|log| log.push(log_item));
217        }
218    }
219
220    ///
221    /// Subscribes to notifications on a particular channel.
222    ///
223    /// The call will issue the `LISTEN` command to PostgreSQL. There is
224    /// currently no mechanism to unsubscribe even though postgres does
225    /// supports UNLISTEN.
226    ///
227    pub async fn subscribe_notify<F>(
228        &mut self,
229        channel: impl Into<String>,
230        callback: F,
231    ) -> Result<(), tokio_postgres::Error>
232    where
233        F: Fn(&PGNotify) + Send + Sync + 'static,
234    {
235        // Issue the listen command to postgres
236        let channel = channel.into();
237        self.client
238            .execute(&format!("LISTEN {}", &channel), &[])
239            .await?;
240
241        // Add the callback to the list of callbacks
242        if let Ok(mut guard) = self.notify_callbacks.write() {
243            guard.entry(channel).or_default().push(Box::new(callback));
244        }
245
246        Ok(())
247    }
248
249    ///
250    /// Subscribes to `RAISE <level> <message>` notifications.
251    ///
252    /// There is currently no mechanism to unsubscribe. This would only require
253    /// returning some form of "token", which could be used to unsubscribe.
254    ///
255    pub fn subscribe_raise(&mut self, callback: impl Fn(&PGRaise) + Send + Sync + 'static) {
256        if let Ok(mut guard) = self.raise_callbacks.write() {
257            guard.push(Box::new(callback));
258        }
259    }
260
261    ///
262    /// Returns the accumulated log since the last capture.
263    ///
264    /// If the code being called issues many `RAISE` commands and you never
265    /// call [`capture_log`](PGNotifier::capture_log), then eventually, you
266    /// might run out of memory. To ensure that this does not happen, you
267    /// might consider using [`with_captured_log`](PGNotifier::with_captured_log)
268    /// instead.
269    ///
270    pub fn capture_log(&self) -> Option<Vec<PGRaise>> {
271        if let Ok(mut guard) = self.log.write() {
272            let captured = guard.take();
273            *guard = Some(Vec::default());
274            captured
275        } else {
276            None
277        }
278    }
279
280    ///
281    /// Given an async closure taking the postgres client, returns the result
282    /// of said closure along with the accumulated log since the beginning of
283    /// the closure.
284    ///
285    /// If you use query pipelining then collect the logs for all queries in
286    /// the pipeline. Otherwise, the logs might not be what you expect.
287    ///
288    pub async fn with_captured_log<F, T>(
289        &self,
290        f: F,
291    ) -> Result<(T, Vec<PGRaise>), tokio_postgres::Error>
292    where
293        F: AsyncFnOnce(&PGClient) -> Result<T, tokio_postgres::Error>,
294    {
295        self.capture_log(); // clear the log
296        let result = f(&self.client).await?;
297        let log = self.capture_log().unwrap_or_default();
298        Ok((result, log))
299    }
300}
301
302impl Drop for PGNotifier {
303    fn drop(&mut self) {
304        self.listen_handle.abort();
305    }
306}
307
///
/// Message received when a `NOTIFY [channel] [payload]` is issued on PostgreSQL.
///
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
#[derive(Clone, Debug)]
pub struct PGNotify {
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Payload sent along with the notification.
    pub payload: String,
}
317
318impl PGNotify {
319    pub fn new(channel: impl Into<String>, payload: impl Into<String>) -> Self {
320        Self {
321            channel: channel.into(),
322            payload: payload.into(),
323        }
324    }
325}
326
327///
328/// # Message received when a `raise <level> <message>` is issued on PostgreSQL.
329///
330#[derive(Debug, Clone)]
331#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
332pub struct PGRaise {
333    #[cfg(feature = "chrono")]
334    pub timestamp: DateTime<Utc>,
335    #[cfg(not(feature = "chrono"))]
336    pub timestamp: SystemTime,
337    pub level: PGRaiseLevel,
338    pub message: String,
339}
340
341impl From<DbError> for PGRaise {
342    #[cfg(feature = "chrono")]
343    fn from(raise: DbError) -> Self {
344        PGRaise {
345            timestamp: Utc::now(),
346            level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
347            message: raise.message().into(),
348        }
349    }
350
351    #[cfg(not(feature = "chrono"))]
352    fn from(raise: DbError) -> Self {
353        PGRaise {
354            timestamp: SystemTime::now(),
355            level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
356            message: raise.message().into(),
357        }
358    }
359}
360
361impl Display for PGRaise {
362    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363        #[cfg(feature = "chrono")]
364        let ts = self.timestamp.to_rfc3339_opts(SecondsFormat::Millis, true);
365
366        #[cfg(not(feature = "chrono"))]
367        let ts = {
368            let duration = self
369                .timestamp
370                .duration_since(SystemTime::UNIX_EPOCH)
371                .unwrap();
372            let millis = duration.as_millis();
373            format!("{}", millis)
374        };
375
376        write!(f, "{}{:>8}: {}", &ts, &self.level.as_ref(), self.message)
377    }
378}
379
/// Severity of a `RAISE` message, mirroring PostgreSQL's message levels.
///
/// The upper-case names PostgreSQL uses for these levels are returned by the
/// `AsRef<str>` implementation and accepted by the `FromStr` implementation.
/// Equality and hashing are derived so that levels can be compared directly
/// (e.g. `raise.level == PGRaiseLevel::Warning`) or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
#[cfg_attr(any(feature = "serde", test), serde(rename_all = "UPPERCASE"))]
pub enum PGRaiseLevel {
    /// Reported as `DEBUG`.
    Debug,
    /// Reported as `LOG`.
    Log,
    /// Reported as `INFO`.
    Info,
    /// Reported as `NOTICE`.
    Notice,
    /// Reported as `WARNING`.
    Warning,
    /// Reported as `ERROR`; also used for severities that cannot be parsed.
    Error,
    /// Reported as `FATAL`.
    Fatal,
    /// Reported as `PANIC`.
    Panic,
}
393
394impl AsRef<str> for PGRaiseLevel {
395    fn as_ref(&self) -> &str {
396        use PGRaiseLevel::*;
397        match self {
398            Debug => "DEBUG",
399            Log => "LOG",
400            Info => "INFO",
401            Notice => "NOTICE",
402            Warning => "WARNING",
403            Error => "ERROR",
404            Fatal => "FATAL",
405            Panic => "PANIC",
406        }
407    }
408}
409
410impl FromStr for PGRaiseLevel {
411    type Err = ();
412    fn from_str(s: &str) -> Result<Self, Self::Err> {
413        match s {
414            "DEBUG" => Ok(PGRaiseLevel::Debug),
415            "LOG" => Ok(PGRaiseLevel::Log),
416            "INFO" => Ok(PGRaiseLevel::Info),
417            "NOTICE" => Ok(PGRaiseLevel::Notice),
418            "WARNING" => Ok(PGRaiseLevel::Warning),
419            "ERROR" => Ok(PGRaiseLevel::Error),
420            "FATAL" => Ok(PGRaiseLevel::Fatal),
421            "PANIC" => Ok(PGRaiseLevel::Panic),
422            _ => Err(()),
423        }
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::{PGNotifier, PGNotify};
    use insta::*;
    use std::sync::{Arc, RwLock};
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::postgres::Postgres;

    /// End-to-end check against a disposable PostgreSQL 16.4 container:
    /// subscribes to the `test` channel, runs a script that interleaves
    /// `RAISE` and `NOTIFY`, then snapshots both the captured log and the
    /// received notifications.
    #[tokio::test]
    async fn test_integration() {
        // ---- start the server -------------------------------------------

        let container = Postgres::default()
            .with_tag("16.4")
            .start()
            .await
            .expect("could not start postgres server");

        let host = container.get_host().await.unwrap();
        let port = container.get_host_port_ipv4(5432).await.unwrap();
        let url = format!("postgres://postgres:postgres@{}:{}/postgres", host, port);

        // ---- connect ------------------------------------------------------

        let (client, conn) = tokio_postgres::connect(&url, tokio_postgres::NoTls)
            .await
            .expect("could not connect to postgres server");

        let mut notifier = PGNotifier::spawn(client, conn);

        // ---- subscribe to NOTIFY ----------------------------------------

        let received: Arc<RwLock<Vec<PGNotify>>> = Arc::new(RwLock::new(Vec::new()));
        let sink = Arc::clone(&received);

        notifier
            .subscribe_notify("test", move |notify: &PGNotify| {
                if let Ok(mut messages) = sink.write() {
                    messages.push(notify.clone());
                }
            })
            .await
            .expect("could not subscribe to notifications");

        // ---- run the script and capture RAISE output --------------------

        let script = r#"
                    set client_min_messages to 'debug';
                    do $$
                    begin
                        raise debug 'this is a DEBUG notification';
                        notify test, 'test#1';
                        raise log 'this is a LOG notification';
                        notify test, 'test#2';
                        raise info 'this is a INFO notification';
                        notify test, 'test#3';
                        raise notice 'this is a NOTICE notification';
                        notify test, 'test#4';
                        raise warning 'this is a WARNING notification';
                        notify test, 'test#5';
                    end;
                    $$;
                "#;

        let (_, execution_log) = notifier
            .with_captured_log(async |client| client.batch_execute(script).await)
            .await
            .expect("could not execute queries on postgres");

        assert_json_snapshot!("raise-notices", &execution_log, {
            "[].timestamp" => "<timestamp>"
        });

        let raise_notices = received.read().expect("could not read notices").clone();
        assert_json_snapshot!("listen/notify", &raise_notices);
    }
}