1mod error;
177mod messages;
178mod notify;
179mod config;
180mod inner;
181
182pub use error::*;
183pub use messages::*;
184use inner::*;
185pub use config::*;
186
187use tokio_postgres::{SimpleQueryMessage, ToStatement};
188
189use {
190 futures::TryFutureExt,
191 std::{
192 time::Duration,
193 },
194 tokio::{
195 time::{sleep, timeout},
196 },
197 tokio_postgres::{
198 Row, RowStream, Socket, Statement, Transaction,
199 tls::MakeTlsConnect,
200 types::{BorrowToSql, ToSql, Type},
201 },
202};
203
/// Result alias used throughout this crate: `Ok(T)` or a [`PGError`].
pub type PGResult<T> = Result<T, PGError>;
206
207
208
/// A PostgreSQL client that keeps the configuration it was created from next to
/// the live connection, so the connection can be re-established later.
pub struct PGRobustClient<TLS> {
    /// Configuration passed to `PGClient::connect` whenever a connection is made.
    config: PGRobustClientConfig<TLS>,
    /// The current underlying connection; swapped for a new one on reconnect.
    inner: PGClient,
}
214
215#[allow(unused)]
216impl<TLS> PGRobustClient<TLS>
217where
218 TLS: MakeTlsConnect<Socket> + Clone,
219 <TLS as MakeTlsConnect<Socket>>::Stream: Send + Sync + 'static,
220{
221 pub async fn spawn(config: PGRobustClientConfig<TLS>) -> PGResult<PGRobustClient<TLS>> {
225 let inner = PGClient::connect(&config).await?;
226 Ok(PGRobustClient { config, inner })
227 }
228
229 pub fn config(&self) -> &PGRobustClientConfig<TLS> {
233 &self.config
234 }
235
236 pub fn config_mut(&mut self) -> &mut PGRobustClientConfig<TLS> {
241 &mut self.config
242 }
243
244 pub async fn cancel_query(&mut self) -> PGResult<()> {
252 self.inner
253 .cancel_token
254 .cancel_query(self.config.make_tls.clone())
255 .await
256 .map_err(Into::into)
257 }
258
259 pub fn capture_and_clear_log(&mut self) -> Vec<PGMessage> {
264 match self.inner.log.write() {
265 Ok(mut guard) => {
266 let empty_log = Vec::default();
267 std::mem::replace(&mut *guard, empty_log)
268 }
269 Err(_) => {
270 #[cfg(feature = "tracing")]
271 tracing::error!("Lock poisoned in capture_and_clear_log - returning empty log");
272 Vec::default()
273 }
274 }
275 }
276
277 fn clear_log(&mut self) {
281 if let Ok(mut guard) = self.inner.log.write() {
282 guard.clear();
283 }
284 }
285
286 pub async fn with_captured_log<F, T>(&mut self, f: F) -> PGResult<(T, Vec<PGMessage>)>
295 where
296 F: AsyncFn(&mut Self) -> PGResult<T>,
297 {
298 self.capture_and_clear_log(); let result = f(self).await?;
300 let log = self.capture_and_clear_log();
301 Ok((result, log))
302 }
303
304 async fn reconnect(&mut self) -> PGResult<()> {
315 use std::cmp::{max, min};
317 let mut attempts = 1;
318 let mut k = 500;
319
320 while attempts <= self.config.max_reconnect_attempts {
321 sleep(Duration::from_millis(k + rand::random_range(0..k / 2))).await;
326 k = min(k * 2, 60000);
327
328 #[cfg(feature = "tracing")]
329 tracing::info!("Reconnect attempt #{}", attempts);
330 (self.config.callback)(PGMessage::reconnect(attempts, self.config.max_reconnect_attempts));
331
332 attempts += 1;
333
334 match PGClient::connect(&self.config).await {
335 Ok(inner) => {
336
337 self.inner = inner;
338
339 (self.config.callback)(PGMessage::connected());
340
341 if let Some(sql) = self.config.full_connect_script() {
342 match self.inner.simple_query(&sql).await {
343 Ok(_) => {
344 return Ok(());
345 }
346 Err(e) if is_pg_connection_issue(&e) => {
347 continue;
348 }
349 Err(e) => {
350 return Err(e.into());
351 }
352 }
353 } else {
354 return Ok(());
355 }
356 }
357 Err(e) if e.is_pg_connection_issue() => {
358 continue;
359 }
360 Err(e) => {
361 return Err(e);
362 }
363 }
364 }
365
366 (self.config.callback)(PGMessage::failed_to_reconnect(self.config.max_reconnect_attempts));
368 Err(PGError::FailedToReconnect(self.config.max_reconnect_attempts))
370 }
371
372
373 pub async fn wrap_reconnect<T>(
384 &mut self,
385 max_dur: Option<Duration>,
386 factory: impl AsyncFn(&mut PGClient) -> Result<T, tokio_postgres::Error>,
387 ) -> PGResult<T> {
388 self.clear_log();
390 let max_dur = max_dur.unwrap_or(self.config.default_timeout);
391 loop {
392 match timeout(max_dur, factory(&mut self.inner)).await {
393 Ok(Ok(o)) => return Ok(o),
395 Ok(Err(e)) if is_pg_connection_issue(&e) => {
397 self.reconnect().await?;
398 }
399 Ok(Err(e)) => {
401 return Err(e.into());
402 }
403 Err(_) => {
405 (self.config.callback)(PGMessage::timeout(max_dur));
407 let status = self.inner.cancel_token.cancel_query(self.config.make_tls.clone()).await;
409 (self.config.callback)(PGMessage::cancelled(!status.is_err()));
411 return Err(PGError::Timeout(max_dur));
413 }
414 }
415 }
416 }
417
418 pub async fn subscribe_notify(
419 &mut self,
420 channels: &[impl AsRef<str> + Send + Sync + 'static],
421 timeout: Option<Duration>,
422 ) -> PGResult<()> {
423
424 if !channels.is_empty() {
425 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
427 PGClient::issue_listen(client, channels).await
428 })
429 .await?;
430
431 self.config.with_subscriptions(channels.iter().map(AsRef::as_ref));
433 }
434 Ok(())
435 }
436
437
438
439 pub async fn unsubscribe_notify(
440 &mut self,
441 channels: &[impl AsRef<str> + Send + Sync + 'static],
442 timeout: Option<Duration>,
443 ) -> PGResult<()> {
444 if !channels.is_empty() {
445 self.wrap_reconnect(timeout, async move |client: &mut PGClient| {
447 PGClient::issue_unlisten(client, channels).await
448 })
449 .await?;
450
451 self.config.without_subscriptions(channels.iter().map(AsRef::as_ref));
453 }
454 Ok(())
455 }
456
457 pub async fn unsubscribe_notify_all(&mut self, timeout: Option<Duration>) -> PGResult<()> {
461 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
462 #[cfg(feature = "tracing")]
464 tracing::info!("Unsubscribing from channels: *");
465 client.simple_query("UNLISTEN *").await?;
467 Ok(())
468 })
469 .await
470 }
471
472
473 pub async fn execute_raw<P, I, T>(
475 &mut self,
476 statement: &T,
477 params: I,
478 timeout: Option<Duration>,
479 ) -> PGResult<u64>
480 where
481 T: ?Sized + ToStatement + Sync + Send,
482 P: BorrowToSql + Clone + Send + Sync,
483 I: IntoIterator<Item = P> + Sync + Send,
484 I::IntoIter: ExactSizeIterator,
485 {
486 let params: Vec<_> = params.into_iter().collect();
487 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
488 client.execute_raw(statement, params.clone()).await
489 })
490 .await
491 }
492
493 pub async fn query<T>(
500 &mut self,
501 query: &T,
502 params: &[&(dyn ToSql + Sync)],
503 timeout: Option<Duration>,
504 ) -> PGResult<Vec<Row>>
505 where
506 T: ?Sized + ToStatement + Sync + Send,
507 {
508 let params = params.to_vec();
509 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
510 client.query(query, ¶ms).await
511 })
512 .await
513 }
514
515 pub async fn query_one<T>(
517 &mut self,
518 statement: &T,
519 params: &[&(dyn ToSql + Sync)],
520 timeout: Option<Duration>,
521 ) -> PGResult<Row>
522 where
523 T: ?Sized + ToStatement + Sync + Send,
524 {
525 let params = params.to_vec();
526 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
527 client.query_one(statement, ¶ms).await
528 })
529 .await
530 }
531
532 pub async fn query_opt<T>(
534 &mut self,
535 statement: &T,
536 params: &[&(dyn ToSql + Sync)],
537 timeout: Option<Duration>,
538 ) -> PGResult<Option<Row>>
539 where
540 T: ?Sized + ToStatement + Sync + Send,
541 {
542 let params = params.to_vec();
543 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
544 client.query_opt(statement, ¶ms).await
545 })
546 .await
547 }
548
549 pub async fn query_raw<T, P, I>(
551 &mut self,
552 statement: &T,
553 params: I,
554 timeout: Option<Duration>,
555 ) -> PGResult<RowStream>
556 where
557 T: ?Sized + ToStatement + Sync + Send,
558 P: BorrowToSql + Clone + Send + Sync,
559 I: IntoIterator<Item = P> + Sync + Send,
560 I::IntoIter: ExactSizeIterator,
561 {
562 let params: Vec<_> = params.into_iter().collect();
563 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
564 client.query_raw(statement, params.clone()).await
565 })
566 .await
567 }
568
569 pub async fn query_typed(
571 &mut self,
572 statement: &str,
573 params: &[(&(dyn ToSql + Sync), Type)],
574 timeout: Option<Duration>,
575 ) -> PGResult<Vec<Row>> {
576 let params = params.to_vec();
577 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
578 client.query_typed(statement, ¶ms).await
579 })
580 .await
581 }
582
583 pub async fn query_typed_raw<P, I>(
585 &mut self,
586 statement: &str,
587 params: I,
588 timeout: Option<Duration>,
589 ) -> PGResult<RowStream>
590 where
591 P: BorrowToSql + Clone + Send + Sync,
592 I: IntoIterator<Item = (P, Type)> + Sync + Send,
593 {
594 let params: Vec<_> = params.into_iter().collect();
595 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
596 client.query_typed_raw(statement, params.clone()).await
597 })
598 .await
599 }
600
601 pub async fn prepare(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<Statement> {
603 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
604 client.prepare(query).map_err(Into::into).await
605 })
606 .await
607 }
608
609 pub async fn prepare_typed(
611 &mut self,
612 query: &str,
613 parameter_types: &[Type],
614 timeout: Option<Duration>,
615 ) -> PGResult<Statement> {
616 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
617 client.prepare_typed(query, parameter_types).await
618 })
619 .await
620 }
621
622 pub async fn transaction<F>(&mut self, timeout: Option<Duration>, f: F) -> PGResult<()>
632 where
633 for<'a> F: AsyncFn(&'a mut Transaction) -> Result<(), tokio_postgres::Error>,
634 {
635 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
636 let mut tx = client.transaction().await?;
637 f(&mut tx).await?;
638 tx.commit().await?;
639 Ok(())
640 })
641 .await
642 }
643
644 pub async fn batch_execute(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<()> {
646 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
647 client.batch_execute(query).await
648 })
649 .await
650 }
651
652 pub async fn simple_query(
654 &mut self,
655 query: &str,
656 timeout: Option<Duration>,
657 ) -> PGResult<Vec<SimpleQueryMessage>> {
658 self.wrap_reconnect(timeout, async |client: &mut PGClient| {
659 client.simple_query(query).await
660 })
661 .await
662 }
663
664 pub fn client(&self) -> &tokio_postgres::Client {
666 &self.inner
667 }
668}
669
670pub async fn wrap_timeout<T>(dur: Duration, fut: impl Future<Output = PGResult<T>>) -> PGResult<T> {
674 match timeout(dur, fut).await {
675 Ok(out) => out,
676 Err(_) => Err(PGError::Timeout(dur)),
677 }
678}
679
680#[cfg(test)]
681mod tests {
682
683 use {
684 super::{PGError, PGMessage, PGRaiseLevel, PGRobustClient, PGRobustClientConfig},
685 insta::*,
686 std::{
687 sync::{Arc, RwLock},
688 time::Duration,
689 },
690 testcontainers::{ImageExt, runners::AsyncRunner},
691 testcontainers_modules::postgres::Postgres,
692 };
693
694 mod unit {
699 use super::*;
700 use tokio_postgres::NoTls;
701
702 #[test]
707 fn config_default_values() {
708 let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
709
710 assert_eq!(config.max_reconnect_attempts, 10);
711 assert_eq!(config.default_timeout, Duration::from_secs(3600));
712 assert!(config.subscriptions.is_empty());
713 assert!(config.connect_script.is_none());
714 assert!(config.application_name.is_none());
715 }
716
717 #[test]
718 fn config_builder_chaining() {
719 let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
720 .max_reconnect_attempts(5)
721 .default_timeout(Duration::from_secs(30))
722 .application_name("test_app")
723 .connect_script("SET timezone = 'UTC'")
724 .subscriptions(["channel1", "channel2"]);
725
726 assert_eq!(config.max_reconnect_attempts, 5);
727 assert_eq!(config.default_timeout, Duration::from_secs(30));
728 assert_eq!(config.application_name, Some("test_app".to_string()));
729 assert_eq!(config.connect_script, Some("SET timezone = 'UTC'".to_string()));
730 assert!(config.subscriptions.contains("channel1"));
731 assert!(config.subscriptions.contains("channel2"));
732 }
733
734 #[test]
735 fn config_with_methods() {
736 let mut config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
737
738 config.with_max_reconnect_attempts(Some(3));
739 config.with_default_timeout(Some(Duration::from_secs(60)));
740 config.with_application_name(Some("my_app"));
741 config.with_connect_script(Some("SELECT 1"));
742 config.with_subscriptions(["events"]);
743
744 assert_eq!(config.max_reconnect_attempts, 3);
745 assert_eq!(config.default_timeout, Duration::from_secs(60));
746 assert_eq!(config.application_name, Some("my_app".to_string()));
747 assert_eq!(config.connect_script, Some("SELECT 1".to_string()));
748 assert!(config.subscriptions.contains("events"));
749 }
750
751 #[test]
752 fn config_full_connect_script_empty() {
753 let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
754 assert!(config.full_connect_script().is_none());
755 }
756
757 #[test]
758 fn config_full_connect_script_with_app_name() {
759 let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
760 .application_name("my_app");
761
762 let script = config.full_connect_script().unwrap();
763 assert!(script.contains("SET application_name = 'my_app'"));
764 }
765
766 #[test]
767 fn config_full_connect_script_with_subscriptions() {
768 let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
769 .subscriptions(["chan1", "chan2"]);
770
771 let script = config.full_connect_script().unwrap();
772 assert!(script.contains("LISTEN chan1;"));
773 assert!(script.contains("LISTEN chan2;"));
774 }
775
776 #[test]
777 fn config_full_connect_script_combined() {
778 let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
779 .application_name("app")
780 .connect_script("SET timezone = 'UTC';")
781 .subscriptions(["events"]);
782
783 let script = config.full_connect_script().unwrap();
784 assert!(script.contains("SET application_name = 'app'"));
785 assert!(script.contains("SET timezone = 'UTC';"));
786 assert!(script.contains("LISTEN events;"));
787 }
788
789 #[test]
790 fn config_without_subscriptions() {
791 let mut config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
792 .subscriptions(["a", "b", "c"]);
793
794 config.without_subscriptions(["b"]);
795
796 assert!(config.subscriptions.contains("a"));
797 assert!(!config.subscriptions.contains("b"));
798 assert!(config.subscriptions.contains("c"));
799 }
800
801 #[test]
806 fn error_timeout_display() {
807 let err = PGError::Timeout(Duration::from_secs(30));
808 let msg = err.to_string();
809 assert!(msg.contains("timed out"));
810 assert!(msg.contains("30"));
811 }
812
813 #[test]
814 fn error_failed_to_reconnect_display() {
815 let err = PGError::FailedToReconnect(5);
816 let msg = err.to_string();
817 assert!(msg.contains("5"));
818 assert!(msg.contains("reconnect"));
819 }
820
821 #[test]
822 fn error_is_timeout() {
823 let timeout_err = PGError::Timeout(Duration::from_secs(1));
824 let reconnect_err = PGError::FailedToReconnect(1);
825
826 assert!(timeout_err.is_timeout());
827 assert!(!reconnect_err.is_timeout());
828 }
829
830 #[test]
831 fn error_other() {
832 let custom_err = std::io::Error::new(std::io::ErrorKind::Other, "custom error");
833 let pg_err = PGError::other(custom_err);
834
835 assert!(matches!(pg_err, PGError::Other(_)));
836 assert!(pg_err.to_string().contains("custom error"));
837 }
838
839 #[test]
844 fn message_reconnect_creation() {
845 let msg = PGMessage::reconnect(3, 10);
846 match msg {
847 PGMessage::Reconnect { attempts, max_attempts, .. } => {
848 assert_eq!(attempts, 3);
849 assert_eq!(max_attempts, 10);
850 }
851 _ => panic!("Expected Reconnect variant"),
852 }
853 }
854
855 #[test]
856 fn message_connected_creation() {
857 let msg = PGMessage::connected();
858 assert!(matches!(msg, PGMessage::Connected { .. }));
859 }
860
861 #[test]
862 fn message_timeout_creation() {
863 let msg = PGMessage::timeout(Duration::from_secs(5));
864 match msg {
865 PGMessage::Timeout { duration, .. } => {
866 assert_eq!(duration, Duration::from_secs(5));
867 }
868 _ => panic!("Expected Timeout variant"),
869 }
870 }
871
872 #[test]
873 fn message_cancelled_creation() {
874 let msg_success = PGMessage::cancelled(true);
875 let msg_failure = PGMessage::cancelled(false);
876
877 match msg_success {
878 PGMessage::Cancelled { success, .. } => assert!(success),
879 _ => panic!("Expected Cancelled variant"),
880 }
881 match msg_failure {
882 PGMessage::Cancelled { success, .. } => assert!(!success),
883 _ => panic!("Expected Cancelled variant"),
884 }
885 }
886
887 #[test]
888 fn message_failed_to_reconnect_creation() {
889 let msg = PGMessage::failed_to_reconnect(5);
890 match msg {
891 PGMessage::FailedToReconnect { attempts, .. } => {
892 assert_eq!(attempts, 5);
893 }
894 _ => panic!("Expected FailedToReconnect variant"),
895 }
896 }
897
898 #[test]
899 fn message_disconnected_creation() {
900 let msg = PGMessage::disconnected("Connection reset");
901 match msg {
902 PGMessage::Disconnected { reason, .. } => {
903 assert_eq!(reason, "Connection reset");
904 }
905 _ => panic!("Expected Disconnected variant"),
906 }
907 }
908
909 #[test]
910 fn message_display_reconnect() {
911 let msg = PGMessage::reconnect(2, 10);
912 let display = msg.to_string();
913 assert!(display.contains("RECONNECT"));
914 assert!(display.contains("2"));
915 assert!(display.contains("10"));
916 }
917
918 #[test]
919 fn message_display_timeout() {
920 let msg = PGMessage::timeout(Duration::from_millis(500));
921 let display = msg.to_string();
922 assert!(display.contains("TIMEOUT"));
923 }
924
925 #[test]
930 fn raise_level_from_str() {
931 use std::str::FromStr;
932
933 assert!(PGRaiseLevel::from_str("DEBUG").is_ok());
935 assert!(PGRaiseLevel::from_str("LOG").is_ok());
936 assert!(PGRaiseLevel::from_str("INFO").is_ok());
937 assert!(PGRaiseLevel::from_str("NOTICE").is_ok());
938 assert!(PGRaiseLevel::from_str("WARNING").is_ok());
939 assert!(PGRaiseLevel::from_str("ERROR").is_ok());
940 assert!(PGRaiseLevel::from_str("FATAL").is_ok());
941 assert!(PGRaiseLevel::from_str("PANIC").is_ok());
942 }
943
944 #[test]
945 fn raise_level_display() {
946 assert_eq!(PGRaiseLevel::Debug.to_string(), "DEBUG");
947 assert_eq!(PGRaiseLevel::Log.to_string(), "LOG");
948 assert_eq!(PGRaiseLevel::Warning.to_string(), "WARNING");
949 }
950
951 #[test]
952 fn raise_level_unknown_returns_error() {
953 use std::str::FromStr;
954 assert!(PGRaiseLevel::from_str("UNKNOWN_LEVEL").is_err());
955 assert!(PGRaiseLevel::from_str("debug").is_err()); }
957 }
958
/// Builds the SQL script used by the integration test.
///
/// The script sets `client_min_messages` to `level`, then runs an anonymous
/// block that raises one message at each of DEBUG, LOG, INFO, NOTICE and
/// WARNING, sending a notification on channel `test` after each one.
fn sql_for_log_and_notify_test(level: PGRaiseLevel) -> String {
    format!(
        r#"
        set client_min_messages to '{}';
        do $$
        begin
        raise debug 'this is a DEBUG notification';
        notify test, 'test#1';
        raise log 'this is a LOG notification';
        notify test, 'test#2';
        raise info 'this is a INFO notification';
        notify test, 'test#3';
        raise notice 'this is a NOTICE notification';
        notify test, 'test#4';
        raise warning 'this is a WARNING notification';
        notify test, 'test#5';
        end;
        $$;
        "#,
        level
    )
}
985
986 #[tokio::test]
987 async fn test_integration() {
988 let pg_server = Postgres::default()
994 .with_tag("16.4")
995 .start()
996 .await
997 .expect("could not start postgres server");
998
999 let database_url = format!(
1001 "postgres://postgres:postgres@{}:{}/postgres",
1002 pg_server.get_host().await.unwrap(),
1003 pg_server.get_host_port_ipv4(5432).await.unwrap()
1004 );
1005
1006 let notices = Arc::new(RwLock::new(Vec::new()));
1013 let notices_clone = notices.clone();
1014
1015 let callback = move |msg: PGMessage| {
1016 if let Ok(mut guard) = notices_clone.write() {
1017 guard.push(msg.to_string());
1018 }
1019 };
1020
1021 let config = PGRobustClientConfig::new(database_url, tokio_postgres::NoTls);
1022
1023 let mut admin = PGRobustClient::spawn(config.clone())
1024 .await
1025 .expect("could not create initial client");
1026
1027 let mut client = PGRobustClient::spawn(config.callback(callback).max_reconnect_attempts(2))
1028 .await
1029 .expect("could not create initial client");
1030
1031 client
1036 .subscribe_notify(&["test"], None)
1037 .await
1038 .expect("could not subscribe");
1039
1040 let (_, execution_log) = client
1041 .with_captured_log(async |client: &mut PGRobustClient<_>| {
1042 client
1043 .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Debug), None)
1044 .await
1045 })
1046 .await
1047 .expect("could not execute queries on postgres");
1048
1049 assert_json_snapshot!("subscribed-executionlog", &execution_log, {
1050 "[].timestamp" => "<timestamp>",
1051 "[].process_id" => "<pid>",
1052 });
1053
1054 assert_snapshot!("subscribed-notify", extract_and_clear_logs(¬ices));
1055
1056 client
1061 .unsubscribe_notify(&["test"], None)
1062 .await
1063 .expect("could not unsubscribe");
1064
1065 let (_, execution_log) = client
1066 .with_captured_log(async |client| {
1067 client
1068 .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Warning), None)
1069 .await
1070 })
1071 .await
1072 .expect("could not execute queries on postgres");
1073
1074 assert_json_snapshot!("unsubscribed-executionlog", &execution_log, {
1075 "[].timestamp" => "<timestamp>",
1076 "[].process_id" => "<pid>",
1077 });
1078
1079 assert_snapshot!("unsubscribed-notify", extract_and_clear_logs(¬ices));
1080
1081 let result = client
1086 .simple_query(
1087 "
1088 do $$
1089 begin
1090 raise info 'before sleep';
1091 perform pg_sleep(3);
1092 raise info 'after sleep';
1093 end;
1094 $$
1095 ",
1096 Some(Duration::from_secs(1)),
1097 )
1098 .await;
1099
1100 assert!(matches!(result, Err(PGError::Timeout(_))));
1101 assert_snapshot!("timeout-messages", extract_and_clear_logs(¬ices));
1102
1103 admin.simple_query("select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()", None)
1108 .await.expect("could not kill other client");
1109
1110 let result = client
1111 .simple_query(
1112 "
1113 do $$
1114 begin
1115 raise info 'before sleep';
1116 perform pg_sleep(1);
1117 raise info 'after sleep';
1118 end;
1119 $$
1120 ",
1121 Some(Duration::from_secs(10)),
1122 )
1123 .await;
1124
1125 assert!(matches!(result, Ok(_)));
1126 assert_snapshot!("reconnect-before", extract_and_clear_logs(¬ices));
1127
1128 let query = client.simple_query(
1133 "
1134 do $$
1135 begin
1136 raise info 'before sleep';
1137 perform pg_sleep(1);
1138 raise info 'after sleep';
1139 end;
1140 $$
1141 ",
1142 None,
1143 );
1144
1145 let kill_later =
1146 admin.simple_query("
1147 select pg_sleep(0.5);
1148 select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()",
1149 None
1150 );
1151
1152 let (_, result) = tokio::join!(kill_later, query);
1153
1154 assert!(matches!(result, Ok(_)));
1155 assert_snapshot!("reconnect-during", extract_and_clear_logs(¬ices));
1156
1157 pg_server.stop().await.expect("could not stop server");
1162
1163 let result = client.simple_query(
1164 "
1165 do $$
1166 begin
1167 raise info 'before sleep';
1168 perform pg_sleep(1);
1169 raise info 'after sleep';
1170 end;
1171 $$
1172 ",
1173 None,
1174 ).await;
1175
1176 eprintln!("result: {result:?}");
1177 assert!(matches!(result, Err(PGError::FailedToReconnect(2))));
1178 assert_snapshot!("reconnect-failure", extract_and_clear_logs(¬ices));
1179
1180
1181 }
1182
1183 fn extract_and_clear_logs(logs: &Arc<RwLock<Vec<String>>>) -> String {
1184 let mut guard = logs.write().expect("could not read notices");
1185 let emtpy_log = Vec::default();
1186 let log = std::mem::replace(&mut *guard, emtpy_log);
1187 redact_pids(&redact_timestamps(&log.join("\n")))
1188 }
1189
1190 fn redact_timestamps(text: &str) -> String {
1191 use regex::Regex;
1192 use std::sync::OnceLock;
1193 pub static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
1194 let pat = TIMESTAMP_PATTERN.get_or_init(|| {
1195 Regex::new(r"\d{4}-\d{2}-\d{2}.?\d{2}:\d{2}:\d{2}(\.\d{3,9})?(Z| UTC|[+-]\d{2}:\d{2})?")
1196 .unwrap()
1197 });
1198 pat.replace_all(text, "<timestamp>").to_string()
1199 }
1200
1201 fn redact_pids(text: &str) -> String {
1202 use regex::Regex;
1203 use std::sync::OnceLock;
1204 pub static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
1205 let pat = TIMESTAMP_PATTERN.get_or_init(|| Regex::new(r"pid=\d+").unwrap());
1206 pat.replace_all(text, "<pid>").to_string()
1207 }
1208}