use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::async_conn::AsyncConn;
use crate::connection::WireConn;
use crate::error::PgWireError;
use crate::protocol::types::RawRow;
use crate::tls::TlsMode;
/// Connection parameters retained by [`AsyncPool`] so it can open new
/// connections, both at startup and when a slot has to be reopened.
///
/// `Debug` is implemented by hand so the password is never printed.
#[non_exhaustive]
#[derive(Clone)]
pub struct ConnConfig {
    /// Server address handed to the wire-level connect call.
    pub addr: String,
    /// User name sent when connecting.
    pub user: String,
    /// Password for `user`; shown as `<redacted>` in `Debug` output.
    pub password: String,
    /// Name of the database to connect to.
    pub database: String,
    /// TLS mode used for every connection the pool opens.
    pub tls_mode: TlsMode,
}
impl std::fmt::Debug for ConnConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnConfig")
.field("addr", &self.addr)
.field("user", &self.user)
.field("password", &"<redacted>")
.field("database", &self.database)
.field("tls_mode", &self.tls_mode)
.finish()
}
}
pub struct AsyncPool {
conns: Vec<RwLock<Arc<AsyncConn>>>,
config: ConnConfig,
counter: AtomicUsize,
}
impl std::fmt::Debug for AsyncPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncPool")
.field("size", &self.conns.len())
.field("config", &self.config)
.finish()
}
}
impl AsyncPool {
pub async fn connect(
addr: &str,
user: &str,
password: &str,
database: &str,
size: usize,
) -> Result<Arc<Self>, PgWireError> {
Self::connect_with_tls(addr, user, password, database, size, TlsMode::default()).await
}
pub async fn connect_with_tls(
addr: &str,
user: &str,
password: &str,
database: &str,
size: usize,
tls_mode: TlsMode,
) -> Result<Arc<Self>, PgWireError> {
if size == 0 {
return Err(PgWireError::Protocol("pool size must be >= 1".into()));
}
let config = ConnConfig {
addr: addr.to_string(),
user: user.to_string(),
password: password.to_string(),
database: database.to_string(),
tls_mode,
};
let mut conns = Vec::with_capacity(size);
for _ in 0..size {
let wire =
WireConn::connect_with_options(addr, user, password, database, &[], tls_mode)
.await?;
conns.push(RwLock::new(Arc::new(AsyncConn::new(wire))));
}
let pool = Arc::new(Self {
conns,
config,
counter: AtomicUsize::new(0),
});
{
let pool_weak = Arc::downgrade(&pool);
tokio::spawn(async move {
health_monitor(pool_weak).await;
});
}
Ok(pool)
}
pub async fn get_async(&self) -> Arc<AsyncConn> {
let len = self.conns.len();
let start = self.counter.fetch_add(1, Ordering::Relaxed) % len;
for i in 0..len {
let idx = (start + i) % len;
let conn = self.conns[idx].read().await;
if conn.is_alive() {
return Arc::clone(&conn);
}
}
let conn = self.conns[start % len].read().await;
Arc::clone(&conn)
}
async fn reconnect(&self, idx: usize) -> Result<(), PgWireError> {
let wire = WireConn::connect_with_options(
&self.config.addr,
&self.config.user,
&self.config.password,
&self.config.database,
&[],
self.config.tls_mode,
)
.await?;
let new_conn = Arc::new(AsyncConn::new(wire));
let mut slot = self.conns[idx].write().await;
*slot = new_conn;
tracing::info!("pg-wired: reconnected slot {idx}");
Ok(())
}
pub fn size(&self) -> usize {
self.conns.len()
}
pub async fn alive_count(&self) -> usize {
let mut count = 0;
for slot in &self.conns {
let conn = slot.read().await;
if conn.is_alive() {
count += 1;
}
}
count
}
pub async fn close(&self) -> Result<(), PgWireError> {
for slot in &self.conns {
let conn = slot.read().await;
let _ = conn.close().await;
}
Ok(())
}
pub async fn exec_transaction(
&self,
setup_sql: &str,
query_sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
) -> Result<Vec<RawRow>, PgWireError> {
self.get_async()
.await
.exec_transaction(setup_sql, query_sql, params, param_oids)
.await
}
pub async fn exec_query(
&self,
sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
) -> Result<Vec<RawRow>, PgWireError> {
self.get_async()
.await
.exec_query(sql, params, param_oids)
.await
}
pub async fn exec_query_with_formats(
&self,
sql: &str,
params: &[Option<&[u8]>],
param_oids: &[u32],
param_formats: &[crate::protocol::types::FormatCode],
result_formats: &[crate::protocol::types::FormatCode],
) -> Result<Vec<RawRow>, PgWireError> {
self.get_async()
.await
.exec_query_with_formats(sql, params, param_oids, param_formats, result_formats)
.await
}
}
async fn health_monitor(pool_weak: std::sync::Weak<AsyncPool>) {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
loop {
interval.tick().await;
let pool = match pool_weak.upgrade() {
Some(p) => p,
None => {
tracing::debug!("pg-wired: health monitor stopping (pool dropped)");
return;
}
};
for idx in 0..pool.conns.len() {
let is_dead = {
let conn = pool.conns[idx].read().await;
!conn.is_alive()
};
if is_dead {
tracing::warn!("pg-wired: slot {idx} is dead, reconnecting...");
match pool.reconnect(idx).await {
Ok(()) => {}
Err(e) => {
tracing::error!("pg-wired: reconnect slot {idx} failed: {e}");
}
}
}
}
}
}