postgres_notify/
lib.rs

1//!
2//! `postgres-notify` started out as an easy way to receive PostgreSQL
3//! notifications but has since evolved into a much more useful client
4//! and is able to handle the following:
5//!
//! - Receives `NOTIFY <channel> <payload>` pub/sub style notifications.
//!
//! - Receives `RAISE` messages and collects execution logs.
9//!
//! - Applies a timeout to all queries. If a query times out then the
11//!   client will attempt to cancel the ongoing query before returning
12//!   an error.
13//!
14//! - Supports cancelling an ongoing query.
15//!
16//! - Automatically reconnects if the connection is lost and uses
//!   exponential backoff with jitter to avoid the thundering herd effect.
18//!
//! - Supports a `connect_script`, which can be executed on connect.
20//! 
21//! - Has a familiar API with an additional `timeout` argument.
22//!
23//!
24//! # BREAKING CHANGE in v0.3.2
25//!
26//! Configuration is done through the [`PGRobustClientConfig`] struct.
27//! 
28//!
29//! # BREAKING CHANGE in v0.3.0
30//!
31//! This latest version is a breaking change. The `PGNotifyingClient` has
32//! been renamed `PGRobustClient` and queries don't need to be made through
33//! the inner client anymore. Furthermore, a single callback handles all
//! of the notifications: NOTIFY, RAISE, TIMEOUT, RECONNECT.
35//!
36//!
37//!
38//! # LISTEN/NOTIFY
39//!
40//! For a very long time (at least since version 7.1) postgres has supported
41//! asynchronous notifications based on LISTEN/NOTIFY commands. This allows
42//! the database to send notifications to the client in an "out-of-band"
43//! channel.
44//!
45//! Once the client has issued a `LISTEN <channel>` command, the database will
46//! send notifications to the client whenever a `NOTIFY <channel> <payload>`
47//! is issued on the database regardless of which session has issued it.
48//! This can act as a cheap alternative to a pub/sub system though without
49//! mailboxes or persistence.
50//!
51//! When calling `subscribe_notify` with a list of channel names, [`PGRobustClient`]
//! will invoke the client callback any time a `NOTIFY` message is received for any of
53//! the subscribed channels.
54//!
55//! ```rust
56//! use postgres_notify::{PGRobustClientConfig, PGRobustClient, PGMessage};
57//! use tokio_postgres::NoTls;
58//! use std::time::Duration;
59//!
60//! let rt = tokio::runtime::Builder::new_current_thread()
61//!     .enable_io()
62//!     .enable_time()
63//!     .build()
64//!     .expect("could not start tokio runtime");
65//!
66//! rt.block_on(async move {
67//!     
68//!     let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
69//!     let config = PGRobustClientConfig::new(database_url, NoTls)
70//!         .callback(|msg:PGMessage| println!("{:?}", &msg));
71//!
72//!     let mut client = PGRobustClient::spawn(config)
73//!         .await.expect("Could not connect to postgres");
74//!
75//!     client.subscribe_notify(&["test"], Some(Duration::from_millis(100)))
76//!         .await.expect("Could not subscribe to channels");
77//! });
78//! ```
79//!
80//!
81//!
82//! # RAISE/LOGS
83//!
84//! Logs in PostgreSQL are created by writing `RAISE <level> <message>` statements
85//! within your functions, stored procedures and scripts. When such a command is
86//! issued, [`PGRobustClient`] receives a notification even if the call is still
87//! in progress. This allows the caller to capture the execution log in realtime
88//! if needed.
89//!
90//! [`PGRobustClient`] simplifies log collection in two ways. Firstly it provides
91//! the [`with_captured_log`](PGRobustClient::with_captured_log) functions,
92//! which collects the execution log and returns it along with the query result.
93//! This is probably what most people will want to use.
94//!
95//! If your needs are more complex or if you want to propagate realtime logs,
96//! then using client callback can be used to forwand the message on an
97//! asynchonous channel.
98//!
99//! ```rust
100//! use postgres_notify::{PGRobustClient, PGRobustClientConfig, PGMessage};
101//! use tokio_postgres::NoTls;
102//! use std::time::Duration;
103//!
104//! let rt = tokio::runtime::Builder::new_current_thread()
105//!     .enable_io()
106//!     .enable_time()
107//!     .build()
108//!     .expect("could not start tokio runtime");
109//!
110//! rt.block_on(async move {
111//!
112//!     let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
113//!     let config = PGRobustClientConfig::new(database_url, NoTls)
114//!         .callback(|msg:PGMessage| println!("{:?}", &msg));
115//! 
116//!     let mut client = PGRobustClient::spawn(config)
117//!         .await.expect("Could not connect to postgres");
118//!
119//!     // Will capture the notices in a Vec
120//!     let (_, log) = client.with_captured_log(async |client| {
121//!         client.simple_query("
122//!             do $$
123//!             begin
124//!                 raise debug 'this is a DEBUG notification';
125//!                 raise log 'this is a LOG notification';
126//!                 raise info 'this is a INFO notification';
127//!                 raise notice 'this is a NOTICE notification';
128//!                 raise warning 'this is a WARNING notification';
129//!             end;
130//!             $$",
131//!             Some(Duration::from_secs(1))
132//!         ).await.expect("Error during query execution");
133//!         Ok(())
134//!     }).await.expect("Error during captur log");
135//!
136//!     println!("{:#?}", &log);
137//!  });
138//! ```
139//!
140//! Note that the client passed to the async callback is `&mut self`, which
141//! means that all queries within that block are subject to the same timeout
142//! and reconnect handling.
143//!
144//! You can look at the unit tests for a more in-depth example.
145//!
146//!
147//!
148//! # TIMEOUT
149//!
150//! All of the query functions in [`PGRobustClient`] have a `timeout` argument.
151//! If the query takes longer than the timeout, then an error is returned.
152//! If not specified, the default timeout is 1 hour.
153//!
154//!
155//! # RECONNECT
156//!
157//! If the connection to the database is lost, then [`PGRobustClient`] will
158//! attempt to reconnect to the database automatically. If the maximum number
159//! of reconnect attempts is reached then an error is returned. Furthermore,
160//! it uses a exponential backoff with jitter in order to avoid thundering
161//! herd effect.
162//!
163//!
164//! # CALLBACK SAFETY
165//!
166//! The callback function runs in a background tokio task that polls the
167//! PostgreSQL connection. If the callback panics:
168//!
169//! - The `RwLock` protecting the message log will be poisoned
170//! - Subsequent calls to [`capture_and_clear_log`](PGRobustClient::capture_and_clear_log) will return empty vectors
171//! - The connection polling task will terminate
172//!
173//! **Recommendation**: Ensure callbacks do not panic. Use `std::panic::catch_unwind`
174//! if calling untrusted code within the callback.
175
176mod error;
177mod messages;
178mod notify;
179mod config;
180mod inner;
181
182pub use error::*;
183pub use messages::*;
184use inner::*;
185pub use config::*;
186
187use tokio_postgres::{SimpleQueryMessage, ToStatement};
188
189use {
190    futures::TryFutureExt,
191    std::{
192        time::Duration,
193    },
194    tokio::{
195        time::{sleep, timeout},
196    },
197    tokio_postgres::{
198        Row, RowStream, Socket, Statement, Transaction,
199        tls::MakeTlsConnect,
200        types::{BorrowToSql, ToSql, Type},
201    },
202};
203
204/// Shorthand for Result with tokio_postgres::Error
205pub type PGResult<T> = Result<T, PGError>;
206
207
208
209pub struct PGRobustClient<TLS>
210{
211    config: PGRobustClientConfig<TLS>,
212    inner: PGClient,
213}
214
215#[allow(unused)]
216impl<TLS> PGRobustClient<TLS>
217where
218    TLS: MakeTlsConnect<Socket> + Clone,
219    <TLS as MakeTlsConnect<Socket>>::Stream: Send + Sync + 'static,
220{
221    ///
222    /// Connects to the database and returns a new client.
223    /// 
224    pub async fn spawn(config: PGRobustClientConfig<TLS>) -> PGResult<PGRobustClient<TLS>> {
225        let inner = PGClient::connect(&config).await?;
226        Ok(PGRobustClient { config, inner })
227    }
228
229    ///
230    /// Returns a reference to the config object used to create this client.
231    /// 
232    pub fn config(&self) -> &PGRobustClientConfig<TLS> {
233        &self.config
234    }
235
236    ///
237    /// Returns a mutable reference to the config object used to create this client.
238    /// Some changes only take effect on the next connection. Others are immediate.
239    ///
240    pub fn config_mut(&mut self) -> &mut PGRobustClientConfig<TLS> {
241        &mut self.config
242    }   
243    
244    ///
245    /// Cancels any query in-progress.
246    ///
247    /// This is the only function that does not take a timeout nor does it
248    /// attempt to reconnect if the connection is lost. It will simply
249    /// return the original error.
250    ///
251    pub async fn cancel_query(&mut self) -> PGResult<()> {
252        self.inner
253            .cancel_token
254            .cancel_query(self.config.make_tls.clone())
255            .await
256            .map_err(Into::into)
257    }
258
259    ///
260    /// Returns the log messages captured since the last call to this function.
261    /// It also clears the log.
262    ///
263    pub fn capture_and_clear_log(&mut self) -> Vec<PGMessage> {
264        match self.inner.log.write() {
265            Ok(mut guard) => {
266                let empty_log = Vec::default();
267                std::mem::replace(&mut *guard, empty_log)
268            }
269            Err(_) => {
270                #[cfg(feature = "tracing")]
271                tracing::error!("Lock poisoned in capture_and_clear_log - returning empty log");
272                Vec::default()
273            }
274        }
275    }
276
277    ///
278    /// Clears the message log without returning its contents.
279    ///
280    fn clear_log(&mut self) {
281        if let Ok(mut guard) = self.inner.log.write() {
282            guard.clear();
283        }
284    }
285
286    ///
287    /// Given an async closure taking the postgres client, returns the result
288    /// of said closure along with the accumulated log since the beginning of
289    /// the closure.
290    ///
291    /// If you use query pipelining then collect the logs for all queries in
292    /// the pipeline. Otherwise, the logs might not be what you expect.
293    ///
294    pub async fn with_captured_log<F, T>(&mut self, f: F) -> PGResult<(T, Vec<PGMessage>)>
295    where
296        F: AsyncFn(&mut Self) -> PGResult<T>,
297    {
298        self.capture_and_clear_log(); // clear the log just in case...
299        let result = f(self).await?;
300        let log = self.capture_and_clear_log();
301        Ok((result, log))
302    }
303
304    ///
305    /// Attempts to reconnect after a connection loss.
306    ///
307    /// Reconnection applies an exponention backoff with jitter in order to
308    /// avoid thundering herd effect. If the maximum number of attempts is
309    /// reached then an error is returned.
310    ///
311    /// If an error unrelated to establishing a new connection is returned
312    /// when trying to connect then that error is returned.
313    ///
314    async fn reconnect(&mut self) -> PGResult<()> {
315        //
316        use std::cmp::{max, min};
317        let mut attempts = 1;
318        let mut k = 500;
319
320        while attempts <= self.config.max_reconnect_attempts {
321            //
322            // Implement exponential backoff + jitter
323            // Initial delay will be 500ms, max delay is 1h.
324            //
325            sleep(Duration::from_millis(k + rand::random_range(0..k / 2))).await;
326            k = min(k * 2, 60000);
327
328            #[cfg(feature = "tracing")]
329            tracing::info!("Reconnect attempt #{}", attempts);
330            (self.config.callback)(PGMessage::reconnect(attempts, self.config.max_reconnect_attempts));
331
332            attempts += 1;
333
334            match PGClient::connect(&self.config).await {
335                Ok(inner) => {
336
337                    self.inner = inner;
338
339                    (self.config.callback)(PGMessage::connected());
340                    
341                    if let Some(sql) = self.config.full_connect_script() {
342                        match self.inner.simple_query(&sql).await {
343                            Ok(_) => {
344                                return Ok(());
345                            }
346                            Err(e) if is_pg_connection_issue(&e) => {
347                                continue;
348                            }
349                            Err(e) => {
350                                return Err(e.into());
351                            }
352                        }
353                    } else {
354                        return Ok(());
355                    }
356                }
357                Err(e) if e.is_pg_connection_issue() => {
358                    continue;
359                }
360                Err(e) => {
361                    return Err(e);
362                }
363            }
364        }
365
366        // Issue the failed to reconnect message
367        (self.config.callback)(PGMessage::failed_to_reconnect(self.config.max_reconnect_attempts));
368        // Return the error
369        Err(PGError::FailedToReconnect(self.config.max_reconnect_attempts))
370    }
371
372
373    ///
374    /// Wraps most calls that use the client with a timeout and reconnect loop.
375    ///
376    /// If you lose the connection during a query, the client will automatically
377    /// reconnect and retry the query.
378    ///
379    /// **Note**: This method clears the message log at the start of each call.
380    /// Messages from previous operations are discarded. Use [`with_captured_log`](Self::with_captured_log)
381    /// if you need to preserve and retrieve messages from a specific operation.
382    ///
383    pub async fn wrap_reconnect<T>(
384        &mut self,
385        max_dur: Option<Duration>,
386        factory: impl AsyncFn(&mut PGClient) -> Result<T, tokio_postgres::Error>,
387    ) -> PGResult<T> {
388        // Clear any accumulated messages from previous operations
389        self.clear_log();
390        let max_dur = max_dur.unwrap_or(self.config.default_timeout);
391        loop {
392            match timeout(max_dur, factory(&mut self.inner)).await {
393                // Query succeeded so return the result
394                Ok(Ok(o)) => return Ok(o),
395                // Query failed because of connection issues
396                Ok(Err(e)) if is_pg_connection_issue(&e) => {
397                    self.reconnect().await?;
398                }
399                // Query failed for some other reason
400                Ok(Err(e)) => {
401                    return Err(e.into());
402                }
403                // Query timed out!
404                Err(_) => {
405                    // Callback with timeout message
406                    (self.config.callback)(PGMessage::timeout(max_dur));
407                    // Cancel the ongoing query
408                    let status = self.inner.cancel_token.cancel_query(self.config.make_tls.clone()).await;
409                    // Callback with cancelled message
410                    (self.config.callback)(PGMessage::cancelled(!status.is_err()));
411                    // Return the timeout error
412                    return Err(PGError::Timeout(max_dur));
413                }
414            }
415        }
416    }
417
418    pub async fn subscribe_notify(
419        &mut self,
420        channels: &[impl AsRef<str> + Send + Sync + 'static],
421        timeout: Option<Duration>,
422    ) -> PGResult<()> {
423
424        if !channels.is_empty() {
425            // Issue the `LISTEN` commands with protection
426            self.wrap_reconnect(timeout, async |client: &mut PGClient| {
427                PGClient::issue_listen(client, channels).await
428            })
429            .await?;
430
431            // Add to our subscriptions
432            self.config.with_subscriptions(channels.iter().map(AsRef::as_ref));
433        }
434        Ok(())
435    }
436
437
438
439    pub async fn unsubscribe_notify(
440        &mut self,
441        channels: &[impl AsRef<str> + Send + Sync + 'static],
442        timeout: Option<Duration>,
443    ) -> PGResult<()> {
444        if !channels.is_empty() {
445            // Issue the `UNLISTEN` commands with protection
446            self.wrap_reconnect(timeout, async move |client: &mut PGClient| {
447                PGClient::issue_unlisten(client, channels).await
448            })
449            .await?;
450
451            // Remove subscriptions
452            self.config.without_subscriptions(channels.iter().map(AsRef::as_ref));
453        }
454        Ok(())
455    }
456
457    ///
458    /// Unsubscribes from all channels.
459    ///
460    pub async fn unsubscribe_notify_all(&mut self, timeout: Option<Duration>) -> PGResult<()> {
461        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
462            // Tell the world we are about to unsubscribe
463            #[cfg(feature = "tracing")]
464            tracing::info!("Unsubscribing from channels: *");
465            // Issue the `UNLISTEN` commands
466            client.simple_query("UNLISTEN *").await?;
467            Ok(())
468        })
469        .await
470    }
471
472
473    /// Like [`Client::execute_raw`].
474    pub async fn execute_raw<P, I, T>(
475        &mut self,
476        statement: &T,
477        params: I,
478        timeout: Option<Duration>,
479    ) -> PGResult<u64>
480    where
481        T: ?Sized + ToStatement + Sync + Send,
482        P: BorrowToSql + Clone + Send + Sync,
483        I: IntoIterator<Item = P> + Sync + Send,
484        I::IntoIter: ExactSizeIterator,
485    {
486        let params: Vec<_> = params.into_iter().collect();
487        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
488            client.execute_raw(statement, params.clone()).await
489        })
490        .await
491    }
492
493    /// Like [`Client::query`].
494    ///
495    /// **Note**: Parameters are cloned into a `Vec` before the async operation
496    /// to satisfy lifetime requirements. For bulk operations with many large
497    /// parameters, consider using [`query_raw`](Self::query_raw) or [`execute_raw`](Self::execute_raw)
498    /// which may be more efficient depending on your use case.
499    pub async fn query<T>(
500        &mut self,
501        query: &T,
502        params: &[&(dyn ToSql + Sync)],
503        timeout: Option<Duration>,
504    ) -> PGResult<Vec<Row>>
505    where
506        T: ?Sized + ToStatement + Sync + Send,
507    {
508        let params = params.to_vec();
509        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
510            client.query(query, &params).await
511        })
512        .await
513    }
514
515    /// Like [`Client::query_one`].
516    pub async fn query_one<T>(
517        &mut self,
518        statement: &T,
519        params: &[&(dyn ToSql + Sync)],
520        timeout: Option<Duration>,
521    ) -> PGResult<Row>
522    where
523        T: ?Sized + ToStatement + Sync + Send,
524    {
525        let params = params.to_vec();
526        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
527            client.query_one(statement, &params).await
528        })
529        .await
530    }
531
532    /// Like [`Client::query_opt`].
533    pub async fn query_opt<T>(
534        &mut self,
535        statement: &T,
536        params: &[&(dyn ToSql + Sync)],
537        timeout: Option<Duration>,
538    ) -> PGResult<Option<Row>>
539    where
540        T: ?Sized + ToStatement + Sync + Send,
541    {
542        let params = params.to_vec();
543        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
544            client.query_opt(statement, &params).await
545        })
546        .await
547    }
548
549    /// Like [`Client::query_raw`].
550    pub async fn query_raw<T, P, I>(
551        &mut self,
552        statement: &T,
553        params: I,
554        timeout: Option<Duration>,
555    ) -> PGResult<RowStream>
556    where
557        T: ?Sized + ToStatement + Sync + Send,
558        P: BorrowToSql + Clone + Send + Sync,
559        I: IntoIterator<Item = P> + Sync + Send,
560        I::IntoIter: ExactSizeIterator,
561    {
562        let params: Vec<_> = params.into_iter().collect();
563        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
564            client.query_raw(statement, params.clone()).await
565        })
566        .await
567    }
568
569    /// Like [`Client::query_typed`]
570    pub async fn query_typed(
571        &mut self,
572        statement: &str,
573        params: &[(&(dyn ToSql + Sync), Type)],
574        timeout: Option<Duration>,
575    ) -> PGResult<Vec<Row>> {
576        let params = params.to_vec();
577        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
578            client.query_typed(statement, &params).await
579        })
580        .await
581    }
582
583    /// Like [`Client::query_typed_raw`]
584    pub async fn query_typed_raw<P, I>(
585        &mut self,
586        statement: &str,
587        params: I,
588        timeout: Option<Duration>,
589    ) -> PGResult<RowStream>
590    where
591        P: BorrowToSql + Clone + Send + Sync,
592        I: IntoIterator<Item = (P, Type)> + Sync + Send,
593    {
594        let params: Vec<_> = params.into_iter().collect();
595        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
596            client.query_typed_raw(statement, params.clone()).await
597        })
598        .await
599    }
600
601    /// Like [`Client::prepare`].
602    pub async fn prepare(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<Statement> {
603        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
604            client.prepare(query).map_err(Into::into).await
605        })
606        .await
607    }
608
609    /// Like [`Client::prepare_typed`].
610    pub async fn prepare_typed(
611        &mut self,
612        query: &str,
613        parameter_types: &[Type],
614        timeout: Option<Duration>,
615    ) -> PGResult<Statement> {
616        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
617            client.prepare_typed(query, parameter_types).await
618        })
619        .await
620    }
621
622    //
623    /// Similar but not quite the same as [`Client::transaction`].
624    ///
625    /// Executes the closure as a single transaction.
626    /// Commit is automatically called after the closure. If any connection
627    /// issues occur during the transaction then the transaction is rolled
628    /// back (on drop) and retried a new with the new connection subject to
629    /// the maximum number of reconnect attempts.
630    ///
631    pub async fn transaction<F>(&mut self, timeout: Option<Duration>, f: F) -> PGResult<()>
632    where
633        for<'a> F: AsyncFn(&'a mut Transaction) -> Result<(), tokio_postgres::Error>,
634    {
635        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
636            let mut tx = client.transaction().await?;
637            f(&mut tx).await?;
638            tx.commit().await?;
639            Ok(())
640        })
641        .await
642    }
643
644    /// Like [`Client::batch_execute`].
645    pub async fn batch_execute(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<()> {
646        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
647            client.batch_execute(query).await
648        })
649        .await
650    }
651
652    /// Like [`Client::simple_query`].
653    pub async fn simple_query(
654        &mut self,
655        query: &str,
656        timeout: Option<Duration>,
657    ) -> PGResult<Vec<SimpleQueryMessage>> {
658        self.wrap_reconnect(timeout, async |client: &mut PGClient| {
659            client.simple_query(query).await
660        })
661        .await
662    }
663
664    /// Returns a reference to the underlying [`tokio_postgres::Client`].
665    pub fn client(&self) -> &tokio_postgres::Client {
666        &self.inner
667    }
668}
669
670///
671/// Wraps any future in a tokio timeout and maps the Elapsed error to a PGError::Timeout.
672///
673pub async fn wrap_timeout<T>(dur: Duration, fut: impl Future<Output = PGResult<T>>) -> PGResult<T> {
674    match timeout(dur, fut).await {
675        Ok(out) => out,
676        Err(_) => Err(PGError::Timeout(dur)),
677    }
678}
679
680#[cfg(test)]
681mod tests {
682
683    use {
684        super::{PGError, PGMessage, PGRaiseLevel, PGRobustClient, PGRobustClientConfig},
685        insta::*,
686        std::{
687            sync::{Arc, RwLock},
688            time::Duration,
689        },
690        testcontainers::{ImageExt, runners::AsyncRunner},
691        testcontainers_modules::postgres::Postgres,
692    };
693
694    // ========================================================================
695    // UNIT TESTS (no database required)
696    // ========================================================================
697
698    mod unit {
699        use super::*;
700        use tokio_postgres::NoTls;
701
702        // --------------------------------------------------------------------
703        // Config Builder Tests
704        // --------------------------------------------------------------------
705
706        #[test]
707        fn config_default_values() {
708            let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
709
710            assert_eq!(config.max_reconnect_attempts, 10);
711            assert_eq!(config.default_timeout, Duration::from_secs(3600));
712            assert!(config.subscriptions.is_empty());
713            assert!(config.connect_script.is_none());
714            assert!(config.application_name.is_none());
715        }
716
717        #[test]
718        fn config_builder_chaining() {
719            let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
720                .max_reconnect_attempts(5)
721                .default_timeout(Duration::from_secs(30))
722                .application_name("test_app")
723                .connect_script("SET timezone = 'UTC'")
724                .subscriptions(["channel1", "channel2"]);
725
726            assert_eq!(config.max_reconnect_attempts, 5);
727            assert_eq!(config.default_timeout, Duration::from_secs(30));
728            assert_eq!(config.application_name, Some("test_app".to_string()));
729            assert_eq!(config.connect_script, Some("SET timezone = 'UTC'".to_string()));
730            assert!(config.subscriptions.contains("channel1"));
731            assert!(config.subscriptions.contains("channel2"));
732        }
733
734        #[test]
735        fn config_with_methods() {
736            let mut config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
737
738            config.with_max_reconnect_attempts(Some(3));
739            config.with_default_timeout(Some(Duration::from_secs(60)));
740            config.with_application_name(Some("my_app"));
741            config.with_connect_script(Some("SELECT 1"));
742            config.with_subscriptions(["events"]);
743
744            assert_eq!(config.max_reconnect_attempts, 3);
745            assert_eq!(config.default_timeout, Duration::from_secs(60));
746            assert_eq!(config.application_name, Some("my_app".to_string()));
747            assert_eq!(config.connect_script, Some("SELECT 1".to_string()));
748            assert!(config.subscriptions.contains("events"));
749        }
750
751        #[test]
752        fn config_full_connect_script_empty() {
753            let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
754            assert!(config.full_connect_script().is_none());
755        }
756
757        #[test]
758        fn config_full_connect_script_with_app_name() {
759            let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
760                .application_name("my_app");
761
762            let script = config.full_connect_script().unwrap();
763            assert!(script.contains("SET application_name = 'my_app'"));
764        }
765
766        #[test]
767        fn config_full_connect_script_with_subscriptions() {
768            let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
769                .subscriptions(["chan1", "chan2"]);
770
771            let script = config.full_connect_script().unwrap();
772            assert!(script.contains("LISTEN chan1;"));
773            assert!(script.contains("LISTEN chan2;"));
774        }
775
776        #[test]
777        fn config_full_connect_script_combined() {
778            let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
779                .application_name("app")
780                .connect_script("SET timezone = 'UTC';")
781                .subscriptions(["events"]);
782
783            let script = config.full_connect_script().unwrap();
784            assert!(script.contains("SET application_name = 'app'"));
785            assert!(script.contains("SET timezone = 'UTC';"));
786            assert!(script.contains("LISTEN events;"));
787        }
788
789        #[test]
790        fn config_without_subscriptions() {
791            let mut config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
792                .subscriptions(["a", "b", "c"]);
793
794            config.without_subscriptions(["b"]);
795
796            assert!(config.subscriptions.contains("a"));
797            assert!(!config.subscriptions.contains("b"));
798            assert!(config.subscriptions.contains("c"));
799        }
800
801        // --------------------------------------------------------------------
802        // PGError Tests
803        // --------------------------------------------------------------------
804
805        #[test]
806        fn error_timeout_display() {
807            let err = PGError::Timeout(Duration::from_secs(30));
808            let msg = err.to_string();
809            assert!(msg.contains("timed out"));
810            assert!(msg.contains("30"));
811        }
812
813        #[test]
814        fn error_failed_to_reconnect_display() {
815            let err = PGError::FailedToReconnect(5);
816            let msg = err.to_string();
817            assert!(msg.contains("5"));
818            assert!(msg.contains("reconnect"));
819        }
820
821        #[test]
822        fn error_is_timeout() {
823            let timeout_err = PGError::Timeout(Duration::from_secs(1));
824            let reconnect_err = PGError::FailedToReconnect(1);
825
826            assert!(timeout_err.is_timeout());
827            assert!(!reconnect_err.is_timeout());
828        }
829
830        #[test]
831        fn error_other() {
832            let custom_err = std::io::Error::new(std::io::ErrorKind::Other, "custom error");
833            let pg_err = PGError::other(custom_err);
834
835            assert!(matches!(pg_err, PGError::Other(_)));
836            assert!(pg_err.to_string().contains("custom error"));
837        }
838
839        // --------------------------------------------------------------------
840        // PGMessage Tests
841        // --------------------------------------------------------------------
842
843        #[test]
844        fn message_reconnect_creation() {
845            let msg = PGMessage::reconnect(3, 10);
846            match msg {
847                PGMessage::Reconnect { attempts, max_attempts, .. } => {
848                    assert_eq!(attempts, 3);
849                    assert_eq!(max_attempts, 10);
850                }
851                _ => panic!("Expected Reconnect variant"),
852            }
853        }
854
855        #[test]
856        fn message_connected_creation() {
857            let msg = PGMessage::connected();
858            assert!(matches!(msg, PGMessage::Connected { .. }));
859        }
860
861        #[test]
862        fn message_timeout_creation() {
863            let msg = PGMessage::timeout(Duration::from_secs(5));
864            match msg {
865                PGMessage::Timeout { duration, .. } => {
866                    assert_eq!(duration, Duration::from_secs(5));
867                }
868                _ => panic!("Expected Timeout variant"),
869            }
870        }
871
872        #[test]
873        fn message_cancelled_creation() {
874            let msg_success = PGMessage::cancelled(true);
875            let msg_failure = PGMessage::cancelled(false);
876
877            match msg_success {
878                PGMessage::Cancelled { success, .. } => assert!(success),
879                _ => panic!("Expected Cancelled variant"),
880            }
881            match msg_failure {
882                PGMessage::Cancelled { success, .. } => assert!(!success),
883                _ => panic!("Expected Cancelled variant"),
884            }
885        }
886
887        #[test]
888        fn message_failed_to_reconnect_creation() {
889            let msg = PGMessage::failed_to_reconnect(5);
890            match msg {
891                PGMessage::FailedToReconnect { attempts, .. } => {
892                    assert_eq!(attempts, 5);
893                }
894                _ => panic!("Expected FailedToReconnect variant"),
895            }
896        }
897
898        #[test]
899        fn message_disconnected_creation() {
900            let msg = PGMessage::disconnected("Connection reset");
901            match msg {
902                PGMessage::Disconnected { reason, .. } => {
903                    assert_eq!(reason, "Connection reset");
904                }
905                _ => panic!("Expected Disconnected variant"),
906            }
907        }
908
909        #[test]
910        fn message_display_reconnect() {
911            let msg = PGMessage::reconnect(2, 10);
912            let display = msg.to_string();
913            assert!(display.contains("RECONNECT"));
914            assert!(display.contains("2"));
915            assert!(display.contains("10"));
916        }
917
918        #[test]
919        fn message_display_timeout() {
920            let msg = PGMessage::timeout(Duration::from_millis(500));
921            let display = msg.to_string();
922            assert!(display.contains("TIMEOUT"));
923        }
924
925        // --------------------------------------------------------------------
926        // PGRaiseLevel Tests
927        // --------------------------------------------------------------------
928
929        #[test]
930        fn raise_level_from_str() {
931            use std::str::FromStr;
932
933            // Test all known levels parse correctly
934            assert!(PGRaiseLevel::from_str("DEBUG").is_ok());
935            assert!(PGRaiseLevel::from_str("LOG").is_ok());
936            assert!(PGRaiseLevel::from_str("INFO").is_ok());
937            assert!(PGRaiseLevel::from_str("NOTICE").is_ok());
938            assert!(PGRaiseLevel::from_str("WARNING").is_ok());
939            assert!(PGRaiseLevel::from_str("ERROR").is_ok());
940            assert!(PGRaiseLevel::from_str("FATAL").is_ok());
941            assert!(PGRaiseLevel::from_str("PANIC").is_ok());
942        }
943
944        #[test]
945        fn raise_level_display() {
946            assert_eq!(PGRaiseLevel::Debug.to_string(), "DEBUG");
947            assert_eq!(PGRaiseLevel::Log.to_string(), "LOG");
948            assert_eq!(PGRaiseLevel::Warning.to_string(), "WARNING");
949        }
950
951        #[test]
952        fn raise_level_unknown_returns_error() {
953            use std::str::FromStr;
954            assert!(PGRaiseLevel::from_str("UNKNOWN_LEVEL").is_err());
955            assert!(PGRaiseLevel::from_str("debug").is_err()); // case sensitive
956        }
957    }
958
959    // ========================================================================
960    // INTEGRATION TESTS (require database)
961    // ========================================================================
962
963    fn sql_for_log_and_notify_test(level: PGRaiseLevel) -> String {
964        format!(
965            r#"
966                    set client_min_messages to '{}';
967                    do $$
968                    begin
969                        raise debug 'this is a DEBUG notification';
970                        notify test, 'test#1';
971                        raise log 'this is a LOG notification';
972                        notify test, 'test#2';
973                        raise info 'this is a INFO notification';
974                        notify test, 'test#3';
975                        raise notice 'this is a NOTICE notification';
976                        notify test, 'test#4';
977                        raise warning 'this is a WARNING notification';
978                        notify test, 'test#5';
979                    end;
980                    $$;
981                "#,
982            level
983        )
984    }
985
986    #[tokio::test]
987    async fn test_integration() {
988        //
989        // --------------------------------------------------------------------
990        // Setup Postgres Server
991        // --------------------------------------------------------------------
992
993        let pg_server = Postgres::default()
994            .with_tag("16.4")
995            .start()
996            .await
997            .expect("could not start postgres server");
998
999        // NOTE: this stuff with Box::leak allows us to create a static string
1000        let database_url = format!(
1001            "postgres://postgres:postgres@{}:{}/postgres",
1002            pg_server.get_host().await.unwrap(),
1003            pg_server.get_host_port_ipv4(5432).await.unwrap()
1004        );
1005
1006        // let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
1007
1008        // --------------------------------------------------------------------
1009        // Connect to the server
1010        // --------------------------------------------------------------------
1011
1012        let notices = Arc::new(RwLock::new(Vec::new()));
1013        let notices_clone = notices.clone();
1014
1015        let callback = move |msg: PGMessage| {
1016            if let Ok(mut guard) = notices_clone.write() {
1017                guard.push(msg.to_string());
1018            }
1019        };
1020
1021        let config = PGRobustClientConfig::new(database_url, tokio_postgres::NoTls);
1022
1023        let mut admin = PGRobustClient::spawn(config.clone())
1024            .await
1025            .expect("could not create initial client");
1026
1027        let mut client = PGRobustClient::spawn(config.callback(callback).max_reconnect_attempts(2))
1028            .await
1029            .expect("could not create initial client");
1030
1031        // --------------------------------------------------------------------
1032        // Subscribe to notify and raise
1033        // --------------------------------------------------------------------
1034
1035        client
1036            .subscribe_notify(&["test"], None)
1037            .await
1038            .expect("could not subscribe");
1039
1040        let (_, execution_log) = client
1041            .with_captured_log(async |client: &mut PGRobustClient<_>| {
1042                client
1043                    .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Debug), None)
1044                    .await
1045            })
1046            .await
1047            .expect("could not execute queries on postgres");
1048
1049        assert_json_snapshot!("subscribed-executionlog", &execution_log, {
1050            "[].timestamp" => "<timestamp>",
1051            "[].process_id" => "<pid>",
1052        });
1053
1054        assert_snapshot!("subscribed-notify", extract_and_clear_logs(&notices));
1055
1056        // --------------------------------------------------------------------
1057        // Unsubscribe
1058        // --------------------------------------------------------------------
1059
1060        client
1061            .unsubscribe_notify(&["test"], None)
1062            .await
1063            .expect("could not unsubscribe");
1064
1065        let (_, execution_log) = client
1066            .with_captured_log(async |client| {
1067                client
1068                    .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Warning), None)
1069                    .await
1070            })
1071            .await
1072            .expect("could not execute queries on postgres");
1073
1074        assert_json_snapshot!("unsubscribed-executionlog", &execution_log, {
1075            "[].timestamp" => "<timestamp>",
1076            "[].process_id" => "<pid>",
1077        });
1078
1079        assert_snapshot!("unsubscribed-notify", extract_and_clear_logs(&notices));
1080
1081        // --------------------------------------------------------------------
1082        // Timeout
1083        // --------------------------------------------------------------------
1084
1085        let result = client
1086            .simple_query(
1087                "
1088                    do $$
1089                    begin
1090                        raise info 'before sleep';
1091                        perform pg_sleep(3);
1092                        raise info 'after sleep';
1093                    end;
1094                    $$
1095                ",
1096                Some(Duration::from_secs(1)),
1097            )
1098            .await;
1099
1100        assert!(matches!(result, Err(PGError::Timeout(_))));
1101        assert_snapshot!("timeout-messages", extract_and_clear_logs(&notices));
1102
1103        // --------------------------------------------------------------------
1104        // Reconnect (before query)
1105        // --------------------------------------------------------------------
1106
1107        admin.simple_query("select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()", None)
1108            .await.expect("could not kill other client");
1109
1110        let result = client
1111            .simple_query(
1112                "
1113                    do $$
1114                    begin
1115                        raise info 'before sleep';
1116                        perform pg_sleep(1);
1117                        raise info 'after sleep';
1118                    end;
1119                    $$
1120                ",
1121                Some(Duration::from_secs(10)),
1122            )
1123            .await;
1124
1125        assert!(matches!(result, Ok(_)));
1126        assert_snapshot!("reconnect-before", extract_and_clear_logs(&notices));
1127
1128        // --------------------------------------------------------------------
1129        // Reconnect (during query)
1130        // --------------------------------------------------------------------
1131
1132        let query = client.simple_query(
1133            "
1134                    do $$
1135                    begin
1136                        raise info 'before sleep';
1137                        perform pg_sleep(1);
1138                        raise info 'after sleep';
1139                    end;
1140                    $$
1141                ",
1142            None,
1143        );
1144
1145        let kill_later = 
1146            admin.simple_query("
1147                select pg_sleep(0.5); 
1148                select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()", 
1149                None
1150            );
1151
1152        let (_, result) = tokio::join!(kill_later, query);
1153
1154        assert!(matches!(result, Ok(_)));
1155        assert_snapshot!("reconnect-during", extract_and_clear_logs(&notices));
1156
1157        // --------------------------------------------------------------------
1158        // Reconnect (failure)
1159        // --------------------------------------------------------------------
1160
1161        pg_server.stop().await.expect("could not stop server");
1162
1163        let result = client.simple_query(
1164            "
1165                do $$
1166                begin
1167                    raise info 'before sleep';
1168                    perform pg_sleep(1);
1169                    raise info 'after sleep';
1170                end;
1171                $$
1172            ",
1173            None,
1174        ).await;
1175
1176        eprintln!("result: {result:?}");
1177        assert!(matches!(result, Err(PGError::FailedToReconnect(2))));
1178        assert_snapshot!("reconnect-failure", extract_and_clear_logs(&notices));
1179
1180
1181    }
1182
1183    fn extract_and_clear_logs(logs: &Arc<RwLock<Vec<String>>>) -> String {
1184        let mut guard = logs.write().expect("could not read notices");
1185        let emtpy_log = Vec::default();
1186        let log = std::mem::replace(&mut *guard, emtpy_log);
1187        redact_pids(&redact_timestamps(&log.join("\n")))
1188    }
1189
1190    fn redact_timestamps(text: &str) -> String {
1191        use regex::Regex;
1192        use std::sync::OnceLock;
1193        pub static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
1194        let pat = TIMESTAMP_PATTERN.get_or_init(|| {
1195            Regex::new(r"\d{4}-\d{2}-\d{2}.?\d{2}:\d{2}:\d{2}(\.\d{3,9})?(Z| UTC|[+-]\d{2}:\d{2})?")
1196                .unwrap()
1197        });
1198        pat.replace_all(text, "<timestamp>").to_string()
1199    }
1200
1201    fn redact_pids(text: &str) -> String {
1202        use regex::Regex;
1203        use std::sync::OnceLock;
1204        pub static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
1205        let pat = TIMESTAMP_PATTERN.get_or_init(|| Regex::new(r"pid=\d+").unwrap());
1206        pat.replace_all(text, "<pid>").to_string()
1207    }
1208}