use std::sync::Arc;
use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
use tokio_postgres::NoTls;
use tokio_postgres::types::ToSql;
use crate::error::{BsqlError, BsqlResult, ConnectError};
use crate::singleflight::{FlightStatus, Singleflight, sql_key};
use crate::stream::QueryStream;
use crate::transaction::Transaction;
/// Connection pool for a primary PostgreSQL server plus optional read replicas.
///
/// Besides the deadpool pools it keeps a round-robin cursor for replica
/// selection, the PgBouncer probe result taken when the pool was created, and a
/// [`Singleflight`] used to share one result between concurrent, identical,
/// parameterless queries.
pub struct Pool {
    /// Pool for the primary server (used by `acquire`, `begin`, `query_stream`
    /// and as the fallback for replica reads).
    primary: deadpool_postgres::Pool,
    /// One pool per configured replica, in the order the replicas were added.
    replicas: Vec<deadpool_postgres::Pool>,
    /// Counter advanced on every replica checkout; taken modulo `replicas.len()`
    /// to pick the next replica.
    replica_idx: std::sync::atomic::AtomicUsize,
    /// PgBouncer probe result for the primary; copied into each `PoolConnection`.
    pgbouncer: PgBouncerInfo,
    /// Coalesces concurrent executions of identical parameterless SQL.
    singleflight: Singleflight,
}
/// Outcome of probing a backend for PgBouncer.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PgBouncerInfo {
    /// `true` when the backend answered PgBouncer's `SHOW POOLS` command.
    detected: bool,
    /// Whether named (protocol-level) prepared statements may be used.
    supports_named_stmts: bool,
}

