use std::sync::Mutex;
use std::time::Duration;
use bsql_driver_postgres::arena::acquire_arena;
use bsql_driver_postgres::codec::Encode;
use crate::error::{BsqlError, BsqlResult};
use crate::stream::QueryStream;
use crate::transaction::Transaction;
/// One row returned by [`Pool::raw_query`], holding each column as text.
///
/// `None` entries are SQL `NULL` values. Rows can be compared with `==`
/// (handy in tests and when diffing result sets).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawRow(Vec<Option<String>>);
impl RawRow {
pub fn get(&self, idx: usize) -> Option<&str> {
self.0.get(idx)?.as_deref()
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = Option<&str>> {
self.0.iter().map(|v| v.as_deref())
}
}
/// Handle to a primary driver connection pool plus an optional read-replica
/// pool.
///
/// Cloning a `Pool` clones both underlying driver pools.
pub struct Pool {
    /// Primary pool. In this module every operation uses it, except the
    /// `readonly` paths of `for_each_raw` / `__for_each_raw_bytes` when a
    /// replica is configured.
    pub(crate) inner: bsql_driver_postgres::Pool,
    /// Replica pool; `Some` only when the builder was given a replica URL.
    /// Closed together with the primary by `close`.
    pub(crate) read_pool: Option<bsql_driver_postgres::Pool>,
}
/// Step-by-step configuration for a [`Pool`].
///
/// Obtained from [`Pool::builder`] and finished with [`PoolBuilder::build`].
pub struct PoolBuilder {
    /// Primary connection URL; `build` returns an error if this is still `None`.
    url: Option<String>,
    /// Maximum size of the primary pool (`Pool::builder` starts it at 10).
    max_size: usize,
    /// Two-level option.
    ///
    /// - Outer `None`: the driver's own setting is left untouched.
    /// - `Some(v)`: `v` is forwarded to the driver. The builder tests treat
    ///   `Some(None)` as disabling the limit.
    ///
    /// Applied to both the primary and the replica pool.
    max_lifetime: Option<Option<Duration>>,
    /// Same two-level encoding as `max_lifetime`, for the acquire timeout.
    /// Applied to both pools.
    acquire_timeout: Option<Option<Duration>>,
    /// Minimum idle connections; `build` applies this to the primary pool only.
    min_idle: Option<usize>,
    /// Optional read-replica URL; when set, `build` creates a second pool.
    replica_url: Option<String>,
    /// Replica pool size; `build` falls back to `max_size` when unset.
    replica_max_size: Option<usize>,
}
impl PoolBuilder {
    /// Sets the primary (read-write) connection URL. Required by [`build`](Self::build).
    pub fn url(self, url: &str) -> Self {
        Self {
            url: Some(url.to_owned()),
            ..self
        }
    }

    /// Sets the maximum number of connections in the primary pool.
    pub fn max_size(self, size: usize) -> Self {
        Self {
            max_size: size,
            ..self
        }
    }

    /// Sets the connection lifetime forwarded to the driver.
    ///
    /// Passing `None` forwards `None`; the builder tests treat this as
    /// disabling the limit.
    pub fn max_lifetime(self, d: Option<Duration>) -> Self {
        Self {
            max_lifetime: Some(d),
            ..self
        }
    }

    /// Shorthand for `max_lifetime(Some(Duration::from_secs(secs)))`.
    pub fn max_lifetime_secs(self, secs: u64) -> Self {
        let lifetime = Duration::from_secs(secs);
        self.max_lifetime(Some(lifetime))
    }

    /// Alias of [`max_lifetime_secs`](Self::max_lifetime_secs).
    pub fn lifetime_secs(self, secs: u64) -> Self {
        Self::max_lifetime_secs(self, secs)
    }

    /// Sets the acquire timeout forwarded to the driver.
    ///
    /// Passing `None` forwards `None`; the builder tests treat this as
    /// disabling the timeout.
    pub fn acquire_timeout(self, d: Option<Duration>) -> Self {
        Self {
            acquire_timeout: Some(d),
            ..self
        }
    }

    /// Shorthand for `acquire_timeout(Some(Duration::from_secs(secs)))`.
    pub fn acquire_timeout_secs(self, secs: u64) -> Self {
        let timeout = Duration::from_secs(secs);
        self.acquire_timeout(Some(timeout))
    }

