postgres_notify/
lib.rs

1//!
2//! [`PGNotifier`] makes it easy to subscribe to PostgreSQL notifications.
3//!
4//! There are few examples in Rust that show how to capture these notifications
5//! mostly because tokio_postgres examples spawn off the connection half such
6//! that you can't listen for notifications anymore. [`PGNotifier`] also spawns
7//! a task for the connection, but it also listens for notifications.
8//!
//! [`PGNotifier`] maintains two lists of callback functions, which are called
//! every time it receives a notification. These two lists match the types
//! of notifications sent by Postgres: `NOTIFY` and `RAISE`.
12//!
13//! # LISTEN/NOTIFY
14//!
15//! For a very long time (at least since version 7.1) postgres has supported
16//! asynchronous notifications based on LISTEN/NOTIFY commands. This allows
17//! the database to send notifications to the client in an "out-of-band"
18//! channel.
19//!
20//! Once the client has issued a `LISTEN <channel>` command, the database will
21//! send notifications to the client whenever a `NOTIFY <channel> <payload>`
22//! is issued on the database regardless of which session has issued it.
23//! This can act as a cheap alternative to a pubsub system.
24//!
25//! When calling `subscribe_notify` with a channel name, [`PGNotifier`] will
26//! call the supplied closure upon receiving a NOTIFY message but only if it
27//! matches the requested channel name.
28//!
//! ```rust,ignore
//! use postgres_notify::PGNotifier;
//!
//! // `client` and `conn` are the pair returned by `tokio_postgres::connect`.
//! let mut notifier = PGNotifier::spawn(client, conn);
//!
//! notifier.subscribe_notify("test_channel", |notify| {
//!     println!("[{}]: {}", &notify.channel, &notify.payload);
//! }).await?;
//! ```
38//!
39//!
40//! # RAISE/LOGS
41//!
//! Logs in PostgreSQL are created by issuing `RAISE <level> <message>` commands
//! within your functions, stored procedures and scripts. When such a command is
//! issued, [`PGNotifier`] receives a notification even if the call is still in
//! progress, which allows a user to capture the execution log in real time.
//! Note that the server only forwards messages whose level passes the session's
//! `client_min_messages` setting (`NOTICE` by default), so `DEBUG` and `LOG`
//! messages require lowering it (e.g. `set client_min_messages to 'debug'`).
//!
//! [`PGNotifier`] simplifies log collection in two ways: first, it provides the
//! `subscribe_raise` function, which registers a callback. Second, it provides
//! the [`capture_log`](PGNotifier::capture_log) and
//! [`with_captured_log`](PGNotifier::with_captured_log) functions.
51//!
//! ```rust,ignore
//! use postgres_notify::PGNotifier;
//!
//! // `client` and `conn` are the pair returned by `tokio_postgres::connect`.
//! let mut notifier = PGNotifier::spawn(client, conn);
//!
//! notifier.subscribe_raise(|notice| {
//!     // Prints each of the messages below to stdout
//!     println!("{}", &notice);
//! });
//!
//! // Captures the notices in a Vec
//! let (_, log) = notifier.with_captured_log(async |client| {
//!     client.batch_execute(r#"
//!        set client_min_messages to 'debug';
//!        do $$
//!        begin
//!            raise debug 'this is a DEBUG notification';
//!            raise log 'this is a LOG notification';
//!            raise info 'this is a INFO notification';
//!            raise notice 'this is a NOTICE notification';
//!            raise warning 'this is a WARNING notification';
//!        end;
//!        $$
//!     "#).await
//! }).await?;
//!
//! println!("{:#?}", &log);
//! ```
80//!
81//! You can look at the unit tests for a more in-depth example.
82//!
83#[cfg(feature = "chrono")]
84use chrono::{DateTime, SecondsFormat, Utc};
85#[cfg(not(feature = "chrono"))]
86use std::time::SystemTime;
87
88use {
89    futures::{StreamExt, stream},
90    std::{
91        collections::BTreeMap,
92        fmt::{self, Display},
93        str::FromStr,
94        sync::{Arc, RwLock},
95    },
96    tokio::{
97        io::{AsyncRead, AsyncWrite},
98        task::JoinHandle,
99    },
100    tokio_postgres::{
101        AsyncMessage, Client as PGClient, Connection as PGConnection, Error as PGError,
102        Notification, error::DbError,
103    },
104};
105
106/// Shorthand for Result with tokio_postgres::Error
107pub type PGResult<T> = Result<T, PGError>;
108
109/// Type used to store callbacks for LISTEN/NOTIFY calls.
110pub type NotifyCallbacks =
111    Arc<RwLock<BTreeMap<String, Vec<Box<dyn for<'a> Fn(&'a PGNotify) + Send + Sync + 'static>>>>>;
112
113/// Type used to store callbacks for RAISE &lt;level&gt; &lt;message&gt; calls.
114pub type RaiseCallbacks =
115    Arc<RwLock<Vec<Box<dyn for<'a> Fn(&'a PGRaise) + Send + Sync + 'static>>>>;
116
117///
118/// Wraps a [`PGNotifier`] and reconnects upon connection loss.
119///
120/// This struct keeps a callback that can be use to spawn new connections to postgres.
121/// It's called upon each time a connection is lost. All the heavy lifting is actually
122/// done by the [`PGNotifier`] struct.
123///
124#[allow(unused)]
125pub struct PGRobustNotifier<F> {
126    notify_callbacks: NotifyCallbacks,
127    raise_callbacks: RaiseCallbacks,
128    subscriptions: Vec<JoinHandle<()>>,
129    connect: F,
130    inner: PGNotifier,
131}
132
133impl<F, S, T> PGRobustNotifier<F>
134where
135    F: AsyncFn() -> PGResult<(PGClient, PGConnection<S, T>)>,
136    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
137    T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
138{
139    pub async fn new(connect: F) -> PGResult<Self> {
140        //
141        let (client, conn) = connect().await?;
142        let inner = PGNotifier::spawn(client, conn);
143        let notify_callbacks = inner.notify_callbacks.clone();
144        let raise_callbacks = inner.raise_callbacks.clone();
145
146        Ok(Self {
147            notify_callbacks,
148            raise_callbacks,
149            subscriptions: vec![],
150            connect,
151            inner,
152        })
153    }
154
155    ///
156    /// Attempts to reconnect after a connection loss.
157    ///
158    async fn reconnect(&mut self) -> PGResult<()> {
159        let (client, conn) = (self.connect)().await?;
160        self.inner =
161            PGNotifier::respawn(client, conn, &self.notify_callbacks, &self.raise_callbacks)
162                .await?;
163        Ok(())
164    }
165
166    ///
167    /// Returns the underlying postgres client.
168    /// If the connection has been closed then it is reconnected.
169    ///
170    /// Note that `Client::is_closed` is not reliable unless we have a high frequency TPC keepalive,
171    /// over which we have no control. So we actually attempt a real query each time. Taking inspiration
172    /// from sqlx, we issue a comment request so that it does not show up in logs. The possibility of
173    /// the connection being closed right after the ping still exist but should be handled by the
174    /// caller.
175    ///
176    /// The pool will keep trying to reconnect until it succeeds using exponential backoff with
177    /// additional jitter.
178    ///
179    pub async fn client(&mut self) -> PGResult<&PGClient> {
180        if let Err(e) = self.inner.client.execute("/* PING */", &[]).await
181            && e.is_closed()
182        {
183            // We implement exponential backoff + jitter.
184            let mut k = 1;
185            let mut attempts = 1;
186
187            loop {
188                tracing::info!("Connection is closed. Reconnect attempt #{}", attempts);
189                attempts += 1;
190
191                match self.reconnect().await {
192                    Ok(_) => {
193                        break;
194                    }
195                    Err(e) if e.is_closed() => {
196                        k *= std::cmp::min(k, 60);
197                        let t = k + rand::random_range(0..k);
198                        tokio::time::sleep(tokio::time::Duration::from_secs(t)).await;
199                    }
200                    Err(e) => return Err(e),
201                }
202            }
203        }
204
205        Ok(&self.inner.client)
206    }
207
208    // Forwards the call to the inner notifier.
209    pub async fn subscribe_notify<CB>(
210        &mut self,
211        channel: impl Into<String>,
212        callback: CB,
213    ) -> PGResult<()>
214    where
215        CB: Fn(&PGNotify) + Send + Sync + 'static,
216    {
217        self.inner.subscribe_notify(channel, callback).await
218    }
219
220    // Forwards the call to the inner notifier.
221    pub async fn subscribe_raise(&mut self, callback: impl Fn(&PGRaise) + Send + Sync + 'static) {
222        self.inner.subscribe_raise(callback)
223    }
224
225    // Forwards the call to the inner notifier.
226    pub async fn capture_log(&mut self) -> Option<Vec<PGRaise>> {
227        self.inner.capture_log()
228    }
229
230    // Forwards the call to the inner notifier.
231    pub async fn with_captured_log<CB, Data>(&mut self, f: CB) -> PGResult<(Data, Vec<PGRaise>)>
232    where
233        CB: AsyncFnOnce(&PGClient) -> PGResult<Data>,
234    {
235        self.inner.with_captured_log(f).await
236    }
237}
238
///
/// Forwards PostgreSQL `NOTIFY` and `RAISE` commands to subscribers.
///
/// Holds the client half of a connection; the connection half is polled by a
/// background task created in [`PGNotifier::spawn`], and that task is aborted
/// when the notifier is dropped.
///
pub struct PGNotifier {
    /// Client for issuing queries on the notifier's connection.
    pub client: PGClient,
    // Task polling the connection and dispatching messages; aborted in `Drop`.
    listen_handle: JoinHandle<()>,
    // RAISE messages received since the last `capture_log`. It is initialised
    // to `Some` and `capture_log` puts a fresh `Some` back after taking it.
    log: Arc<RwLock<Option<Vec<PGRaise>>>>,
    // Callbacks invoked for every RAISE message.
    raise_callbacks: RaiseCallbacks,
    // Callbacks invoked for NOTIFY messages, keyed by channel name.
    notify_callbacks: NotifyCallbacks,
}
249
250impl PGNotifier {
251    ///
252    /// Spawns a new postgres client/connection pair.
253    ///
254    pub fn spawn<S, T>(client: PGClient, mut conn: PGConnection<S, T>) -> Self
255    where
256        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
257        T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
258    {
259        let log = Arc::new(RwLock::new(Some(Vec::default())));
260        let notify_callbacks: NotifyCallbacks = Arc::new(RwLock::new(BTreeMap::new()));
261        let raise_callbacks: RaiseCallbacks = Arc::new(RwLock::new(Vec::new()));
262
263        // Spawn the connection and poll for messages on it.
264        let listen_handle = {
265            //
266            let log = log.clone();
267            let notify_callbacks = notify_callbacks.clone();
268            let raise_callbacks = raise_callbacks.clone();
269
270            tokio::spawn(async move {
271                //
272                let mut stream =
273                    stream::poll_fn(move |cx| conn.poll_message(cx).map_err(|e| panic!("{}", e)));
274
275                while let Some(msg) = stream.next().await {
276                    match msg {
277                        Ok(AsyncMessage::Notice(raise)) => {
278                            Self::handle_raise(&raise_callbacks, &log, raise)
279                        }
280                        Ok(AsyncMessage::Notification(notice)) => {
281                            Self::handle_notify(&notify_callbacks, notice)
282                        }
283                        _ => {
284                            #[cfg(feature = "tracing")]
285                            tracing::error!("connection to the server was closed");
286                            #[cfg(not(feature = "tracing"))]
287                            eprintln!("connection to the server was closed");
288                            break;
289                        }
290                    }
291                }
292            })
293        };
294
295        Self {
296            client,
297            listen_handle,
298            log,
299            notify_callbacks,
300            raise_callbacks,
301        }
302    }
303
304    ///
305    /// Spawns a new postgres client/connection pair after detecting that connection was lost.
306    ///
307    pub async fn respawn<S, T>(
308        client: PGClient,
309        conn: PGConnection<S, T>,
310        notify_callbacks: &NotifyCallbacks,
311        raise_callbacks: &RaiseCallbacks,
312    ) -> PGResult<Self>
313    where
314        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
315        T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
316    {
317        let mut notifier = Self::spawn(client, conn);
318        notifier.notify_callbacks = notify_callbacks.clone();
319        notifier.raise_callbacks = raise_callbacks.clone();
320
321        if let Ok(guard) = notify_callbacks.read() {
322            let sql = guard
323                .keys()
324                .map(|channel| format!("LISTEN {}", channel))
325                .collect::<Vec<_>>()
326                .join(";\n");
327            notifier.client.batch_execute(&sql).await?;
328        }
329
330        Ok(notifier)
331    }
332
333    ///
334    /// Handles the notification of LISTEN/NOTIFY subscribers.
335    ///
336    fn handle_notify(callbacks: &NotifyCallbacks, note: Notification) {
337        let notice = PGNotify::new(note.channel(), note.payload());
338        if let Ok(guard) = callbacks.read() {
339            if let Some(cbs) = guard.get(note.channel()) {
340                for callback in cbs.iter() {
341                    callback(&notice);
342                }
343            }
344        }
345    }
346
347    ///
348    /// Handles the notification of `RAISE <level> <message>` subscribers.
349    ///
350    fn handle_raise(
351        callbacks: &RaiseCallbacks,
352        log: &Arc<RwLock<Option<Vec<PGRaise>>>>,
353        raise: DbError,
354    ) {
355        let log_item = PGRaise {
356            #[cfg(feature = "chrono")]
357            timestamp: Utc::now(),
358            #[cfg(not(feature = "chrono"))]
359            timestamp: SystemTime::now(),
360            level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
361            message: raise.message().into(),
362        };
363
364        if let Ok(guard) = callbacks.read() {
365            for callback in guard.iter() {
366                callback(&log_item);
367            }
368        }
369
370        if let Ok(mut guard) = log.write() {
371            guard.as_mut().map(|log| log.push(log_item));
372        }
373    }
374
375    ///
376    /// Subscribes to notifications on a particular channel.
377    ///
378    /// The call will issue the `LISTEN` command to PostgreSQL. There is
379    /// currently no mechanism to unsubscribe even though postgres does
380    /// supports UNLISTEN.
381    ///
382    pub async fn subscribe_notify<F>(
383        &mut self,
384        channel: impl Into<String>,
385        callback: F,
386    ) -> PGResult<()>
387    where
388        F: Fn(&PGNotify) + Send + Sync + 'static,
389    {
390        // Issue the listen command to postgres
391        let channel = channel.into();
392        self.client
393            .execute(&format!("LISTEN {}", &channel), &[])
394            .await?;
395
396        // Add the callback to the list of callbacks
397        if let Ok(mut guard) = self.notify_callbacks.write() {
398            guard.entry(channel).or_default().push(Box::new(callback));
399        }
400
401        Ok(())
402    }
403
404    ///
405    /// Subscribes to `RAISE <level> <message>` notifications.
406    ///
407    /// There is currently no mechanism to unsubscribe. This would only require
408    /// returning some form of "token", which could be used to unsubscribe.
409    ///
410    pub fn subscribe_raise(&mut self, callback: impl Fn(&PGRaise) + Send + Sync + 'static) {
411        if let Ok(mut guard) = self.raise_callbacks.write() {
412            guard.push(Box::new(callback));
413        }
414    }
415
416    ///
417    /// Returns the accumulated log since the last capture.
418    ///
419    /// If the code being called issues many `RAISE` commands and you never
420    /// call [`capture_log`](PGNotifier::capture_log), then eventually, you
421    /// might run out of memory. To ensure that this does not happen, you
422    /// might consider using [`with_captured_log`](PGNotifier::with_captured_log)
423    /// instead.
424    ///
425    pub fn capture_log(&self) -> Option<Vec<PGRaise>> {
426        if let Ok(mut guard) = self.log.write() {
427            let captured = guard.take();
428            *guard = Some(Vec::default());
429            captured
430        } else {
431            None
432        }
433    }
434
435    ///
436    /// Given an async closure taking the postgres client, returns the result
437    /// of said closure along with the accumulated log since the beginning of
438    /// the closure.
439    ///
440    /// If you use query pipelining then collect the logs for all queries in
441    /// the pipeline. Otherwise, the logs might not be what you expect.
442    ///
443    pub async fn with_captured_log<F, T>(&self, f: F) -> PGResult<(T, Vec<PGRaise>)>
444    where
445        F: AsyncFnOnce(&PGClient) -> PGResult<T>,
446    {
447        self.capture_log(); // clear the log
448        let result = f(&self.client).await?;
449        let log = self.capture_log().unwrap_or_default();
450        Ok((result, log))
451    }
452}
453
454impl Drop for PGNotifier {
455    fn drop(&mut self) {
456        self.listen_handle.abort();
457    }
458}
459
///
/// Message received when a `NOTIFY [channel] [payload]` is issued on PostgreSQL.
///
#[derive(Debug, Clone)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
pub struct PGNotify {
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Payload string carried by the notification.
    pub payload: String,
}

