Skip to main content

aegis_client/
transaction.rs

1//! Aegis Client Transaction Management
2//!
3//! Transaction handling for database operations.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::connection::PooledConnection;
9use crate::error::ClientError;
10use crate::result::{QueryResult, Value};
11use std::sync::atomic::{AtomicBool, Ordering};
12
13// =============================================================================
14// Transaction
15// =============================================================================
16
17/// A database transaction.
18pub struct Transaction {
19    connection: PooledConnection,
20    committed: AtomicBool,
21    rolled_back: AtomicBool,
22}
23
24impl Transaction {
25    /// Begin a new transaction.
26    pub async fn begin(connection: PooledConnection) -> Result<Self, ClientError> {
27        connection.execute("BEGIN").await?;
28
29        Ok(Self {
30            connection,
31            committed: AtomicBool::new(false),
32            rolled_back: AtomicBool::new(false),
33        })
34    }
35
36    /// Check if the transaction is active.
37    pub fn is_active(&self) -> bool {
38        !self.committed.load(Ordering::SeqCst) && !self.rolled_back.load(Ordering::SeqCst)
39    }
40
41    /// Execute a query within the transaction.
42    pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
43        self.check_active()?;
44        self.connection.query(sql).await
45    }
46
47    /// Execute a query with parameters.
48    pub async fn query_with_params(
49        &self,
50        sql: &str,
51        params: Vec<Value>,
52    ) -> Result<QueryResult, ClientError> {
53        self.check_active()?;
54        self.connection.query_with_params(sql, params).await
55    }
56
57    /// Execute a statement within the transaction.
58    pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
59        self.check_active()?;
60        self.connection.execute(sql).await
61    }
62
63    /// Execute a statement with parameters.
64    pub async fn execute_with_params(
65        &self,
66        sql: &str,
67        params: Vec<Value>,
68    ) -> Result<u64, ClientError> {
69        self.check_active()?;
70        self.connection.execute_with_params(sql, params).await
71    }
72
73    /// Commit the transaction.
74    pub async fn commit(self) -> Result<(), ClientError> {
75        self.check_active()?;
76        self.connection.execute("COMMIT").await?;
77        self.committed.store(true, Ordering::SeqCst);
78        Ok(())
79    }
80
81    /// Rollback the transaction.
82    pub async fn rollback(self) -> Result<(), ClientError> {
83        self.check_active()?;
84        self.connection.execute("ROLLBACK").await?;
85        self.rolled_back.store(true, Ordering::SeqCst);
86        Ok(())
87    }
88
89    /// Create a savepoint.
90    pub async fn savepoint(&self, name: &str) -> Result<Savepoint<'_>, ClientError> {
91        self.check_active()?;
92        self.connection
93            .execute(&format!("SAVEPOINT {}", name))
94            .await?;
95        Ok(Savepoint {
96            transaction: self,
97            name: name.to_string(),
98            released: AtomicBool::new(false),
99        })
100    }
101
102    fn check_active(&self) -> Result<(), ClientError> {
103        if !self.is_active() {
104            return Err(ClientError::NoTransaction);
105        }
106        Ok(())
107    }
108}
109
110impl Drop for Transaction {
111    fn drop(&mut self) {
112        // If the transaction is still active when dropped, it should be rolled back
113        // In async context, we can't do async operations in Drop
114        // A production implementation would use a background task or similar
115        if self.is_active() {
116            self.rolled_back.store(true, Ordering::SeqCst);
117        }
118    }
119}
120
121// =============================================================================
122// Savepoint
123// =============================================================================
124
125/// A savepoint within a transaction.
126pub struct Savepoint<'a> {
127    transaction: &'a Transaction,
128    name: String,
129    released: AtomicBool,
130}
131
132impl<'a> Savepoint<'a> {
133    /// Release the savepoint (commit changes since savepoint).
134    pub async fn release(self) -> Result<(), ClientError> {
135        if self.released.load(Ordering::SeqCst) {
136            return Err(ClientError::NoTransaction);
137        }
138        self.transaction
139            .connection
140            .execute(&format!("RELEASE SAVEPOINT {}", self.name))
141            .await?;
142        self.released.store(true, Ordering::SeqCst);
143        Ok(())
144    }
145
146    /// Rollback to the savepoint.
147    pub async fn rollback(self) -> Result<(), ClientError> {
148        if self.released.load(Ordering::SeqCst) {
149            return Err(ClientError::NoTransaction);
150        }
151        self.transaction
152            .connection
153            .execute(&format!("ROLLBACK TO SAVEPOINT {}", self.name))
154            .await?;
155        self.released.store(true, Ordering::SeqCst);
156        Ok(())
157    }
158
159    /// Get the savepoint name.
160    pub fn name(&self) -> &str {
161        &self.name
162    }
163}
164
165// =============================================================================
166// Transaction Options
167// =============================================================================
168
/// Options controlling how a transaction is started.
///
/// Build with [`TransactionOptions::new`] and the builder methods, then render the
/// SQL with [`TransactionOptions::begin_statement`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TransactionOptions {
    /// Isolation level emitted in the `BEGIN` statement.
    pub isolation_level: IsolationLevel,
    /// When `true`, `READ ONLY` is appended to the `BEGIN` statement.
    pub read_only: bool,
    /// When `true`, `DEFERRABLE` is appended to the `BEGIN` statement.
    pub deferrable: bool,
}

