postgres_notify/
lib.rs

1//!
2//! `postgres-notify` started out as an easy way to receive PostgreSQL
3//! notifications but has since evolved into a much more useful client
4//! and is able to handle the following:
5//!
6//! - Receive `NOTIFY <channel> <payload>` pub/sub style notifications
7//!
//! - Receive `RAISE` messages and collect execution logs
9//!
//! - Applies a timeout to all queries. If a query times out then the
//!   client will attempt to cancel the ongoing query before returning
//!   an error.
13//!
14//! - Supports cancelling an ongoing query.
15//!
16//! - Automatically reconnects if the connection is lost and uses
17//!   exponential backoff with jitter to avoid thundering herd effect.
18//!
19//! - Has a familiar API with an additional `timeout` argument.
20//!
21//!
22//!
23//! # BREAKING CHANGE in v0.3.0
24//!
//! This latest version is a breaking change. The `PGNotifyingClient` has
//! been renamed to `PGRobustClient` and queries don't need to be made through
//! the inner client anymore. Furthermore, a single callback handles all
//! of the notifications: NOTIFY, RAISE, TIMEOUT, RECONNECT.
29//!
30//!
31//!
32//! # LISTEN/NOTIFY
33//!
34//! For a very long time (at least since version 7.1) postgres has supported
35//! asynchronous notifications based on LISTEN/NOTIFY commands. This allows
36//! the database to send notifications to the client in an "out-of-band"
37//! channel.
38//!
39//! Once the client has issued a `LISTEN <channel>` command, the database will
40//! send notifications to the client whenever a `NOTIFY <channel> <payload>`
41//! is issued on the database regardless of which session has issued it.
42//! This can act as a cheap alternative to a pub/sub system though without
43//! mailboxes or persistence.
44//!
//! When calling `subscribe_notify` with a list of channel names, [`PGRobustClient`]
//! will invoke the client callback any time a `NOTIFY` message is received on any
//! of the subscribed channels.
48//!
49//! ```rust
50//! use postgres_notify::{PGRobustClient, PGMessage};
51//! use tokio_postgres::NoTls;
52//! use std::time::Duration;
53//!
54//! let rt = tokio::runtime::Builder::new_current_thread()
55//!     .enable_io()
56//!     .enable_time()
57//!     .build()
58//!     .expect("could not start tokio runtime");
59//!
60//! rt.block_on(async move {
61//!     let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
62//!     let callback = |msg:PGMessage| println!("{:?}", &msg);
63//!     let mut client = PGRobustClient::spawn(database_url, NoTls, callback)
64//!         .await.expect("Could not connect to postgres");
65//!
66//!     client.subscribe_notify(&["test"], Some(Duration::from_millis(100)))
67//!         .await.expect("Could not subscribe to channels");
68//! });
69//! ```
70//!
71//!
72//!
73//! # RAISE/LOGS
74//!
75//! Logs in PostgreSQL are created by writing `RAISE <level> <message>` statements
76//! within your functions, stored procedures and scripts. When such a command is
77//! issued, [`PGRobustClient`] receives a notification even if the call is still
78//! in progress. This allows the caller to capture the execution log in realtime
79//! if needed.
80//!
//! [`PGRobustClient`] simplifies log collection in two ways. Firstly it provides
//! the [`with_captured_log`](PGRobustClient::with_captured_log) function,
//! which collects the execution log and returns it along with the query result.
//! This is probably what most people will want to use.
85//!
//! Secondly, if your needs are more complex or if you want to propagate logs in
//! realtime, the client callback can be used to forward each message on an
//! asynchronous channel.
89//!
90//! ```rust
91//! use postgres_notify::{PGRobustClient, PGMessage};
92//! use tokio_postgres::NoTls;
93//! use std::time::Duration;
94//!
95//! let rt = tokio::runtime::Builder::new_current_thread()
96//!     .enable_io()
97//!     .enable_time()
98//!     .build()
99//!     .expect("could not start tokio runtime");
100//!
101//! rt.block_on(async move {
102//!
103//!     let callback = |msg:PGMessage| println!("{:?}", &msg);
104//!
105//!     let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
106//!     let mut client = PGRobustClient::spawn(database_url, NoTls, callback)
107//!         .await.expect("Could not connect to postgres");
108//!
109//!     // Will capture the notices in a Vec
110//!     let (_, log) = client.with_captured_log(async |client| {
111//!         client.simple_query("
112//!             do $$
113//!             begin
114//!                 raise debug 'this is a DEBUG notification';
115//!                 raise log 'this is a LOG notification';
116//!                 raise info 'this is a INFO notification';
117//!                 raise notice 'this is a NOTICE notification';
118//!                 raise warning 'this is a WARNING notification';
119//!             end;
120//!             $$",
121//!             Some(Duration::from_secs(1))
122//!         ).await.expect("Error during query execution");
123//!         Ok(())
//!     }).await.expect("Error during capture log");
125//!
126//!     println!("{:#?}", &log);
127//!  });
128//! ```
129//!
130//! Note that the client passed to the async callback is `&mut self`, which
131//! means that all queries within that block are subject to the same timeout
132//! and reconnect handling.
133//!
134//! You can look at the unit tests for a more in-depth example.
135//!
136//!
137//!
138//! # TIMEOUT
139//!
140//! All of the query functions in [`PGRobustClient`] have a `timeout` argument.
141//! If the query takes longer than the timeout, then an error is returned.
142//! If not specified, the default timeout is 1 hour.
143//!
144//!
145//! # RECONNECT
146//!
//! If the connection to the database is lost, then [`PGRobustClient`] will
//! attempt to reconnect to the database automatically. If the maximum number
//! of reconnect attempts is reached then an error is returned. Furthermore,
//! it uses an exponential backoff with jitter in order to avoid the thundering
//! herd effect.
152//!
153
154mod error;
155mod messages;
156mod notify;
157
158pub use error::*;
159pub use messages::*;
160use tokio_postgres::{SimpleQueryMessage, ToStatement};
161
162use {
163    futures::TryFutureExt,
164    std::{
165        collections::BTreeSet,
166        sync::{Arc, RwLock},
167        time::Duration,
168    },
169    tokio::{
170        task::JoinHandle,
171        time::{sleep, timeout},
172    },
173    tokio_postgres::{
174        CancelToken, Client as PGClient, Row, RowStream, Socket, Statement, Transaction,
175        tls::MakeTlsConnect,
176        types::{BorrowToSql, ToSql, Type},
177    },
178};
179
180/// Shorthand for Result with tokio_postgres::Error
181pub type PGResult<T> = Result<T, PGError>;
182
183pub struct PGRobustClient<TLS>
184where
185    TLS: MakeTlsConnect<Socket>,
186{
187    database_url: String,
188    make_tls: TLS,
189    client: PGClient,
190    conn_handle: JoinHandle<()>,
191    cancel_token: CancelToken,
192    subscriptions: BTreeSet<String>,
193    callback: Arc<dyn Fn(PGMessage) + Send + Sync + 'static>,
194    max_reconnect_attempts: u32,
195    default_timeout: Duration,
196    log: Arc<RwLock<Vec<PGMessage>>>,
197}
198
199#[allow(unused)]
200impl<TLS> PGRobustClient<TLS>
201where
202    TLS: MakeTlsConnect<Socket> + Clone,
203    <TLS as MakeTlsConnect<Socket>>::Stream: Send + Sync + 'static,
204{
205    ///
206    /// Given a connect factory and a callback, returns a new [`PGRobustClient`].
207    ///
208    /// The callback will be called whenever a new NOTIFY/RAISE message is received.
209    /// Furthermore, it is also called with a [`PGMessage::Timeout`], when a query
210    /// times out, [`PGMessage::Disconnected`] if the internal state of the client
211    /// is not as expected (Poisoned lock, dropped connections, etc.) or
212    /// [`PGMessage::Reconnect`] whenever a new reconnect attempt is made.
213    ///
214    pub async fn spawn(
215        database_url: impl AsRef<str>,
216        make_tls: TLS,
217        callback: impl Fn(PGMessage) + Send + Sync + 'static,
218    ) -> PGResult<Self> {
219        //
220        // Setup log and other default values
221        //
222        let log = Arc::new(RwLock::new(Vec::default()));
223        let default_timeout = Duration::from_secs(60 * 60);
224
225        //
226        // We wrap the callback so that it also inserts into the log.
227        //
228        // NOTE: we need to type erase here because otherwise the call to Self::connect
229        //      will not compile.
230        //
231        let callback: Arc<dyn Fn(PGMessage) + Send + Sync + 'static> = Arc::new({
232            let log = log.clone();
233            move |msg: PGMessage| {
234                callback(msg.clone());
235                if let Ok(mut log) = log.write() {
236                    log.push(msg);
237                }
238            }
239        });
240
241        // Connect to the database
242        let (client, conn_handle, cancel_token) =
243            Self::connect(database_url.as_ref(), &make_tls, &callback).await?;
244
245        Ok(Self {
246            database_url: database_url.as_ref().to_string(),
247            make_tls,
248            client,
249            conn_handle,
250            cancel_token,
251            subscriptions: BTreeSet::new(),
252            callback,
253            max_reconnect_attempts: u32::MAX,
254            default_timeout,
255            log,
256        })
257    }
258
259    ///
260    /// Sets the default timeout for all queries. Defaults to 1 hour.
261    ///
262    /// This function consumes and returns self and is therefor usually used
263    /// just after [`PGRobustClient::spawn`].
264    ///
265    pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
266        self.default_timeout = timeout;
267        self
268    }
269
270    ///
271    /// Sets the maximum number of reconnect attempts before giving up.
272    /// Defaults to `u32::MAX`.
273    ///
274    /// This function consumes and returns self and is therefor usually used
275    /// just after [`PGRobustClient::spawn`].
276    ///
277    pub fn with_max_reconnect_attempts(mut self, max_attempts: u32) -> Self {
278        self.max_reconnect_attempts = max_attempts;
279        self
280    }
281
282    ///
283    /// PRIVATE
284    /// Does the necessary details to connect to the database and hookup callbacks and notifications.
285    ///
286    async fn connect(
287        database_url: &str,
288        make_tls: &TLS,
289        callback: &Arc<dyn Fn(PGMessage) + Send + Sync + 'static>,
290    ) -> PGResult<(PGClient, JoinHandle<()>, CancelToken)> {
291        //
292        let (client, conn) = tokio_postgres::connect(database_url, make_tls.clone()).await?;
293        let cancel_token = client.cancel_token();
294
295        let callback = callback.clone();
296        let handle = tokio::spawn(notify::handle_connection_polling(conn, move |msg| {
297            callback(msg)
298        }));
299
300        Ok((client, handle, cancel_token))
301    }
302
303    ///
304    /// Cancels any in-progress query.
305    ///
306    /// This is the only function that does not take a timeout nor does it
307    /// attempt to reconnect if the connection is lost. It will simply
308    /// return the original error.
309    ///
310    pub async fn cancel_query(&mut self) -> PGResult<()> {
311        self.cancel_token
312            .cancel_query(self.make_tls.clone())
313            .await
314            .map_err(Into::into)
315    }
316
317    ///
318    /// Returns the log messages captured since the last call to this function.
319    /// It also clears the log.
320    ///
321    pub fn capture_and_clear_log(&mut self) -> Vec<PGMessage> {
322        if let Ok(mut guard) = self.log.write() {
323            let empty_log = Vec::default();
324            std::mem::replace(&mut *guard, empty_log)
325        } else {
326            Vec::default()
327        }
328    }
329
330    ///
331    /// Given an async closure taking the postgres client, returns the result
332    /// of said closure along with the accumulated log since the beginning of
333    /// the closure.
334    ///
335    /// If you use query pipelining then collect the logs for all queries in
336    /// the pipeline. Otherwise, the logs might not be what you expect.
337    ///
338    pub async fn with_captured_log<F, T>(&mut self, f: F) -> PGResult<(T, Vec<PGMessage>)>
339    where
340        F: AsyncFn(&mut Self) -> PGResult<T>,
341    {
342        self.capture_and_clear_log(); // clear the log just in case...
343        let result = f(self).await?;
344        let log = self.capture_and_clear_log();
345        Ok((result, log))
346    }
347
348    ///
349    /// Attempts to reconnect after a connection loss.
350    ///
351    /// Reconnection applies an exponention backoff with jitter in order to
352    /// avoid thundering herd effect. If the maximum number of attempts is
353    /// reached then an error is returned.
354    ///
355    /// If an error unrelated to establishing a new connection is returned
356    /// when trying to connect then that error is returned.
357    ///
358    pub async fn reconnect(&mut self) -> PGResult<()> {
359        //
360        use std::cmp::{max, min};
361        let mut attempts = 1;
362        let mut k = 500;
363
364        while attempts <= self.max_reconnect_attempts {
365            //
366            // Implement exponential backoff + jitter
367            // Initial delay will be 500ms, max delay is 1h.
368            //
369            sleep(Duration::from_millis(k + rand::random_range(0..k / 2))).await;
370            k = min(k * 2, 60000);
371
372            tracing::info!("Reconnect attempt #{}", attempts);
373            (self.callback)(PGMessage::reconnect(attempts, self.max_reconnect_attempts));
374
375            attempts += 1;
376
377            let maybe_triple =
378                Self::connect(&self.database_url, &self.make_tls, &self.callback).await;
379
380            match maybe_triple {
381                Ok((client, conn_handle, cancel_token)) => {
382                    // Abort the old connection just in case
383                    self.conn_handle.abort();
384
385                    self.client = client;
386                    self.conn_handle = conn_handle;
387                    self.cancel_token = cancel_token;
388
389                    // Resubscribe to previously subscribed channels
390                    let subs: Vec<_> = self.subscriptions.iter().map(String::from).collect();
391
392                    match Self::subscribe_notify_impl(&self.client, &subs).await {
393                        Ok(_) => {
394                            return Ok(());
395                        }
396                        Err(e) if is_pg_connection_issue(&e) => {
397                            continue;
398                        }
399                        Err(e) => {
400                            return Err(e.into());
401                        }
402                    }
403                }
404                Err(e) if e.is_pg_connection_issue() => {
405                    continue;
406                }
407                Err(e) => {
408                    return Err(e);
409                }
410            }
411        }
412
413        // Issue the failed to reconnect message
414        (self.callback)(PGMessage::failed_to_reconnect(self.max_reconnect_attempts));
415        // Return the error
416        Err(PGError::FailedToReconnect(self.max_reconnect_attempts))
417    }
418
419    pub async fn wrap_reconnect<T>(
420        &mut self,
421        max_dur: Option<Duration>,
422        factory: impl AsyncFn(&mut PGClient) -> Result<T, tokio_postgres::Error>,
423    ) -> PGResult<T> {
424        let max_dur = max_dur.unwrap_or(self.default_timeout);
425        loop {
426            match timeout(max_dur, factory(&mut self.client)).await {
427                // Query succeeded so return the result
428                Ok(Ok(o)) => return Ok(o),
429                // Query failed because of connection issues
430                Ok(Err(e)) if is_pg_connection_issue(&e) => {
431                    self.reconnect().await?;
432                }
433                // Query failed for some other reason
434                Ok(Err(e)) => {
435                    return Err(e.into());
436                }
437                // Query timed out!
438                Err(_) => {
439                    // Callback with timeout message
440                    (self.callback)(PGMessage::timeout(max_dur));
441                    // Cancel the ongoing query
442                    let status = self.cancel_token.cancel_query(self.make_tls.clone()).await;
443                    // Callback with cancelled message
444                    (self.callback)(PGMessage::cancelled(!status.is_err()));
445                    // Return the timeout error
446                    return Err(PGError::Timeout(max_dur));
447                }
448            }
449        }
450    }
451
452    pub async fn subscribe_notify(
453        &mut self,
454        channels: &[impl AsRef<str> + Send + Sync + 'static],
455        timeout: Option<Duration>,
456    ) -> PGResult<()> {
457        if !channels.is_empty() {
458            self.wrap_reconnect(timeout, async |client: &mut PGClient| {
459                Self::subscribe_notify_impl(client, channels).await
460            })
461            .await?;
462
463            // Add to our subscriptions
464            channels.iter().for_each(|ch| {
465                self.subscriptions.insert(ch.as_ref().to_string());
466            });
467        }
468        Ok(())
469    }
470
471    async fn subscribe_notify_impl(
472        client: &PGClient,
473        channels: &[impl AsRef<str> + Send + Sync + 'static],
474    ) -> Result<(), tokio_postgres::Error> {
475        // Build a sequence of `LISTEN` commands
476        let sql = channels
477            .iter()
478            .map(|ch| format!("LISTEN {};", ch.as_ref()))
479            .collect::<Vec<_>>()
480            .join("\n");
481
482        // Tell the world we are about to subscribe
483        #[cfg(feature = "tracing")]
484        tracing::info!(
485            "Subscribing to channels: \"{}\"",
486            &channels
487                .iter()
488                .map(AsRef::as_ref)
489                .collect::<Vec<_>>()
490                .join(",")
491        );
492
493        // Issue the `LISTEN` commands
494        client.simple_query(&sql).await?;
495        Ok(())
496    }
497
498    pub async fn unsubscribe_notify(
499        &mut self,
500        channels: &[impl AsRef<str> + Send + Sync + 'static],
501        timeout: Option<Duration>,
502    ) -> PGResult<()> {
503        if !channels.is_empty() {
504            self.wrap_reconnect(timeout, async move |client: &mut PGClient| {
505                // Build a sequence of `LISTEN` commands
506                let sql = channels
507                    .iter()
508                    .map(|ch| format!("UNLISTEN {};", ch.as_ref()))
509                    .collect::<Vec<_>>()
510                    .join("\n");
511
512                // Tell the world we are about to subscribe
513                #[cfg(feature = "tracing")]
514                tracing::info!(
515                    "Unsubscribing from channels: \"{}\"",
516                    &channels
517                        .iter()
518                        .map(AsRef::as_ref)
519                        .collect::<Vec<_>>()
520                        .join(",")
521                );
522
523                // Issue the `LISTEN` commands
524                client.simple_query(&sql).await?;
525                Ok(())
526            })
527            .await?;
528
529            // Remove subscriptions
530            channels.iter().for_each(|ch| {
531                self.subscriptions.remove(ch.as_ref());
532            });
533        }
534        Ok(())
535    }
536
537    ///
538    /// Unsubscribes from all channels.
539    ///
540    pub async fn unsubscribe_notify_all(&mut self, timeout: Option<Duration>) -> PGResult<()> {
541        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
542            // Tell the world we are about to unsubscribe
543            #[cfg(feature = "tracing")]
544            tracing::info!("Unsubscribing from channels: *");
545            // Issue the `UNLISTEN` commands
546            client.simple_query("UNLISTEN *").await?;
547            Ok(())
548        })
549        .await
550    }
551
552    /// Like [`Client::execute_raw`].
553    pub async fn execute_raw<P, I, T>(
554        &mut self,
555        statement: &T,
556        params: I,
557        timeout: Option<Duration>,
558    ) -> PGResult<u64>
559    where
560        T: ?Sized + ToStatement + Sync + Send,
561        P: BorrowToSql + Clone + Send + Sync,
562        I: IntoIterator<Item = P> + Sync + Send,
563        I::IntoIter: ExactSizeIterator,
564    {
565        let params: Vec<_> = params.into_iter().collect();
566        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
567            client.execute_raw(statement, params.clone()).await
568        })
569        .await
570    }
571
572    /// Like [`Client::query`].
573    pub async fn query<T>(
574        &mut self,
575        query: &T,
576        params: &[&(dyn ToSql + Sync)],
577        timeout: Option<Duration>,
578    ) -> PGResult<Vec<Row>>
579    where
580        T: ?Sized + ToStatement + Sync + Send,
581    {
582        let params = params.to_vec();
583        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
584            client.query(query, &params).await
585        })
586        .await
587    }
588
589    /// Like [`Client::query_one`].
590    pub async fn query_one<T>(
591        &mut self,
592        statement: &T,
593        params: &[&(dyn ToSql + Sync)],
594        timeout: Option<Duration>,
595    ) -> PGResult<Row>
596    where
597        T: ?Sized + ToStatement + Sync + Send,
598    {
599        let params = params.to_vec();
600        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
601            client.query_one(statement, &params).await
602        })
603        .await
604    }
605
606    /// Like [`Client::query_opt`].
607    pub async fn query_opt<T>(
608        &mut self,
609        statement: &T,
610        params: &[&(dyn ToSql + Sync)],
611        timeout: Option<Duration>,
612    ) -> PGResult<Option<Row>>
613    where
614        T: ?Sized + ToStatement + Sync + Send,
615    {
616        let params = params.to_vec();
617        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
618            client.query_opt(statement, &params).await
619        })
620        .await
621    }
622
623    /// Like [`Client::query_raw`].
624    pub async fn query_raw<T, P, I>(
625        &mut self,
626        statement: &T,
627        params: I,
628        timeout: Option<Duration>,
629    ) -> PGResult<RowStream>
630    where
631        T: ?Sized + ToStatement + Sync + Send,
632        P: BorrowToSql + Clone + Send + Sync,
633        I: IntoIterator<Item = P> + Sync + Send,
634        I::IntoIter: ExactSizeIterator,
635    {
636        let params: Vec<_> = params.into_iter().collect();
637        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
638            client.query_raw(statement, params.clone()).await
639        })
640        .await
641    }
642
643    /// Like [`Client::query_typed`]
644    pub async fn query_typed(
645        &mut self,
646        statement: &str,
647        params: &[(&(dyn ToSql + Sync), Type)],
648        timeout: Option<Duration>,
649    ) -> PGResult<Vec<Row>> {
650        let params = params.to_vec();
651        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
652            client.query_typed(statement, &params).await
653        })
654        .await
655    }
656
657    /// Like [`Client::query_typed_raw`]
658    pub async fn query_typed_raw<P, I>(
659        &mut self,
660        statement: &str,
661        params: I,
662        timeout: Option<Duration>,
663    ) -> PGResult<RowStream>
664    where
665        P: BorrowToSql + Clone + Send + Sync,
666        I: IntoIterator<Item = (P, Type)> + Sync + Send,
667    {
668        let params: Vec<_> = params.into_iter().collect();
669        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
670            client.query_typed_raw(statement, params.clone()).await
671        })
672        .await
673    }
674
675    /// Like [`Client::prepare`].
676    pub async fn prepare(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<Statement> {
677        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
678            client.prepare(query).map_err(Into::into).await
679        })
680        .await
681    }
682
683    /// Like [`Client::prepare_typed`].
684    pub async fn prepare_typed(
685        &mut self,
686        query: &str,
687        parameter_types: &[Type],
688        timeout: Option<Duration>,
689    ) -> PGResult<Statement> {
690        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
691            client.prepare_typed(query, parameter_types).await
692        })
693        .await
694    }
695
696    //
697    /// Similar but not quite the same as [`Client::transaction`].
698    ///
699    /// Executes the closure as a single transaction.
700    /// Commit is automatically called after the closure. If any connection
701    /// issues occur during the transaction then the transaction is rolled
702    /// back (on drop) and retried a new with the new connection subject to
703    /// the maximum number of reconnect attempts.
704    ///
705    pub async fn transaction<F>(&mut self, timeout: Option<Duration>, f: F) -> PGResult<()>
706    where
707        for<'a> F: AsyncFn(&'a mut Transaction) -> Result<(), tokio_postgres::Error>,
708    {
709        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
710            let mut tx = client.transaction().await?;
711            f(&mut tx).await?;
712            tx.commit().await?;
713            Ok(())
714        })
715        .await
716    }
717
718    /// Like [`Client::batch_execute`].
719    pub async fn batch_execute(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<()> {
720        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
721            client.batch_execute(query).await
722        })
723        .await
724    }
725
726    /// Like [`Client::simple_query`].
727    pub async fn simple_query(
728        &mut self,
729        query: &str,
730        timeout: Option<Duration>,
731    ) -> PGResult<Vec<SimpleQueryMessage>> {
732        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
733            client.simple_query(query).await
734        })
735        .await
736    }
737
738    /// Returns a reference to the underlying [`Client`].
739    pub fn client(&self) -> &PGClient {
740        &self.client
741    }
742}
743
744///
745/// Wraps any future in a tokio timeout and maps the Elapsed error to a PGError::Timeout.
746///
747pub async fn wrap_timeout<T>(dur: Duration, fut: impl Future<Output = PGResult<T>>) -> PGResult<T> {
748    match timeout(dur, fut).await {
749        Ok(out) => out,
750        Err(_) => Err(PGError::Timeout(dur)),
751    }
752}
753
754#[cfg(test)]
755mod tests {
756
757    use {
758        super::{PGError, PGMessage, PGRaiseLevel, PGRobustClient},
759        insta::*,
760        std::{
761            sync::{Arc, RwLock},
762            time::Duration,
763        },
764        testcontainers::{ImageExt, runners::AsyncRunner},
765        testcontainers_modules::postgres::Postgres,
766    };
767
768    fn sql_for_log_and_notify_test(level: PGRaiseLevel) -> String {
769        format!(
770            r#"
771                    set client_min_messages to '{}';
772                    do $$
773                    begin
774                        raise debug 'this is a DEBUG notification';
775                        notify test, 'test#1';
776                        raise log 'this is a LOG notification';
777                        notify test, 'test#2';
778                        raise info 'this is a INFO notification';
779                        notify test, 'test#3';
780                        raise notice 'this is a NOTICE notification';
781                        notify test, 'test#4';
782                        raise warning 'this is a WARNING notification';
783                        notify test, 'test#5';
784                    end;
785                    $$;
786                "#,
787            level
788        )
789    }
790
791    #[tokio::test]
792    async fn test_integration() {
793        //
794        // --------------------------------------------------------------------
795        // Setup Postgres Server
796        // --------------------------------------------------------------------
797
798        let pg_server = Postgres::default()
799            .with_tag("16.4")
800            .start()
801            .await
802            .expect("could not start postgres server");
803
804        // NOTE: this stuff with Box::leak allows us to create a static string
805        let database_url = format!(
806            "postgres://postgres:postgres@{}:{}/postgres",
807            pg_server.get_host().await.unwrap(),
808            pg_server.get_host_port_ipv4(5432).await.unwrap()
809        );
810
811        // let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
812
813        // --------------------------------------------------------------------
814        // Connect to the server
815        // --------------------------------------------------------------------
816
817        let notices = Arc::new(RwLock::new(Vec::new()));
818        let notices_clone = notices.clone();
819
820        let callback = move |msg: PGMessage| {
821            if let Ok(mut guard) = notices_clone.write() {
822                guard.push(msg.to_string());
823            }
824        };
825
826        let mut admin = PGRobustClient::spawn(&database_url, tokio_postgres::NoTls, |_| {})
827            .await
828            .expect("could not create initial client");
829
830        let mut client = PGRobustClient::spawn(&database_url, tokio_postgres::NoTls, callback)
831            .await
832            .expect("could not create initial client")
833            .with_max_reconnect_attempts(2);
834
835        // --------------------------------------------------------------------
836        // Subscribe to notify and raise
837        // --------------------------------------------------------------------
838
839        client
840            .subscribe_notify(&["test"], None)
841            .await
842            .expect("could not subscribe");
843
844        let (_, execution_log) = client
845            .with_captured_log(async |client: &mut PGRobustClient<_>| {
846                client
847                    .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Debug), None)
848                    .await
849            })
850            .await
851            .expect("could not execute queries on postgres");
852
853        assert_json_snapshot!("subscribed-executionlog", &execution_log, {
854            "[].timestamp" => "<timestamp>",
855            "[].process_id" => "<pid>",
856        });
857
858        assert_snapshot!("subscribed-notify", extract_and_clear_logs(&notices));
859
860        // --------------------------------------------------------------------
861        // Unsubscribe
862        // --------------------------------------------------------------------
863
864        client
865            .unsubscribe_notify(&["test"], None)
866            .await
867            .expect("could not unsubscribe");
868
869        let (_, execution_log) = client
870            .with_captured_log(async |client| {
871                client
872                    .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Warning), None)
873                    .await
874            })
875            .await
876            .expect("could not execute queries on postgres");
877
878        assert_json_snapshot!("unsubscribed-executionlog", &execution_log, {
879            "[].timestamp" => "<timestamp>",
880            "[].process_id" => "<pid>",
881        });
882
883        assert_snapshot!("unsubscribed-notify", extract_and_clear_logs(&notices));
884
885        // --------------------------------------------------------------------
886        // Timeout
887        // --------------------------------------------------------------------
888
889        let result = client
890            .simple_query(
891                "
892                    do $$
893                    begin
894                        raise info 'before sleep';
895                        perform pg_sleep(3);
896                        raise info 'after sleep';
897                    end;
898                    $$
899                ",
900                Some(Duration::from_secs(1)),
901            )
902            .await;
903
904        assert!(matches!(result, Err(PGError::Timeout(_))));
905        assert_snapshot!("timeout-messages", extract_and_clear_logs(&notices));
906
907        // --------------------------------------------------------------------
908        // Reconnect (before query)
909        // --------------------------------------------------------------------
910
911        admin.simple_query("select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()", None)
912            .await.expect("could not kill other client");
913
914        let result = client
915            .simple_query(
916                "
917                    do $$
918                    begin
919                        raise info 'before sleep';
920                        perform pg_sleep(1);
921                        raise info 'after sleep';
922                    end;
923                    $$
924                ",
925                Some(Duration::from_secs(10)),
926            )
927            .await;
928
929        assert!(matches!(result, Ok(_)));
930        assert_snapshot!("reconnect-before", extract_and_clear_logs(&notices));
931
932        // --------------------------------------------------------------------
933        // Reconnect (during query)
934        // --------------------------------------------------------------------
935
936        let query = client.simple_query(
937            "
938                    do $$
939                    begin
940                        raise info 'before sleep';
941                        perform pg_sleep(1);
942                        raise info 'after sleep';
943                    end;
944                    $$
945                ",
946            None,
947        );
948
949        let kill_later = 
950            admin.simple_query("
951                select pg_sleep(0.5); 
952                select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()", 
953                None
954            );
955
956        let (_, result) = tokio::join!(kill_later, query);
957
958        assert!(matches!(result, Ok(_)));
959        assert_snapshot!("reconnect-during", extract_and_clear_logs(&notices));
960
961        // --------------------------------------------------------------------
962        // Reconnect (failure)
963        // --------------------------------------------------------------------
964
965        pg_server.stop().await.expect("could not stop server");
966
967        let result = client.simple_query(
968            "
969                do $$
970                begin
971                    raise info 'before sleep';
972                    perform pg_sleep(1);
973                    raise info 'after sleep';
974                end;
975                $$
976            ",
977            None,
978        ).await;
979
980        eprintln!("result: {result:?}");
981        assert!(matches!(result, Err(PGError::FailedToReconnect(2))));
982        assert_snapshot!("reconnect-failure", extract_and_clear_logs(&notices));
983
984
985    }
986
987    fn extract_and_clear_logs(logs: &Arc<RwLock<Vec<String>>>) -> String {
988        let mut guard = logs.write().expect("could not read notices");
989        let emtpy_log = Vec::default();
990        let log = std::mem::replace(&mut *guard, emtpy_log);
991        redact_pids(&redact_timestamps(&log.join("\n")))
992    }
993
994    fn redact_timestamps(text: &str) -> String {
995        use regex::Regex;
996        use std::sync::OnceLock;
997        pub static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
998        let pat = TIMESTAMP_PATTERN.get_or_init(|| {
999            Regex::new(r"\d{4}-\d{2}-\d{2}.?\d{2}:\d{2}:\d{2}(\.\d{3,9})?(Z| UTC|[+-]\d{2}:\d{2})?")
1000                .unwrap()
1001        });
1002        pat.replace_all(text, "<timestamp>").to_string()
1003    }
1004
1005    fn redact_pids(text: &str) -> String {
1006        use regex::Regex;
1007        use std::sync::OnceLock;
1008        pub static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
1009        let pat = TIMESTAMP_PATTERN.get_or_init(|| Regex::new(r"pid=\d+").unwrap());
1010        pat.replace_all(text, "<pid>").to_string()
1011    }
1012}