use std::sync::Arc;
use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
#[cfg(not(feature = "tls"))]
use tokio_postgres::NoTls;
use tokio_postgres::types::ToSql;
use crate::error::{BsqlError, BsqlResult, ConnectError};
use crate::singleflight::{FlightStatus, Singleflight, sql_key};
use crate::stream::QueryStream;
use crate::transaction::Transaction;
/// Connection pool over one primary server and zero or more read replicas.
///
/// Construct it with [`Pool::builder`] or [`Pool::connect`].
pub struct Pool {
    /// Pool of connections to the primary server.
    primary: deadpool_postgres::Pool,
    /// One pool per configured replica URL; replica reads pick one round-robin.
    replicas: Vec<deadpool_postgres::Pool>,
    /// Round-robin cursor into `replicas`, advanced on every replica checkout.
    replica_idx: std::sync::atomic::AtomicUsize,
    /// Set at build time if `SHOW POOLS` succeeded on the primary or any replica.
    pgbouncer: bool,
    /// Coalesces concurrent identical parameterless queries into one execution.
    singleflight: Singleflight,
}
/// Builder for [`Pool`]; obtain one with [`Pool::builder`].
pub struct PoolBuilder {
    /// Primary server host (`None` leaves the choice to deadpool's config defaults).
    host: Option<String>,
    /// Primary server port.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
    /// Role to connect as.
    user: Option<String>,
    /// Password for `user`.
    password: Option<String>,
    /// Maximum connections per pool (the primary and each replica get their own pool).
    max_size: usize,
    /// Connect timeout in seconds (defaults to 5 via `Pool::builder`).
    connect_timeout_secs: u64,
    /// Replica connection URLs, parsed and connected to in `build`.
    replica_urls: Vec<String>,
}
impl PoolBuilder {
    /// Fills host, port, database name, user and password from a
    /// `postgres://` connection URL.
    ///
    /// All five fields are overwritten, so a component the URL omits clears
    /// any value set earlier through the individual setters. Only the first
    /// host and the first port of a multi-host URL are kept.
    ///
    /// # Errors
    ///
    /// Returns an error if the URL cannot be parsed or its password is not
    /// valid UTF-8.
    pub fn url(mut self, url: &str) -> Result<Self, BsqlError> {
        let parsed = parse_pg_url(url)?;
        self.host = parsed.host;
        self.port = parsed.port;
        self.dbname = parsed.dbname;
        self.user = parsed.user;
        self.password = parsed.password;
        Ok(self)
    }

    /// Sets the primary server host.
    pub fn host(mut self, host: &str) -> Self {
        self.host = Some(host.into());
        self
    }

    /// Sets the primary server port.
    pub fn port(mut self, port: u16) -> Self {
        self.port = Some(port);
        self
    }

    /// Sets the database name.
    pub fn dbname(mut self, dbname: &str) -> Self {
        self.dbname = Some(dbname.into());
        self
    }

    /// Sets the role to connect as.
    pub fn user(mut self, user: &str) -> Self {
        self.user = Some(user.into());
        self
    }

    /// Sets the password used for the primary connection.
    pub fn password(mut self, password: &str) -> Self {
        self.password = Some(password.into());
        self
    }

    /// Sets the maximum number of connections per pool. The primary and
    /// every replica each get a pool of this size.
    pub fn max_size(mut self, size: usize) -> Self {
        self.max_size = size;
        self
    }

    /// Sets the connect timeout in seconds for the primary and replicas.
    ///
    /// `0` disables the timeout, which matches the meaning of
    /// `connect_timeout=0` in a libpq connection string.
    pub fn connect_timeout(mut self, secs: u64) -> Self {
        self.connect_timeout_secs = secs;
        self
    }

    /// Adds a read-replica connection URL. The URL is parsed and connected to
    /// in [`PoolBuilder::build`], not here.
    pub fn replica(mut self, url: &str) -> Self {
        self.replica_urls.push(url.into());
        self
    }