impl TransactionOptions {
    /// Create options with the defaults: [`IsolationLevel::ReadCommitted`],
    /// `read_only = false`, `deferrable = false`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the isolation level.
    #[must_use]
    pub fn with_isolation(mut self, level: IsolationLevel) -> Self {
        self.isolation_level = level;
        self
    }

    /// Request a read-only transaction (`READ ONLY`).
    #[must_use]
    pub fn read_only(mut self) -> Self {
        self.read_only = true;
        self
    }

    /// Request a deferrable transaction (`DEFERRABLE`).
    #[must_use]
    pub fn deferrable(mut self) -> Self {
        self.deferrable = true;
        self
    }

    /// Generate the `BEGIN` statement for these options.
    ///
    /// The isolation level is always emitted. `READ ONLY` and then `DEFERRABLE` are
    /// appended when enabled, all separated by single spaces, e.g.
    /// `BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE`.
    #[must_use]
    pub fn begin_statement(&self) -> String {
        let mut statement = String::from("BEGIN ISOLATION LEVEL ");
        statement.push_str(self.isolation_level.as_sql());

        if self.read_only {
            statement.push_str(" READ ONLY");
        }

        if self.deferrable {
            statement.push_str(" DEFERRABLE");
        }

        statement
    }
}

/// Transaction isolation levels.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`; the default.
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// SQL keywords naming this level, as written after `ISOLATION LEVEL`.
    #[must_use]
    pub const fn as_sql(self) -> &'static str {
        match self {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        }
    }
}
237
238// =============================================================================
239// Tests
240// =============================================================================
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ConnectionConfig, PoolConfig};
    use crate::pool::ConnectionPool;

    /// Get test connection config - uses AEGIS_TEST_PORT env var or defaults to 9090
    fn test_connection_config() -> ConnectionConfig {
        let port = std::env::var("AEGIS_TEST_PORT")
            .ok()
            .and_then(|p| p.parse().ok())
            .unwrap_or(9090);
        ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port,
            ..Default::default()
        }
    }

    /// Open a pool against the test server and begin a transaction; `None` when the
    /// server is unreachable so server-dependent tests can skip.
    async fn try_create_transaction() -> Option<Transaction> {
        let config = PoolConfig::default();
        let pool = ConnectionPool::with_connection_config(config, test_connection_config())
            .await
            .ok()?;
        let conn = pool.get().await.ok()?;
        Transaction::begin(conn).await.ok()
    }

    #[tokio::test]
    async fn test_transaction_begin() {
        if let Some(tx) = try_create_transaction().await {
            assert!(tx.is_active());
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_commit() {
        if let Some(tx) = try_create_transaction().await {
            tx.commit()
                .await
                .expect("Transaction commit should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_rollback() {
        if let Some(tx) = try_create_transaction().await {
            tx.rollback()
                .await
                .expect("Transaction rollback should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_execute() {
        if let Some(tx) = try_create_transaction().await {
            // Note: This may fail if the server doesn't support this query.
            match tx.execute("INSERT INTO test VALUES (1)").await {
                Ok(affected) => {
                    // A single-row INSERT reports 1 row, or 0 on servers that do not
                    // report counts; anything more is wrong.
                    assert!(
                        affected <= 1,
                        "single-row INSERT reported {affected} affected rows"
                    );
                    // Roll back so the test leaves no data behind in the test database.
                    tx.rollback()
                        .await
                        .expect("Transaction rollback should succeed");
                }
                Err(_) => {
                    let _ = tx.rollback().await;
                }
            }
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[test]
    fn test_transaction_options() {
        let opts = TransactionOptions::new()
            .with_isolation(IsolationLevel::Serializable)
            .read_only();

        let stmt = opts.begin_statement();
        assert!(stmt.contains("SERIALIZABLE"));
        assert!(stmt.contains("READ ONLY"));
    }

    #[test]
    fn test_default_begin_statement() {
        assert_eq!(
            TransactionOptions::new().begin_statement(),
            "BEGIN ISOLATION LEVEL READ COMMITTED"
        );
    }

    #[test]
    fn test_isolation_levels() {
        let cases = [
            (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];

        for (level, keywords) in cases {
            let stmt = TransactionOptions::new().with_isolation(level).begin_statement();
            assert_eq!(stmt, format!("BEGIN ISOLATION LEVEL {keywords}"));
        }
    }

    #[test]
    fn test_deferrable_clause_order() {
        let stmt = TransactionOptions::new()
            .with_isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable()
            .begin_statement();
        assert_eq!(stmt, "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE");
    }
}