mod error;
mod messages;
mod notify;
mod config;
mod inner;
pub use error::*;
pub use messages::*;
use inner::*;
pub use config::*;
use tokio_postgres::{SimpleQueryMessage, ToStatement};
use {
futures::TryFutureExt,
std::{
time::Duration,
},
tokio::{
time::{sleep, timeout},
},
tokio_postgres::{
Row, RowStream, Socket, Statement, Transaction,
tls::MakeTlsConnect,
types::{BorrowToSql, ToSql, Type},
},
};
/// Shorthand for a [`Result`] whose error type is [`PGError`].
pub type PGResult<T> = Result<T, PGError>;
/// A Postgres client that keeps its configuration next to the live connection.
///
/// The methods on this type use the stored configuration to re-establish the
/// connection when it is lost and to bound each call with a timeout.
pub struct PGRobustClient<TLS> {
    /// Settings used for the initial connection and for every reconnect.
    config: PGRobustClientConfig<TLS>,
    /// The currently active connection; replaced when a reconnect succeeds.
    inner: PGClient,
}
#[allow(unused)]
impl<TLS> PGRobustClient<TLS>
where
TLS: MakeTlsConnect<Socket> + Clone,
<TLS as MakeTlsConnect<Socket>>::Stream: Send + Sync + 'static,
{
pub async fn spawn(config: PGRobustClientConfig<TLS>) -> PGResult<PGRobustClient<TLS>> {
let inner = PGClient::connect(&config).await?;
Ok(PGRobustClient { config, inner })
}
pub fn config(&self) -> &PGRobustClientConfig<TLS> {
&self.config
}
pub fn config_mut(&mut self) -> &mut PGRobustClientConfig<TLS> {
&mut self.config
}
/// Sends a cancellation request through the current connection's cancel token.
///
/// A clone of the configured TLS connector is handed to the token for the request.
///
/// # Errors
/// An error reported by the cancel request is converted into a [`PGError`].
pub async fn cancel_query(&mut self) -> PGResult<()> {
    let tls = self.config.make_tls.clone();
    let outcome = self.inner.cancel_token.cancel_query(tls).await;
    outcome.map_err(Into::into)
}
/// Takes every message currently stored in the connection log, leaving the log empty.
///
/// If the log's lock is poisoned, an empty vector is returned (and an error is
/// traced when the `tracing` feature is enabled); the stored messages are left alone.
pub fn capture_and_clear_log(&mut self) -> Vec<PGMessage> {
    let Ok(mut entries) = self.inner.log.write() else {
        #[cfg(feature = "tracing")]
        tracing::error!("Lock poisoned in capture_and_clear_log - returning empty log");
        return Vec::default();
    };
    // Swap the stored vector for an empty one and hand the old contents to the caller.
    std::mem::take(&mut *entries)
}
/// Empties the connection log in place.
///
/// A poisoned lock is silently ignored, in which case the log is left untouched.
fn clear_log(&mut self) {
    let Ok(mut entries) = self.inner.log.write() else {
        return;
    };
    entries.clear();
}
/// Runs `f` against this client and returns its result together with the
/// messages that are in the log once `f` has finished.
///
/// The log is emptied before `f` runs. Note that [`Self::wrap_reconnect`] clears
/// the log again at the start of every wrapped call, so when `f` issues several
/// calls only the messages logged since the last of them began are returned.
///
/// # Errors
/// If `f` fails its error is returned as is and the log is not collected.
pub async fn with_captured_log<F, T>(&mut self, f: F) -> PGResult<(T, Vec<PGMessage>)>
where
    F: AsyncFn(&mut Self) -> PGResult<T>,
{
    // Discard whatever was logged before this call.
    let _stale = self.capture_and_clear_log();
    let value = f(self).await?;
    Ok((value, self.capture_and_clear_log()))
}
async fn reconnect(&mut self) -> PGResult<()> {
use std::cmp::{max, min};
let mut attempts = 1;
let mut k = 500;
while attempts <= self.config.max_reconnect_attempts {
sleep(Duration::from_millis(k + rand::random_range(0..k / 2))).await;
k = min(k * 2, 60000);
#[cfg(feature = "tracing")]
tracing::info!("Reconnect attempt #{}", attempts);
(self.config.callback)(PGMessage::reconnect(attempts, self.config.max_reconnect_attempts));
attempts += 1;
match PGClient::connect(&self.config).await {
Ok(inner) => {
self.inner = inner;
(self.config.callback)(PGMessage::connected());
if let Some(sql) = self.config.full_connect_script() {
match self.inner.simple_query(&sql).await {
Ok(_) => {
return Ok(());
}
Err(e) if is_pg_connection_issue(&e) => {
continue;
}
Err(e) => {
return Err(e.into());
}
}
} else {
return Ok(());
}
}
Err(e) if e.is_pg_connection_issue() => {
continue;
}
Err(e) => {
return Err(e);
}
}
}
(self.config.callback)(PGMessage::failed_to_reconnect(self.config.max_reconnect_attempts));
Err(PGError::FailedToReconnect(self.config.max_reconnect_attempts))
}
/// Runs `factory` against the current connection, reconnecting and retrying
/// when the connection turns out to be broken.
///
/// * `max_dur` - time limit for one invocation of `factory`; falls back to
///   `config.default_timeout` when `None`. The limit is applied to each
///   invocation separately and does not cover time spent reconnecting.
/// * `factory` - builds the operation to run; it is invoked again after every
///   successful reconnect.
///
/// The connection log is cleared before the first invocation.
///
/// # Errors
/// * [`PGError::Timeout`] when an invocation exceeds the limit. The callback is
///   notified, a cancel request is sent, and the callback is told whether that
///   request succeeded.
/// * Any error from reconnecting.
/// * Any error from `factory` that is not a connection issue, converted into a
///   [`PGError`].
pub async fn wrap_reconnect<T>(
    &mut self,
    max_dur: Option<Duration>,
    factory: impl AsyncFn(&mut PGClient) -> Result<T, tokio_postgres::Error>,
) -> PGResult<T> {
    self.clear_log();
    let limit = max_dur.unwrap_or(self.config.default_timeout);
    loop {
        let Ok(outcome) = timeout(limit, factory(&mut self.inner)).await else {
            // The operation ran out of time: report it and try to cancel it.
            (self.config.callback)(PGMessage::timeout(limit));
            let cancel = self
                .inner
                .cancel_token
                .cancel_query(self.config.make_tls.clone())
                .await;
            (self.config.callback)(PGMessage::cancelled(cancel.is_ok()));
            return Err(PGError::Timeout(limit));
        };
        match outcome {
            Ok(value) => return Ok(value),
            // Broken connection: reconnect, then loop to run the operation again.
            Err(e) if is_pg_connection_issue(&e) => self.reconnect().await?,
            Err(e) => return Err(e.into()),
        }
    }
}
/// Starts listening on `channels` and records them in the configuration.
///
/// Nothing is done for an empty slice. The channels are added to the
/// configuration only after the listen request has succeeded.
///
/// # Errors
/// Returns the error from [`Self::wrap_reconnect`] if the listen request fails.
pub async fn subscribe_notify(
    &mut self,
    channels: &[impl AsRef<str> + Send + Sync + 'static],
    timeout: Option<Duration>,
) -> PGResult<()> {
    if channels.is_empty() {
        return Ok(());
    }
    self.wrap_reconnect(timeout, async |client: &mut PGClient| {
        PGClient::issue_listen(client, channels).await
    })
    .await?;
    self.config.with_subscriptions(channels.iter().map(AsRef::as_ref));
    Ok(())
}
/// Stops listening on `channels` and removes them from the configuration.
///
/// Nothing is done for an empty slice. The channels are removed from the
/// configuration only after the unlisten request has succeeded.
///
/// # Errors
/// Returns the error from [`Self::wrap_reconnect`] if the unlisten request fails.
pub async fn unsubscribe_notify(
    &mut self,
    channels: &[impl AsRef<str> + Send + Sync + 'static],
    timeout: Option<Duration>,
) -> PGResult<()> {
    if channels.is_empty() {
        return Ok(());
    }
    self.wrap_reconnect(timeout, async move |client: &mut PGClient| {
        PGClient::issue_unlisten(client, channels).await
    })
    .await?;
    self.config.without_subscriptions(channels.iter().map(AsRef::as_ref));
    Ok(())
}
/// Stops listening on every channel by issuing `UNLISTEN *`, then forgets all
/// channels recorded in the configuration.
///
/// Clearing the recorded channels keeps this method consistent with
/// [`Self::unsubscribe_notify`]: otherwise they would remain in the
/// configuration and be picked up again by the connect script after a reconnect.
///
/// # Errors
/// Returns the error from [`Self::wrap_reconnect`] if the request fails; in
/// that case the recorded channels are left untouched.
pub async fn unsubscribe_notify_all(&mut self, timeout: Option<Duration>) -> PGResult<()> {
    self.wrap_reconnect(timeout, async |client: &mut PGClient| {
        #[cfg(feature = "tracing")]
        tracing::info!("Unsubscribing from channels: *");
        client.simple_query("UNLISTEN *").await?;
        Ok(())
    })
    .await?;
    // NOTE(review): assumes `subscriptions` is a standard set-like collection
    // with `clear()` (tests call `is_empty`/`contains` on it) - confirm.
    self.config.subscriptions.clear();
    Ok(())
}
/// Executes `statement` with `params` via the connection's `execute_raw` and
/// returns the `u64` it reports.
///
/// The parameters are collected into a vector up front so that a fresh clone
/// can be handed to every attempt made by [`Self::wrap_reconnect`].
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn execute_raw<P, I, T>(
    &mut self,
    statement: &T,
    params: I,
    timeout: Option<Duration>,
) -> PGResult<u64>
where
    T: ?Sized + ToStatement + Sync + Send,
    P: BorrowToSql + Clone + Send + Sync,
    I: IntoIterator<Item = P> + Sync + Send,
    I::IntoIter: ExactSizeIterator,
{
    let owned_params = Vec::from_iter(params);
    self.wrap_reconnect(timeout, async |pg: &mut PGClient| {
        pg.execute_raw(statement, owned_params.clone()).await
    })
    .await
}
/// Runs `query` with `params` and returns all resulting rows.
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn query<T>(
    &mut self,
    query: &T,
    params: &[&(dyn ToSql + Sync)],
    timeout: Option<Duration>,
) -> PGResult<Vec<Row>>
where
    T: ?Sized + ToStatement + Sync + Send,
{
    let params = params.to_vec();
    self.wrap_reconnect(timeout, async |client: &mut PGClient| {
        client.query(query, &params).await
    })
    .await
}
/// Runs `statement` with `params` via the connection's `query_one` and returns
/// the single row it yields.
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn query_one<T>(
    &mut self,
    statement: &T,
    params: &[&(dyn ToSql + Sync)],
    timeout: Option<Duration>,
) -> PGResult<Row>
where
    T: ?Sized + ToStatement + Sync + Send,
{
    let params = params.to_vec();
    self.wrap_reconnect(timeout, async |client: &mut PGClient| {
        client.query_one(statement, &params).await
    })
    .await
}
/// Runs `statement` with `params` via the connection's `query_opt` and returns
/// the optional row it yields.
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn query_opt<T>(
    &mut self,
    statement: &T,
    params: &[&(dyn ToSql + Sync)],
    timeout: Option<Duration>,
) -> PGResult<Option<Row>>
where
    T: ?Sized + ToStatement + Sync + Send,
{
    let params = params.to_vec();
    self.wrap_reconnect(timeout, async |client: &mut PGClient| {
        client.query_opt(statement, &params).await
    })
    .await
}
/// Runs `statement` with `params` via the connection's `query_raw` and returns
/// the resulting [`RowStream`].
///
/// The parameters are collected into a vector up front so that a fresh clone
/// can be handed to every attempt made by [`Self::wrap_reconnect`]. Only the
/// creation of the stream is covered by the timeout and retry logic here.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn query_raw<T, P, I>(
    &mut self,
    statement: &T,
    params: I,
    timeout: Option<Duration>,
) -> PGResult<RowStream>
where
    T: ?Sized + ToStatement + Sync + Send,
    P: BorrowToSql + Clone + Send + Sync,
    I: IntoIterator<Item = P> + Sync + Send,
    I::IntoIter: ExactSizeIterator,
{
    let owned_params = Vec::from_iter(params);
    self.wrap_reconnect(timeout, async |pg: &mut PGClient| {
        pg.query_raw(statement, owned_params.clone()).await
    })
    .await
}
/// Runs the SQL text `statement` with explicitly typed `params` via the
/// connection's `query_typed` and returns all resulting rows.
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn query_typed(
    &mut self,
    statement: &str,
    params: &[(&(dyn ToSql + Sync), Type)],
    timeout: Option<Duration>,
) -> PGResult<Vec<Row>> {
    let params = params.to_vec();
    self.wrap_reconnect(timeout, async |client: &mut PGClient| {
        client.query_typed(statement, &params).await
    })
    .await
}
/// Runs the SQL text `statement` with explicitly typed `params` via the
/// connection's `query_typed_raw` and returns the resulting [`RowStream`].
///
/// The parameters are collected into a vector up front so that a fresh clone
/// can be handed to every attempt made by [`Self::wrap_reconnect`].
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn query_typed_raw<P, I>(
    &mut self,
    statement: &str,
    params: I,
    timeout: Option<Duration>,
) -> PGResult<RowStream>
where
    P: BorrowToSql + Clone + Send + Sync,
    I: IntoIterator<Item = (P, Type)> + Sync + Send,
{
    let owned_params = Vec::from_iter(params);
    self.wrap_reconnect(timeout, async |pg: &mut PGClient| {
        pg.query_typed_raw(statement, owned_params.clone()).await
    })
    .await
}
/// Prepares `query` on the current connection and returns the [`Statement`].
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn prepare(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<Statement> {
    self.wrap_reconnect(
        timeout,
        async |pg: &mut PGClient| pg.prepare(query).map_err(Into::into).await,
    )
    .await
}
/// Prepares `query` with the given `parameter_types` on the current connection
/// and returns the [`Statement`].
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn prepare_typed(
    &mut self,
    query: &str,
    parameter_types: &[Type],
    timeout: Option<Duration>,
) -> PGResult<Statement> {
    self.wrap_reconnect(
        timeout,
        async |pg: &mut PGClient| pg.prepare_typed(query, parameter_types).await,
    )
    .await
}
/// Opens a transaction, runs `f` inside it and commits.
///
/// If `f` fails, its error is propagated and `commit` is never called. The
/// whole sequence (begin, `f`, commit) is what [`Self::wrap_reconnect`] bounds
/// with `timeout` and runs again after a reconnect, so `f` may be invoked more
/// than once.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn transaction<F>(&mut self, timeout: Option<Duration>, f: F) -> PGResult<()>
where
    for<'a> F: AsyncFn(&'a mut Transaction) -> Result<(), tokio_postgres::Error>,
{
    self.wrap_reconnect(timeout, async |pg: &mut PGClient| {
        let mut txn = pg.transaction().await?;
        f(&mut txn).await?;
        txn.commit().await
    })
    .await
}
/// Passes `query` to the connection's `batch_execute`.
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn batch_execute(&mut self, query: &str, timeout: Option<Duration>) -> PGResult<()> {
    self.wrap_reconnect(
        timeout,
        async |pg: &mut PGClient| pg.batch_execute(query).await,
    )
    .await
}
/// Passes `query` to the connection's `simple_query` and returns the messages
/// it produces.
///
/// The call goes through [`Self::wrap_reconnect`], so it is bounded by
/// `timeout` (or the configured default) and retried after a reconnect.
///
/// # Errors
/// See [`Self::wrap_reconnect`].
pub async fn simple_query(
    &mut self,
    query: &str,
    timeout: Option<Duration>,
) -> PGResult<Vec<SimpleQueryMessage>> {
    self.wrap_reconnect(
        timeout,
        async |pg: &mut PGClient| pg.simple_query(query).await,
    )
    .await
}
/// Gives direct access to the underlying [`tokio_postgres::Client`].
///
/// Calls made through the returned reference bypass [`Self::wrap_reconnect`],
/// so they get neither the timeout nor the reconnect handling.
pub fn client(&self) -> &tokio_postgres::Client {
    // `PGClient` coerces to the wrapped client here.
    let connection: &tokio_postgres::Client = &self.inner;
    connection
}
}
/// Awaits `fut` for at most `dur`.
///
/// # Errors
/// Returns [`PGError::Timeout`] carrying `dur` when the time runs out;
/// otherwise the result of `fut` is passed through unchanged.
pub async fn wrap_timeout<T>(dur: Duration, fut: impl Future<Output = PGResult<T>>) -> PGResult<T> {
    timeout(dur, fut)
        .await
        .unwrap_or_else(|_elapsed| Err(PGError::Timeout(dur)))
}
#[cfg(test)]
mod tests {
use {
super::{PGError, PGMessage, PGRaiseLevel, PGRobustClient, PGRobustClientConfig},
insta::*,
std::{
sync::{Arc, RwLock},
time::Duration,
},
testcontainers::{ImageExt, runners::AsyncRunner},
testcontainers_modules::postgres::Postgres,
};
/// Tests that need no database: configuration builders, error and message
/// constructors, and `PGRaiseLevel` parsing/formatting.
mod unit {
    use super::*;
    use tokio_postgres::NoTls;

    /// A freshly created configuration carries the documented defaults.
    #[test]
    fn config_default_values() {
        let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
        assert_eq!(config.max_reconnect_attempts, 10);
        assert_eq!(config.default_timeout, Duration::from_secs(3600));
        assert!(config.subscriptions.is_empty());
        assert!(config.connect_script.is_none());
        assert!(config.application_name.is_none());
    }

    /// The consuming builder methods can be chained and each one sticks.
    #[test]
    fn config_builder_chaining() {
        let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
            .max_reconnect_attempts(5)
            .default_timeout(Duration::from_secs(30))
            .application_name("test_app")
            .connect_script("SET timezone = 'UTC'")
            .subscriptions(["channel1", "channel2"]);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.default_timeout, Duration::from_secs(30));
        assert_eq!(config.application_name, Some("test_app".to_string()));
        assert_eq!(config.connect_script, Some("SET timezone = 'UTC'".to_string()));
        assert!(config.subscriptions.contains("channel1"));
        assert!(config.subscriptions.contains("channel2"));
    }

    /// The in-place `with_*` setters update the same fields as the builder.
    #[test]
    fn config_with_methods() {
        let mut config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
        config.with_max_reconnect_attempts(Some(3));
        config.with_default_timeout(Some(Duration::from_secs(60)));
        config.with_application_name(Some("my_app"));
        config.with_connect_script(Some("SELECT 1"));
        config.with_subscriptions(["events"]);
        assert_eq!(config.max_reconnect_attempts, 3);
        assert_eq!(config.default_timeout, Duration::from_secs(60));
        assert_eq!(config.application_name, Some("my_app".to_string()));
        assert_eq!(config.connect_script, Some("SELECT 1".to_string()));
        assert!(config.subscriptions.contains("events"));
    }

    /// With nothing configured there is no connect script at all.
    #[test]
    fn config_full_connect_script_empty() {
        let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls);
        assert!(config.full_connect_script().is_none());
    }

    /// The application name shows up as a `SET application_name` statement.
    #[test]
    fn config_full_connect_script_with_app_name() {
        let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
            .application_name("my_app");
        let script = config.full_connect_script().unwrap();
        assert!(script.contains("SET application_name = 'my_app'"));
    }

    /// Every subscription shows up as a `LISTEN` statement.
    #[test]
    fn config_full_connect_script_with_subscriptions() {
        let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
            .subscriptions(["chan1", "chan2"]);
        let script = config.full_connect_script().unwrap();
        assert!(script.contains("LISTEN chan1;"));
        assert!(script.contains("LISTEN chan2;"));
    }

    /// Application name, custom script and subscriptions are all included.
    #[test]
    fn config_full_connect_script_combined() {
        let config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
            .application_name("app")
            .connect_script("SET timezone = 'UTC';")
            .subscriptions(["events"]);
        let script = config.full_connect_script().unwrap();
        assert!(script.contains("SET application_name = 'app'"));
        assert!(script.contains("SET timezone = 'UTC';"));
        assert!(script.contains("LISTEN events;"));
    }

    /// Removing a subscription leaves the other ones in place.
    #[test]
    fn config_without_subscriptions() {
        let mut config = PGRobustClientConfig::new("postgres://localhost/test", NoTls)
            .subscriptions(["a", "b", "c"]);
        config.without_subscriptions(["b"]);
        assert!(config.subscriptions.contains("a"));
        assert!(!config.subscriptions.contains("b"));
        assert!(config.subscriptions.contains("c"));
    }

    /// The timeout error message mentions the timeout and its duration.
    #[test]
    fn error_timeout_display() {
        let err = PGError::Timeout(Duration::from_secs(30));
        let msg = err.to_string();
        assert!(msg.contains("timed out"));
        assert!(msg.contains("30"));
    }

    /// The reconnect-failure message mentions the number of attempts.
    #[test]
    fn error_failed_to_reconnect_display() {
        let err = PGError::FailedToReconnect(5);
        let msg = err.to_string();
        assert!(msg.contains("5"));
        assert!(msg.contains("reconnect"));
    }

    /// `is_timeout` is true for the timeout variant only.
    #[test]
    fn error_is_timeout() {
        let timeout_err = PGError::Timeout(Duration::from_secs(1));
        let reconnect_err = PGError::FailedToReconnect(1);
        assert!(timeout_err.is_timeout());
        assert!(!reconnect_err.is_timeout());
    }

    /// Arbitrary errors can be wrapped and keep their message.
    #[test]
    fn error_other() {
        let custom_err = std::io::Error::other("custom error");
        let pg_err = PGError::other(custom_err);
        assert!(matches!(pg_err, PGError::Other(_)));
        assert!(pg_err.to_string().contains("custom error"));
    }

    /// `PGMessage::reconnect` stores both counters.
    #[test]
    fn message_reconnect_creation() {
        let msg = PGMessage::reconnect(3, 10);
        match msg {
            PGMessage::Reconnect { attempts, max_attempts, .. } => {
                assert_eq!(attempts, 3);
                assert_eq!(max_attempts, 10);
            }
            _ => panic!("Expected Reconnect variant"),
        }
    }

    /// `PGMessage::connected` builds the `Connected` variant.
    #[test]
    fn message_connected_creation() {
        let msg = PGMessage::connected();
        assert!(matches!(msg, PGMessage::Connected { .. }));
    }

    /// `PGMessage::timeout` stores the duration.
    #[test]
    fn message_timeout_creation() {
        let msg = PGMessage::timeout(Duration::from_secs(5));
        match msg {
            PGMessage::Timeout { duration, .. } => {
                assert_eq!(duration, Duration::from_secs(5));
            }
            _ => panic!("Expected Timeout variant"),
        }
    }

    /// `PGMessage::cancelled` stores the success flag for both values.
    #[test]
    fn message_cancelled_creation() {
        let msg_success = PGMessage::cancelled(true);
        let msg_failure = PGMessage::cancelled(false);
        match msg_success {
            PGMessage::Cancelled { success, .. } => assert!(success),
            _ => panic!("Expected Cancelled variant"),
        }
        match msg_failure {
            PGMessage::Cancelled { success, .. } => assert!(!success),
            _ => panic!("Expected Cancelled variant"),
        }
    }

    /// `PGMessage::failed_to_reconnect` stores the attempt count.
    #[test]
    fn message_failed_to_reconnect_creation() {
        let msg = PGMessage::failed_to_reconnect(5);
        match msg {
            PGMessage::FailedToReconnect { attempts, .. } => {
                assert_eq!(attempts, 5);
            }
            _ => panic!("Expected FailedToReconnect variant"),
        }
    }

    /// `PGMessage::disconnected` stores the reason text.
    #[test]
    fn message_disconnected_creation() {
        let msg = PGMessage::disconnected("Connection reset");
        match msg {
            PGMessage::Disconnected { reason, .. } => {
                assert_eq!(reason, "Connection reset");
            }
            _ => panic!("Expected Disconnected variant"),
        }
    }

    /// The reconnect message renders its tag and both counters.
    #[test]
    fn message_display_reconnect() {
        let msg = PGMessage::reconnect(2, 10);
        let display = msg.to_string();
        assert!(display.contains("RECONNECT"));
        assert!(display.contains("2"));
        assert!(display.contains("10"));
    }

    /// The timeout message renders its tag.
    #[test]
    fn message_display_timeout() {
        let msg = PGMessage::timeout(Duration::from_millis(500));
        let display = msg.to_string();
        assert!(display.contains("TIMEOUT"));
    }

    /// All upper-case level names parse.
    #[test]
    fn raise_level_from_str() {
        use std::str::FromStr;
        assert!(PGRaiseLevel::from_str("DEBUG").is_ok());
        assert!(PGRaiseLevel::from_str("LOG").is_ok());
        assert!(PGRaiseLevel::from_str("INFO").is_ok());
        assert!(PGRaiseLevel::from_str("NOTICE").is_ok());
        assert!(PGRaiseLevel::from_str("WARNING").is_ok());
        assert!(PGRaiseLevel::from_str("ERROR").is_ok());
        assert!(PGRaiseLevel::from_str("FATAL").is_ok());
        assert!(PGRaiseLevel::from_str("PANIC").is_ok());
    }

    /// Levels render as their upper-case names.
    #[test]
    fn raise_level_display() {
        assert_eq!(PGRaiseLevel::Debug.to_string(), "DEBUG");
        assert_eq!(PGRaiseLevel::Log.to_string(), "LOG");
        assert_eq!(PGRaiseLevel::Warning.to_string(), "WARNING");
    }

    /// Unknown names and lower-case names are rejected.
    #[test]
    fn raise_level_unknown_returns_error() {
        use std::str::FromStr;
        assert!(PGRaiseLevel::from_str("UNKNOWN_LEVEL").is_err());
        assert!(PGRaiseLevel::from_str("debug").is_err());
    }
}
/// Builds the SQL script used by the integration test.
///
/// The script sets `client_min_messages` to `level`, then raises one message
/// at each of the debug, log, info, notice and warning levels, sending a
/// notification on channel `test` after each of them.
fn sql_for_log_and_notify_test(level: PGRaiseLevel) -> String {
    let min_level = level.to_string();
    format!(
        r#"
set client_min_messages to '{}';
do $$
begin
raise debug 'this is a DEBUG notification';
notify test, 'test#1';
raise log 'this is a LOG notification';
notify test, 'test#2';
raise info 'this is a INFO notification';
notify test, 'test#3';
raise notice 'this is a NOTICE notification';
notify test, 'test#4';
raise warning 'this is a WARNING notification';
notify test, 'test#5';
end;
$$;
"#,
        min_level
    )
}
/// End-to-end test against a throwaway Postgres container.
///
/// It walks through, in order: a subscribed client receiving log messages and
/// notifications, the same after unsubscribing, a query that times out, a
/// reconnect before a query, a reconnect in the middle of a query, and finally
/// a failed reconnect once the server is stopped. Callback output and captured
/// logs are compared with stored snapshots after redacting timestamps and pids.
#[tokio::test]
async fn test_integration() {
    let pg_server = Postgres::default()
        .with_tag("16.4")
        .start()
        .await
        .expect("could not start postgres server");
    let database_url = format!(
        "postgres://postgres:postgres@{}:{}/postgres",
        pg_server.get_host().await.unwrap(),
        pg_server.get_host_port_ipv4(5432).await.unwrap()
    );
    // Every message delivered to the client callback is recorded here as text.
    let notices = Arc::new(RwLock::new(Vec::new()));
    let notices_clone = notices.clone();
    let callback = move |msg: PGMessage| {
        if let Ok(mut guard) = notices_clone.write() {
            guard.push(msg.to_string());
        }
    };
    let config = PGRobustClientConfig::new(database_url, tokio_postgres::NoTls);
    // `admin` is a second connection used to terminate the backend of `client`.
    let mut admin = PGRobustClient::spawn(config.clone())
        .await
        .expect("could not create initial client");
    let mut client = PGRobustClient::spawn(config.callback(callback).max_reconnect_attempts(2))
        .await
        .expect("could not create initial client");

    // Phase 1: subscribed to channel `test`.
    client
        .subscribe_notify(&["test"], None)
        .await
        .expect("could not subscribe");
    let (_, execution_log) = client
        .with_captured_log(async |client: &mut PGRobustClient<_>| {
            client
                .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Debug), None)
                .await
        })
        .await
        .expect("could not execute queries on postgres");
    assert_json_snapshot!("subscribed-executionlog", &execution_log, {
        "[].timestamp" => "<timestamp>",
        "[].process_id" => "<pid>",
    });
    assert_snapshot!("subscribed-notify", extract_and_clear_logs(&notices));

    // Phase 2: unsubscribed again.
    client
        .unsubscribe_notify(&["test"], None)
        .await
        .expect("could not unsubscribe");
    let (_, execution_log) = client
        .with_captured_log(async |client| {
            client
                .simple_query(&sql_for_log_and_notify_test(PGRaiseLevel::Warning), None)
                .await
        })
        .await
        .expect("could not execute queries on postgres");
    assert_json_snapshot!("unsubscribed-executionlog", &execution_log, {
        "[].timestamp" => "<timestamp>",
        "[].process_id" => "<pid>",
    });
    assert_snapshot!("unsubscribed-notify", extract_and_clear_logs(&notices));

    // Phase 3: a 3 second sleep with a 1 second limit must time out.
    let result = client
        .simple_query(
            "
do $$
begin
raise info 'before sleep';
perform pg_sleep(3);
raise info 'after sleep';
end;
$$
",
            Some(Duration::from_secs(1)),
        )
        .await;
    assert!(matches!(result, Err(PGError::Timeout(_))));
    assert_snapshot!("timeout-messages", extract_and_clear_logs(&notices));

    // Phase 4: the backend is terminated before the query is issued.
    admin.simple_query("select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()", None)
        .await.expect("could not kill other client");
    let result = client
        .simple_query(
            "
do $$
begin
raise info 'before sleep';
perform pg_sleep(1);
raise info 'after sleep';
end;
$$
",
            Some(Duration::from_secs(10)),
        )
        .await;
    assert!(result.is_ok());
    assert_snapshot!("reconnect-before", extract_and_clear_logs(&notices));

    // Phase 5: the backend is terminated while the query is running.
    let query = client.simple_query(
        "
do $$
begin
raise info 'before sleep';
perform pg_sleep(1);
raise info 'after sleep';
end;
$$
",
        None,
    );
    let kill_later =
        admin.simple_query("
select pg_sleep(0.5);
select pg_terminate_backend(pid) from pg_stat_activity where pid != pg_backend_pid()",
            None
        );
    let (_, result) = tokio::join!(kill_later, query);
    assert!(result.is_ok());
    assert_snapshot!("reconnect-during", extract_and_clear_logs(&notices));

    // Phase 6: with the server gone, both reconnect attempts must fail.
    pg_server.stop().await.expect("could not stop server");
    let result = client.simple_query(
        "
do $$
begin
raise info 'before sleep';
perform pg_sleep(1);
raise info 'after sleep';
end;
$$
",
        None,
    ).await;
    eprintln!("result: {result:?}");
    assert!(matches!(result, Err(PGError::FailedToReconnect(2))));
    assert_snapshot!("reconnect-failure", extract_and_clear_logs(&notices));
}
/// Drains the recorded callback messages and returns them as one
/// newline-separated string with timestamps and pids redacted.
///
/// # Panics
/// Panics if the lock around `logs` is poisoned.
fn extract_and_clear_logs(logs: &Arc<RwLock<Vec<String>>>) -> String {
    let mut guard = logs.write().expect("could not read notices");
    // Leave an empty vector behind so the next phase starts from a clean slate.
    let log = std::mem::take(&mut *guard);
    redact_pids(&redact_timestamps(&log.join("\n")))
}
/// Replaces every timestamp-looking substring of `text` with `<timestamp>`.
///
/// The pattern is compiled once and reused on later calls.
fn redact_timestamps(text: &str) -> String {
    use regex::Regex;
    use std::sync::OnceLock;
    static TIMESTAMP_PATTERN: OnceLock<Regex> = OnceLock::new();
    let timestamp_re = TIMESTAMP_PATTERN.get_or_init(|| {
        Regex::new(r"\d{4}-\d{2}-\d{2}.?\d{2}:\d{2}:\d{2}(\.\d{3,9})?(Z| UTC|[+-]\d{2}:\d{2})?")
            .unwrap()
    });
    let redacted = timestamp_re.replace_all(text, "<timestamp>");
    redacted.into_owned()
}
/// Replaces every `pid=<digits>` substring of `text` with `<pid>`.
///
/// The pattern is compiled once and reused on later calls.
fn redact_pids(text: &str) -> String {
    use regex::Regex;
    use std::sync::OnceLock;
    static PID_PATTERN: OnceLock<Regex> = OnceLock::new();
    let pat = PID_PATTERN.get_or_init(|| Regex::new(r"pid=\d+").unwrap());
    pat.replace_all(text, "<pid>").to_string()
}
}