use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
use tokio_postgres::NoTls;
use tokio_postgres::types::ToSql;
use crate::error::{BsqlError, BsqlResult, ConnectError};
use crate::stream::QueryStream;
use crate::transaction::Transaction;
pub struct Pool {
inner: deadpool_postgres::Pool,
pgbouncer: PgBouncerInfo,
}
#[derive(Debug, Clone, Copy)]
pub(crate) struct PgBouncerInfo {
detected: bool,
supports_named_stmts: bool,
}
impl PgBouncerInfo {
const DIRECT: Self = Self {
detected: false,
supports_named_stmts: true,
};
}
pub struct PoolBuilder {
host: Option<String>,
port: Option<u16>,
dbname: Option<String>,
user: Option<String>,
password: Option<String>,
max_size: usize,
connect_timeout_secs: u64,
}
impl PoolBuilder {
pub fn host(mut self, host: &str) -> Self {
self.host = Some(host.into());
self
}
pub fn port(mut self, port: u16) -> Self {
self.port = Some(port);
self
}
pub fn dbname(mut self, dbname: &str) -> Self {
self.dbname = Some(dbname.into());
self
}
pub fn user(mut self, user: &str) -> Self {
self.user = Some(user.into());
self
}
pub fn password(mut self, password: &str) -> Self {
self.password = Some(password.into());
self
}
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
pub fn connect_timeout(mut self, secs: u64) -> Self {
self.connect_timeout_secs = secs;
self
}
pub async fn build(self) -> BsqlResult<Pool> {
let mut cfg = Config::new();
cfg.host = self.host;
cfg.port = self.port;
cfg.dbname = self.dbname;
cfg.user = self.user;
cfg.password = self.password;
cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
cfg.manager = Some(ManagerConfig {
recycling_method: RecyclingMethod::Fast,
});
cfg.pool = Some(deadpool_postgres::PoolConfig {
max_size: self.max_size,
timeouts: deadpool_postgres::Timeouts {
wait: Some(std::time::Duration::ZERO),
create: None,
recycle: None,
},
..Default::default()
});
let pool = cfg
.create_pool(Some(Runtime::Tokio1), NoTls)
.map_err(|e| ConnectError::create(e.to_string()))?;
let pgbouncer = detect_pgbouncer(&pool).await?;
Ok(Pool {
inner: pool,
pgbouncer,
})
}
}
impl Pool {
pub async fn connect(url: &str) -> BsqlResult<Self> {
let config: tokio_postgres::Config = url
.parse()
.map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
let mut cfg = Config::new();
cfg.host = config.get_hosts().first().map(|h| match h {
tokio_postgres::config::Host::Tcp(s) => s.clone(),
#[cfg(unix)]
tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
});
cfg.port = config.get_ports().first().copied();
cfg.dbname = config.get_dbname().map(String::from);
cfg.user = config.get_user().map(String::from);
cfg.password = config
.get_password()
.map(|p| String::from_utf8_lossy(p).into_owned());
cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
cfg.manager = Some(ManagerConfig {
recycling_method: RecyclingMethod::Fast,
});
cfg.pool = Some(deadpool_postgres::PoolConfig {
max_size: 16,
timeouts: deadpool_postgres::Timeouts {
wait: Some(std::time::Duration::ZERO),
create: None,
recycle: None,
},
..Default::default()
});
let pool = cfg
.create_pool(Some(Runtime::Tokio1), NoTls)
.map_err(|e| ConnectError::create(e.to_string()))?;
let pgbouncer = detect_pgbouncer(&pool).await?;
Ok(Pool {
inner: pool,
pgbouncer,
})
}
pub fn builder() -> PoolBuilder {
PoolBuilder {
host: None,
port: None,
dbname: None,
user: None,
password: None,
max_size: 16,
connect_timeout_secs: 5,
}
}
pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
let conn = self.inner.get().await.map_err(BsqlError::from)?;
Ok(PoolConnection {
inner: conn,
pgbouncer: self.pgbouncer,
})
}
pub fn is_pgbouncer(&self) -> bool {
self.pgbouncer.detected
}
pub fn supports_named_statements(&self) -> bool {
self.pgbouncer.supports_named_stmts
}
pub async fn begin(&self) -> BsqlResult<Transaction> {
let conn = self.acquire().await?;
conn.inner
.batch_execute("BEGIN")
.await
.map_err(BsqlError::from)?;
Ok(Transaction::new(conn))
}
pub async fn query_stream(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<QueryStream> {
let conn = self.acquire().await?;
let stmt = conn
.inner
.prepare_cached(sql)
.await
.map_err(BsqlError::from)?;
let row_stream = conn
.inner
.query_raw(&stmt, params.iter().copied())
.await
.map_err(BsqlError::from)?;
Ok(QueryStream::new(conn, row_stream))
}
pub fn status(&self) -> PoolStatus {
let status = self.inner.status();
PoolStatus {
available: status.available,
size: status.size,
max_size: status.max_size,
}
}
}
pub struct PoolConnection {
pub(crate) inner: deadpool_postgres::Object,
pub(crate) pgbouncer: PgBouncerInfo,
}
impl PoolConnection {
pub fn supports_named_statements(&self) -> bool {
self.pgbouncer.supports_named_stmts
}
}
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
pub available: usize,
pub size: usize,
pub max_size: usize,
}
async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<PgBouncerInfo> {
let conn = pool.get().await.map_err(|e| {
ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
})?;
let is_pgbouncer = conn.simple_query("SHOW POOLS").await.is_ok();
if !is_pgbouncer {
return Ok(PgBouncerInfo::DIRECT);
}
let supports_named = match conn.simple_query("SHOW CONFIG").await {
Ok(messages) => messages.iter().any(|msg| {
if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
row.get(0) == Some("prepared_statements") && row.get(1) == Some("yes")
} else {
false
}
}),
Err(_) => false,
};
Ok(PgBouncerInfo {
detected: true,
supports_named_stmts: supports_named,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn builder_defaults() {
let b = Pool::builder();
assert_eq!(b.max_size, 16);
assert_eq!(b.connect_timeout_secs, 5);
}
#[test]
fn builder_config() {
let b = Pool::builder()
.host("localhost")
.port(5432)
.dbname("test")
.user("app")
.password("secret")
.max_size(8)
.connect_timeout(10);
assert_eq!(b.host.as_deref(), Some("localhost"));
assert_eq!(b.port, Some(5432));
assert_eq!(b.dbname.as_deref(), Some("test"));
assert_eq!(b.user.as_deref(), Some("app"));
assert_eq!(b.password.as_deref(), Some("secret"));
assert_eq!(b.max_size, 8);
assert_eq!(b.connect_timeout_secs, 10);
}
#[test]
fn pgbouncer_direct_defaults() {
let info = PgBouncerInfo::DIRECT;
assert!(!info.detected);
assert!(info.supports_named_stmts);
}
#[test]
fn pool_status_type_is_copy() {
fn assert_copy<T: Copy>() {}
assert_copy::<PoolStatus>();
}
}