use std::sync::Arc;
use pg_pool::async_wire::AsyncPoolable;
use pg_pool::{ConnPool, ConnPoolConfig, LifecycleHooks, PoolError, PoolGuard};
use crate::encode::SqlParam;
use crate::error::TypedError;
use crate::row::Row;
/// Async PostgreSQL connection pool whose checkouts are handed out as
/// [`PooledClient`]s.
///
/// Thin wrapper over a shared `pg_pool::ConnPool<AsyncPoolable>` that adds
/// tracing and crate-level metrics around checkout (see `ExclusivePool::get`).
#[derive(Debug)]
pub struct ExclusivePool {
/// Underlying pool, held behind an `Arc`.
pool: Arc<ConnPool<AsyncPoolable>>,
}
impl ExclusivePool {
pub async fn new(
config: ConnPoolConfig,
hooks: LifecycleHooks<AsyncPoolable>,
) -> Result<Self, PoolError<pg_wired::PgWireError>> {
let pool = ConnPool::new(config, hooks).await?;
Ok(Self { pool })
}
pub async fn connect(
addr: &str,
user: &str,
password: &str,
database: &str,
max_size: usize,
) -> Result<Self, PoolError<pg_wired::PgWireError>> {
let mut config = ConnPoolConfig::default();
config.addr = addr.to_string();
config.user = user.to_string();
config.password = password.to_string();
config.database = database.to_string();
config.max_size = max_size;
Self::new(config, LifecycleHooks::default()).await
}
pub async fn get(&self) -> Result<PooledClient, TypedError> {
tracing::debug!("pool checkout");
crate::metrics::record_pool_checkout();
let guard = self.pool.get().await.map_err(|e| {
tracing::warn!(error = %e, "pool checkout failed");
crate::metrics::record_pool_timeout();
TypedError::from(e)
})?;
Ok(PooledClient { guard })
}
pub fn metrics(&self) -> pg_pool::PoolMetrics {
self.pool.metrics()
}
pub async fn warm_up(&self, target: usize) {
self.pool.warm_up(target).await;
}
pub async fn drain(&self) {
self.pool.drain().await;
}
}
/// A single connection checked out from an `ExclusivePool`.
///
/// Every query helper on this type runs on the one connection held by the
/// pool guard. NOTE(review): the connection presumably returns to the pool
/// when the guard is dropped (`PoolGuard` semantics); confirm in pg_pool.
pub struct PooledClient {
/// Pool guard owning the checked-out connection.
guard: PoolGuard<AsyncPoolable>,
}
impl std::fmt::Debug for PooledClient {
    /// Prints `PooledClient { .. }` without exposing the connection guard.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("PooledClient");
        builder.finish_non_exhaustive()
    }
}
impl PooledClient {
pub fn conn(&self) -> &pg_wired::AsyncConn {
self.guard.conn()
}
pub async fn query(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<Vec<Row>, TypedError> {
crate::query::Client::query_on_conn(self.guard.conn(), sql, params).await
}
pub async fn execute(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<u64, TypedError> {
crate::query::Client::execute_on_conn(self.guard.conn(), sql, params).await
}
pub async fn simple_query(&self, sql: &str) -> Result<(), TypedError> {
self.guard.conn().mark_state_mutated();
crate::query::Client::simple_query_on_conn(self.guard.conn(), sql).await
}
pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, TypedError> {
self.guard.conn().mark_state_mutated();
self.guard
.conn()
.copy_in(copy_sql, data)
.await
.map_err(|e| TypedError::from(e).with_sql(copy_sql))
}
pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, TypedError> {
self.guard.conn().mark_state_mutated();
self.guard
.conn()
.copy_out(copy_sql)
.await
.map_err(|e| TypedError::from(e).with_sql(copy_sql))
}
pub fn is_alive(&self) -> bool {
self.guard.conn().is_alive()
}
pub fn cancel_token(&self) -> pg_wired::CancelToken {
self.guard.conn().cancel_token()
}
pub async fn advisory_lock(&self, key: i64) -> Result<(), TypedError> {
self.guard.conn().mark_state_mutated();
self.execute("SELECT pg_advisory_lock($1)", &[&key]).await?;
Ok(())
}
pub async fn try_advisory_lock(&self, key: i64) -> Result<bool, TypedError> {
self.guard.conn().mark_state_mutated();
let rows = self
.query("SELECT pg_try_advisory_lock($1)", &[&key])
.await?;
rows[0].get::<bool>(0)
}
pub async fn advisory_unlock(&self, key: i64) -> Result<bool, TypedError> {
let rows = self.query("SELECT pg_advisory_unlock($1)", &[&key]).await?;
rows[0].get::<bool>(0)
}
pub async fn advisory_xact_lock(&self, key: i64) -> Result<(), TypedError> {
self.execute("SELECT pg_advisory_xact_lock($1)", &[&key])
.await?;
Ok(())
}
pub async fn try_advisory_xact_lock(&self, key: i64) -> Result<bool, TypedError> {
let rows = self
.query("SELECT pg_try_advisory_xact_lock($1)", &[&key])
.await?;
rows[0].get::<bool>(0)
}
pub async fn begin(&self) -> Result<PooledTransaction<'_>, TypedError> {
self.simple_query("BEGIN").await?;
Ok(PooledTransaction {
client: self,
done: false,
})
}
pub async fn begin_with(
&self,
level: crate::IsolationLevel,
) -> Result<PooledTransaction<'_>, TypedError> {
let sql = format!("BEGIN ISOLATION LEVEL {}", level.as_sql());
self.simple_query(&sql).await?;
Ok(PooledTransaction {
client: self,
done: false,
})
}
}
/// An open transaction on a borrowed `PooledClient`.
///
/// Finish it with `commit` or `rollback`. If it is dropped while still open
/// and the connection is alive, its `Drop` impl queues a `ROLLBACK`, or marks
/// the connection broken if that cannot be queued.
pub struct PooledTransaction<'a> {
/// Client (and thus connection) the transaction runs on.
client: &'a PooledClient,
/// Set once `commit`/`rollback` has run; suppresses the rollback-on-drop path.
done: bool,
}
impl std::fmt::Debug for PooledTransaction<'_> {
    /// Prints `PooledTransaction { done: .., .. }`, omitting the client.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("PooledTransaction");
        builder.field("done", &self.done);
        builder.finish_non_exhaustive()
    }
}
impl<'a> PooledTransaction<'a> {
    /// Returns the client this transaction runs on.
    pub fn client(&self) -> &'a PooledClient {
        self.client
    }

    /// Runs `sql` with positional `params` inside the transaction.
    pub async fn query(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<Vec<Row>, TypedError> {
        self.client.query(sql, params).await
    }

    /// Executes `sql` with positional `params` inside the transaction.
    pub async fn execute(&self, sql: &str, params: &[&dyn SqlParam]) -> Result<u64, TypedError> {
        self.client.execute(sql, params).await
    }

    /// Runs `sql` with named parameters.
    ///
    /// The SQL is rewritten via `crate::named_params::rewrite`, and `params`
    /// are reordered to match the extracted names via
    /// `crate::query::resolve_named_params`.
    ///
    /// # Errors
    ///
    /// Fails if the parameters cannot be resolved, or if the query fails.
    pub async fn query_named(
        &self,
        sql: &str,
        params: &[(&str, &dyn SqlParam)],
    ) -> Result<Vec<Row>, TypedError> {
        let (rewritten, names) = crate::named_params::rewrite(sql);
        let ordered = crate::query::resolve_named_params(&names, params)?;
        self.client.query(&rewritten, &ordered).await
    }

    /// Executes `sql` with named parameters.
    ///
    /// Resolution works as in `query_named`.
    ///
    /// # Errors
    ///
    /// Fails if the parameters cannot be resolved, or if execution fails.
    pub async fn execute_named(
        &self,
        sql: &str,
        params: &[(&str, &dyn SqlParam)],
    ) -> Result<u64, TypedError> {
        let (rewritten, names) = crate::named_params::rewrite(sql);
        let ordered = crate::query::resolve_named_params(&names, params)?;
        self.client.execute(&rewritten, &ordered).await
    }

    /// Commits the transaction by sending `COMMIT`.
    ///
    /// The transaction is marked finished only after the `COMMIT` round trip
    /// completes, whether it succeeded or not. If this future is cancelled
    /// before that, the transaction is still unfinished when it is dropped.
    /// The drop path then queues a `ROLLBACK` (or marks the connection
    /// broken) instead of silently returning a connection that may still be
    /// inside an open transaction.
    pub async fn commit(mut self) -> Result<(), TypedError> {
        let outcome = self.client.simple_query("COMMIT").await;
        self.done = true;
        outcome
    }

    /// Rolls the transaction back by sending `ROLLBACK`.
    ///
    /// Like `commit`, the transaction is marked finished only after the round
    /// trip completes, so a cancelled call still falls back to the
    /// rollback-on-drop path.
    pub async fn rollback(mut self) -> Result<(), TypedError> {
        let outcome = self.client.simple_query("ROLLBACK").await;
        self.done = true;
        outcome
    }
}
impl Drop for PooledTransaction<'_> {
    /// Best-effort cleanup for a transaction that was neither committed nor
    /// rolled back.
    ///
    /// Nothing is done when the transaction is already finished or the
    /// connection is no longer alive. Otherwise a `ROLLBACK` is queued. If it
    /// cannot be queued, the connection is marked broken (presumably so the
    /// pool discards it) and a warning is logged.
    fn drop(&mut self) {
        if self.done || !self.client.is_alive() {
            return;
        }
        let conn = self.client.conn();
        if conn.enqueue_rollback() {
            return;
        }
        conn.mark_broken();
        tracing::warn!(
            "PooledTransaction dropped without commit/rollback; could not queue ROLLBACK, connection marked broken"
        );
    }
}