use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use crate::DriverError;
use crate::arena::Arena;
use crate::codec::Encode;
use crate::conn::Connection;
use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
#[cfg(feature = "async")]
use crate::async_conn::AsyncConnection;
/// A pooled connection: a blocking [`Connection`] or, with the `async` feature,
/// an [`AsyncConnection`].
pub(crate) enum PoolSlot {
    Sync(Connection),
    #[cfg(feature = "async")]
    Async(AsyncConnection),
}
/// Handle to a connection pool. Clones share the same [`PoolInner`] through an `Arc`.
pub struct Pool {
    inner: Arc<PoolInner>,
}
/// State shared by every `Pool` handle and every checked-out `PoolGuard`.
struct PoolInner {
    /// Idle connections; acquirers `pop` from and releasers `push` to the end (LIFO).
    stack: std::sync::Mutex<Vec<PoolSlot>>,
    /// Upper bound on `open_count` enforced when reserving a slot for a new connection.
    max_size: usize,
    /// Connections counted against `max_size`: idle, checked out, and ones still being
    /// opened (the count is incremented before connecting).
    open_count: AtomicUsize,
    /// Connection settings handed to every new connection.
    config: Arc<Config>,
    /// Set by `Pool::close`; checked by acquirers and by `PoolGuard::drop`.
    closed: AtomicBool,
    /// Mutex/condvar pair that acquirers wait on while the pool is full; notified when
    /// connections are returned or the pool is closed.
    release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
    /// Idle connections created at least this long ago are discarded instead of reused.
    max_lifetime: Option<Duration>,
    /// How long an acquirer may wait for capacity; `None` fails immediately when full.
    acquire_timeout: Option<Duration>,
    /// Idle-connection target for the background thread spawned when this is non-zero.
    min_idle: usize,
    /// Statements prepared on connections opened by `acquire`; wrapped in an `Arc` so
    /// readers only clone a pointer while holding the lock.
    warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
    /// Value passed to `set_max_stmt_cache_size` on newly opened connections.
    max_stmt_cache_size: usize,
    /// Idle connections idle for at least this long are discarded instead of reused.
    stale_timeout: Duration,
}
impl Pool {
pub fn connect(url: &str) -> Result<Self, DriverError> {
PoolBuilder::new().url(url).build()
}
pub fn builder() -> PoolBuilder {
PoolBuilder::new()
}
#[inline]
pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
if self.inner.closed.load(Ordering::Acquire) {
return Err(DriverError::Pool("pool is closed".into()));
}
if let Some(guard) = self.try_pop_idle()? {
return Ok(guard);
}
loop {
let current = self.inner.open_count.load(Ordering::Acquire);
if current >= self.inner.max_size {
if let Some(timeout) = self.inner.acquire_timeout {
let (lock, cvar) = &self.inner.release_pair;
let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
let (_guard, result) = cvar
.wait_timeout(guard, timeout)
.unwrap_or_else(|e| e.into_inner());
if result.timed_out() {
return Err(DriverError::Pool(
"pool exhausted: acquire timeout expired".into(),
));
}
if let Some(guard) = self.try_pop_idle()? {
return Ok(guard);
}
continue;
}
return Err(DriverError::Pool(
"pool exhausted: all connections in use".into(),
));
}
if self
.inner
.open_count
.compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
break;
}
}
let conn_result = Connection::connect_arc(self.inner.config.clone());
match conn_result {
Ok(mut conn) => {
conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
self.warmup_conn(&mut conn);
Ok(PoolGuard {
conn: Some(PoolSlot::Sync(conn)),
pool: self.inner.clone(),
discard: false,
})
}
Err(e) => {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
Err(e)
}
}
}
#[inline]
fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
while let Some(mut slot) = stack.pop() {
let (created_at, idle_dur) = match &slot {
PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
#[cfg(feature = "async")]
PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
};
if let Some(max_lifetime) = self.inner.max_lifetime {
if created_at.elapsed() >= max_lifetime {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
continue;
}
}
if idle_dur >= self.inner.stale_timeout {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
continue;
}
if idle_dur > Duration::from_secs(5) {
let alive = match &mut slot {
PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
#[cfg(feature = "async")]
PoolSlot::Async(_) => true, };
if !alive {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
continue;
}
}
return Ok(Some(PoolGuard {
conn: Some(slot),
pool: self.inner.clone(),
discard: false,
}));
}
Ok(None)
}
pub fn is_uds(&self) -> bool {
#[cfg(unix)]
{
self.inner.config.host_is_uds()
}
#[cfg(not(unix))]
{
false
}
}
pub fn begin(&self) -> Result<Transaction, DriverError> {
let mut guard = self.acquire()?;
guard.simple_query("BEGIN")?;
Ok(Transaction {
guard,
committed: false,
deferred_buf: Vec::new(),
deferred_count: 0,
})
}
pub fn open_count(&self) -> usize {
self.inner.open_count.load(Ordering::Relaxed)
}
pub fn max_size(&self) -> usize {
self.inner.max_size
}
pub fn status(&self) -> PoolStatus {
let idle = self
.inner
.stack
.lock()
.unwrap_or_else(|e| e.into_inner())
.len();
let open = self.inner.open_count.load(Ordering::Relaxed);
let active = open.saturating_sub(idle);
PoolStatus {
idle,
active,
open,
max_size: self.inner.max_size,
}
}
fn warmup_conn(&self, conn: &mut Connection) {
let sqls = self
.inner
.warmup_sqls
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone();
if sqls.is_empty() {
return;
}
for sql in sqls.iter() {
let sql_hash = crate::types::hash_sql(sql);
let _ = conn.prepare_only(sql, sql_hash);
}
}
pub fn set_warmup_sqls(&self, sqls: &[&str]) {
let boxed: Arc<Vec<Box<str>>> =
Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
*self
.inner
.warmup_sqls
.lock()
.unwrap_or_else(|e| e.into_inner()) = boxed;
}
pub fn close(&self) {
self.inner.closed.store(true, Ordering::Release);
let slots: Vec<PoolSlot> = {
let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
std::mem::take(&mut *stack)
};
for slot in slots {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
match slot {
PoolSlot::Sync(conn) => {
let _ = conn.close();
}
#[cfg(feature = "async")]
PoolSlot::Async(_conn) => {
}
}
}
let (_, cvar) = &self.inner.release_pair;
cvar.notify_all();
}
pub fn is_closed(&self) -> bool {
self.inner.closed.load(Ordering::Acquire)
}
#[cfg(feature = "async")]
pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
if self.inner.closed.load(Ordering::Acquire) {
return Err(DriverError::Pool("pool is closed".into()));
}
if let Some(guard) = self.try_pop_idle()? {
return Ok(guard);
}
loop {
let current = self.inner.open_count.load(Ordering::Acquire);
if current >= self.inner.max_size {
if let Some(timeout) = self.inner.acquire_timeout {
let (lock, cvar) = &self.inner.release_pair;
let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
let (_guard, result) = cvar
.wait_timeout(guard, timeout)
.unwrap_or_else(|e| e.into_inner());
if result.timed_out() {
return Err(DriverError::Pool(
"pool exhausted: acquire timeout expired".into(),
));
}
if let Some(guard) = self.try_pop_idle()? {
return Ok(guard);
}
continue;
}
return Err(DriverError::Pool(
"pool exhausted: all connections in use".into(),
));
}
if self
.inner
.open_count
.compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
break;
}
}
if self.inner.config.host_is_uds() {
let conn_result = Connection::connect_arc(self.inner.config.clone());
match conn_result {
Ok(mut conn) => {
conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
self.warmup_conn(&mut conn);
Ok(PoolGuard {
conn: Some(PoolSlot::Sync(conn)),
pool: self.inner.clone(),
discard: false,
})
}
Err(e) => {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
Err(e)
}
}
} else {
let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
match conn_result {
Ok(mut conn) => {
conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
Ok(PoolGuard {
conn: Some(PoolSlot::Async(conn)),
pool: self.inner.clone(),
discard: false,
})
}
Err(e) => {
self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
Err(e)
}
}
}
}
}
impl Clone for Pool {
fn clone(&self) -> Self {
Pool {
inner: self.inner.clone(),
}
}
}
/// Point-in-time snapshot of pool occupancy, as returned by `Pool::status`.
///
/// The idle count and the open count are read separately, so under concurrent use
/// the fields may be momentarily inconsistent with each other.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Connections sitting in the idle stack.
    pub idle: usize,
    /// `open - idle` (saturating): connections checked out or still being opened.
    pub active: usize,
    /// Connections counted against `max_size`, including ones being opened.
    pub open: usize,
    /// Configured upper bound on `open`.
    pub max_size: usize,
}
/// Builder for [`Pool`]; obtain one with `Pool::builder()`.
pub struct PoolBuilder {
    /// Connection URL; required by [`PoolBuilder::build`].
    url: Option<String>,
    /// Maximum number of open connections (default 10).
    max_size: usize,
    /// Maximum connection age before an idle connection is discarded (default 30 min).
    max_lifetime: Option<Duration>,
    /// How long `acquire` may wait when the pool is full (default 5 s).
    acquire_timeout: Option<Duration>,
    /// Idle-connection target for the background filler thread (default 0 = disabled).
    min_idle: usize,
    /// Per-connection statement-cache size (default 256).
    max_stmt_cache_size: usize,
    /// Idle time after which a connection is discarded on acquire (default 30 s).
    stale_timeout: Duration,
}
impl PoolBuilder {
    /// Creates a builder populated with the defaults listed on each field.
    fn new() -> Self {
        Self {
            url: None,
            max_size: 10,
            max_lifetime: Some(Duration::from_secs(30 * 60)),
            acquire_timeout: Some(Duration::from_secs(5)),
            min_idle: 0,
            max_stmt_cache_size: 256,
            stale_timeout: Duration::from_secs(30),
        }
    }