    /// Creates the primary pool and one pool per replica URL, opening an
    /// initial connection to each.
    ///
    /// Every pool uses fast recycling and `max_size` connections. Each also
    /// has a zero wait timeout, so a checkout does not wait for a busy slot.
    /// The PgBouncer flag is set if `SHOW POOLS` succeeds on the primary or
    /// on any replica.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases: deadpool rejects the configuration,
    /// a replica URL cannot be parsed, or an initial connection to the
    /// primary or a replica fails.
    pub async fn build(self) -> BsqlResult<Pool> {
        // A literal zero `Duration` would give every connect attempt no time
        // at all; treat 0 as "no timeout", like libpq/tokio-postgres URLs do.
        let connect_timeout = (self.connect_timeout_secs > 0)
            .then(|| std::time::Duration::from_secs(self.connect_timeout_secs));

        let mut cfg = Config::new();
        cfg.host = self.host;
        cfg.port = self.port;
        cfg.dbname = self.dbname;
        cfg.user = self.user;
        cfg.password = self.password;
        cfg.connect_timeout = connect_timeout;
        cfg.manager = Some(ManagerConfig {
            recycling_method: RecyclingMethod::Fast,
        });
        cfg.pool = Some(deadpool_postgres::PoolConfig {
            max_size: self.max_size,
            timeouts: deadpool_postgres::Timeouts {
                wait: Some(std::time::Duration::ZERO),
                create: None,
                recycle: None,
            },
            ..Default::default()
        });
        let pool = create_deadpool(cfg)?;
        // This also serves as the connectivity check for the primary.
        let mut pgbouncer = detect_pgbouncer(&pool).await?;

        let mut replicas = Vec::with_capacity(self.replica_urls.len());
        for url in &self.replica_urls {
            let replica_pool =
                create_pool_from_url(url, self.max_size, self.connect_timeout_secs).await?;
            pgbouncer |= detect_pgbouncer(&replica_pool).await?;
            replicas.push(replica_pool);
        }

        Ok(Pool {
            primary: pool,
            replicas,
            replica_idx: std::sync::atomic::AtomicUsize::new(0),
            pgbouncer,
            singleflight: Singleflight::new(),
        })
    }
}
impl Pool {
pub async fn connect(url: &str) -> BsqlResult<Self> {
Pool::builder().url(url)?.build().await
}
pub fn builder() -> PoolBuilder {
PoolBuilder {
host: None,
port: None,
dbname: None,
user: None,
password: None,
max_size: 16,
connect_timeout_secs: 5,
replica_urls: Vec::new(),
}
}
pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
let conn = self.primary.get().await.map_err(BsqlError::from)?;
Ok(PoolConnection { inner: conn })
}
pub fn is_pgbouncer(&self) -> bool {
self.pgbouncer
}
pub fn has_replicas(&self) -> bool {
!self.replicas.is_empty()
}
pub async fn begin(&self) -> BsqlResult<Transaction> {
let conn = self.acquire().await?;
Ok(Transaction::new(conn))
}
pub async fn query_stream(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<QueryStream> {
let conn = self.acquire().await?;
let stmt = conn
.inner
.prepare_cached(sql)
.await
.map_err(BsqlError::from)?;
let row_stream = conn
.inner
.query_raw(&stmt, params.iter().copied())
.await
.map_err(BsqlError::from)?;
Ok(QueryStream::new(conn, row_stream))
}
pub async fn warmup(&self, sqls: &[&str]) -> BsqlResult<()> {
if sqls.is_empty() {
return Ok(());
}
let conn = self.acquire().await?;
for sql in sqls {
conn.inner
.prepare_cached(sql)
.await
.map_err(BsqlError::from)?;
}
Ok(())
}
pub fn status(&self) -> PoolStatus {
let status = self.primary.status();
PoolStatus {
available: status.available,
size: status.size,
max_size: status.max_size,
}
}
pub(crate) async fn query_raw_primary(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
if params.is_empty() {
let key = sql_key(sql);
self.query_with_singleflight(key, sql, params, false).await
} else {
self.execute_on_pool(sql, params, false).await
}
}
pub(crate) async fn query_raw_read(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
if self.replicas.is_empty() {
return self.query_raw_primary(sql, params).await;
}
if params.is_empty() {
let key = sql_key(sql);
match self.query_with_singleflight(key, sql, params, true).await {
Ok(rows) => Ok(rows),
Err(_) => self.query_with_singleflight(key, sql, params, false).await,
}
} else {
match self.execute_on_pool(sql, params, true).await {
Ok(rows) => Ok(rows),
Err(_) => self.execute_on_pool(sql, params, false).await,
}
}
}
async fn query_with_singleflight(
&self,
key: u64,
sql: &str,
params: &[&(dyn ToSql + Sync)],
use_replica: bool,
) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
match self.singleflight.try_join(key) {
FlightStatus::Follower(mut rx) => {
match rx.recv().await {
Ok(rows) => Ok(rows),
Err(_) => {
self.execute_on_pool(sql, params, use_replica).await
}
}
}
FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
Ok(rows) => {
self.singleflight.complete(key, Arc::clone(&rows));
Ok(rows)
}
Err(e) => {
self.singleflight.abandon(key);
Err(e)
}
},
}
}
async fn execute_on_pool(
&self,
sql: &str,
params: &[&(dyn ToSql + Sync)],
use_replica: bool,
) -> BsqlResult<Arc<[tokio_postgres::Row]>> {
let raw_conn = if use_replica && !self.replicas.is_empty() {
let idx = self
.replica_idx
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
% self.replicas.len();
self.replicas[idx].get().await.map_err(BsqlError::from)?
} else {
self.primary.get().await.map_err(BsqlError::from)?
};
let stmt = raw_conn
.prepare_cached(sql)
.await
.map_err(BsqlError::from)?;
let rows = raw_conn
.query(&stmt, params)
.await
.map_err(BsqlError::from)?;
Ok(Arc::from(rows))
}
}
/// Debug output shows pool health rather than internals: the primary's
/// status counters, the PgBouncer flag and how many replicas exist.
impl std::fmt::Debug for Pool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let replica_count = self.replicas.len();
        let mut out = f.debug_struct("Pool");
        out.field("status", &self.status());
        out.field("is_pgbouncer", &self.pgbouncer);
        out.field("replicas", &replica_count);
        out.finish()
    }
}
/// A connection checked out of the primary pool, e.g. by [`Pool::acquire`].
///
/// Dropping it hands the underlying deadpool object back to its pool.
pub struct PoolConnection {
    /// The deadpool-managed client; crate-visible so query helpers can use it directly.
    pub(crate) inner: deadpool_postgres::Object,
}
/// Point-in-time counters for the primary pool, copied from deadpool's `Status`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// deadpool's `Status::available` for the primary pool.
    pub available: usize,
    /// deadpool's `Status::size`: connections currently created by the pool.
    pub size: usize,
    /// Configured maximum pool size.
    pub max_size: usize,
}
/// Connection fields extracted from a `postgres://` URL by `parse_pg_url`.
///
/// Each field is `None` when the URL does not specify it.
struct ParsedUrl {
    /// First host in the URL (a socket path string for Unix-socket hosts).
    host: Option<String>,
    /// First port in the URL.
    port: Option<u16>,
    /// Database name.
    dbname: Option<String>,
    /// User name.
    user: Option<String>,
    /// Password, required to be valid UTF-8.
    password: Option<String>,
}
/// Parses a PostgreSQL connection string with tokio-postgres and keeps the
/// fields this crate forwards to deadpool.
///
/// Only the first host and first port are kept. On Unix, a socket-path host
/// becomes its lossy string form.
///
/// # Errors
///
/// Returns a connect error if the string does not parse, or if the password
/// is not valid UTF-8.
fn parse_pg_url(url: &str) -> BsqlResult<ParsedUrl> {
    let config = url
        .parse::<tokio_postgres::Config>()
        .map_err(|e| ConnectError::create(e.to_string()))?;

    let host = match config.get_hosts().first() {
        Some(tokio_postgres::config::Host::Tcp(name)) => Some(name.clone()),
        #[cfg(unix)]
        Some(tokio_postgres::config::Host::Unix(path)) => {
            Some(path.to_string_lossy().into_owned())
        }
        None => None,
    };

    // tokio-postgres stores the password as raw bytes; require UTF-8 here.
    let password = config
        .get_password()
        .map(|bytes| String::from_utf8(bytes.to_vec()))
        .transpose()
        .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?;

    Ok(ParsedUrl {
        host,
        port: config.get_ports().first().copied(),
        dbname: config.get_dbname().map(String::from),
        user: config.get_user().map(String::from),
        password,
    })
}
/// Builds a deadpool pool for a replica URL and verifies it by checking out
/// one connection.
///
/// The pool uses fast recycling, `max_size` connections and a zero wait
/// timeout, so a checkout does not wait for a busy slot. A
/// `connect_timeout_secs` of `0` disables the connect timeout, matching the
/// meaning of `connect_timeout=0` in a libpq connection string.
///
/// # Errors
///
/// Returns an error if `url` cannot be parsed, deadpool rejects the
/// configuration, or the initial connection fails.
async fn create_pool_from_url(
    url: &str,
    max_size: usize,
    connect_timeout_secs: u64,
) -> BsqlResult<deadpool_postgres::Pool> {
    let parsed = parse_pg_url(url)?;
    let mut cfg = Config::new();
    cfg.host = parsed.host;
    cfg.port = parsed.port;
    cfg.dbname = parsed.dbname;
    cfg.user = parsed.user;
    cfg.password = parsed.password;
    // A literal zero `Duration` would give every connect attempt no time at all.
    cfg.connect_timeout =
        (connect_timeout_secs > 0).then(|| std::time::Duration::from_secs(connect_timeout_secs));
    cfg.manager = Some(ManagerConfig {
        recycling_method: RecyclingMethod::Fast,
    });
    cfg.pool = Some(deadpool_postgres::PoolConfig {
        max_size,
        timeouts: deadpool_postgres::Timeouts {
            wait: Some(std::time::Duration::ZERO),
            create: None,
            recycle: None,
        },
        ..Default::default()
    });
    let pool = create_deadpool(cfg)?;
    // Connect once up front so an unreachable replica is reported at build
    // time; the connection goes back to the pool when `_conn` drops.
    let _conn = pool
        .get()
        .await
        .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
    Ok(pool)
}
/// Creates a deadpool pool from `cfg` on the Tokio runtime.
///
/// With the `tls` feature enabled, connections use the rustls connector from
/// `make_rustls_connect`; otherwise they are unencrypted (`NoTls`).
///
/// # Errors
///
/// Returns a connect error if deadpool rejects the configuration.
fn create_deadpool(cfg: Config) -> BsqlResult<deadpool_postgres::Pool> {
    #[cfg(feature = "tls")]
    let connector = make_rustls_connect();
    #[cfg(not(feature = "tls"))]
    let connector = NoTls;

    cfg.create_pool(Some(Runtime::Tokio1), connector)
        .map_err(|e| ConnectError::create(e.to_string()))
}
/// Builds the rustls connector used for TLS connections.
///
/// It trusts only the root certificates bundled in `webpki-roots` and sends
/// no client certificate. Server certificates that do not chain to those
/// roots (e.g. self-signed ones) will fail verification.
#[cfg(feature = "tls")]
pub(crate) fn make_rustls_connect() -> tokio_postgres_rustls::MakeRustlsConnect {
    let mut trust_anchors = rustls::RootCertStore::empty();
    trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    tokio_postgres_rustls::MakeRustlsConnect::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth(),
    )
}
/// Opens one connection from `pool` and reports whether `SHOW POOLS`
/// succeeds on it.
///
/// Any error from the probe query counts as "not PgBouncer"; only failing to
/// obtain the connection is an error.
///
/// NOTE(review): PgBouncer answers `SHOW POOLS` on its admin console. Confirm
/// whether it also answers it on connections to ordinary pooled databases;
/// if not, this probe may miss PgBouncer in common setups.
///
/// # Errors
///
/// Returns a connect error if no initial connection can be established.
async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<bool> {
    let client = pool.get().await.map_err(|err| {
        ConnectError::with_source(format!("failed to establish initial connection: {err}"), err)
    })?;
    let probe = client.simple_query("SHOW POOLS").await;
    Ok(probe.is_ok())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert_eq!(b.max_size, 16);
        assert_eq!(b.connect_timeout_secs, 5);
        assert!(b.replica_urls.is_empty());
    }

    #[test]
    fn builder_config() {
        let b = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);
        assert_eq!(b.host.as_deref(), Some("localhost"));
        assert_eq!(b.port, Some(5432));
        assert_eq!(b.dbname.as_deref(), Some("test"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
        assert_eq!(b.max_size, 8);
        assert_eq!(b.connect_timeout_secs, 10);
    }

    #[test]
    fn builder_replicas() {
        let b = Pool::builder()
            .replica("postgres://replica1:5432/db")
            .replica("postgres://replica2:5432/db");
        assert_eq!(b.replica_urls.len(), 2);
    }

    /// `url()` should populate every connection field from the URL.
    #[test]
    fn builder_url_fills_connection_fields() {
        let b = match Pool::builder().url("postgres://app:secret@db.example:6543/test") {
            Ok(b) => b,
            Err(_) => panic!("well-formed URL should parse"),
        };
        assert_eq!(b.host.as_deref(), Some("db.example"));
        assert_eq!(b.port, Some(6543));
        assert_eq!(b.dbname.as_deref(), Some("test"));
        assert_eq!(b.user.as_deref(), Some("app"));
        assert_eq!(b.password.as_deref(), Some("secret"));
    }

    /// A later `url()` call replaces a host set earlier with `host()`.
    #[test]
    fn builder_url_overrides_earlier_setters() {
        let b = match Pool::builder()
            .host("old-host")
            .url("postgres://app@new-host:5432/test")
        {
            Ok(b) => b,
            Err(_) => panic!("well-formed URL should parse"),
        };
        assert_eq!(b.host.as_deref(), Some("new-host"));
    }

    /// A percent-encoded password that is not UTF-8 must be rejected.
    #[test]
    fn builder_url_rejects_non_utf8_password() {
        assert!(Pool::builder().url("postgres://app:%FF@localhost/test").is_err());
    }

    /// A non-numeric port must be rejected rather than silently dropped.
    #[test]
    fn builder_url_rejects_malformed_port() {
        assert!(Pool::builder().url("postgres://app@localhost:notaport/test").is_err());
    }
}