use crate::connection::PooledConnection;
use crate::error::ClientError;
use crate::result::{QueryResult, Value};
use std::sync::atomic::{AtomicBool, Ordering};
/// A database transaction bound to a single pooled connection.
///
/// Created with [`Transaction::begin`] and finished with [`Transaction::commit`]
/// or [`Transaction::rollback`], both of which consume the value.
pub struct Transaction {
    /// Connection on which `BEGIN` was issued; every statement of the transaction runs here.
    connection: PooledConnection,
    /// Set once `COMMIT` has executed successfully.
    committed: AtomicBool,
    /// Set once `ROLLBACK` has executed successfully, or by `Drop` for an unfinished transaction.
    rolled_back: AtomicBool,
}
impl Transaction {
pub async fn begin(connection: PooledConnection) -> Result<Self, ClientError> {
connection.execute("BEGIN").await?;
Ok(Self {
connection,
committed: AtomicBool::new(false),
rolled_back: AtomicBool::new(false),
})
}
pub fn is_active(&self) -> bool {
!self.committed.load(Ordering::SeqCst) && !self.rolled_back.load(Ordering::SeqCst)
}
pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
self.check_active()?;
self.connection.query(sql).await
}
pub async fn query_with_params(
&self,
sql: &str,
params: Vec<Value>,
) -> Result<QueryResult, ClientError> {
self.check_active()?;
self.connection.query_with_params(sql, params).await
}
pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
self.check_active()?;
self.connection.execute(sql).await
}
pub async fn execute_with_params(
&self,
sql: &str,
params: Vec<Value>,
) -> Result<u64, ClientError> {
self.check_active()?;
self.connection.execute_with_params(sql, params).await
}
pub async fn commit(self) -> Result<(), ClientError> {
self.check_active()?;
self.connection.execute("COMMIT").await?;
self.committed.store(true, Ordering::SeqCst);
Ok(())
}
pub async fn rollback(self) -> Result<(), ClientError> {
self.check_active()?;
self.connection.execute("ROLLBACK").await?;
self.rolled_back.store(true, Ordering::SeqCst);
Ok(())
}
pub async fn savepoint(&self, name: &str) -> Result<Savepoint<'_>, ClientError> {
self.check_active()?;
self.connection
.execute(&format!("SAVEPOINT {}", name))
.await?;
Ok(Savepoint {
transaction: self,
name: name.to_string(),
released: AtomicBool::new(false),
})
}
fn check_active(&self) -> Result<(), ClientError> {
if !self.is_active() {
return Err(ClientError::NoTransaction);
}
Ok(())
}
}
/// Marks a still-open transaction as rolled back when it goes out of scope.
///
/// NOTE(review): only the in-memory flag is flipped, on a value that is being
/// destroyed; no `ROLLBACK` is sent. A transaction dropped without a successful
/// `commit`/`rollback` leaves the server-side transaction open on the pooled
/// connection. This includes a transaction whose `COMMIT` or `ROLLBACK`
/// returned an error. Whether the pool resets connections on return is not
/// visible here; confirm, or issue the rollback from an async context.
impl Drop for Transaction {
    fn drop(&mut self) {
        let abandoned =
            !self.committed.load(Ordering::SeqCst) && !self.rolled_back.load(Ordering::SeqCst);
        if abandoned {
            self.rolled_back.store(true, Ordering::SeqCst);
        }
    }
}
/// A named savepoint inside a [`Transaction`], created by [`Transaction::savepoint`].
///
/// Because it borrows the transaction, the transaction cannot be committed or
/// rolled back (both take it by value) while the savepoint is alive.
pub struct Savepoint<'a> {
    /// Transaction whose connection runs the savepoint statements.
    transaction: &'a Transaction,
    /// Savepoint name, interpolated verbatim into the SQL statements.
    name: String,
    /// Set once `release` or `rollback` has executed successfully.
    released: AtomicBool,
}
impl<'a> Savepoint<'a> {
    /// Sends `RELEASE SAVEPOINT <name>` and consumes the savepoint.
    ///
    /// # Errors
    ///
    /// [`ClientError::NoTransaction`] if the savepoint was already finished, or
    /// the error from executing the statement.
    pub async fn release(self) -> Result<(), ClientError> {
        let statement = format!("RELEASE SAVEPOINT {}", self.name);
        self.finish(&statement).await
    }

    /// Sends `ROLLBACK TO SAVEPOINT <name>` and consumes the savepoint.
    ///
    /// # Errors
    ///
    /// [`ClientError::NoTransaction`] if the savepoint was already finished, or
    /// the error from executing the statement.
    pub async fn rollback(self) -> Result<(), ClientError> {
        let statement = format!("ROLLBACK TO SAVEPOINT {}", self.name);
        self.finish(&statement).await
    }