    /// Sets the connection URL (required).
    pub fn url(self, url: &str) -> Self {
        Self {
            url: Some(url.to_owned()),
            ..self
        }
    }

    /// Sets the maximum number of open connections.
    pub fn max_size(self, size: usize) -> Self {
        Self {
            max_size: size,
            ..self
        }
    }

    /// Sets the maximum connection age; `None` disables the age check.
    pub fn max_lifetime(self, lifetime: Option<Duration>) -> Self {
        Self {
            max_lifetime: lifetime,
            ..self
        }
    }

    /// Sets how long `acquire` may wait for capacity; `None` fails immediately when full.
    pub fn acquire_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Sets the idle-connection target; a non-zero value starts a background thread.
    pub fn min_idle(self, count: usize) -> Self {
        Self {
            min_idle: count,
            ..self
        }
    }

    /// Sets the per-connection statement-cache size.
    pub fn max_stmt_cache_size(self, size: usize) -> Self {
        Self {
            max_stmt_cache_size: size,
            ..self
        }
    }

    /// Sets the idle time after which a pooled connection is discarded.
    pub fn stale_timeout(self, timeout: Duration) -> Self {
        Self {
            stale_timeout: timeout,
            ..self
        }
    }

    /// Parses the URL and creates the pool without opening any connection.
    ///
    /// When `min_idle > 0`, a background thread running `maintain_min_idle` is spawned.
    ///
    /// # Errors
    /// `DriverError::Pool` if no URL was set, or the error from `Config::from_url`.
    pub fn build(self) -> Result<Pool, DriverError> {
        let PoolBuilder {
            url,
            max_size,
            max_lifetime,
            acquire_timeout,
            min_idle,
            max_stmt_cache_size,
            stale_timeout,
        } = self;
        let url = match url {
            Some(url) => url,
            None => return Err(DriverError::Pool("pool builder requires a URL".into())),
        };
        let config = Config::from_url(&url)?;
        let inner = Arc::new(PoolInner {
            stack: std::sync::Mutex::new(Vec::with_capacity(max_size)),
            max_size,
            open_count: AtomicUsize::new(0),
            config: Arc::new(config),
            closed: AtomicBool::new(false),
            release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
            max_lifetime,
            acquire_timeout,
            min_idle,
            warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
            max_stmt_cache_size,
            stale_timeout,
        });
        if min_idle > 0 {
            let background = Arc::clone(&inner);
            std::thread::spawn(move || maintain_min_idle(background));
        }
        Ok(Pool { inner })
    }
}
fn maintain_min_idle(inner: Arc<PoolInner>) {
loop {
if inner.closed.load(Ordering::Acquire) {
return;
}
let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
let needed = inner.min_idle.saturating_sub(idle_count);
for _ in 0..needed {
if inner.closed.load(Ordering::Acquire) {
return;
}
let current = inner.open_count.load(Ordering::Acquire);
if current >= inner.max_size {
break;
}
if inner
.open_count
.compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
continue;
}
match Connection::connect_arc(inner.config.clone()) {
Ok(conn) => {
let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
stack.push(PoolSlot::Sync(conn));
let (_, cvar) = &inner.release_pair;
cvar.notify_one();
}
Err(_) => {
inner.open_count.fetch_sub(1, Ordering::AcqRel);
}
}
}
std::thread::sleep(Duration::from_secs(1));
}
}
/// A connection checked out of a `Pool`.
///
/// Dropping the guard hands the connection back to the pool or discards it; see the
/// `Drop` impl for the rules.
pub struct PoolGuard {
    /// The checked-out connection; `None` once it has been taken out.
    conn: Option<PoolSlot>,
    /// Shared state of the pool the connection belongs to.
    pool: Arc<PoolInner>,
    /// Set by [`PoolGuard::mark_discard`]; the connection is then not recycled.
    discard: bool,
}
impl PoolGuard {
    /// Error reported when the connection has already been taken out of the guard.
    fn taken() -> DriverError {
        DriverError::Pool("connection already taken".into())
    }