impl PGNotify {
    /// Builds a notification from anything convertible into owned strings.
    pub fn new(channel: impl Into<String>, payload: impl Into<String>) -> Self {
        let channel: String = channel.into();
        let payload: String = payload.into();
        Self { channel, payload }
    }
}
478
///
/// Log entry produced when a `RAISE <level> <message>` is issued on PostgreSQL.
///
#[derive(Debug, Clone)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
pub struct PGRaise {
    /// Client-side time at which the message was processed.
    #[cfg(feature = "chrono")]
    pub timestamp: DateTime<Utc>,
    /// Client-side time at which the message was processed.
    #[cfg(not(feature = "chrono"))]
    pub timestamp: SystemTime,
    /// Severity reported by the server (`Error` when it could not be parsed).
    pub level: PGRaiseLevel,
    /// Message text reported by the server.
    pub message: String,
}
492
493impl From<DbError> for PGRaise {
494    #[cfg(feature = "chrono")]
495    fn from(raise: DbError) -> Self {
496        PGRaise {
497            timestamp: Utc::now(),
498            level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
499            message: raise.message().into(),
500        }
501    }
502
503    #[cfg(not(feature = "chrono"))]
504    fn from(raise: DbError) -> Self {
505        PGRaise {
506            timestamp: SystemTime::now(),
507            level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
508            message: raise.message().into(),
509        }
510    }
511}
512
513impl Display for PGRaise {
514    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
515        #[cfg(feature = "chrono")]
516        let ts = self.timestamp.to_rfc3339_opts(SecondsFormat::Millis, true);
517
518        #[cfg(not(feature = "chrono"))]
519        let ts = {
520            let duration = self
521                .timestamp
522                .duration_since(SystemTime::UNIX_EPOCH)
523                .unwrap();
524            let millis = duration.as_millis();
525            format!("{}", millis)
526        };
527
528        write!(f, "{}{:>8}: {}", &ts, &self.level.as_ref(), self.message)
529    }
530}
531
/// Severity of a message raised by PostgreSQL.
///
/// Each variant corresponds to one upper-case severity name (`DEBUG` ... `PANIC`),
/// see the `AsRef<str>` and `FromStr` implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
#[cfg_attr(any(feature = "serde", test), serde(rename_all = "UPPERCASE"))]
pub enum PGRaiseLevel {
    /// `DEBUG`
    Debug,
    /// `LOG`
    Log,
    /// `INFO`
    Info,
    /// `NOTICE`
    Notice,
    /// `WARNING`
    Warning,
    /// `ERROR`
    Error,
    /// `FATAL`
    Fatal,
    /// `PANIC`
    Panic,
}
545
546impl AsRef<str> for PGRaiseLevel {
547    fn as_ref(&self) -> &str {
548        use PGRaiseLevel::*;
549        match self {
550            Debug => "DEBUG",
551            Log => "LOG",
552            Info => "INFO",
553            Notice => "NOTICE",
554            Warning => "WARNING",
555            Error => "ERROR",
556            Fatal => "FATAL",
557            Panic => "PANIC",
558        }
559    }
560}
561
562impl FromStr for PGRaiseLevel {
563    type Err = ();
564    fn from_str(s: &str) -> Result<Self, Self::Err> {
565        match s {
566            "DEBUG" => Ok(PGRaiseLevel::Debug),
567            "LOG" => Ok(PGRaiseLevel::Log),
568            "INFO" => Ok(PGRaiseLevel::Info),
569            "NOTICE" => Ok(PGRaiseLevel::Notice),
570            "WARNING" => Ok(PGRaiseLevel::Warning),
571            "ERROR" => Ok(PGRaiseLevel::Error),
572            "FATAL" => Ok(PGRaiseLevel::Fatal),
573            "PANIC" => Ok(PGRaiseLevel::Panic),
574            _ => Err(()),
575        }
576    }
577}
578
#[cfg(test)]
mod tests {

