// aegis-client 0.2.6
//
// Client SDK for Aegis database
// Documentation
//! Aegis Client Transaction Management
//!
//! Transaction handling for database operations.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use crate::connection::PooledConnection;
use crate::error::ClientError;
use crate::result::{QueryResult, Value};
use std::sync::atomic::{AtomicBool, Ordering};

// =============================================================================
// Transaction
// =============================================================================

/// A database transaction.
///
/// Created by [`Transaction::begin`], which executes `BEGIN` on the supplied
/// connection. The transaction is considered active until one of the two
/// completion flags below has been set.
pub struct Transaction {
    /// Connection on which every statement of this transaction is executed.
    connection: PooledConnection,
    /// Set after `COMMIT` has been executed successfully.
    committed: AtomicBool,
    /// Set after `ROLLBACK` has been executed successfully, or when the value
    /// is dropped while still active.
    rolled_back: AtomicBool,
}

impl Transaction {
    /// Starts a transaction by executing `BEGIN` on `connection`.
    ///
    /// The connection is moved into the returned value and is used for every
    /// later statement issued through it.
    ///
    /// # Errors
    ///
    /// Propagates the connection's error if `BEGIN` fails.
    pub async fn begin(connection: PooledConnection) -> Result<Self, ClientError> {
        connection.execute("BEGIN").await?;

        let transaction = Self {
            connection,
            committed: AtomicBool::new(false),
            rolled_back: AtomicBool::new(false),
        };
        Ok(transaction)
    }

    /// Returns `true` while the transaction is neither committed nor rolled back.
    pub fn is_active(&self) -> bool {
        let finished =
            self.committed.load(Ordering::SeqCst) || self.rolled_back.load(Ordering::SeqCst);
        !finished
    }

    /// Runs a query on the transaction's connection.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active; otherwise propagates the connection's error.
    pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
        self.check_active()?;
        self.connection.query(sql).await
    }

    /// Runs a query with bound `params` on the transaction's connection.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active; otherwise propagates the connection's error.
    pub async fn query_with_params(
        &self,
        sql: &str,
        params: Vec<Value>,
    ) -> Result<QueryResult, ClientError> {
        self.check_active()?;
        self.connection.query_with_params(sql, params).await
    }

    /// Runs a statement on the transaction's connection and returns the `u64`
    /// reported by the connection.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active; otherwise propagates the connection's error.
    pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
        self.check_active()?;
        self.connection.execute(sql).await
    }

    /// Runs a statement with bound `params` on the transaction's connection.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active; otherwise propagates the connection's error.
    pub async fn execute_with_params(
        &self,
        sql: &str,
        params: Vec<Value>,
    ) -> Result<u64, ClientError> {
        self.check_active()?;
        self.connection.execute_with_params(sql, params).await
    }

    /// Executes `COMMIT` and consumes the transaction.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active, or the connection's error if `COMMIT` fails. The value is
    /// consumed in either case.
    pub async fn commit(self) -> Result<(), ClientError> {
        self.check_active()?;
        self.connection.execute("COMMIT").await?;
        // Flag completion before `self` goes out of scope so that `Drop` sees
        // a finished transaction.
        self.committed.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Executes `ROLLBACK` and consumes the transaction.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active, or the connection's error if `ROLLBACK` fails. The value is
    /// consumed in either case.
    pub async fn rollback(self) -> Result<(), ClientError> {
        self.check_active()?;
        self.connection.execute("ROLLBACK").await?;
        // Flag completion before `self` goes out of scope so that `Drop` sees
        // a finished transaction.
        self.rolled_back.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Executes `SAVEPOINT <name>` and returns a handle borrowing this transaction.
    ///
    /// `name` is inserted into the SQL text verbatim, without quoting or
    /// validation, so callers must pass a trusted identifier.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the transaction is no longer
    /// active; otherwise propagates the connection's error.
    pub async fn savepoint(&self, name: &str) -> Result<Savepoint<'_>, ClientError> {
        self.check_active()?;

        let statement = format!("SAVEPOINT {name}");
        self.connection.execute(&statement).await?;

        Ok(Savepoint {
            transaction: self,
            name: name.to_owned(),
            released: AtomicBool::new(false),
        })
    }

    /// Guard shared by every operation: fails once the transaction is finished.
    fn check_active(&self) -> Result<(), ClientError> {
        if self.is_active() {
            Ok(())
        } else {
            Err(ClientError::NoTransaction)
        }
    }
}