    /// Error reported when a sync-only method is called on an async connection.
    #[cfg(feature = "async")]
    fn wrong_kind() -> DriverError {
        DriverError::Pool("expected sync connection, got async; use async methods".into())
    }

    /// Borrows the slot for the infallible accessors.
    ///
    /// # Panics
    /// If the connection has already been taken.
    fn slot(&self) -> &PoolSlot {
        self.conn.as_ref().expect("connection taken")
    }

    /// Shared access to the blocking connection, or an error for async or taken slots.
    #[inline]
    fn sync_conn(&self) -> Result<&Connection, DriverError> {
        match &self.conn {
            Some(PoolSlot::Sync(conn)) => Ok(conn),
            #[cfg(feature = "async")]
            Some(PoolSlot::Async(_)) => Err(Self::wrong_kind()),
            None => Err(Self::taken()),
        }
    }

    /// Exclusive access to the blocking connection, or an error for async or taken slots.
    #[inline]
    fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
        match &mut self.conn {
            Some(PoolSlot::Sync(conn)) => Ok(conn),
            #[cfg(feature = "async")]
            Some(PoolSlot::Async(_)) => Err(Self::wrong_kind()),
            None => Err(Self::taken()),
        }
    }

    /// Requests that the connection be dropped instead of returned to the pool.
    pub fn mark_discard(&mut self) {
        self.discard = true;
    }

    /// Forwards to the blocking connection's `cancel`.
    pub fn cancel(&self) -> Result<(), DriverError> {
        let conn = self.sync_conn()?;
        conn.cancel()
    }

    /// Backend process id of the connection.
    ///
    /// # Panics
    /// If the connection has already been taken.
    pub fn pid(&self) -> i32 {
        match self.slot() {
            PoolSlot::Sync(conn) => conn.pid(),
            #[cfg(feature = "async")]
            PoolSlot::Async(conn) => conn.pid(),
        }
    }

    /// Forwards to the connection's `is_idle`.
    ///
    /// # Panics
    /// If the connection has already been taken.
    pub fn is_idle(&self) -> bool {
        match self.slot() {
            PoolSlot::Sync(conn) => conn.is_idle(),
            #[cfg(feature = "async")]
            PoolSlot::Async(conn) => conn.is_idle(),
        }
    }

    /// Forwards to the connection's `is_in_transaction`.
    ///
    /// # Panics
    /// If the connection has already been taken.
    pub fn is_in_transaction(&self) -> bool {
        match self.slot() {
            PoolSlot::Sync(conn) => conn.is_in_transaction(),
            #[cfg(feature = "async")]
            PoolSlot::Async(conn) => conn.is_in_transaction(),
        }
    }

    /// Runs `sql` on the blocking connection and returns its result set.
    #[inline]
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.query(sql, sql_hash, params)
    }

    /// Runs `sql` on the blocking connection and returns the connection's `u64` result.
    #[inline]
    pub fn execute(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<u64, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.execute(sql, sql_hash, params)
    }

    /// Runs `sql` once per parameter set via the connection's pipeline support.
    pub fn execute_pipeline(
        &mut self,
        sql: &str,
        sql_hash: u64,
        param_sets: &[&[&(dyn Encode + Sync)]],
    ) -> Result<Vec<u64>, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.execute_pipeline(sql, sql_hash, param_sets)
    }

    /// Sends `sql` using the connection's simple-query path, discarding rows.
    pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.simple_query(sql)
    }

    /// Sends `sql` using the simple-query path and collects the returned rows.
    pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.simple_query_rows(sql)
    }

    /// Runs `sql` and calls `f` for each decoded row.
    pub fn for_each<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
    {
        let conn = self.sync_conn_mut()?;
        conn.for_each(sql, sql_hash, params, f)
    }

    /// Runs `sql` and calls `f` with each row's raw bytes.
    pub fn for_each_raw<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(&[u8]) -> Result<(), DriverError>,
    {
        let conn = self.sync_conn_mut()?;
        conn.for_each_raw(sql, sql_hash, params, f)
    }

    /// Starts a chunked streaming query; forwards to the connection.
    pub fn query_streaming_start(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        chunk_size: i32,
    ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.query_streaming_start(sql, sql_hash, params, chunk_size)
    }

    /// Requests the next chunk of a streaming query; forwards to the connection.
    pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.streaming_send_execute(chunk_size)
    }

    /// Reads the next streamed chunk into `arena`; forwards to the connection.
    pub fn streaming_next_chunk(
        &mut self,
        arena: &mut Arena,
        all_col_offsets: &mut Vec<(usize, i32)>,
    ) -> Result<bool, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.streaming_next_chunk(arena, all_col_offsets)
    }

    /// Copies `rows` into `table` (`columns`); forwards to the connection's `copy_in`.
    pub fn copy_in<'a, I>(
        &mut self,
        table: &str,
        columns: &[&str],
        rows: I,
    ) -> Result<u64, DriverError>
    where
        I: IntoIterator<Item = &'a str>,
    {
        let conn = self.sync_conn_mut()?;
        conn.copy_in(table, columns, rows)
    }

    /// Copies the output of `query` into `writer`; forwards to the connection's `copy_out`.
    pub fn copy_out<W: std::io::Write>(
        &mut self,
        query: &str,
        writer: &mut W,
    ) -> Result<u64, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.copy_out(query, writer)
    }

    /// Whether the guard currently holds a blocking connection.
    pub fn is_sync(&self) -> bool {
        matches!(&self.conn, Some(PoolSlot::Sync(_)))
    }

    /// Whether the guard currently holds an async connection.
    #[cfg(feature = "async")]
    pub fn is_async(&self) -> bool {
        matches!(&self.conn, Some(PoolSlot::Async(_)))
    }

    /// `query` for either connection kind; blocking slots run synchronously.
    #[cfg(feature = "async")]
    pub async fn query_async(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        let slot = self.conn.as_mut().ok_or_else(Self::taken)?;
        match slot {
            PoolSlot::Sync(conn) => conn.query(sql, sql_hash, params),
            PoolSlot::Async(conn) => conn.query(sql, sql_hash, params).await,
        }
    }

    /// `execute` for either connection kind; blocking slots run synchronously.
    #[cfg(feature = "async")]
    pub async fn execute_async(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<u64, DriverError> {
        let slot = self.conn.as_mut().ok_or_else(Self::taken)?;
        match slot {
            PoolSlot::Sync(conn) => conn.execute(sql, sql_hash, params),
            PoolSlot::Async(conn) => conn.execute(sql, sql_hash, params).await,
        }
    }

    /// `simple_query` for either connection kind; blocking slots run synchronously.
    #[cfg(feature = "async")]
    pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
        let slot = self.conn.as_mut().ok_or_else(Self::taken)?;
        match slot {
            PoolSlot::Sync(conn) => conn.simple_query(sql),
            PoolSlot::Async(conn) => conn.simple_query(sql).await,
        }
    }

    /// Ensures `sql` is prepared on the blocking connection; forwards to the connection.
    pub(crate) fn ensure_stmt_prepared(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<[u8; 18], DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.ensure_stmt_prepared(sql, sql_hash, params)
    }

    /// Appends the deferred messages for one statement to `buf`.
    ///
    /// # Panics
    /// If the guard does not hold a blocking connection.
    pub(crate) fn write_deferred_bind_execute(
        &self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        buf: &mut Vec<u8>,
    ) {
        self.sync_conn()
            .expect("sync_conn failed in write_deferred")
            .write_deferred_bind_execute(sql, sql_hash, params, buf);
    }

    /// Sends `count` buffered deferred statements; forwards to the connection.
    pub(crate) fn flush_deferred_pipeline(
        &mut self,
        buf: &mut Vec<u8>,
        count: usize,
    ) -> Result<Vec<u64>, DriverError> {
        let conn = self.sync_conn_mut()?;
        conn.flush_deferred_pipeline(buf, count)
    }
}
impl Drop for PoolGuard {
    /// Returns the connection to the idle stack, or drops it when it cannot be safely
    /// reused. That happens when `mark_discard` was called, the pool is closed, or the
    /// connection is in a (failed) transaction or mid-stream.
    ///
    /// Waiters blocked on a full pool are woken whenever a slot may have been freed.
    fn drop(&mut self) {
        let mut slot = match self.conn.take() {
            Some(slot) => slot,
            None => return,
        };
        let unusable = self.discard
            || self.pool.closed.load(Ordering::Acquire)
            || match &slot {
                PoolSlot::Sync(conn) => {
                    conn.is_in_failed_transaction()
                        || conn.is_in_transaction()
                        || conn.is_streaming()
                }
                #[cfg(feature = "async")]
                PoolSlot::Async(conn) => {
                    conn.is_in_failed_transaction() || conn.is_in_transaction()
                }
            };
        // `Some(slot)` means the connection must be thrown away instead of recycled.
        let rejected = if unusable {
            Some(slot)
        } else {
            // NOTE(review): `touch()` presumably refreshes the timestamp behind
            // `idle_duration()`, yet it only runs on every 64th query; confirm intent.
            match &mut slot {
                PoolSlot::Sync(conn) => {
                    if conn.query_counter() & 63 == 0 {
                        conn.touch();
                    }
                }
                #[cfg(feature = "async")]
                PoolSlot::Async(conn) => {
                    if conn.query_counter() & 63 == 0 {
                        conn.touch();
                    }
                }
            }
            let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
            // Re-checked under the stack lock. `close()` sets `closed` before draining the
            // stack under this same lock, so a connection pushed here is never orphaned.
            if self.pool.closed.load(Ordering::Acquire) {
                Some(slot)
            } else {
                stack.push(slot);
                None
            }
        };
        let freed_slot = match rejected {
            Some(slot) => {
                drop(slot);
                self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
                true
            }
            None => false,
        };
        // A discarded connection always frees a slot. A returned one matters to waiters
        // only when the pool is at capacity.
        if freed_slot || self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
            let (lock, cvar) = &self.pool.release_pair;
            // Taken so the wake-up cannot fall between a waiter's check and its wait.
            let _token = lock.lock().unwrap_or_else(|e| e.into_inner());
            cvar.notify_one();
        }
    }
}
/// A transaction started by `Pool::begin` on a checked-out connection.
///
/// Statements queued with [`Transaction::defer_execute`] are buffered locally and sent
/// together by [`Transaction::flush_deferred`]. Every other statement-issuing method
/// flushes them first, so statements reach the server in the order they were issued.
///
/// Dropping a transaction that was neither committed nor rolled back discards its
/// connection instead of returning it to the pool.
pub struct Transaction {
    guard: PoolGuard,
    /// Set once COMMIT or ROLLBACK succeeded; the connection may then be recycled.
    committed: bool,
    /// Messages written by `write_deferred_bind_execute` that have not been flushed.
    deferred_buf: Vec<u8>,
    /// Number of statements encoded in `deferred_buf`.
    deferred_count: usize,
}
impl Transaction {
    /// Flushes pending deferred statements, then sends `COMMIT`.
    ///
    /// # Errors
    /// Any flush or COMMIT error. The transaction then counts as unfinished, and its
    /// connection is discarded on drop.
    pub fn commit(mut self) -> Result<(), DriverError> {
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.simple_query("COMMIT")?;
        self.committed = true;
        Ok(())
    }