    use super::{PGClient, PGNotifier, PGNotify, PGRobustNotifier};
    use insta::*;
    use std::sync::{
        Arc, RwLock,
        atomic::{AtomicI32, Ordering},
    };
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::postgres::Postgres;

    #[tokio::test]
    async fn test_integration() {
        //
        // --------------------------------------------------------------------
        // Setup Postgres Server
        // --------------------------------------------------------------------

        let pg_server = Postgres::default()
            .with_tag("16.4")
            .start()
            .await
            .expect("could not start postgres server");

        let database_url = format!(
            "postgres://postgres:postgres@{}:{}/postgres",
            pg_server.get_host().await.unwrap(),
            pg_server.get_host_port_ipv4(5432).await.unwrap()
        );

        // --------------------------------------------------------------------
        // Connect to the server
        // --------------------------------------------------------------------

        let (client, conn) = tokio_postgres::connect(&database_url, tokio_postgres::NoTls)
            .await
            .expect("could not connect to postgres server");

        let mut notifier = PGNotifier::spawn(client, conn);

        // --------------------------------------------------------------------
        // Subscribe to notify and raise
        // --------------------------------------------------------------------

        let notices = Arc::new(RwLock::new(Vec::new()));
        let notices_clone = notices.clone();

        notifier
            .subscribe_notify("test", move |notify: &PGNotify| {
                if let Ok(mut guard) = notices_clone.write() {
                    guard.push(notify.clone());
                }
            })
            .await
            .expect("could not subscribe to notifications");

        let (_, execution_log) = notifier
            .with_captured_log(async |client| {
                client
                    .batch_execute(
                        r#"
                    set client_min_messages to 'debug';
                    do $$
                    begin
                        raise debug 'this is a DEBUG notification';
                        notify test, 'test#1';
                        raise log 'this is a LOG notification';
                        notify test, 'test#2';
                        raise info 'this is a INFO notification';
                        notify test, 'test#3';
                        raise notice 'this is a NOTICE notification';
                        notify test, 'test#4';
                        raise warning 'this is a WARNING notification';
                        notify test, 'test#5';
                    end;
                    $$;
                "#,
                    )
                    .await
            })
            .await
            .expect("could not execute queries on postgres");

        assert_json_snapshot!("raise-notices", &execution_log, {
            "[].timestamp" => "<timestamp>"
        });

        // Clone out of the lock in one statement so the read guard is not held
        // across the awaits below (a listener callback would need the write lock).
        let notify_messages = notices.read().expect("could not read notices").clone();
        assert_json_snapshot!("listen/notify", &notify_messages);

        // --------------------------------------------------------------------
        // RobustNotifier
        // --------------------------------------------------------------------

        let counter = Arc::new(AtomicI32::new(0));
        let (client, conn) = tokio_postgres::connect(&database_url, tokio_postgres::NoTls)
            .await
            .expect("could not connect to postgres server");
        let admin = PGNotifier::spawn(client, conn);

        let counter_clone = counter.clone();
        let mut notifier = PGRobustNotifier::new(async move || {
            counter_clone.fetch_add(1, Ordering::Relaxed);
            tokio_postgres::connect(&database_url, tokio_postgres::NoTls).await
        })
        .await
        .expect("could not connect to postgres server");

        let client: &PGClient = notifier.client().await.expect("could not get client");
        assert!(client.execute("select 1", &[]).await.is_ok());

        admin
            .client
            .execute(
                r#"
                SELECT pg_terminate_backend(pg_stat_activity.pid)
                FROM pg_stat_activity
                WHERE pid <> pg_backend_pid();
            "#,
                &[],
            )
            .await
            .expect("could not kill other connections");

        let client: &PGClient = notifier.client().await.expect("could not get client");
        assert!(client.execute("select 1", &[]).await.is_ok());
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }
}