postgres_notify/lib.rs
1//!
2//! `postgres-notify` started out as an easy way to receive PostgreSQL
3//! notifications but has since evolved into a much more useful client
4//! and is able to handle the following:
5//!
6//! - Receive `NOTIFY <channel> <payload>` pub/sub style notifications
7//!
//! - Receive `RAISE` messages and collect execution logs
9//!
//! - Applies a timeout to all queries. If a query times out then the
//! client will attempt to cancel the ongoing query before returning
//! an error.
13//!
14//! - Supports cancelling an ongoing query.
15//!
16//! - Automatically reconnects if the connection is lost and uses
17//! exponential backoff with jitter to avoid thundering herd effect.
18//!
19//! - Has a familiar API with an additional `timeout` argument.
20//!
21//!
22//!
23//! # BREAKING CHANGE in v0.3.0
24//!
//! This latest version is a breaking change. The `PGNotifyingClient` has
//! been renamed to `PGRobustClient` and queries don't need to be made through
//! the inner client anymore. Furthermore, a single callback handles all
//! of the notifications: NOTIFY, RAISE, TIMEOUT, RECONNECT.
29//!
30//!
31//!
32//! # LISTEN/NOTIFY
33//!
34//! For a very long time (at least since version 7.1) postgres has supported
35//! asynchronous notifications based on LISTEN/NOTIFY commands. This allows
36//! the database to send notifications to the client in an "out-of-band"
37//! channel.
38//!
39//! Once the client has issued a `LISTEN <channel>` command, the database will
40//! send notifications to the client whenever a `NOTIFY <channel> <payload>`
41//! is issued on the database regardless of which session has issued it.
42//! This can act as a cheap alternative to a pub/sub system though without
43//! mailboxes or persistence.
44//!
//! When calling `subscribe_notify` with a list of channel names, [`PGRobustClient`]
//! will invoke the client callback any time a `NOTIFY` message is received for any
//! of the subscribed channels.
48//!
//! ```rust,no_run
50//! use postgres_notify::{PGRobustClient, PGMessage};
51//! use tokio_postgres::NoTls;
52//! use std::time::Duration;
53//!
54//! let rt = tokio::runtime::Builder::new_current_thread()
55//! .enable_io()
56//! .enable_time()
57//! .build()
58//! .expect("could not start tokio runtime");
59//!
60//! rt.block_on(async move {
61//! let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
62//! let callback = |msg:PGMessage| println!("{:?}", &msg);
63//! let mut client = PGRobustClient::spawn(database_url, NoTls, callback)
64//! .await.expect("Could not connect to postgres");
65//!
66//! client.subscribe_notify(&["test"], Some(Duration::from_millis(100)))
67//! .await.expect("Could not subscribe to channels");
68//! });
69//! ```
70//!
71//!
72//!
73//! # RAISE/LOGS
74//!
75//! Logs in PostgreSQL are created by writing `RAISE <level> <message>` statements
76//! within your functions, stored procedures and scripts. When such a command is
77//! issued, [`PGRobustClient`] receives a notification even if the call is still
78//! in progress. This allows the caller to capture the execution log in realtime
79//! if needed.
80//!
//! [`PGRobustClient`] simplifies log collection in two ways. Firstly, it provides
//! the [`with_captured_log`](PGRobustClient::with_captured_log) function,
//! which collects the execution log and returns it along with the query result.
84//! This is probably what most people will want to use.
85//!
//! If your needs are more complex or if you want to propagate realtime logs,
//! then the client callback can be used to forward the messages on an
//! asynchronous channel.
89//!
//! ```rust,no_run
91//! use postgres_notify::{PGRobustClient, PGMessage};
92//! use tokio_postgres::NoTls;
93//! use std::time::Duration;
94//!
95//! let rt = tokio::runtime::Builder::new_current_thread()
96//! .enable_io()
97//! .enable_time()
98//! .build()
99//! .expect("could not start tokio runtime");
100//!
101//! rt.block_on(async move {
102//!
103//! let callback = |msg:PGMessage| println!("{:?}", &msg);
104//!
105//! let database_url = "postgres://postgres:postgres@localhost:5432/postgres";
106//! let mut client = PGRobustClient::spawn(database_url, NoTls, callback)
107//! .await.expect("Could not connect to postgres");
108//!
109//! // Will capture the notices in a Vec
110//! let (_, log) = client.with_captured_log(async |client| {
111//! client.simple_query("
112//! do $$
113//! begin
114//! raise debug 'this is a DEBUG notification';
115//! raise log 'this is a LOG notification';
116//! raise info 'this is a INFO notification';
117//! raise notice 'this is a NOTICE notification';
118//! raise warning 'this is a WARNING notification';
119//! end;
120//! $$",
121//! Some(Duration::from_secs(1))
122//! ).await.expect("Error during query execution");
123//! Ok(())
124//! }).await.expect("Error during captur log");
125//!
126//! println!("{:#?}", &log);
127//! });
128//! ```
129//!
//! Note that the async closure receives the client itself (as `&mut Self`),
//! which means that all queries within that block are subject to the same
//! timeout and reconnect handling.
133//!
134//! You can look at the unit tests for a more in-depth example.
135//!
136//!
137//!
138//! # TIMEOUT
139//!
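//! All of the query functions in [`PGRobustClient`] have a `timeout` argument.
//! If the query takes longer than the timeout, the client attempts to cancel it
//! and an error is returned. If not specified, the default timeout is 1 hour; it
//! can be changed with [`with_default_timeout`](PGRobustClient::with_default_timeout).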
143//!
144//!
145//! # RECONNECT
146//!
147//! If the connection to the database is lost, then [`PGRobustClient`] will
148//! attempt to reconnect to the database automatically. If the maximum number
//! of reconnect attempts is reached then an error is returned. Furthermore,
//! it uses an exponential backoff with jitter in order to avoid the thundering
//! herd effect.
152//!
153
154mod error;
155mod messages;
156mod notify;
157
158pub use error::*;
159pub use messages::*;
160use tokio_postgres::{SimpleQueryMessage, ToStatement};
161
162use {
163 futures::TryFutureExt,
164 std::{
165 collections::BTreeSet,
166 sync::{Arc, RwLock},
167 time::Duration,
168 },
169 tokio::{
170 task::JoinHandle,
171 time::{sleep, timeout},
172 },
173 tokio_postgres::{
174 CancelToken, Client as PGClient, Row, RowStream, Socket, Statement, Transaction,
175 tls::MakeTlsConnect,
176 types::{BorrowToSql, ToSql, Type},
177 },
178};
179
180/// Shorthand for Result with tokio_postgres::Error
181pub type PGResult<T> = Result<T, PGError>;
182
183pub struct PGRobustClient<TLS>
184where
185 TLS: MakeTlsConnect<Socket>,
186{
187 database_url: String,
188 make_tls: TLS,
189 client: PGClient,
190 conn_handle: JoinHandle<()>,
191 cancel_token: CancelToken,
192 subscriptions: BTreeSet<String>,
193 callback: Arc<dyn Fn(PGMessage) + Send + Sync + 'static>,
194 max_reconnect_attempts: u32,
195 default_timeout: Duration,
196 log: Arc<RwLock<Vec<PGMessage>>>,
197}
198
199#[allow(unused)]
200impl<TLS> PGRobustClient<TLS>
201where
202 TLS: MakeTlsConnect<Socket> + Clone,
203 <TLS as MakeTlsConnect<Socket>>::Stream: Send + Sync + 'static,
204{
205 ///
206 /// Given a connect factory and a callback, returns a new [`PGRobustClient`].
207 ///
208 /// The callback will be called whenever a new NOTIFY/RAISE message is received.
209 /// Furthermore, it is also called with a [`PGMessage::Timeout`], when a query
210 /// times out, [`PGMessage::Disconnected`] if the internal state of the client
211 /// is not as expected (Poisoned lock, dropped connections, etc.) or
212 /// [`PGMessage::Reconnect`] whenever a new reconnect attempt is made.
213 ///
214 pub async fn spawn(
215 database_url: impl AsRef<str>,
216 make_tls: TLS,
217 callback: impl Fn(PGMessage) + Send + Sync + 'static,
218 ) -> PGResult<Self> {
219 //
220 // Setup log and other default values
221 //
222 let log = Arc::new(RwLock::new(Vec::default()));
223 let default_timeout = Duration::from_secs(60 * 60);
224
225 //
226 // We wrap the callback so that it also inserts into the log.
227 //
228 // NOTE: we need to type erase here because otherwise the call to Self::connect
229 // will not compile.
230 //
231 let callback: Arc<dyn Fn(PGMessage) + Send + Sync + 'static> = Arc::new({
232 let log = log.clone();
233 move |msg: PGMessage| {
234 callback(msg.clone());
235 if let Ok(mut log) = log.write() {
236 log.push(msg);
237 }
238 }
239 });
240
241 // Connect to the database
242 let (client, conn_handle, cancel_token) =
243 Self::connect(database_url.as_ref(), &make_tls, &callback).await?;
244
245 Ok(Self {
246 database_url: database_url.as_ref().to_string(),
247 make_tls,
248 client,
249 conn_handle,
250 cancel_token,
251 subscriptions: BTreeSet::new(),
252 callback,
253 max_reconnect_attempts: u32::MAX,
254 default_timeout,
255 log,
256 })
257 }
258
259 ///
260 /// Sets the default timeout for all queries. Defaults to 1 hour.
261 ///
262 /// This function consumes and returns self and is therefor usually used
263 /// just after [`PGRobustClient::spawn`].
264 ///
265 pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
266 self.default_timeout = timeout;
267 self
268 }
269
270 ///
271 /// Sets the maximum number of reconnect attempts before giving up.
272 /// Defaults to `u32::MAX`.
273 ///
274 /// This function consumes and returns self and is therefor usually used
275 /// just after [`PGRobustClient::spawn`].
276 ///
277 pub fn with_max_reconnect_attempts(mut self, max_attempts: u32) -> Self {
278 self.max_reconnect_attempts = max_attempts;
279 self
280 }
281
282 ///
283 /// PRIVATE
284 /// Does the necessary details to connect to the database and hookup callbacks and notifications.
285 ///
286 async fn connect(
287 database_url: &str,
288 make_tls: &TLS,
289 callback: &Arc<dyn Fn(PGMessage) + Send + Sync + 'static>,
290 ) -> PGResult<(PGClient, JoinHandle<()>, CancelToken)> {
291 //
292 let (client, conn) = tokio_postgres::connect(database_url, make_tls.clone()).await?;
293 let cancel_token = client.cancel_token();
294
295 let callback = callback.clone();
296 let handle = tokio::spawn(notify::handle_connection_polling(conn, move |msg| {
297 callback(msg)
298 }));
299
300 Ok((client, handle, cancel_token))
301 }
302
303 ///
304 /// Cancels any in-progress query.
305 ///
306 /// This is the only function that does not take a timeout nor does it
307 /// attempt to reconnect if the connection is lost. It will simply
308 /// return the original error.
309 ///
310 pub async fn cancel_query(&mut self) -> PGResult<()> {
311 self.cancel_token
312 .cancel_query(self.make_tls.clone())
313 .await
314 .map_err(Into::into)
315 }
316
317 ///
318 /// Returns the log messages captured since the last call to this function.
319 /// It also clears the log.
320 ///
321 pub fn capture_and_clear_log(&mut self) -> Vec<PGMessage> {
322 if let Ok(mut guard) = self.log.write() {
323 let empty_log = Vec::default();
324 std::mem::replace(&mut *guard, empty_log)
325 } else {
326 Vec::default()
327 }
328 }
329
330 ///
331 /// Given an async closure taking the postgres client, returns the result
332 /// of said closure along with the accumulated log since the beginning of
333 /// the closure.
334 ///
335 /// If you use query pipelining then collect the logs for all queries in
336 /// the pipeline. Otherwise, the logs might not be what you expect.
337 ///
338 pub async fn with_captured_log<F, T>(&mut self, f: F) -> PGResult<(T, Vec<PGMessage>)>
339 where
340 F: AsyncFn(&mut Self) -> PGResult<T>,
341 {
342 self.capture_and_clear_log(); // clear the log just in case...
343 let result = f(self).await?;
344 let log = self.capture_and_clear_log();
345 Ok((result, log))
346 }
347
348 ///
349 /// Attempts to reconnect after a connection loss.
350 ///
351 /// Reconnection applies an exponention backoff with jitter in order to
352 /// avoid thundering herd effect. If the maximum number of attempts is
353 /// reached then an error is returned.
354 ///
355 /// If an error unrelated to establishing a new connection is returned
356 /// when trying to connect then that error is returned.
357 ///
358 pub async fn reconnect(&mut self) -> PGResult<()> {
359 //
360 use std::cmp::{max, min};
361 let mut attempts = 1;
362 let mut k = 500;
363
364 while attempts <= self.max_reconnect_attempts {
365 //
366 // Implement exponential backoff + jitter
367 // Initial delay will be 500ms, max delay is 1h.
368 //
369 sleep(Duration::from_millis(k + rand::random_range(0..k / 2))).await;
370 k = min(k * 2, 60000);
371
372 tracing::info!("Reconnect attempt #{}", attempts);
373 (self.callback)(PGMessage::reconnect(attempts, self.max_reconnect_attempts));
374
375 attempts += 1;
376
377 let maybe_triple =
378 Self::connect(&self.database_url, &self.make_tls, &self.callback).await;
379
380 match maybe_triple {
381 Ok((client, conn_handle, cancel_token)) => {
382 // Abort the old connection just in case
383 self.conn_handle.abort();
384
385 self.client = client;
386 self.conn_handle = conn_handle;
387 self.cancel_token = cancel_token;
388
389 // Resubscribe to previously subscribed channels
390 let subs: Vec<_> = self.subscriptions.iter().map(String::from).collect();
391
392 match Self::subscribe_notify_impl(&self.client, &subs).await {
393 Ok(_) => {
394 return Ok(());
395 }
396 Err(e) if is_pg_connection_issue(&e) => {
397 continue;
398 }
399 Err(e) => {
400 return Err(e.into());
401 }
402 }
403 }
404 Err(e) if e.is_pg_connection_issue() => {
405 continue;
406 }
407 Err(e) => {
408 return Err(e);
409 }
410 }
411 }
412
413 // Issue the failed to reconnect message
414 (self.callback)(PGMessage::failed_to_reconnect(self.max_reconnect_attempts));
415 // Return the error
416 Err(PGError::FailedToReconnect(self.max_reconnect_attempts))
417 }
418
419 pub async fn wrap_reconnect<T>(
420 &mut self,
421 max_dur: Option<Duration>,
422 factory: impl AsyncFn(&mut PGClient) -> Result<T, tokio_postgres::Error>,
423 ) -> PGResult<T> {
424 let max_dur = max_dur.unwrap_or(self.default_timeout);
425 loop {
426 match timeout(max_dur, factory(&mut self.client)).await {
427 // Query succeeded so return the result
428 Ok(Ok(o)) => return Ok(o),
429 // Query failed because of connection issues
430 Ok(Err(e)) if is_pg_connection_issue(&e) => {
431 self.reconnect().await?;
432 }
433 // Query failed for some other reason
434 Ok(Err(e)) => {
435 return Err(e.into());
436 }
437 // Query timed out!
438 Err(_) => {
439 // Callback with timeout message
440 (self.callback)(PGMessage::timeout(max_dur));
441 // Cancel the ongoing query
442 let status = self.cancel_token.cancel_query(self.make_tls.clone()).await;
443 // Callback with cancelled message
444 (self.callback)(PGMessage::cancelled(!status.is_err()));
445 // Return the timeout error
446 return Err(PGError::Timeout(max_dur));
447 }
448 }
449 }
450 }
451
452 pub async fn subscribe_notify(
453 &mut self,
454 channels: &[impl AsRef<str> + Send + Sync + 'static],
455 timeout: Option<Duration>,
456 ) -> PGResult<()> {
457 if !channels.is_empty() {
458 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
459 Self::subscribe_notify_impl(client, channels).await
460 })
461 .await?;
462
463 // Add to our subscriptions
464 channels.iter().for_each(|ch| {
465 self.subscriptions.insert(ch.as_ref().to_string());
466 });
467 }
468 Ok(())
469 }
470
471 async fn subscribe_notify_impl(
472 client: &PGClient,
473 channels: &[impl AsRef<str> + Send + Sync + 'static],
474 ) -> Result<(), tokio_postgres::Error> {
475 // Build a sequence of `LISTEN` commands
476 let sql = channels
477 .iter()
478 .map(|ch| format!("LISTEN {};", ch.as_ref()))
479 .collect::<Vec<_>>()
480 .join("\n");
481
482 // Tell the world we are about to subscribe
483 #[cfg(feature = "tracing")]
484 tracing::info!(
485 "Subscribing to channels: \"{}\"",
486 &channels
487 .iter()
488 .map(AsRef::as_ref)
489 .collect::<Vec<_>>()
490 .join(",")
491 );
492
493 // Issue the `LISTEN` commands
494 client.simple_query(&sql).await?;
495 Ok(())
496 }
497
498 pub async fn unsubscribe_notify(
499 &mut self,
500 channels: &[impl AsRef<str> + Send + Sync + 'static],
501 timeout: Option<Duration>,
502 ) -> PGResult<()> {
503 if !channels.is_empty() {
504 self.wrap_reconnect(timeout, async move |client: &mut PGClient| {
505 // Build a sequence of `LISTEN` commands
506 let sql = channels
507 .iter()
508 .map(|ch| format!("UNLISTEN {};", ch.as_ref()))
509 .collect::<Vec<_>>()
510 .join("\n");
511
512 // Tell the world we are about to subscribe
513 #[cfg(feature = "tracing")]
514 tracing::info!(
515 "Unsubscribing from channels: \"{}\"",
516 &channels
517 .iter()
518 .map(AsRef::as_ref)
519 .collect::<Vec<_>>()
520 .join(",")
521 );
522
523 // Issue the `LISTEN` commands
524 client.simple_query(&sql).await?;
525 Ok(())
526 })
527 .await?;
528
529 // Remove subscriptions
530 channels.iter().for_each(|ch| {
531 self.subscriptions.remove(ch.as_ref());
532 });
533 }
534 Ok(())
535 }
536
537 ///
538 /// Unsubscribes from all channels.
539 ///
540 pub async fn unsubscribe_notify_all(&mut self, timeout: Option<Duration>) -> PGResult<()> {
541 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
542 // Tell the world we are about to unsubscribe
543 #[cfg(feature = "tracing")]
544 tracing::info!("Unsubscribing from channels: *");
545 // Issue the `UNLISTEN` commands
546 client.simple_query("UNLISTEN *").await?;
547 Ok(())
548 })
549 .await
550 }
551
552 /// Like [`Client::execute_raw`].
553 pub async fn execute_raw<P, I, T>(
554 &mut self,
555 statement: &T,
556 params: I,
557 timeout: Option<Duration>,
558 ) -> PGResult<u64>
559 where
560 T: ?Sized + ToStatement + Sync + Send,
561 P: BorrowToSql + Clone + Send + Sync,
562 I: IntoIterator<Item = P> + Sync + Send,
563 I::IntoIter: ExactSizeIterator,
564 {
565 let params: Vec<_> = params.into_iter().collect();
566 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
567 client.execute_raw(statement, params.clone()).await
568 })
569 .await
570 }
571
572 /// Like [`Client::query`].
573 pub async fn query<T>(
574 &mut self,
575 query: &T,
576 params: &[&(dyn ToSql + Sync)],
577 timeout: Option<Duration>,
578 ) -> PGResult<Vec<Row>>
579 where
580 T: ?Sized + ToStatement + Sync + Send,
581 {
582 let params = params.to_vec();
583 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
584 client.query(query, ¶ms).await
585 })
586 .await
587 }
588
589 /// Like [`Client::query_one`].
590 pub async fn query_one<T>(
591 &mut self,
592 statement: &T,
593 params: &[&(dyn ToSql + Sync)],
594 timeout: Option<Duration>,
595 ) -> PGResult<Row>
596 where
597 T: ?Sized + ToStatement + Sync + Send,
598 {
599 let params = params.to_vec();
600 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
601 client.query_one(statement, ¶ms).await
602 })
603 .await
604 }
605
606 /// Like [`Client::query_opt`].
607 pub async fn query_opt<T>(
608 &mut self,
609 statement: &T,
610 params: &[&(dyn ToSql + Sync)],
611 timeout: Option<Duration>,
612 ) -> PGResult<Option<Row>>
613 where
614 T: ?Sized + ToStatement + Sync + Send,
615 {
616 let params = params.to_vec();
617 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
618 client.query_opt(statement, ¶ms).await
619 })
620 .await
621 }
622
623 /// Like [`Client::query_raw`].
624 pub async fn query_raw<T, P, I>(
625 &mut self,
626 statement: &T,
627 params: I,
628 timeout: Option<Duration>,
629 ) -> PGResult<RowStream>
630 where
631 T: ?Sized + ToStatement + Sync + Send,
632 P: BorrowToSql + Clone + Send + Sync,
633 I: IntoIterator<Item = P> + Sync + Send,
634 I::IntoIter: ExactSizeIterator,
635 {
636 let params: Vec<_> = params.into_iter().collect();
637 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
638 client.query_raw(statement, params.clone()).await
639 })
640 .await
641 }
642
643 /// Like [`Client::query_typed`]
644 pub async fn query_typed(
645 &mut self,
646 statement: &str,
647 params: &[(&(dyn ToSql + Sync), Type)],
648 timeout: Option<Duration>,
649 ) -> PGResult<Vec<Row>> {
650 let params = params.to_vec();
651 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
652 client.query_typed(statement, ¶ms).await
653 })
654 .await
655 }
656
657 /// Like [`Client::query_typed_raw`]
658 pub async fn query_typed_raw<P, I>(
659 &mut self,
660 statement: &str,
661 params: I,
662 timeout: Option<Duration>,
663 ) -> PGResult<RowStream>
664 where
665 P: BorrowToSql + Clone + Send + Sync,
666 I: IntoIterator<Item = (P, Type)> + Sync + Send,
667 {
668 let params: Vec<_> = params.into_iter().collect();
669 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
670 client.query_typed_raw(statement, params.clone()).await
671 })
672 .await
673 }
674
675 /// Like [`Client::prepare`].
676 pub async fn prepare(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<Statement> {
677 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
678 client.prepare(query).map_err(Into::into).await
679 })
680 .await
681 }
682
683 /// Like [`Client::prepare_typed`].
684 pub async fn prepare_typed(
685 &mut self,
686 query: &str,
687 parameter_types: &[Type],
688 timeout: Option<Duration>,
689 ) -> PGResult<Statement> {
690 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
691 client.prepare_typed(query, parameter_types).await
692 })
693 .await
694 }
695
696 //
697 /// Similar but not quite the same as [`Client::transaction`].
698 ///
699 /// Executes the closure as a single transaction.
700 /// Commit is automatically called after the closure. If any connection
701 /// issues occur during the transaction then the transaction is rolled
702 /// back (on drop) and retried a new with the new connection subject to
703 /// the maximum number of reconnect attempts.
704 ///
705 pub async fn transaction<F>(&mut self, timeout: Option<Duration>, f: F) -> PGResult<()>
706 where
707 for<'a> F: AsyncFn(&'a mut Transaction) -> Result<(), tokio_postgres::Error>,
708 {
709 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
710 let mut tx = client.transaction().await?;
711 f(&mut tx).await?;
712 tx.commit().await?;
713 Ok(())
714 })
715 .await
716 }
717
718 /// Like [`Client::batch_execute`].
719 pub async fn batch_execute(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<()> {
720 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
721 client.batch_execute(query).await
722 })
723 .await
724 }
725
726 /// Like [`Client::simple_query`].
727 pub async fn simple_query(
728 &mut self,
729 query: &str,
730 timeout: Option<Duration>,
731 ) -> PGResult<Vec<SimpleQueryMessage>> {
732 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
733 client.simple_query(query).await
734 })
735 .await
736 }
737
738 /// Returns a reference to the underlying [`Client`].
739 pub fn client(&self) -> &PGClient {
740 &self.client
741 }
742}
743
744///
745/// Wraps any future in a tokio timeout and maps the Elapsed error to a PGError::Timeout.
746///
747pub async fn wrap_timeout<T>(dur: Duration, fut: impl Future<Output = PGResult<T>>) -> PGResult<T> {
748 match timeout(dur, fut).await {
749 Ok(out) => out,
750 Err(_) => Err(PGError::Timeout(dur)),
751 }
752}
753
#[cfg(test)]
mod tests {

    use {
        super::{PGError, PGMessage, PGRaiseLevel, PGRobustClient},
        insta::*,
        std::{
            sync::{Arc, RwLock},
            time::Duration,
        },
        testcontainers::{ImageExt, runners::AsyncRunner},
        testcontainers_modules::postgres::Postgres,
    };

    /// Builds a script that sets `client_min_messages` to `level` and then
    /// interleaves one `RAISE` of each level with a `NOTIFY test` message.
    fn sql_for_log_and_notify_test(level: PGRaiseLevel) -> String {
        format!(
            r#"
            set client_min_messages to '{}';
            do $$
            begin
                raise debug 'this is a DEBUG notification';
                notify test, 'test#1';
                raise log 'this is a LOG notification';
                notify test, 'test#2';
                raise info 'this is a INFO notification';
                notify test, 'test#3';
                raise notice 'this is a NOTICE notification';
                notify test, 'test#4';
                raise warning 'this is a WARNING notification';
                notify test, 'test#5';
            end;
            $$;
            "#,
            level
        )
    }

    #[tokio::test]
    async fn test_integration() {
        //
        // --------------------------------------------------------------------
        // Setup Postgres Server
        // --------------------------------------------------------------------

        let pg_server = Postgres::default()
            .with_tag("16.4")
            .start()
            .await
            .expect("could not start postgres server");

        // Build the connection URL from the host/port mapped by the container.
        let database_url = format!(
            "postgres://postgres:postgres@{}:{}/postgres",
            pg_server.get_host().await.unwrap(),
            pg_server.get_host_port_ipv4(5432).await.unwrap()
        );

        // --------------------------------------------------------------------
        // Connect to the server
        // --------------------------------------------------------------------

        // Every message delivered to `client`'s callback, rendered with Display.
        let notices = Arc::new(RwLock::new(Vec::new()));
        let notices_clone = notices.clone();

        let callback = move |msg: PGMessage| {
            if let Ok(mut guard) = notices_clone.write() {
                guard.push(msg.to_string());
            }
        };

        // `admin` is used to terminate `client`'s backend in the reconnect scenarios.
        let mut admin = PGRobustClient::spawn(&database_url, tokio_postgres::NoTls, |_| {})
            .await
            .expect("could not create initial client");

        let mut client = PGRobustClient::spawn(&database_url, tokio_postgres::NoTls, callback)
            .await
            .expect("could not create initial client")
            .with_max_reconnect_attempts(2);

        // --------------------------------------------------------------------
        // Subscribe to notify and raise
        // --------------------------------------------------------------------

        client
            .subscribe_notify(&["test"], None)
            .await
            .expect("could not subscribe");

        let (_, execution_log) = client
            .with_captured_log(async |client: &mut PGRobustClient<_>| {
                client
                    .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Debug), None)
                    .await
            })
            .await
            .expect("could not execute queries on postgres");

        assert_json_snapshot!("subscribed-executionlog", &execution_log, {
            "[].timestamp" => "<timestamp>",
            "[].process_id" => "<pid>",
        });

        assert_snapshot!("subscribed-notify", extract_and_clear_logs(&notices));

        // --------------------------------------------------------------------
        // Unsubscribe
        // --------------------------------------------------------------------

        client
            .unsubscribe_notify(&["test"], None)
            .await
            .expect("could not unsubscribe");

        let (_, execution_log) = client
            .with_captured_log(async |client| {
                client
                    .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Warning), None)
                    .await
            })
            .await
            .expect("could not execute queries on postgres");

        assert_json_snapshot!("unsubscribed-executionlog", &execution_log, {
            "[].timestamp" => "<timestamp>",
            "[].process_id" => "<pid>",
        });

        assert_snapshot!("unsubscribed-notify", extract_and_clear_logs(&notices));

        // --------------------------------------------------------------------
        // Timeout
        // --------------------------------------------------------------------

        let result = client
            .simple_query(
                "
                do $$
                begin
                    raise info 'before sleep';
                    perform pg_sleep(3);
                    raise info 'after sleep';
                end;
                $$
                ",
                Some(Duration::from_secs(1)),
            )
            .await;

        assert!(
            matches!(result, Err(PGError::Timeout(_))),
            "expected a timeout, got: {result:?}"
        );
        assert_snapshot!("timeout-messages", extract_and_clear_logs(&notices));

        // --------------------------------------------------------------------
        // Reconnect (before query)
        // --------------------------------------------------------------------

        admin
            .simple_query(
                "select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()",
                None,
            )
            .await
            .expect("could not kill other client");

        let result = client
            .simple_query(
                "
                do $$
                begin
                    raise info 'before sleep';
                    perform pg_sleep(1);
                    raise info 'after sleep';
                end;
                $$
                ",
                Some(Duration::from_secs(10)),
            )
            .await;

        assert!(result.is_ok(), "expected success after reconnect, got: {result:?}");
        assert_snapshot!("reconnect-before", extract_and_clear_logs(&notices));

        // --------------------------------------------------------------------
        // Reconnect (during query)
        // --------------------------------------------------------------------

        let query = client.simple_query(
            "
            do $$
            begin
                raise info 'before sleep';
                perform pg_sleep(1);
                raise info 'after sleep';
            end;
            $$
            ",
            None,
        );

        // Terminates `client`'s backend half-way through the query above.
        let kill_later = admin.simple_query(
            "
            select pg_sleep(0.5);
            select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()",
            None,
        );

        let (_, result) = tokio::join!(kill_later, query);

        assert!(result.is_ok(), "expected success after reconnect, got: {result:?}");
        assert_snapshot!("reconnect-during", extract_and_clear_logs(&notices));

        // --------------------------------------------------------------------
        // Reconnect (failure)
        // --------------------------------------------------------------------

        pg_server.stop().await.expect("could not stop server");

        let result = client
            .simple_query(
                "
                do $$
                begin
                    raise info 'before sleep';
                    perform pg_sleep(1);
                    raise info 'after sleep';
                end;
                $$
                ",
                None,
            )
            .await;

        assert!(
            matches!(result, Err(PGError::FailedToReconnect(2))),
            "expected FailedToReconnect(2), got: {result:?}"
        );
        assert_snapshot!("reconnect-failure", extract_and_clear_logs(&notices));
    }

    /// Takes every message collected by the test callback (leaving the list empty)
    /// and joins them with newlines, with timestamps and pids redacted.
    fn extract_and_clear_logs(logs: &Arc<RwLock<Vec<String>>>) -> String {
        let mut guard = logs.write().expect("could not lock notices");
        let log = std::mem::take(&mut *guard);
        redact_pids(&redact_timestamps(&log.join("\n")))
    }

    /// Replaces date-time stamps (optionally with fractional seconds and a zone
    /// suffix) with `<timestamp>` so snapshots are stable.
    fn redact_timestamps(text: &str) -> String {
        use regex::Regex;
        use std::sync::OnceLock;
        static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
        let pat = TIMESTAMP_PATTERN.get_or_init(|| {
            Regex::new(r"\d{4}-\d{2}-\d{2}.?\d{2}:\d{2}:\d{2}(\.\d{3,9})?(Z| UTC|[+-]\d{2}:\d{2})?")
                .unwrap()
        });
        pat.replace_all(text, "<timestamp>").to_string()
    }

    /// Replaces `pid=<digits>` with `<pid>` so snapshots are stable.
    fn redact_pids(text: &str) -> String {
        use regex::Regex;
        use std::sync::OnceLock;
        static PID_PATTERN: OnceLock<Regex> = OnceLock::new();
        let pat = PID_PATTERN.get_or_init(|| Regex::new(r"pid=\d+").unwrap());
        pat.replace_all(text, "<pid>").to_string()
    }
}