    /// Alias of [`acquire_timeout_secs`](Self::acquire_timeout_secs).
    pub fn timeout_secs(self, secs: u64) -> Self {
        Self::acquire_timeout_secs(self, secs)
    }

    /// Sets the minimum idle connection count (primary pool only).
    pub fn min_idle(self, n: usize) -> Self {
        Self {
            min_idle: Some(n),
            ..self
        }
    }

    /// Configures a read replica; `build` will create a second pool for it.
    pub fn replica_url(self, url: &str) -> Self {
        Self {
            replica_url: Some(url.to_owned()),
            ..self
        }
    }

    /// Sets the replica pool size (defaults to `max_size` when unset).
    pub fn replica_max_size(self, size: usize) -> Self {
        Self {
            replica_max_size: Some(size),
            ..self
        }
    }

    /// Builds the primary pool and, if a replica URL was given, the replica pool.
    ///
    /// The replica inherits `max_lifetime` and `acquire_timeout`, but not
    /// `min_idle`. The body contains no `.await`.
    ///
    /// # Errors
    /// - A driver `Pool` error if no URL was set.
    /// - Any error from building either driver pool.
    pub async fn build(self) -> BsqlResult<Pool> {
        let primary_url = match self.url {
            Some(u) => u,
            None => {
                return Err(BsqlError::from(
                    bsql_driver_postgres::DriverError::Pool("pool builder requires a URL".into()),
                ));
            }
        };

        let mut primary = bsql_driver_postgres::Pool::builder()
            .url(&primary_url)
            .max_size(self.max_size);
        if let Some(lifetime) = self.max_lifetime {
            primary = primary.max_lifetime(lifetime);
        }
        if let Some(timeout) = self.acquire_timeout {
            primary = primary.acquire_timeout(timeout);
        }
        if let Some(idle) = self.min_idle {
            primary = primary.min_idle(idle);
        }
        let inner = primary.build().map_err(BsqlError::from)?;

        let read_pool = match &self.replica_url {
            None => None,
            Some(replica_url) => {
                let mut replica = bsql_driver_postgres::Pool::builder()
                    .url(replica_url)
                    .max_size(self.replica_max_size.unwrap_or(self.max_size));
                if let Some(lifetime) = self.max_lifetime {
                    replica = replica.max_lifetime(lifetime);
                }
                if let Some(timeout) = self.acquire_timeout {
                    replica = replica.acquire_timeout(timeout);
                }
                Some(replica.build().map_err(BsqlError::from)?)
            }
        };

        Ok(Pool { inner, read_pool })
    }
}
impl Pool {
    /// Creates a pool for `url` with the driver's default settings and no
    /// read replica.
    ///
    /// The body contains no `.await`; the driver call runs synchronously.
    ///
    /// # Errors
    /// Propagates the error from `bsql_driver_postgres::Pool::connect`.
    pub async fn connect(url: &str) -> BsqlResult<Self> {
        let inner = bsql_driver_postgres::Pool::connect(url).map_err(BsqlError::from)?;
        Ok(Pool {
            inner,
            read_pool: None,
        })
    }

    /// Returns a [`PoolBuilder`] with `max_size` 10 and every other option unset.
    pub fn builder() -> PoolBuilder {
        PoolBuilder {
            url: None,
            max_size: 10,
            max_lifetime: None,
            acquire_timeout: None,
            min_idle: None,
            replica_url: None,
            replica_max_size: None,
        }
    }

    /// Chooses the driver pool for a statement.
    ///
    /// Returns the read replica when `readonly` is true and a replica is
    /// configured, otherwise the primary.
    fn driver_pool_for(&self, readonly: bool) -> &bsql_driver_postgres::Pool {
        match (readonly, &self.read_pool) {
            (true, Some(replica)) => replica,
            _ => &self.inner,
        }
    }

    /// Checks out a connection from the primary pool.
    ///
    /// # Errors
    /// Propagates the driver's acquire error.
    pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
        let guard = self.inner.acquire().map_err(BsqlError::from)?;
        Ok(PoolConnection {
            inner: Mutex::new(guard),
        })
    }

