use std::fmt;
use bsql_driver_postgres::codec::Encode;
use crate::error::{BsqlError, BsqlResult, QueryError};
use crate::executor::OwnedResult;
/// Transaction isolation level, as used by `SET TRANSACTION ISOLATION LEVEL`.
///
/// The [`Display`](fmt::Display) form of each variant is its exact SQL
/// keyword sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SERIALIZABLE`
    Serializable,
}

impl IsolationLevel {
    /// SQL spelling of this level, ready to splice after
    /// `SET TRANSACTION ISOLATION LEVEL`.
    fn as_sql(&self) -> &'static str {
        match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}

impl fmt::Display for IsolationLevel {
    /// Writes the SQL keyword sequence (no padding/alignment is applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = self.as_sql();
        f.write_str(keyword)
    }
}
pub struct Transaction {
inner: Option<bsql_driver_postgres::Transaction>,
finished: bool,
}
impl Transaction {
pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
Self {
inner: Some(tx),
finished: false,
}
}
fn consumed_error() -> BsqlError {
BsqlError::Query(QueryError {
message: "transaction already consumed".into(),
pg_code: None,
source: None,
})
}
pub async fn commit(mut self) -> BsqlResult<()> {
self.finished = true;
let tx = self.inner.take().ok_or_else(Self::consumed_error)?;
tx.commit().map_err(BsqlError::from)
}
pub async fn rollback(mut self) -> BsqlResult<()> {
self.finished = true;
let tx = self.inner.take().ok_or_else(Self::consumed_error)?;
tx.rollback().map_err(BsqlError::from)
}
pub async fn savepoint(&mut self, name: &str) -> BsqlResult<()> {
validate_savepoint_name(name)?;
let sql = format!("SAVEPOINT {name}");
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
}
pub async fn release_savepoint(&mut self, name: &str) -> BsqlResult<()> {
validate_savepoint_name(name)?;
let sql = format!("RELEASE SAVEPOINT {name}");
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
}
pub async fn rollback_to(&mut self, name: &str) -> BsqlResult<()> {
validate_savepoint_name(name)?;
let sql = format!("ROLLBACK TO SAVEPOINT {name}");
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
}
pub async fn set_isolation(&mut self, level: IsolationLevel) -> BsqlResult<()> {
let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
}
pub(crate) fn query_inner(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> BsqlResult<OwnedResult> {
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
let result = tx
.query(sql, sql_hash, params)
.map_err(BsqlError::from_driver_query)?;
Ok(OwnedResult::without_arena(result))
}
pub(crate) fn execute_inner(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> BsqlResult<u64> {
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.execute(sql, sql_hash, params)
.map_err(BsqlError::from_driver_query)
}
pub async fn execute_pipeline(
&mut self,
sql: &str,
sql_hash: u64,
param_sets: &[&[&(dyn Encode + Sync)]],
) -> BsqlResult<Vec<u64>> {
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.execute_pipeline(sql, sql_hash, param_sets)
.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn defer_execute(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> BsqlResult<()> {
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.defer_execute(sql, sql_hash, params)
.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn flush_deferred(&mut self) -> BsqlResult<Vec<u64>> {
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
tx.flush_deferred().map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub fn deferred_count(&self) -> usize {
match self.inner.as_ref() {
Some(tx) => tx.deferred_count(),
None => 0,
}
}
pub async fn for_each_raw<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> BsqlResult<()>
where
F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
{
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
let mut user_err: Option<BsqlError> = None;
let driver_result = tx.for_each(sql, sql_hash, params, |row| match f(row) {
Ok(()) => Ok(()),
Err(e) => {
user_err = Some(e);
Err(bsql_driver_postgres::DriverError::Protocol(
"for_each closure error".into(),
))
}
});
if let Some(e) = user_err {
return Err(e);
}
driver_result.map_err(BsqlError::from_driver_query)
}
#[doc(hidden)]
pub async fn __for_each_raw_bytes<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
mut f: F,
) -> BsqlResult<()>
where
F: FnMut(&[u8]) -> BsqlResult<()>,
{
let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
let mut user_err: Option<BsqlError> = None;
let driver_result = tx.for_each_raw(sql, sql_hash, params, |data| match f(data) {
Ok(()) => Ok(()),
Err(e) => {
user_err = Some(e);
Err(bsql_driver_postgres::DriverError::Protocol(
"for_each closure error".into(),
))
}
});
if let Some(e) = user_err {
return Err(e);
}
driver_result.map_err(BsqlError::from_driver_query)
}
}
impl fmt::Debug for Transaction {
    /// Renders `Transaction { finished: .. }`; the driver transaction itself
    /// is intentionally left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Transaction");
        out.field("finished", &self.finished);
        out.finish()
    }
}
impl Drop for Transaction {
    /// Logs a warning (via `log::warn!`) when the handle goes away without
    /// `commit()` or `rollback()` having been called; otherwise does nothing.
    /// The inner driver transaction is dropped afterwards as an ordinary field.
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        log::warn!(
            "bsql: Transaction dropped without commit() or rollback() — \
             connection discarded from pool. This is safe but wasteful."
        );
    }
}
/// Checks that `name` may be spliced into `SAVEPOINT` / `RELEASE SAVEPOINT` /
/// `ROLLBACK TO SAVEPOINT` SQL text.
///
/// Thin delegate to `crate::util::validate_savepoint_name`. The tests in this
/// module exercise its rules: non-empty, at most 63 characters, must not start
/// with a digit, and characters such as `-`, `.`, space, `;`, quotes, NUL and
/// non-ASCII letters are rejected.
fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
    crate::util::validate_savepoint_name(name)
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- savepoint name validation ----

    #[test]
    fn validate_savepoint_name_valid() {
        assert!(validate_savepoint_name("sp1").is_ok());
        assert!(validate_savepoint_name("_sp").is_ok());
        assert!(validate_savepoint_name("my_savepoint_123").is_ok());
    }

    #[test]
    fn validate_savepoint_name_empty() {
        assert!(validate_savepoint_name("").is_err());
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        let long = "a".repeat(64);
        assert!(validate_savepoint_name(&long).is_err());
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        let max = "a".repeat(63);
        assert!(validate_savepoint_name(&max).is_ok());
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        assert!(validate_savepoint_name("1sp").is_err());
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        assert!(validate_savepoint_name("_sp").is_ok());
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        assert!(validate_savepoint_name("sp-1").is_err());
        assert!(validate_savepoint_name("sp.1").is_err());
        assert!(validate_savepoint_name("sp 1").is_err());
        assert!(validate_savepoint_name("sp;1").is_err());
        assert!(validate_savepoint_name("sp'1").is_err());
    }

    #[test]
    fn validate_savepoint_name_single_char() {
        assert!(validate_savepoint_name("a").is_ok());
        assert!(validate_savepoint_name("_").is_ok());
    }

    #[test]
    fn validate_savepoint_name_all_digits_after_letter() {
        assert!(validate_savepoint_name("a123456789").is_ok());
    }

    #[test]
    fn validate_savepoint_name_all_underscores() {
        assert!(validate_savepoint_name("___").is_ok());
    }

    #[test]
    fn validate_savepoint_name_unicode_rejected() {
        assert!(
            validate_savepoint_name("sp_\u{00e9}").is_err(),
            "unicode chars should be rejected"
        );
    }

    #[test]
    fn validate_savepoint_name_sql_injection_rejected() {
        assert!(validate_savepoint_name("sp; DROP TABLE").is_err());
        assert!(validate_savepoint_name("sp'--").is_err());
        assert!(validate_savepoint_name("sp\"test").is_err());
    }

    #[test]
    fn validate_savepoint_name_null_byte_rejected() {
        assert!(
            validate_savepoint_name("sp\0name").is_err(),
            "null byte in savepoint name should be rejected"
        );
    }

    #[test]
    fn validate_savepoint_name_boundary_63_and_64() {
        let ok_63 = format!("a{}", "b".repeat(62));
        assert!(validate_savepoint_name(&ok_63).is_ok());
        let err_64 = format!("a{}", "b".repeat(63));
        assert!(validate_savepoint_name(&err_64).is_err());
    }

    // ---- IsolationLevel ----

    #[test]
    fn isolation_level_display() {
        assert_eq!(
            IsolationLevel::ReadUncommitted.to_string(),
            "READ UNCOMMITTED"
        );
        assert_eq!(IsolationLevel::ReadCommitted.to_string(), "READ COMMITTED");
        assert_eq!(
            IsolationLevel::RepeatableRead.to_string(),
            "REPEATABLE READ"
        );
        assert_eq!(IsolationLevel::Serializable.to_string(), "SERIALIZABLE");
    }

    /// Exercises the `Clone` derive explicitly (the old version of this test
    /// only performed a `Copy`).
    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately calling `.clone()` on a Copy type
    fn isolation_level_clone() {
        let level = IsolationLevel::Serializable;
        let cloned = level.clone();
        assert_eq!(level, cloned);
    }

    /// `level` remains usable after the move-by-assignment, proving `Copy`.
    #[test]
    fn isolation_level_copy() {
        let level = IsolationLevel::Serializable;
        let copied = level;
        assert_eq!(level, copied);
    }

    #[test]
    fn isolation_level_debug() {
        let level = IsolationLevel::RepeatableRead;
        let dbg = format!("{level:?}");
        assert_eq!(dbg, "RepeatableRead", "Debug should show variant name: {dbg}");
    }

    #[test]
    fn isolation_level_eq() {
        assert_eq!(IsolationLevel::Serializable, IsolationLevel::Serializable);
        assert_ne!(IsolationLevel::Serializable, IsolationLevel::ReadCommitted);
    }

    #[test]
    fn isolation_level_as_sql_all_variants() {
        assert_eq!(IsolationLevel::ReadUncommitted.as_sql(), "READ UNCOMMITTED");
        assert_eq!(IsolationLevel::ReadCommitted.as_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.as_sql(), "REPEATABLE READ");
        assert_eq!(IsolationLevel::Serializable.as_sql(), "SERIALIZABLE");
    }

    #[test]
    fn isolation_level_as_sql_is_idempotent() {
        let level = IsolationLevel::Serializable;
        assert_eq!(level.as_sql(), level.as_sql());
        assert_eq!(level.as_sql(), "SERIALIZABLE");
    }

    #[test]
    fn isolation_level_display_matches_as_sql() {
        for level in [
            IsolationLevel::ReadUncommitted,
            IsolationLevel::ReadCommitted,
            IsolationLevel::RepeatableRead,
            IsolationLevel::Serializable,
        ] {
            assert_eq!(level.to_string(), level.as_sql());
        }
    }

    // ---- Transaction (no live driver available: type-level checks only) ----

    /// Compile-time check only: a `Transaction` cannot be built here without a
    /// live driver transaction, so its Debug *output* is not exercised.
    #[test]
    fn transaction_implements_debug() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<Transaction>();
    }

    fn _assert_send<T: Send>() {}

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send() {
        _assert_send::<IsolationLevel>();
    }

    #[test]
    fn transaction_from_driver_compiles() {
        // Never called: only proves the constructor's signature type-checks.
        fn _check(tx: bsql_driver_postgres::Transaction) -> Transaction {
            Transaction::from_driver(tx)
        }
    }

    #[test]
    fn consumed_error_message_is_descriptive() {
        let e = Transaction::consumed_error();
        let display = e.to_string();
        assert!(
            display.contains("transaction already consumed"),
            "consumed error should be descriptive: {display}"
        );
    }

    #[test]
    fn consumed_error_is_query_variant() {
        let e = Transaction::consumed_error();
        assert!(
            matches!(e, BsqlError::Query(_)),
            "consumed_error should be Query variant"
        );
    }
}