use std::fmt;
use bsql_driver_postgres::arena::acquire_arena;
use bsql_driver_postgres::codec::Encode;
use tokio::sync::Mutex;
use crate::error::{BsqlError, BsqlResult, QueryError};
use crate::executor::OwnedResult;
/// SQL transaction isolation level.
///
/// [`Display`](fmt::Display) renders the exact keyword phrase used in
/// `SET TRANSACTION ISOLATION LEVEL ...`, e.g. `READ COMMITTED`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`. PostgreSQL accepts this but behaves as `READ COMMITTED`.
    ReadUncommitted,
    /// `READ COMMITTED`.
    ReadCommitted,
    /// `REPEATABLE READ`.
    RepeatableRead,
    /// `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keyword phrase for this level.
    ///
    /// Taken by value: the type is `Copy`, so borrowing buys nothing.
    const fn as_sql(self) -> &'static str {
        match self {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        }
    }
}

impl fmt::Display for IsolationLevel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_sql())
    }
}
/// An open transaction on a single driver connection.
///
/// The driver transaction sits behind an async mutex so that most operations
/// can be issued through `&self`. Call `commit` or `rollback` to finish it.
pub struct Transaction {
    /// Flipped to `true` by `commit`/`rollback`; inspected when the value is dropped.
    finished: bool,
    /// The driver transaction, or `None` once `commit`/`rollback` has taken it.
    inner: Mutex<Option<bsql_driver_postgres::Transaction>>,
}
impl Transaction {
pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
Self {
inner: Mutex::new(Some(tx)),
finished: false,
}
}
fn consumed_error() -> BsqlError {
BsqlError::Query(QueryError {
message: "transaction already consumed".into(),
pg_code: None,
source: None,
})
}
pub async fn commit(mut self) -> BsqlResult<()> {
self.finished = true;
let tx = self
.inner
.lock()
.await
.take()
.ok_or_else(Self::consumed_error)?;
tx.commit().await.map_err(BsqlError::from)
}
pub async fn rollback(mut self) -> BsqlResult<()> {
self.finished = true;
let tx = self
.inner
.lock()
.await
.take()
.ok_or_else(Self::consumed_error)?;
tx.rollback().await.map_err(BsqlError::from)
}
pub async fn savepoint(&self, name: &str) -> BsqlResult<()> {
validate_savepoint_name(name)?;
let sql = format!("SAVEPOINT {name}");
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql)
.await
.map_err(BsqlError::from_driver_query)
}
pub async fn release_savepoint(&self, name: &str) -> BsqlResult<()> {
validate_savepoint_name(name)?;
let sql = format!("RELEASE SAVEPOINT {name}");
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql)
.await
.map_err(BsqlError::from_driver_query)
}
pub async fn rollback_to(&self, name: &str) -> BsqlResult<()> {
validate_savepoint_name(name)?;
let sql = format!("ROLLBACK TO SAVEPOINT {name}");
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql)
.await
.map_err(BsqlError::from_driver_query)
}
pub async fn set_isolation(&self, level: IsolationLevel) -> BsqlResult<()> {
let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql)
.await
.map_err(BsqlError::from_driver_query)
}
pub(crate) async fn query_inner(
&self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> BsqlResult<OwnedResult> {
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
let mut arena = acquire_arena();
let result = tx
.query(sql, sql_hash, params, &mut arena)
.await
.map_err(BsqlError::from_driver_query)?;
Ok(OwnedResult::new(result, arena))
}
pub(crate) async fn execute_inner(
&self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> BsqlResult<u64> {
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.execute(sql, sql_hash, params)
.await
.map_err(BsqlError::from_driver_query)
}
pub async fn execute_pipeline(
&self,
sql: &str,
sql_hash: u64,
param_sets: &[&[&(dyn Encode + Sync)]],
) -> BsqlResult<Vec<u64>> {
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.execute_pipeline(sql, sql_hash, param_sets)
.await
.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn defer_execute(
&self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> BsqlResult<()> {
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.defer_execute(sql, sql_hash, params)
.await
.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn flush_deferred(&self) -> BsqlResult<Vec<u64>> {
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
tx.flush_deferred()
.await
.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn deferred_count(&self) -> usize {
let guard = self.inner.lock().await;
match guard.as_ref() {
Some(tx) => tx.deferred_count(),
None => 0,
}
}
pub async fn for_each_raw<F>(
&self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> BsqlResult<()>
where
F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
{
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
let mut user_err: Option<BsqlError> = None;
let driver_result = tx
.for_each(sql, sql_hash, params, |row| match f(row) {
Ok(()) => Ok(()),
Err(e) => {
user_err = Some(e);
Err(bsql_driver_postgres::DriverError::Protocol(
"for_each closure error".into(),
))
}
})
.await;
if let Some(e) = user_err {
return Err(e);
}
driver_result.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn __for_each_raw_bytes<F>(
&self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> BsqlResult<()>
where
F: FnMut(&[u8]) -> BsqlResult<()>,
{
let mut guard = self.inner.lock().await;
let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
let mut user_err: Option<BsqlError> = None;
let driver_result = tx
.for_each_raw(sql, sql_hash, params, |data| match f(data) {
Ok(()) => Ok(()),
Err(e) => {
user_err = Some(e);
Err(bsql_driver_postgres::DriverError::Protocol(
"for_each closure error".into(),
))
}
})
.await;
if let Some(e) = user_err {
return Err(e);
}
driver_result.map_err(BsqlError::from_driver_query)
}
}
impl fmt::Debug for Transaction {
    // Only the lifecycle flag is shown; the driver transaction is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Transaction");
        builder.field("finished", &self.finished);
        builder.finish()
    }
}
impl Drop for Transaction {
    // Warns on stderr when the value goes away without commit()/rollback()
    // having been called (i.e. `finished` was never set).
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        eprintln!(
            "bsql: Transaction dropped without commit() or rollback() — \
             connection discarded from pool. This is safe but wasteful."
        );
    }
}
fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
crate::util::validate_savepoint_name(name)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time trait checks: these only need to type-check.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}
    fn _assert_debug<T: std::fmt::Debug>() {}

    #[test]
    fn validate_savepoint_name_valid() {
        for name in ["sp1", "my_savepoint_123"] {
            assert!(validate_savepoint_name(name).is_ok(), "{name} should be accepted");
        }
    }

    #[test]
    fn validate_savepoint_name_empty() {
        assert!(validate_savepoint_name("").is_err());
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        let long = "a".repeat(64);
        assert!(validate_savepoint_name(&long).is_err());
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        let max = "a".repeat(63);
        assert!(validate_savepoint_name(&max).is_ok());
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        assert!(validate_savepoint_name("1sp").is_err());
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        assert!(validate_savepoint_name("_sp").is_ok());
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        for name in ["sp-1", "sp.1", "sp 1", "sp;1", "sp'1"] {
            assert!(validate_savepoint_name(name).is_err(), "{name} should be rejected");
        }
    }

    #[test]
    fn isolation_level_display() {
        let cases = [
            (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];
        for (level, expected) in cases {
            assert_eq!(level.to_string(), expected);
            // Display must stay in lockstep with the SQL actually sent.
            assert_eq!(level.to_string(), level.as_sql());
        }
    }

    #[test]
    fn isolation_level_copy_and_clone() {
        let level = IsolationLevel::Serializable;
        let copied = level;
        let cloned = Clone::clone(&level);
        assert_eq!(copied, level);
        assert_eq!(cloned, level);
    }

    #[test]
    fn isolation_level_debug() {
        let level = IsolationLevel::RepeatableRead;
        let dbg = format!("{level:?}");
        assert!(
            dbg.contains("RepeatableRead"),
            "Debug should show variant name: {dbg}"
        );
    }

    #[test]
    fn isolation_level_eq() {
        assert_eq!(IsolationLevel::Serializable, IsolationLevel::Serializable);
        assert_ne!(IsolationLevel::Serializable, IsolationLevel::ReadCommitted);
    }

    #[test]
    fn transaction_implements_debug() {
        _assert_debug::<Transaction>();
    }

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn transaction_is_sync() {
        _assert_sync::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send_and_sync() {
        _assert_send::<IsolationLevel>();
        _assert_sync::<IsolationLevel>();
    }
}