    /// Begins a transaction on the primary pool.
    ///
    /// # Errors
    /// Propagates the driver's `begin` error.
    pub async fn begin(&self) -> BsqlResult<Transaction> {
        let tx = self.inner.begin().map_err(BsqlError::from)?;
        Ok(Transaction::from_driver(tx))
    }

    /// Starts a chunked streaming query on the primary pool.
    ///
    /// The first chunk is fetched before this returns. The checked-out
    /// connection and the arena are handed to the returned [`QueryStream`].
    ///
    /// # Errors
    /// Acquire, query-start and first-chunk errors are all mapped with
    /// `BsqlError::from`.
    pub async fn query_stream(
        &self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> BsqlResult<QueryStream> {
        // Chunk size handed to the driver. The offset buffer below is sized
        // `num_cols * CHUNK_SIZE`, so presumably this counts rows per chunk.
        const CHUNK_SIZE: i32 = 64;

        let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
        let mut arena = acquire_arena();
        let (columns, _) = guard
            .query_streaming_start(sql, sql_hash, params, CHUNK_SIZE)
            .map_err(BsqlError::from)?;
        let num_cols = columns.len();
        let mut all_col_offsets: Vec<(usize, i32)> =
            Vec::with_capacity(num_cols * CHUNK_SIZE as usize);
        let more = guard
            .streaming_next_chunk(&mut arena, &mut all_col_offsets)
            .map_err(BsqlError::from)?;
        let first_result = bsql_driver_postgres::QueryResult::from_parts(
            all_col_offsets,
            num_cols,
            columns.clone(),
            0,
        );
        // `more` is presumably "further chunks remain"; the stream receives
        // the negation (finished after the first chunk) — confirm against QueryStream::new.
        Ok(QueryStream::new(guard, arena, first_result, columns, !more))
    }

    /// Forwards the warm-up statement list to the primary driver pool.
    pub fn set_warmup_sqls(&self, sqls: &[&str]) {
        self.inner.set_warmup_sqls(sqls);
    }

    /// Runs `sql` via the simple-query protocol on the primary pool and
    /// returns every row as text.
    ///
    /// # Errors
    /// - Acquire errors map via `BsqlError::from`.
    /// - Query errors map via `BsqlError::from_driver_query`.
    pub async fn raw_query(&self, sql: &str) -> BsqlResult<Vec<RawRow>> {
        let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
        let rows = guard
            .simple_query_rows(sql)
            .map_err(BsqlError::from_driver_query)?;
        Ok(rows.into_iter().map(RawRow).collect())
    }

    /// Runs `sql` via the simple-query protocol on the primary pool,
    /// discarding any result.
    ///
    /// # Errors
    /// Same mapping as [`Pool::raw_query`].
    pub async fn raw_execute(&self, sql: &str) -> BsqlResult<()> {
        let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
        guard
            .simple_query(sql)
            .map_err(BsqlError::from_driver_query)?;
        Ok(())
    }

    /// Streams `rows` into `table` (restricted to `columns`) on the primary pool.
    ///
    /// Returns the count reported by the driver's `copy_in`.
    pub async fn copy_in<'a, I>(&self, table: &str, columns: &[&str], rows: I) -> BsqlResult<u64>
    where
        I: IntoIterator<Item = &'a str>,
    {
        let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
        guard
            .copy_in(table, columns, rows)
            .map_err(BsqlError::from_driver_query)
    }

    /// Writes the output of `query` to `writer` on the primary pool.
    ///
    /// Returns the count reported by the driver's `copy_out`.
    pub async fn copy_out<W: std::io::Write>(
        &self,
        query: &str,
        writer: &mut W,
    ) -> BsqlResult<u64> {
        let mut guard = self.inner.acquire().map_err(BsqlError::from)?;
        guard
            .copy_out(query, writer)
            .map_err(BsqlError::from_driver_query)
    }

    /// Snapshot of the primary pool's counters. The replica is not included.
    pub fn status(&self) -> PoolStatus {
        let driver_status = self.inner.status();
        PoolStatus {
            idle: driver_status.idle,
            active: driver_status.active,
            open: driver_status.open,
            max_size: driver_status.max_size,
        }
    }

    /// Closes the primary pool and, if configured, the replica pool.
    pub fn close(&self) {
        self.inner.close();
        if let Some(ref rp) = self.read_pool {
            rp.close();
        }
    }

    /// Whether the primary pool reports itself closed.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Whether a read-replica pool was configured.
    pub fn has_replica(&self) -> bool {
        self.read_pool.is_some()
    }

    /// Whether the primary pool connects over a Unix domain socket.
    pub fn is_uds(&self) -> bool {
        self.inner.is_uds()
    }

    /// Runs `sql` and calls `f` for every row.
    ///
    /// The statement runs on the replica when `readonly` is true and one is
    /// configured, otherwise on the primary.
    ///
    /// # Errors
    /// - If `f` returns an error, iteration stops and that exact error is
    ///   returned.
    /// - Other driver errors map via `BsqlError::from_driver_query`.
    pub async fn for_each_raw<F>(
        &self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        readonly: bool,
        mut f: F,
    ) -> BsqlResult<()>
    where
        F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
    {
        let mut guard = self
            .driver_pool_for(readonly)
            .acquire()
            .map_err(BsqlError::from)?;
        // The driver callback's error type is `DriverError`. So the caller's
        // error is stashed here and a placeholder stops the iteration.
        let mut user_err: Option<BsqlError> = None;
        let driver_result = guard.for_each(sql, sql_hash, params, |row| match f(row) {
            Ok(()) => Ok(()),
            Err(e) => {
                user_err = Some(e);
                Err(bsql_driver_postgres::DriverError::Protocol(
                    "for_each closure error".into(),
                ))
            }
        });
        // Prefer the caller's own error over the placeholder.
        if let Some(e) = user_err {
            return Err(e);
        }
        driver_result.map_err(BsqlError::from_driver_query)
    }

    /// Like [`Pool::for_each_raw`], but hands `f` each row's raw bytes.
    ///
    /// Pool selection and error forwarding are identical to `for_each_raw`.
    #[doc(hidden)]
    pub async fn __for_each_raw_bytes<F>(
        &self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        readonly: bool,
        mut f: F,
    ) -> BsqlResult<()>
    where
        F: FnMut(&[u8]) -> BsqlResult<()>,
    {
        let mut guard = self
            .driver_pool_for(readonly)
            .acquire()
            .map_err(BsqlError::from)?;
        // See `for_each_raw`: stash the caller's error, stop via a placeholder.
        let mut user_err: Option<BsqlError> = None;
        let driver_result = guard.for_each_raw(sql, sql_hash, params, |data| match f(data) {
            Ok(()) => Ok(()),
            Err(e) => {
                user_err = Some(e);
                Err(bsql_driver_postgres::DriverError::Protocol(
                    "for_each closure error".into(),
                ))
            }
        });
        if let Some(e) = user_err {
            return Err(e);
        }
        driver_result.map_err(BsqlError::from_driver_query)
    }
}
impl Clone for Pool {
    /// Produces another handle by cloning the primary pool and, when present,
    /// the replica pool.
    fn clone(&self) -> Self {
        let Pool { inner, read_pool } = self;
        Self {
            inner: inner.clone(),
            read_pool: read_pool.clone(),
        }
    }
}
impl std::fmt::Debug for Pool {
    /// Shows the primary pool's counters and whether a read replica is
    /// configured. No connection details are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Pool")
            .field("status", &self.status())
            .field("has_replica", &self.has_replica())
            .finish()
    }
}
/// A connection checked out of the primary pool by [`Pool::acquire`].
///
/// The driver guard sits behind a `std::sync::Mutex`; this module's tests
/// assert that `PoolConnection` is `Send + Sync`.
pub struct PoolConnection {
    /// Driver guard. Presumably it returns the connection to the pool when
    /// dropped — driver behavior, not visible here; confirm.
    pub(crate) inner: Mutex<bsql_driver_postgres::PoolGuard>,
}
impl std::fmt::Debug for PoolConnection {
    /// Prints only the type name; the wrapped driver guard is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct(..).finish()` renders exactly the name,
        // so writing it directly is equivalent.
        f.write_str("PoolConnection")
    }
}
/// Point-in-time counters copied from the primary driver pool by
/// [`Pool::status`].
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Idle connection count, as reported by the driver.
    pub idle: usize,
    /// Active (checked-out) connection count, as reported by the driver.
    pub active: usize,
    /// Open connection count, as reported by the driver.
    pub open: usize,
    /// Configured maximum pool size, as reported by the driver.
    pub max_size: usize,
}
impl std::fmt::Display for PoolStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"idle={}, active={}, open={}, max={}",
self.idle, self.active, self.open, self.max_size
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let b = Pool::builder();
        assert!(b.url.is_none());
        assert_eq!(b.max_size, 10);
        assert!(b.max_lifetime.is_none());
        assert!(b.acquire_timeout.is_none());
        assert!(b.min_idle.is_none());
    }

    #[test]
    fn builder_max_lifetime() {
        let b = Pool::builder().max_lifetime(Some(Duration::from_secs(60)));
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(60))));
    }

    #[test]
    fn builder_max_lifetime_none_disables() {
        let b = Pool::builder().max_lifetime(None);
        assert_eq!(b.max_lifetime, Some(None));
    }

    #[test]
    fn builder_acquire_timeout() {
        let b = Pool::builder().acquire_timeout(Some(Duration::from_secs(3)));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_acquire_timeout_none_disables() {
        let b = Pool::builder().acquire_timeout(None);
        assert_eq!(b.acquire_timeout, Some(None));
    }

    #[test]
    fn builder_min_idle() {
        let b = Pool::builder().min_idle(5);
        assert_eq!(b.min_idle, Some(5));
    }

    #[test]
    fn builder_max_lifetime_secs() {
        let b = Pool::builder().max_lifetime_secs(1800);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(1800))));
    }

    #[test]
    fn builder_acquire_timeout_secs() {
        let b = Pool::builder().acquire_timeout_secs(5);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(5))));
    }

    #[test]
    fn builder_lifetime_secs_shorthand() {
        let b = Pool::builder().lifetime_secs(900);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(900))));
    }

    #[test]
    fn builder_timeout_secs_shorthand() {
        let b = Pool::builder().timeout_secs(3);
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
    }

    #[test]
    fn builder_defaults_no_replica() {
        let b = Pool::builder();
        assert!(b.replica_url.is_none());
        assert!(b.replica_max_size.is_none());
    }

    #[test]
    fn builder_replica_url() {
        let b = Pool::builder().replica_url("postgres://replica:5432/db");
        assert_eq!(b.replica_url.as_deref(), Some("postgres://replica:5432/db"));
    }

    #[test]
    fn builder_replica_max_size() {
        let b = Pool::builder().replica_max_size(20);
        assert_eq!(b.replica_max_size, Some(20));
    }

    #[tokio::test]
    async fn pool_connect_has_no_replica() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.has_replica());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_tcp() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_is_uds_true_for_unix_socket() {
        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp")
            .await
            .unwrap();
        assert!(pool.is_uds());
    }

    #[tokio::test]
    async fn pool_is_uds_false_for_ip() {
        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db")
            .await
            .unwrap();
        assert!(!pool.is_uds());
    }

    #[test]
    fn pool_status_display() {
        let status = PoolStatus {
            idle: 3,
            active: 2,
            open: 5,
            max_size: 10,
        };
        assert_eq!(status.to_string(), "idle=3, active=2, open=5, max=10");
    }

    #[test]
    fn pool_status_display_zeros() {
        let status = PoolStatus {
            idle: 0,
            active: 0,
            open: 0,
            max_size: 0,
        };
        assert_eq!(status.to_string(), "idle=0, active=0, open=0, max=0");
    }

    /// `PoolConnection` needs a live driver guard to construct, so this only
    /// checks at compile time that it implements `Debug`.
    #[test]
    fn pool_connection_debug() {
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<PoolConnection>();
    }

    #[tokio::test]
    async fn pool_debug() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let dbg = format!("{pool:?}");
        assert!(dbg.contains("Pool"), "Debug should show Pool: {dbg}");
    }

    #[tokio::test]
    async fn pool_clone_is_cheap() {
        let pool = Pool::connect("postgres://user:pass@localhost/db")
            .await
            .unwrap();
        let pool2 = pool.clone();
        assert_eq!(pool.status().max_size, pool2.status().max_size);
        assert!(!pool.has_replica());
        assert!(!pool2.has_replica());
    }

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn pool_is_send_and_sync() {
        _assert_send::<Pool>();
        _assert_sync::<Pool>();
    }

    #[test]
    fn pool_connection_is_send_and_sync() {
        _assert_send::<PoolConnection>();
        _assert_sync::<PoolConnection>();
    }

    #[test]
    fn pool_status_is_send_and_sync() {
        _assert_send::<PoolStatus>();
        _assert_sync::<PoolStatus>();
    }

    #[tokio::test]
    async fn builder_build_without_url_errors() {
        let result = Pool::builder().build().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("URL"), "error should mention URL: {err}");
    }

    /// Every setter in a chain must land in its own field.
    #[test]
    fn builder_chaining() {
        let b = Pool::builder()
            .url("postgres://u@localhost/db")
            .max_size(20)
            .lifetime_secs(600)
            .timeout_secs(3)
            .min_idle(2)
            .replica_url("postgres://u@replica/db")
            .replica_max_size(10);
        assert_eq!(b.url.as_deref(), Some("postgres://u@localhost/db"));
        assert_eq!(b.max_size, 20);
        assert_eq!(b.max_lifetime, Some(Some(Duration::from_secs(600))));
        assert_eq!(b.acquire_timeout, Some(Some(Duration::from_secs(3))));
        assert_eq!(b.min_idle, Some(2));
        assert_eq!(b.replica_url.as_deref(), Some("postgres://u@replica/db"));
        assert_eq!(b.replica_max_size, Some(10));
    }

    #[test]
    fn raw_row_get() {
        let row = RawRow(vec![Some("hello".into()), None, Some("42".into())]);
        assert_eq!(row.get(0), Some("hello"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), Some("42"));
        assert_eq!(row.get(99), None);
        assert_eq!(row.len(), 3);
    }

    #[test]
    fn raw_row_is_empty() {
        let empty = RawRow(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        let non_empty = RawRow(vec![Some("x".into())]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn raw_row_iter() {
        let row = RawRow(vec![Some("a".into()), None, Some("b".into())]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![Some("a"), None, Some("b")]);
    }

    #[test]
    fn raw_row_clone() {
        let row = RawRow(vec![Some("hello".into()), None]);
        let cloned = row.clone();
        assert_eq!(cloned.get(0), Some("hello"));
        assert_eq!(cloned.get(1), None);
        assert_eq!(cloned.len(), 2);
    }

    #[test]
    fn raw_row_debug() {
        let row = RawRow(vec![Some("x".into())]);
        let dbg = format!("{row:?}");
        assert!(dbg.contains("RawRow"), "Debug should show RawRow: {dbg}");
    }

    #[test]
    fn raw_row_all_null_values() {
        let row = RawRow(vec![None, None, None]);
        assert_eq!(row.len(), 3);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(2), None);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(vals, vec![None, None, None]);
    }

    #[test]
    fn raw_row_empty_string_values() {
        let row = RawRow(vec![Some(String::new()), Some("".into())]);
        assert_eq!(row.len(), 2);
        assert_eq!(row.get(0), Some(""));
        assert_eq!(row.get(1), Some(""));
    }

    #[test]
    fn raw_row_get_out_of_bounds() {
        let row = RawRow(vec![Some("only".into())]);
        assert_eq!(row.get(0), Some("only"));
        assert_eq!(row.get(1), None);
        assert_eq!(row.get(100), None);
        assert_eq!(row.get(usize::MAX), None);
    }

    #[test]
    fn raw_row_iter_empty() {
        let row = RawRow(vec![]);
        let vals: Vec<_> = row.iter().collect();
        assert!(vals.is_empty());
    }

    #[test]
    fn raw_row_iter_mixed() {
        let row = RawRow(vec![
            Some("hello".into()),
            None,
            Some("world".into()),
            None,
            Some("".into()),
        ]);
        let vals: Vec<_> = row.iter().collect();
        assert_eq!(
            vals,
            vec![Some("hello"), None, Some("world"), None, Some("")]
        );
    }

    #[test]
    fn raw_row_single_null() {
        let row = RawRow(vec![None]);
        assert_eq!(row.len(), 1);
        assert!(!row.is_empty());
        assert_eq!(row.get(0), None);
    }
}