impl PgBouncerInfo {
    /// Profile of a direct PostgreSQL connection: no PgBouncer in front, and
    /// named prepared statements are available.
    const DIRECT: Self = PgBouncerInfo {
        supports_named_stmts: true,
        detected: false,
    };
}
/// Builder for a `Pool` with explicit primary connection settings and optional
/// read replicas.
///
/// Every setter consumes and returns the builder; nothing connects until
/// `build` is awaited, so dropping a builder chain is always a mistake.
#[must_use = "a PoolBuilder does nothing until `build` is called"]
pub struct PoolBuilder {
    /// Host of the primary server (`None` leaves it to the driver default).
    host: Option<String>,
    /// Port of the primary server.
    port: Option<u16>,
    /// Database name on the primary server.
    dbname: Option<String>,
    /// User name for the primary server.
    user: Option<String>,
    /// Password for the primary server.
    password: Option<String>,
    /// Maximum number of connections per pool (primary and each replica).
    max_size: usize,
    /// Connect timeout for the primary, in whole seconds.
    connect_timeout_secs: u64,
    /// Connection URLs of read replicas, in insertion order.
    replica_urls: Vec<String>,
}
impl PoolBuilder {
pub fn host(mut self, host: &str) -> Self {
self.host = Some(host.into());
self
}
pub fn port(mut self, port: u16) -> Self {
self.port = Some(port);
self
}
pub fn dbname(mut self, dbname: &str) -> Self {
self.dbname = Some(dbname.into());
self
}
pub fn user(mut self, user: &str) -> Self {
self.user = Some(user.into());
self
}
pub fn password(mut self, password: &str) -> Self {
self.password = Some(password.into());
self
}
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
pub fn connect_timeout(mut self, secs: u64) -> Self {
self.connect_timeout_secs = secs;
self
}
pub fn replica(mut self, url: &str) -> Self {
self.replica_urls.push(url.into());
self
}
pub async fn build(self) -> BsqlResult<Pool> {
let mut cfg = Config::new();
cfg.host = self.host;
cfg.port = self.port;
cfg.dbname = self.dbname;
cfg.user = self.user;
cfg.password = self.password;
cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
cfg.manager = Some(ManagerConfig {
recycling_method: RecyclingMethod::Fast,
});
cfg.pool = Some(deadpool_postgres::PoolConfig {
max_size: self.max_size,
timeouts: deadpool_postgres::Timeouts {
wait: Some(std::time::Duration::ZERO),
create: None,
recycle: None,
},
..Default::default()
});
let pool = cfg
.create_pool(Some(Runtime::Tokio1), NoTls)
.map_err(|e| ConnectError::create(e.to_string()))?;
let pgbouncer = detect_pgbouncer(&pool).await?;
let mut replicas = Vec::with_capacity(self.replica_urls.len());
for url in &self.replica_urls {
let replica_pool = create_pool_from_url(url, self.max_size).await?;
replicas.push(replica_pool);
}
Ok(Pool {
primary: pool,
replicas,
replica_idx: std::sync::atomic::AtomicUsize::new(0),
pgbouncer,
singleflight: Singleflight::new(),
})
}
}
impl Pool {
pub async fn connect(url: &str) -> BsqlResult<Self> {
let config: tokio_postgres::Config = url
.parse()
.map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
let mut cfg = Config::new();
cfg.host = config.get_hosts().first().map(|h| match h {
tokio_postgres::config::Host::Tcp(s) => s.clone(),
#[cfg(unix)]
tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
});
cfg.port = config.get_ports().first().copied();
cfg.dbname = config.get_dbname().map(String::from);
cfg.user = config.get_user().map(String::from);
cfg.password =
match config.get_password() {
Some(p) => Some(String::from_utf8(p.to_vec()).map_err(|_| {
ConnectError::create("database password contains invalid UTF-8")
})?),
None => None,
};
cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
cfg.manager = Some(ManagerConfig {
recycling_method: RecyclingMethod::Fast,
});
cfg.pool = Some(deadpool_postgres::PoolConfig {
max_size: 16,
timeouts: deadpool_postgres::Timeouts {
wait: Some(std::time::Duration::ZERO),
create: None,
recycle: None,
},
..Default::default()
});
let pool = cfg
.create_pool(Some(Runtime::Tokio1), NoTls)
.map_err(|e| ConnectError::create(e.to_string()))?;
let pgbouncer = detect_pgbouncer(&pool).await?;
Ok(Pool {
primary: pool,
replicas: Vec::new(),
replica_idx: std::sync::atomic::AtomicUsize::new(0),
pgbouncer,
singleflight: Singleflight::new(),
})
}
pub fn builder() -> PoolBuilder {
PoolBuilder {
host: None,
port: None,
dbname: None,
user: None,
password: None,
max_size: 16,
connect_timeout_secs: 5,
replica_urls: Vec::new(),
}
}
pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
let conn = self.primary.get().await.map_err(BsqlError::from)?;
Ok(PoolConnection {
inner: conn,
pgbouncer: self.pgbouncer,
})
}
pub fn is_pgbouncer(&self) -> bool {
self.pgbouncer.detected
}
pub fn supports_named_statements(&self) -> bool {
self.pgbouncer.supports_named_stmts
}
pub fn has_replicas(&self) -> bool {
!self.replicas.is_empty()
}
pub async fn begin(&self) -> BsqlResult<Transaction> {
let conn = self.acquire().await?;
conn.inner
.batch_execute("BEGIN")
.await
.map_err(BsqlError::from)?;
Ok(Transaction::new(conn))
}
pub async fn query_stream(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<QueryStream> {
let conn = self.acquire().await?;
let stmt = conn
.inner
.prepare_cached(sql)
.await
.map_err(BsqlError::from)?;
let row_stream = conn
.inner
.query_raw(&stmt, params.iter().copied())
.await
.map_err(BsqlError::from)?;
Ok(QueryStream::new(conn, row_stream))
}
pub fn status(&self) -> PoolStatus {
let status = self.primary.status();
PoolStatus {
available: status.available,
size: status.size,
max_size: status.max_size,
}
}
pub(crate) async fn query_raw_primary(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
if params.is_empty() {
let key = sql_key(sql);
self.query_with_singleflight(key, sql, params, false).await
} else {
self.execute_on_pool(sql, params, false).await
}
}
pub(crate) async fn query_raw_read(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
if self.replicas.is_empty() {
return self.query_raw_primary(sql, params).await;
}
if params.is_empty() {
let key = sql_key(sql);
match self.query_with_singleflight(key, sql, params, true).await {
Ok(rows) => Ok(rows),
Err(_) => self.query_with_singleflight(key, sql, params, false).await,
}
} else {
match self.execute_on_pool(sql, params, true).await {
Ok(rows) => Ok(rows),
Err(_) => self.execute_on_pool(sql, params, false).await,
}
}
}
async fn query_with_singleflight(
&self,
key: u64,
sql: &str,
params: &[&(dyn ToSql + Sync)],
use_replica: bool,
) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
match self.singleflight.try_join(key) {
FlightStatus::Follower(mut rx) => {
match rx.recv().await {
Ok(rows) => Ok(rows),
Err(_) => {
self.execute_on_pool(sql, params, use_replica).await
}
}
}
FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
Ok(rows) => {
self.singleflight.complete(key, Arc::clone(&rows));
Ok(rows)
}
Err(e) => {
self.singleflight.abandon(key);
Err(e)
}
},
}
}
async fn execute_on_pool(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
use_replica: bool,
) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
let raw_conn = if use_replica && !self.replicas.is_empty() {
let idx = self
.replica_idx
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
% self.replicas.len();
self.replicas[idx].get().await.map_err(BsqlError::from)?
} else {
self.primary.get().await.map_err(BsqlError::from)?
};
let stmt = raw_conn
.prepare_cached(sql)
.await
.map_err(BsqlError::from)?;
let rows = raw_conn
.query(&stmt, params)
.await
.map_err(BsqlError::from)?;
Ok(Arc::new(rows))
}
}
pub struct PoolConnection {
pub(crate) inner: deadpool_postgres::Object,
pub(crate) pgbouncer: PgBouncerInfo,
}
impl PoolConnection {
pub fn supports_named_statements(&self) -> bool {
self.pgbouncer.supports_named_stmts
}
}
/// Point-in-time occupancy of the primary pool, mirroring deadpool's `Status`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// Copied from deadpool's `Status::available`.
    pub available: usize,
    /// Copied from deadpool's `Status::size`.
    pub size: usize,
    /// Configured maximum number of connections.
    pub max_size: usize,
}
/// Creates a pool for the replica described by `url` and verifies it by
/// checking out one connection.
///
/// Only the first host and port in the URL are used. The pool holds at most
/// `max_size` connections with a zero checkout wait timeout. A
/// `connect_timeout` given in the URL is honoured; otherwise 5 seconds applies.
///
/// # Errors
///
/// Fails if the URL cannot be parsed, the password is not valid UTF-8, the
/// pool cannot be created, or the verification connection cannot be obtained.
async fn create_pool_from_url(url: &str, max_size: usize) -> BsqlResult<deadpool_postgres::Pool> {
    let config: tokio_postgres::Config = url
        .parse()
        .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
    let mut cfg = Config::new();
    cfg.host = config.get_hosts().first().map(|h| match h {
        tokio_postgres::config::Host::Tcp(s) => s.clone(),
        #[cfg(unix)]
        tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
    });
    cfg.port = config.get_ports().first().copied();
    cfg.dbname = config.get_dbname().map(String::from);
    cfg.user = config.get_user().map(String::from);
    cfg.password = match config.get_password() {
        Some(p) => Some(
            String::from_utf8(p.to_vec())
                .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
        ),
        None => None,
    };
    // An explicit `connect_timeout` in the URL wins over the 5 second default.
    cfg.connect_timeout = Some(
        config
            .get_connect_timeout()
            .copied()
            .unwrap_or(std::time::Duration::from_secs(5)),
    );
    cfg.manager = Some(ManagerConfig {
        recycling_method: RecyclingMethod::Fast,
    });
    cfg.pool = Some(deadpool_postgres::PoolConfig {
        max_size,
        timeouts: deadpool_postgres::Timeouts {
            wait: Some(std::time::Duration::ZERO),
            create: None,
            recycle: None,
        },
        ..Default::default()
    });
    let pool = cfg
        .create_pool(Some(Runtime::Tokio1), NoTls)
        .map_err(|e| ConnectError::create(e.to_string()))?;
    // Check out (and immediately release) one connection so an unreachable
    // replica is reported while building, not on its first query.
    let probe = pool
        .get()
        .await
        .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
    drop(probe);
    Ok(pool)
}
/// Probes `pool`'s backend to decide whether it is PgBouncer and, if so,
/// whether named prepared statements can be used through it.
///
/// PgBouncer is assumed when `SHOW POOLS` succeeds. Named-statement support is
/// then read from `SHOW CONFIG`; if that command fails, support is assumed
/// absent.
///
/// NOTE(review): PgBouncer answers `SHOW POOLS` / `SHOW CONFIG` on its admin
/// console database; on an ordinary pooled connection such commands are
/// presumably forwarded to PostgreSQL and fail. Confirm detection works for
/// the URLs this crate is used with.
///
/// # Errors
///
/// Fails only if no connection can be checked out of `pool`.
async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<PgBouncerInfo> {
    let conn = pool.get().await.map_err(|e| {
        ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
    })?;
    let is_pgbouncer = conn.simple_query("SHOW POOLS").await.is_ok();
    if !is_pgbouncer {
        return Ok(PgBouncerInfo::DIRECT);
    }
    let supports_named = match conn.simple_query("SHOW CONFIG").await {
        Ok(messages) => messages.iter().any(|msg| match msg {
            tokio_postgres::SimpleQueryMessage::Row(row) => {
                // `try_get` rather than `get`: `get` panics on an out-of-range column.
                match (row.try_get(0), row.try_get(1)) {
                    (Ok(Some(key)), Ok(Some(value))) => {
                        pgbouncer_setting_enables_named_stmts(key, value)
                    }
                    _ => false,
                }
            }
            _ => false,
        }),
        Err(_) => false,
    };
    Ok(PgBouncerInfo {
        detected: true,
        supports_named_stmts: supports_named,
    })
}