    /// The savepoint name as passed to [`Transaction::savepoint`].
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Shared tail of `release`/`rollback`: refuse if already finished, run
    /// `statement` on the transaction's connection, then mark the savepoint finished.
    async fn finish(&self, statement: &str) -> Result<(), ClientError> {
        let already_finished = self.released.load(Ordering::SeqCst);
        if already_finished {
            return Err(ClientError::NoTransaction);
        }
        self.transaction.connection.execute(statement).await?;
        self.released.store(true, Ordering::SeqCst);
        Ok(())
    }
}
/// Options used to build the `BEGIN` statement for a transaction.
///
/// The derived default is `READ COMMITTED` with `read_only` and `deferrable`
/// both `false`.
#[derive(Debug, Clone, Default)]
pub struct TransactionOptions {
    /// Isolation level, emitted as `ISOLATION LEVEL ...`.
    pub isolation_level: IsolationLevel,
    /// When `true`, `READ ONLY` is appended to the statement.
    pub read_only: bool,
    /// When `true`, `DEFERRABLE` is appended to the statement.
    pub deferrable: bool,
}
impl TransactionOptions {
    /// Returns the default options: `READ COMMITTED`, not read-only, not deferrable.
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder step that replaces the isolation level.
    pub fn with_isolation(self, level: IsolationLevel) -> Self {
        Self {
            isolation_level: level,
            ..self
        }
    }

    /// Builder step that requests a `READ ONLY` transaction.
    pub fn read_only(self) -> Self {
        Self {
            read_only: true,
            ..self
        }
    }

    /// Builder step that requests a `DEFERRABLE` transaction.
    pub fn deferrable(self) -> Self {
        Self {
            deferrable: true,
            ..self
        }
    }

    /// Renders the options as one space-separated statement:
    /// `BEGIN ISOLATION LEVEL <level>`, optionally followed by `READ ONLY` and
    /// then `DEFERRABLE`. An isolation level is always included.
    pub fn begin_statement(&self) -> String {
        let isolation = match self.isolation_level {
            IsolationLevel::ReadUncommitted => "ISOLATION LEVEL READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "ISOLATION LEVEL READ COMMITTED",
            IsolationLevel::RepeatableRead => "ISOLATION LEVEL REPEATABLE READ",
            IsolationLevel::Serializable => "ISOLATION LEVEL SERIALIZABLE",
        };
        let mut clauses: Vec<&str> = vec!["BEGIN", isolation];
        if self.read_only {
            clauses.push("READ ONLY");
        }
        if self.deferrable {
            clauses.push("DEFERRABLE");
        }
        clauses.join(" ")
    }
}
/// SQL transaction isolation level; [`IsolationLevel::ReadCommitted`] is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED` (the default).
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ConnectionConfig, PoolConfig};
    use crate::pool::ConnectionPool;

    /// Connection settings for the server-backed tests. The port is read from
    /// `AEGIS_TEST_PORT` and falls back to 9090.
    fn test_connection_config() -> ConnectionConfig {
        let port = std::env::var("AEGIS_TEST_PORT")
            .ok()
            .and_then(|p| p.parse().ok())
            .unwrap_or(9090);
        ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port,
            ..Default::default()
        }
    }

    /// Opens a pool and begins a transaction. Returns `None` on any failure so
    /// the server-backed tests can skip when no server is reachable.
    async fn try_create_transaction() -> Option<Transaction> {
        let config = PoolConfig::default();
        let pool = ConnectionPool::with_connection_config(config, test_connection_config())
            .await
            .ok()?;
        let conn = pool.get().await.ok()?;
        Transaction::begin(conn).await.ok()
    }

    #[tokio::test]
    async fn test_transaction_begin() {
        if let Some(tx) = try_create_transaction().await {
            assert!(tx.is_active());
            // `Drop for Transaction` sends no ROLLBACK, so end the transaction
            // explicitly instead of returning a connection with an open
            // transaction to the pool.
            tx.rollback()
                .await
                .expect("Transaction rollback should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_commit() {
        if let Some(tx) = try_create_transaction().await {
            tx.commit()
                .await
                .expect("Transaction commit should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_rollback() {
        if let Some(tx) = try_create_transaction().await {
            tx.rollback()
                .await
                .expect("Transaction rollback should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_execute() {
        if let Some(tx) = try_create_transaction().await {
            match tx.execute("INSERT INTO test VALUES (1)").await {
                Ok(affected) => {
                    // NOTE(review): a successful single-row INSERT normally reports
                    // 1 affected row; confirm that 0 is what this server returns.
                    assert_eq!(affected, 0);
                    let _ = tx.commit().await;
                }
                Err(_) => {
                    let _ = tx.rollback().await;
                }
            }
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[test]
    fn test_default_begin_statement() {
        assert_eq!(
            TransactionOptions::new().begin_statement(),
            "BEGIN ISOLATION LEVEL READ COMMITTED"
        );
    }

    #[test]
    fn test_transaction_options() {
        let opts = TransactionOptions::new()
            .with_isolation(IsolationLevel::Serializable)
            .read_only();
        assert_eq!(
            opts.begin_statement(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY"
        );
        assert_eq!(
            opts.deferrable().begin_statement(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }

    #[test]
    fn test_isolation_levels() {
        let cases = [
            (IsolationLevel::ReadUncommitted, "BEGIN ISOLATION LEVEL READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "BEGIN ISOLATION LEVEL READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "BEGIN ISOLATION LEVEL REPEATABLE READ"),
            (IsolationLevel::Serializable, "BEGIN ISOLATION LEVEL SERIALIZABLE"),
        ];
        for &(level, expected) in cases.iter() {
            let opts = TransactionOptions::new().with_isolation(level);
            assert_eq!(opts.begin_statement(), expected);
        }
    }
}