use std::sync::Arc;
use arc_swap::ArcSwap;
use tokio::sync::Mutex;
use crate::encode::SqlParam;
use crate::error::TypedError;
use crate::query::Client;
use crate::row::Row;
/// Wrapper around a [`Client`] that transparently re-dials the server.
///
/// When `query` or `execute` fails with a connection-level error, the
/// connection is re-established from the stored parameters and the statement
/// is retried once.
pub struct ReconnectingClient {
    // Everything needed to dial a fresh connection later on.
    addr: String,
    user: String,
    password: String,
    database: String,
    // Forwarded to `Client::connect_with_init` on every (re)connect.
    init_sql: Vec<String>,
    // The connection currently in use; replaced atomically after a reconnect.
    client: ArcSwap<Client>,
    // Held for the duration of a reconnect so that only one task dials at a time.
    reconnect_lock: Mutex<()>,
}
impl std::fmt::Debug for ReconnectingClient {
    /// Shows the connection target and whether the current connection is alive.
    ///
    /// `password` is intentionally left out so credentials never end up in logs,
    /// and `init_sql` is omitted for brevity. `finish_non_exhaustive` renders a
    /// trailing `..` so readers can tell that fields were hidden.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReconnectingClient")
            .field("addr", &self.addr)
            .field("user", &self.user)
            .field("database", &self.database)
            .field("alive", &self.client.load().is_alive())
            .finish_non_exhaustive()
    }
}
impl ReconnectingClient {
    /// Opens the initial connection and stores the parameters needed to re-dial.
    ///
    /// `init_sql` is forwarded to `Client::connect_with_init` for this connection
    /// and again for every replacement connection created by a reconnect.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the initial connection attempt.
    pub async fn new(
        addr: &str,
        user: &str,
        password: &str,
        database: &str,
        init_sql: Vec<String>,
    ) -> Result<Self, TypedError> {
        let client = Self::connect_inner(addr, user, password, database, &init_sql).await?;
        Ok(Self {
            client: ArcSwap::from_pointee(client),
            reconnect_lock: Mutex::new(()),
            addr: addr.to_string(),
            user: user.to_string(),
            password: password.to_string(),
            database: database.to_string(),
            init_sql,
        })
    }

    /// Dials a new `Client`, adapting the owned init statements to the
    /// `&[&str]` form that `Client::connect_with_init` takes.
    async fn connect_inner(
        addr: &str,
        user: &str,
        password: &str,
        database: &str,
        init_sql: &[String],
    ) -> Result<Client, TypedError> {
        let init_refs: Vec<&str> = init_sql.iter().map(|s| s.as_str()).collect();
        Client::connect_with_init(addr, user, password, database, &init_refs).await
    }

    /// Replaces `failed` with a freshly dialed connection, unless another task
    /// has already done so.
    ///
    /// `failed` is the connection on which the caller just observed a
    /// connection-level error. Reconnects are serialized by `reconnect_lock`.
    async fn reconnect(&self, failed: &Arc<Client>) -> Result<(), TypedError> {
        let _guard = self.reconnect_lock.lock().await;
        // Skip only if the stored client is a *different*, alive connection,
        // i.e. someone else reconnected while we waited for the lock. If it is
        // still the connection that just failed, reconnect even if it claims to
        // be alive. The caller keeps `failed` alive, so its address cannot be
        // reused and `ptr_eq` is a reliable identity check. The guard is scoped
        // so it is not held across the connect `.await` below.
        let already_replaced = {
            let current = self.client.load();
            !Arc::ptr_eq(&*current, failed) && current.is_alive()
        };
        if already_replaced {
            return Ok(());
        }
        tracing::info!(addr = %self.addr, database = %self.database, "reconnecting");
        let new_client = Self::connect_inner(
            &self.addr,
            &self.user,
            &self.password,
            &self.database,
            &self.init_sql,
        )
        .await?;
        self.client.store(Arc::new(new_client));
        Ok(())
    }

    /// Runs `sql` and returns its rows, reconnecting and retrying once if the
    /// first attempt fails with a connection-level error.
    ///
    /// # Errors
    ///
    /// - Non-connection errors from the first attempt are returned unchanged.
    /// - If reconnecting fails, that error is returned. The original connection
    ///   error is only logged.
    /// - Otherwise, whatever the retry returns.
    pub async fn query(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<Vec<Row>, TypedError> {
        // `load_full` (an owned `Arc`) instead of a `load` guard: the handle lives
        // across the network round-trip, and arc-swap guards are meant for short
        // borrows.
        let client = self.client.load_full();
        match client.query(sql, params).await {
            Err(e) if is_connection_error(&e) => {
                tracing::warn!(error = %e, "connection lost, reconnecting");
                self.reconnect(&client).await?;
                self.client.load_full().query(sql, params).await
            }
            result => result,
        }
    }

    /// Executes `sql` and returns the count reported by `Client::execute`,
    /// reconnecting and retrying once on a connection-level error.
    ///
    /// WARNING: a connection can drop after the server has applied the statement
    /// but before the reply arrives. The automatic retry may then apply a
    /// non-idempotent statement (e.g. a plain `INSERT`) twice.
    ///
    /// # Errors
    ///
    /// Same rules as [`ReconnectingClient::query`].
    pub async fn execute(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<u64, TypedError> {
        let client = self.client.load_full();
        match client.execute(sql, params).await {
            Err(e) if is_connection_error(&e) => {
                tracing::warn!(error = %e, "connection lost, reconnecting");
                self.reconnect(&client).await?;
                self.client.load_full().execute(sql, params).await
            }
            result => result,
        }
    }

    /// Returns a guard on whichever connection is current at the time of the call.
    ///
    /// This never reconnects. The guard keeps referring to the same connection
    /// even if it is replaced afterwards. Per arc-swap guidance, avoid holding
    /// the guard for long periods (e.g. across awaits).
    pub fn client(&self) -> arc_swap::Guard<Arc<Client>> {
        self.client.load()
    }

    /// Returns `Client::is_alive` for the current connection; never reconnects.
    pub fn is_alive(&self) -> bool {
        self.client.load().is_alive()
    }
}
/// Decides whether `e` indicates a broken connection that is worth reconnecting for.
///
/// Only wire-level `Io` and `ConnectionClosed` errors qualify. Every other error,
/// including other wire errors, is treated as a failure of the statement itself.
fn is_connection_error(e: &TypedError) -> bool {
    if let TypedError::Wire(wire_err) = e {
        matches!(
            wire_err.as_ref(),
            pg_wired::PgWireError::Io(_) | pg_wired::PgWireError::ConnectionClosed
        )
    } else {
        false
    }
}