/// Returns whether one `SHOW CONFIG` `key`/`value` row indicates that PgBouncer
/// supports named prepared statements.
///
/// PgBouncer exposes this as `max_prepared_statements`, where any positive
/// value enables support. The original `prepared_statements = yes` check is
/// kept for backward compatibility.
fn pgbouncer_setting_enables_named_stmts(key: &str, value: &str) -> bool {
    match key {
        "max_prepared_statements" => matches!(value.trim().parse::<u64>(), Ok(n) if n > 0),
        "prepared_statements" => value == "yes",
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder carries the default limits and no replicas.
    #[test]
    fn builder_defaults() {
        let builder = Pool::builder();
        assert_eq!((builder.max_size, builder.connect_timeout_secs), (16, 5));
        assert!(builder.replica_urls.is_empty());
    }

    /// Every setter stores its value on the builder.
    #[test]
    fn builder_config() {
        let builder = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);
        let text_fields = [
            (builder.host.as_deref(), Some("localhost")),
            (builder.dbname.as_deref(), Some("test")),
            (builder.user.as_deref(), Some("app")),
            (builder.password.as_deref(), Some("secret")),
        ];
        for &(actual, expected) in text_fields.iter() {
            assert_eq!(actual, expected);
        }
        assert_eq!(builder.port, Some(5432));
        assert_eq!(builder.max_size, 8);
        assert_eq!(builder.connect_timeout_secs, 10);
    }

    /// Each `replica` call appends one URL.
    #[test]
    fn builder_replicas() {
        let builder = Pool::builder()
            .replica("postgres://replica1:5432/db")
            .replica("postgres://replica2:5432/db");
        assert_eq!(builder.replica_urls.len(), 2);
    }

    /// The direct-connection profile reports no PgBouncer and named statements.
    #[test]
    fn pgbouncer_direct_defaults() {
        let PgBouncerInfo {
            detected,
            supports_named_stmts,
        } = PgBouncerInfo::DIRECT;
        assert!(!detected);
        assert!(supports_named_stmts);
    }

    /// `PoolStatus` is a plain `Copy` snapshot.
    #[test]
    fn pool_status_type_is_copy() {
        fn requires_copy<T: Copy>() {}
        requires_copy::<PoolStatus>();
    }
}