    /// Forgets pending deferred statements and sends `ROLLBACK`.
    ///
    /// # Errors
    /// The ROLLBACK error; the connection is then discarded on drop.
    pub fn rollback(mut self) -> Result<(), DriverError> {
        // Deferred statements only live in the local buffer, so dropping them is enough.
        self.deferred_buf.clear();
        self.deferred_count = 0;
        self.guard.simple_query("ROLLBACK")?;
        // The transaction is finished; allow the connection to be recycled.
        self.committed = true;
        Ok(())
    }

    /// Flushes deferred statements, then runs `sql` and returns its rows.
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.query(sql, sql_hash, params)
    }

    /// Flushes deferred statements, then runs `sql` via the guard's `execute`.
    pub fn execute(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<u64, DriverError> {
        // Without this flush, `sql` would run before statements deferred earlier.
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.execute(sql, sql_hash, params)
    }

    /// Flushes deferred statements, then runs `sql` once per parameter set.
    pub fn execute_pipeline(
        &mut self,
        sql: &str,
        sql_hash: u64,
        param_sets: &[&[&(dyn Encode + Sync)]],
    ) -> Result<Vec<u64>, DriverError> {
        // Same ordering guarantee as `execute`.
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.execute_pipeline(sql, sql_hash, param_sets)
    }

    /// Flushes deferred statements, then calls `f` for each decoded row of `sql`.
    pub fn for_each<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
    {
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.for_each(sql, sql_hash, params, f)
    }

    /// Flushes deferred statements, then calls `f` with each raw row of `sql`.
    pub fn for_each_raw<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(&[u8]) -> Result<(), DriverError>,
    {
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.for_each_raw(sql, sql_hash, params, f)
    }

    /// Flushes deferred statements, then sends `sql` via the simple-query path.
    pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        self.guard.simple_query(sql)
    }

    /// Prepares `sql` if needed and queues one execution of it without a round trip.
    ///
    /// # Errors
    /// `DriverError::Protocol` when more than `i16::MAX` parameters are passed, or any
    /// error from preparing the statement.
    pub fn defer_execute(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<(), DriverError> {
        if params.len() > i16::MAX as usize {
            return Err(DriverError::Protocol(format!(
                "parameter count {} exceeds maximum {}",
                params.len(),
                i16::MAX
            )));
        }
        self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
        self.guard
            .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf);
        self.deferred_count += 1;
        Ok(())
    }

    /// Sends all queued deferred statements as one pipeline and returns the
    /// per-statement results from `flush_deferred_pipeline`. With nothing queued, it
    /// returns an empty `Vec` without contacting the server.
    pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
        let count = std::mem::take(&mut self.deferred_count);
        if count == 0 {
            self.deferred_buf.clear();
            return Ok(Vec::new());
        }
        let result = self
            .guard
            .flush_deferred_pipeline(&mut self.deferred_buf, count);
        // Keep the buffer in step with the now-zero count, even if the flush failed.
        self.deferred_buf.clear();
        result
    }

    /// Number of statements queued by `defer_execute` and not yet flushed.
    pub fn deferred_count(&self) -> usize {
        self.deferred_count
    }
}
impl Drop for Transaction {
    /// An unfinished transaction's connection is dropped instead of recycled, since it is
    /// still inside a transaction. Its slot is released and one waiter is woken.
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        if let Some(slot) = self.guard.conn.take() {
            drop(slot);
            let pool = &self.guard.pool;
            pool.open_count.fetch_sub(1, Ordering::AcqRel);
            // A slot was freed; wake a waiter blocked on a full pool. The mutex is taken
            // so the wake-up cannot fall between a waiter's check and its wait.
            let (lock, cvar) = &pool.release_pair;
            let _token = lock.lock().unwrap_or_else(|e| e.into_inner());
            cvar.notify_one();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // None of these tests needs a running server: pools open connections lazily.
    #[test]
    fn pool_builder_requires_url() {
        let result = PoolBuilder::new().build();
        assert!(result.is_err());
    }
    #[test]
    fn pool_builder_validates_url() {
        let result = PoolBuilder::new().url("not_a_url").build();
        assert!(result.is_err());
    }
    #[test]
    fn pool_builder_accepts_valid_url() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(5)
            .build()
            .unwrap();
        assert_eq!(pool.max_size(), 5);
        assert_eq!(pool.open_count(), 0);
    }
    #[test]
    fn pool_connect_validates_url() {
        let result = Pool::connect("not_a_url");
        assert!(result.is_err());
    }
    #[test]
    fn pool_max_size_zero() {
        // A short timeout keeps the test fast even if acquire waits before failing.
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(0)
            .acquire_timeout(Some(Duration::from_millis(10)))
            .build()
            .unwrap();
        let result = pool.acquire();
        assert!(result.is_err());
        match result {
            Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
            Err(e) => panic!("expected Pool error, got: {e:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }
    #[test]
    fn pool_clone_shares_state() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(5)
            .build()
            .unwrap();
        let pool2 = pool.clone();
        assert_eq!(pool.max_size(), pool2.max_size());
    }
    #[test]
    fn pool_builder_max_lifetime() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_lifetime(Some(Duration::from_secs(60)))
            .build()
            .unwrap();
        assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
    }
    #[test]
    fn pool_builder_max_lifetime_none() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_lifetime(None)
            .build()
            .unwrap();
        assert_eq!(pool.inner.max_lifetime, None);
    }
    #[test]
    fn pool_builder_acquire_timeout_none() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .acquire_timeout(None)
            .build()
            .unwrap();
        assert_eq!(pool.inner.acquire_timeout, None);
    }
    #[test]
    fn pool_builder_acquire_timeout_custom() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .acquire_timeout(Some(Duration::from_secs(10)))
            .build()
            .unwrap();
        assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
    }
    #[test]
    fn pool_builder_min_idle() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .min_idle(2)
            .build()
            .unwrap();
        assert_eq!(pool.inner.min_idle, 2);
        // min_idle > 0 spawns a background filler thread; closing stops it.
        pool.close();
    }
    #[test]
    fn pool_close_marks_closed() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(5)
            .build()
            .unwrap();
        assert!(!pool.is_closed());
        pool.close();
        assert!(pool.is_closed());
        let result = pool.acquire();
        assert!(result.is_err());
        match result {
            Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
            Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }
    #[test]
    fn pool_close_is_idempotent() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        pool.close();
        pool.close();
        assert!(pool.is_closed());
        assert_eq!(pool.open_count(), 0);
    }
    #[test]
    fn pool_status_initial() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(10)
            .build()
            .unwrap();
        let status = pool.status();
        assert_eq!(status.idle, 0);
        assert_eq!(status.active, 0);
        assert_eq!(status.open, 0);
        assert_eq!(status.max_size, 10);
    }
    #[test]
    fn pool_builder_defaults() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.max_size(), 10);
        assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
        assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
        assert_eq!(pool.inner.min_idle, 0);
    }
    #[test]
    fn pool_open_count_initial() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        assert_eq!(pool.open_count(), 0);
    }
    #[test]
    fn pool_builder_max_stmt_cache_size_default() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.max_stmt_cache_size, 256);
    }
    #[test]
    fn pool_builder_max_stmt_cache_size_custom() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_stmt_cache_size(512)
            .build()
            .unwrap();
        assert_eq!(pool.inner.max_stmt_cache_size, 512);
    }
    #[test]
    fn pool_is_uds_false_for_tcp() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        assert!(!pool.is_uds());
    }
    #[cfg(unix)]
    #[test]
    fn pool_is_uds_true_for_unix_socket() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
        assert!(pool.is_uds());
    }
    #[cfg(unix)]
    #[test]
    fn pool_is_uds_true_for_var_run_socket() {
        let pool =
            Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
        assert!(pool.is_uds());
    }
    #[test]
    fn pool_is_uds_false_for_ip_address() {
        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
        assert!(!pool.is_uds());
    }
    #[cfg(unix)]
    #[test]
    fn pool_slot_sync_created_for_uds_config() {
        let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
        assert!(config.host_is_uds());
    }
    #[test]
    fn pool_slot_tcp_config() {
        let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
        assert!(!config.host_is_uds());
    }
    #[test]
    fn pool_is_uds_false_for_hostname() {
        let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
        assert!(!pool.is_uds());
    }
    #[cfg(unix)]
    #[test]
    fn pool_is_uds_true_for_tmp() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
        assert!(pool.is_uds());
    }
    #[test]
    fn pool_close_then_acquire_fails() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(5)
            .build()
            .unwrap();
        pool.close();
        let result = pool.acquire();
        assert!(result.is_err());
        match result {
            Err(DriverError::Pool(msg)) => {
                assert!(msg.contains("closed"), "should say closed: {msg}")
            }
            Err(e) => panic!("expected Pool error, got: {e:?}"),
            Ok(_) => panic!("expected error"),
        }
    }
    #[test]
    fn pool_is_closed_before_and_after() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        assert!(!pool.is_closed());
        pool.close();
        assert!(pool.is_closed());
    }
    #[test]
    fn pool_exhausted_no_timeout() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(0)
            .acquire_timeout(None)
            .build()
            .unwrap();
        let result = pool.acquire();
        assert!(result.is_err());
        match result {
            Err(DriverError::Pool(msg)) => {
                assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
            }
            Err(e) => panic!("expected Pool error, got: {e:?}"),
            Ok(_) => panic!("expected error"),
        }
    }
    #[test]
    fn pool_builder_no_url_error() {
        let result = PoolBuilder::new().max_size(5).build();
        assert!(result.is_err());
        match result {
            Err(DriverError::Pool(msg)) => {
                assert!(msg.contains("URL"), "should mention URL: {msg}")
            }
            Err(e) => panic!("expected Pool error, got: {e:?}"),
            Ok(_) => panic!("expected error"),
        }
    }
    #[test]
    fn pool_builder_invalid_url_error() {
        let result = PoolBuilder::new().url("ftp://something").build();
        assert!(result.is_err());
    }
    #[test]
    fn pool_builder_stmt_cache_size_zero() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_stmt_cache_size(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.max_stmt_cache_size, 0);
    }
    #[test]
    fn pool_status_reflects_max_size() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(20)
            .build()
            .unwrap();
        let status = pool.status();
        assert_eq!(status.max_size, 20);
        assert_eq!(status.idle, 0);
        assert_eq!(status.active, 0);
        assert_eq!(status.open, 0);
    }
    #[test]
    fn pool_clone_shares_config() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(7)
            .build()
            .unwrap();
        let p2 = pool.clone();
        assert_eq!(pool.max_size(), 7);
        assert_eq!(p2.max_size(), 7);
        assert_eq!(pool.open_count(), p2.open_count());
    }
    #[test]
    fn pool_set_warmup_sqls_empty() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        pool.set_warmup_sqls(&[]);
        let sqls = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .clone();
        assert!(sqls.is_empty());
    }
    #[test]
    fn pool_set_warmup_sqls_multiple() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
        let sqls = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .clone();
        assert_eq!(sqls.len(), 3);
        assert_eq!(&*sqls[0], "SELECT 1");
        assert_eq!(&*sqls[1], "SELECT 2");
        assert_eq!(&*sqls[2], "SELECT 3");
    }
    #[test]
    fn pool_set_warmup_sqls_overwrite() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        pool.set_warmup_sqls(&["SELECT 1"]);
        pool.set_warmup_sqls(&["SELECT 99"]);
        let sqls = pool
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .clone();
        assert_eq!(sqls.len(), 1);
        assert_eq!(&*sqls[0], "SELECT 99");
    }
    #[test]
    fn pool_status_debug() {
        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
        let status = pool.status();
        let dbg = format!("{status:?}");
        assert!(dbg.contains("PoolStatus"));
        assert!(dbg.contains("idle"));
        assert!(dbg.contains("active"));
        assert!(dbg.contains("open"));
        assert!(dbg.contains("max_size"));
    }
    // Gated like the other "UDS is true" tests.
    #[cfg(unix)]
    #[test]
    fn config_host_is_uds_returns_true_for_slash() {
        let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
        assert!(config.host_is_uds());
    }
    #[test]
    fn config_host_is_uds_returns_false_for_tcp() {
        let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
        assert!(!config.host_is_uds());
    }
    #[test]
    fn config_host_is_uds_returns_false_for_ip() {
        let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
        assert!(!config.host_is_uds());
    }
    #[test]
    fn pool_builder_full_chain() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(3)
            .max_lifetime(Some(Duration::from_secs(600)))
            .acquire_timeout(Some(Duration::from_secs(5)))
            .min_idle(1)
            .max_stmt_cache_size(128)
            .build()
            .unwrap();
        assert_eq!(pool.max_size(), 3);
        assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
        assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
        assert_eq!(pool.inner.min_idle, 1);
        assert_eq!(pool.inner.max_stmt_cache_size, 128);
        // Stop the min_idle background thread started by `build`.
        pool.close();
    }
    #[test]
    fn pool_max_size_zero_rejects_all_acquires() {
        let pool = PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .max_size(0)
            .acquire_timeout(Some(Duration::from_millis(10)))
            .build()
            .unwrap();
        let result = pool.acquire();
        assert!(result.is_err());
        match &result {
            Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
            _ => panic!("expected pool exhausted error"),
        }
    }
    #[test]
    fn url_parse_unknown_sslmode_returns_error() {
        let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("unknown sslmode"));
    }
    #[test]
    fn url_parse_invalid_port_returns_error() {
        let result = Config::from_url("postgres://u:p@h:abc/d");
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("invalid port"));
    }
    #[test]
    fn url_parse_missing_at_sign_returns_error() {
        let result = Config::from_url("postgres://u:plocalhost/d");
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("missing @"));
    }
    #[test]
    fn url_parse_empty_host_returns_error() {
        let result = Config::from_url("postgres://u:p@/d");
        assert!(result.is_err());
    }
    #[test]
    fn url_parse_empty_user_returns_error() {
        let result = Config::from_url("postgres://:p@h/d");
        assert!(result.is_err());
    }
    #[test]
    fn url_parse_statement_timeout_invalid_uses_default() {
        let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
        assert_eq!(config.statement_timeout_secs, 30);
    }
    #[test]
    fn url_parse_malformed_percent_encoding() {
        let result = Config::from_url("postgres://u%:p@h/d");
        assert!(result.is_err());
    }
    #[test]
    fn url_parse_invalid_hex_in_percent_encoding() {
        let result = Config::from_url("postgres://u%ZZ:p@h/d");
        assert!(result.is_err());
    }
}