impl Drop for Transaction {
    /// Marks a still-active transaction as rolled back when it goes out of scope.
    ///
    /// Only the local flag is updated here: no `ROLLBACK` statement is sent,
    /// because `Drop` cannot await the connection. A production implementation
    /// would hand the rollback to a background task or similar.
    fn drop(&mut self) {
        if !self.is_active() {
            // Already committed or rolled back explicitly; nothing to record.
            return;
        }
        self.rolled_back.store(true, Ordering::SeqCst);
    }
}

// =============================================================================
// Savepoint
// =============================================================================

/// A savepoint within a transaction.
///
/// Borrows the [`Transaction`] it was created from and issues its statements
/// on that transaction's connection.
pub struct Savepoint<'a> {
    /// Transaction whose connection executes the savepoint statements.
    transaction: &'a Transaction,
    /// Savepoint identifier, inserted verbatim into the generated SQL.
    name: String,
    /// Set after `RELEASE SAVEPOINT` or `ROLLBACK TO SAVEPOINT` has been
    /// executed successfully.
    released: AtomicBool,
}

impl<'a> Savepoint<'a> {
    /// Executes `RELEASE SAVEPOINT <name>` and consumes the handle.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the savepoint is already
    /// flagged as released; otherwise propagates the connection's error.
    pub async fn release(self) -> Result<(), ClientError> {
        let already_finished = self.released.load(Ordering::SeqCst);
        if already_finished {
            return Err(ClientError::NoTransaction);
        }

        let statement = format!("RELEASE SAVEPOINT {}", self.name);
        self.transaction.connection.execute(&statement).await?;

        self.released.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Executes `ROLLBACK TO SAVEPOINT <name>` and consumes the handle.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::NoTransaction`] if the savepoint is already
    /// flagged as released; otherwise propagates the connection's error.
    pub async fn rollback(self) -> Result<(), ClientError> {
        let already_finished = self.released.load(Ordering::SeqCst);
        if already_finished {
            return Err(ClientError::NoTransaction);
        }

        let statement = format!("ROLLBACK TO SAVEPOINT {}", self.name);
        self.transaction.connection.execute(&statement).await?;

        self.released.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Returns the name this savepoint was created with.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}

// =============================================================================
// Transaction Options
// =============================================================================

/// Options for transaction behavior.
///
/// Rendered into SQL text by [`TransactionOptions::begin_statement`].
#[derive(Debug, Clone, Default)]
pub struct TransactionOptions {
    /// Isolation level; always emitted as an `ISOLATION LEVEL ...` clause.
    pub isolation_level: IsolationLevel,
    /// When `true`, `READ ONLY` is appended to the statement.
    pub read_only: bool,
    /// When `true`, `DEFERRABLE` is appended to the statement.
    pub deferrable: bool,
}

impl TransactionOptions {
    /// Creates options with every field at its default value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the options with the isolation level replaced by `level`.
    pub fn with_isolation(self, level: IsolationLevel) -> Self {
        Self {
            isolation_level: level,
            ..self
        }
    }

    /// Returns the options with `read_only` switched on.
    pub fn read_only(self) -> Self {
        Self {
            read_only: true,
            ..self
        }
    }

    /// Returns the options with `deferrable` switched on.
    pub fn deferrable(self) -> Self {
        Self {
            deferrable: true,
            ..self
        }
    }

    /// Generate the BEGIN statement for these options.
    ///
    /// The result is `BEGIN`, then the isolation-level clause, then
    /// `READ ONLY` and `DEFERRABLE` when the matching flags are set, all
    /// separated by single spaces.
    pub fn begin_statement(&self) -> String {
        let isolation_clause = match self.isolation_level {
            IsolationLevel::ReadUncommitted => "ISOLATION LEVEL READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "ISOLATION LEVEL READ COMMITTED",
            IsolationLevel::RepeatableRead => "ISOLATION LEVEL REPEATABLE READ",
            IsolationLevel::Serializable => "ISOLATION LEVEL SERIALIZABLE",
        };

        let mut clauses: Vec<&str> = vec!["BEGIN", isolation_clause];
        if self.read_only {
            clauses.push("READ ONLY");
        }
        if self.deferrable {
            clauses.push("DEFERRABLE");
        }

        clauses.join(" ")
    }
}

/// Transaction isolation levels.
///
/// Each variant selects the `ISOLATION LEVEL ...` clause produced by
/// [`TransactionOptions::begin_statement`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Emitted as `ISOLATION LEVEL READ UNCOMMITTED`.
    ReadUncommitted,
    /// Emitted as `ISOLATION LEVEL READ COMMITTED`; this is the default variant.
    #[default]
    ReadCommitted,
    /// Emitted as `ISOLATION LEVEL REPEATABLE READ`.
    RepeatableRead,
    /// Emitted as `ISOLATION LEVEL SERIALIZABLE`.
    Serializable,
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ConnectionConfig, PoolConfig};
    use crate::pool::ConnectionPool;

    /// Builds the connection config for the test server on 127.0.0.1.
    ///
    /// The port is read from `AEGIS_TEST_PORT`; 9090 is used when the variable
    /// is unset or does not parse.
    fn test_connection_config() -> ConnectionConfig {
        let port = match std::env::var("AEGIS_TEST_PORT") {
            Ok(raw) => raw.parse().unwrap_or(9090),
            Err(_) => 9090,
        };

        ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port,
            ..Default::default()
        }
    }

    /// Tries to open a transaction against the test server.
    ///
    /// Returns `None` if the pool cannot be created, no connection can be
    /// obtained, or `BEGIN` fails, so that callers can skip instead of fail.
    async fn try_create_transaction() -> Option<Transaction> {
        let pool =
            ConnectionPool::with_connection_config(PoolConfig::default(), test_connection_config())
                .await
                .ok()?;
        let conn = pool.get().await.ok()?;
        let started = Transaction::begin(conn).await;
        started.ok()
    }

    /// Logs that a server-dependent test was skipped.
    fn report_skipped() {
        eprintln!("Skipping test, server not available");
    }

    #[tokio::test]
    async fn test_transaction_begin() {
        match try_create_transaction().await {
            Some(tx) => assert!(tx.is_active()),
            None => report_skipped(),
        }
    }

    #[tokio::test]
    async fn test_transaction_commit() {
        match try_create_transaction().await {
            Some(tx) => tx
                .commit()
                .await
                .expect("Transaction commit should succeed"),
            None => report_skipped(),
        }
    }

    #[tokio::test]
    async fn test_transaction_rollback() {
        match try_create_transaction().await {
            Some(tx) => tx
                .rollback()
                .await
                .expect("Transaction rollback should succeed"),
            None => report_skipped(),
        }
    }

    #[tokio::test]
    async fn test_transaction_execute() {
        let tx = match try_create_transaction().await {
            Some(tx) => tx,
            None => {
                report_skipped();
                return;
            }
        };

        // The server may not support this statement, so an error is tolerated
        // and simply rolls the transaction back.
        let outcome = tx.execute("INSERT INTO test VALUES (1)").await;
        if let Ok(affected) = outcome {
            // The server may report 0 for an unsupported statement.
            // NOTE(review): this would fail against a server that really
            // inserts the row and reports 1 - confirm the intended expectation.
            assert_eq!(affected, 0);
            let _ = tx.commit().await;
        } else {
            let _ = tx.rollback().await;
        }
    }

    #[test]
    fn test_transaction_options() {
        let stmt = TransactionOptions::new()
            .with_isolation(IsolationLevel::Serializable)
            .read_only()
            .begin_statement();

        assert!(stmt.contains("SERIALIZABLE"));
        assert!(stmt.contains("READ ONLY"));
    }

    #[test]
    fn test_isolation_levels() {
        let stmt = TransactionOptions::new()
            .with_isolation(IsolationLevel::RepeatableRead)
            .begin_statement();

        assert!(stmt.contains("REPEATABLE READ"));
    }
}