aegis_client/
transaction.rs1use crate::connection::PooledConnection;
9use crate::error::ClientError;
10use crate::result::{QueryResult, Value};
11use std::sync::atomic::{AtomicBool, Ordering};
12
13pub struct Transaction {
19 connection: PooledConnection,
20 committed: AtomicBool,
21 rolled_back: AtomicBool,
22}
23
24impl Transaction {
25 pub async fn begin(connection: PooledConnection) -> Result<Self, ClientError> {
27 connection.execute("BEGIN").await?;
28
29 Ok(Self {
30 connection,
31 committed: AtomicBool::new(false),
32 rolled_back: AtomicBool::new(false),
33 })
34 }
35
36 pub fn is_active(&self) -> bool {
38 !self.committed.load(Ordering::SeqCst) && !self.rolled_back.load(Ordering::SeqCst)
39 }
40
41 pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
43 self.check_active()?;
44 self.connection.query(sql).await
45 }
46
47 pub async fn query_with_params(
49 &self,
50 sql: &str,
51 params: Vec<Value>,
52 ) -> Result<QueryResult, ClientError> {
53 self.check_active()?;
54 self.connection.query_with_params(sql, params).await
55 }
56
57 pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
59 self.check_active()?;
60 self.connection.execute(sql).await
61 }
62
63 pub async fn execute_with_params(
65 &self,
66 sql: &str,
67 params: Vec<Value>,
68 ) -> Result<u64, ClientError> {
69 self.check_active()?;
70 self.connection.execute_with_params(sql, params).await
71 }
72
73 pub async fn commit(self) -> Result<(), ClientError> {
75 self.check_active()?;
76 self.connection.execute("COMMIT").await?;
77 self.committed.store(true, Ordering::SeqCst);
78 Ok(())
79 }
80
81 pub async fn rollback(self) -> Result<(), ClientError> {
83 self.check_active()?;
84 self.connection.execute("ROLLBACK").await?;
85 self.rolled_back.store(true, Ordering::SeqCst);
86 Ok(())
87 }
88
89 pub async fn savepoint(&self, name: &str) -> Result<Savepoint<'_>, ClientError> {
91 self.check_active()?;
92 self.connection
93 .execute(&format!("SAVEPOINT {}", name))
94 .await?;
95 Ok(Savepoint {
96 transaction: self,
97 name: name.to_string(),
98 released: AtomicBool::new(false),
99 })
100 }
101
102 fn check_active(&self) -> Result<(), ClientError> {
103 if !self.is_active() {
104 return Err(ClientError::NoTransaction);
105 }
106 Ok(())
107 }
108}
109
110impl Drop for Transaction {
111 fn drop(&mut self) {
112 if self.is_active() {
116 self.rolled_back.store(true, Ordering::SeqCst);
117 }
118 }
119}
120
121pub struct Savepoint<'a> {
127 transaction: &'a Transaction,
128 name: String,
129 released: AtomicBool,
130}
131
132impl<'a> Savepoint<'a> {
133 pub async fn release(self) -> Result<(), ClientError> {
135 if self.released.load(Ordering::SeqCst) {
136 return Err(ClientError::NoTransaction);
137 }
138 self.transaction
139 .connection
140 .execute(&format!("RELEASE SAVEPOINT {}", self.name))
141 .await?;
142 self.released.store(true, Ordering::SeqCst);
143 Ok(())
144 }
145
146 pub async fn rollback(self) -> Result<(), ClientError> {
148 if self.released.load(Ordering::SeqCst) {
149 return Err(ClientError::NoTransaction);
150 }
151 self.transaction
152 .connection
153 .execute(&format!("ROLLBACK TO SAVEPOINT {}", self.name))
154 .await?;
155 self.released.store(true, Ordering::SeqCst);
156 Ok(())
157 }
158
159 pub fn name(&self) -> &str {
161 &self.name
162 }
163}
164
/// Settings used to build the statement that opens a transaction.
#[derive(Debug, Clone, Default)]
pub struct TransactionOptions {
    /// Isolation level written into the opening statement.
    pub isolation_level: IsolationLevel,
    /// When `true`, `READ ONLY` is appended to the opening statement.
    pub read_only: bool,
    /// When `true`, `DEFERRABLE` is appended to the opening statement.
    pub deferrable: bool,
}

impl TransactionOptions {
    /// Same as [`TransactionOptions::default`]: `READ COMMITTED`, read/write, not deferrable.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns these options with the isolation level replaced by `level`.
    pub fn with_isolation(self, level: IsolationLevel) -> Self {
        Self {
            isolation_level: level,
            ..self
        }
    }

    /// Returns these options with the read-only flag turned on.
    pub fn read_only(self) -> Self {
        Self {
            read_only: true,
            ..self
        }
    }

    /// Returns these options with the deferrable flag turned on.
    pub fn deferrable(self) -> Self {
        Self {
            deferrable: true,
            ..self
        }
    }

    /// Renders the opening statement, e.g. `BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY`.
    ///
    /// Clauses always appear in this order: `BEGIN`, the isolation level, then `READ ONLY`
    /// and `DEFERRABLE` when their flags are set, separated by single spaces.
    pub fn begin_statement(&self) -> String {
        let isolation = match self.isolation_level {
            IsolationLevel::ReadUncommitted => "ISOLATION LEVEL READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "ISOLATION LEVEL READ COMMITTED",
            IsolationLevel::RepeatableRead => "ISOLATION LEVEL REPEATABLE READ",
            IsolationLevel::Serializable => "ISOLATION LEVEL SERIALIZABLE",
        };

        let mut clauses: Vec<&str> = vec!["BEGIN", isolation];
        if self.read_only {
            clauses.push("READ ONLY");
        }
        if self.deferrable {
            clauses.push("DEFERRABLE");
        }
        clauses.join(" ")
    }
}

/// Transaction isolation levels understood by [`TransactionOptions`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`; this is the default level.
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ConnectionConfig, PoolConfig};
    use crate::pool::ConnectionPool;

    /// Connection settings for the integration tests: 127.0.0.1, with the port taken from
    /// `AEGIS_TEST_PORT` when it parses, otherwise 9090.
    fn test_connection_config() -> ConnectionConfig {
        let port = std::env::var("AEGIS_TEST_PORT")
            .ok()
            .and_then(|p| p.parse().ok())
            .unwrap_or(9090);
        ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port,
            ..Default::default()
        }
    }

    /// Builds a pool and begins a transaction, returning `None` on any failure so the
    /// server-backed tests can skip themselves when no server is reachable.
    async fn try_create_transaction() -> Option<Transaction> {
        let config = PoolConfig::default();
        let pool = ConnectionPool::with_connection_config(config, test_connection_config())
            .await
            .ok()?;
        let conn = pool.get().await.ok()?;
        Transaction::begin(conn).await.ok()
    }

    #[tokio::test]
    async fn test_transaction_begin() {
        if let Some(tx) = try_create_transaction().await {
            assert!(tx.is_active());
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_commit() {
        if let Some(tx) = try_create_transaction().await {
            tx.commit()
                .await
                .expect("Transaction commit should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_rollback() {
        if let Some(tx) = try_create_transaction().await {
            tx.rollback()
                .await
                .expect("Transaction rollback should succeed");
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[tokio::test]
    async fn test_transaction_execute() {
        if let Some(tx) = try_create_transaction().await {
            // The insert is allowed to fail (e.g. no `test` table); that path just rolls back.
            match tx.execute("INSERT INTO test VALUES (1)").await {
                Ok(affected) => {
                    // A single-row INSERT can report at most one affected row.
                    assert!(
                        affected <= 1,
                        "single-row insert reported {} affected rows",
                        affected
                    );
                    // Roll back so repeated runs do not leave rows behind in the shared table.
                    tx.rollback()
                        .await
                        .expect("Transaction rollback should succeed");
                }
                Err(_) => {
                    let _ = tx.rollback().await;
                }
            }
        } else {
            eprintln!("Skipping test, server not available");
        }
    }

    #[test]
    fn test_transaction_options() {
        let opts = TransactionOptions::new()
            .with_isolation(IsolationLevel::Serializable)
            .read_only();

        let stmt = opts.begin_statement();
        assert!(stmt.contains("SERIALIZABLE"));
        assert!(stmt.contains("READ ONLY"));
    }

    #[test]
    fn test_isolation_levels() {
        let opts = TransactionOptions::new().with_isolation(IsolationLevel::RepeatableRead);

        assert!(opts.begin_statement().contains("REPEATABLE READ"));
    }

    #[test]
    fn test_default_options() {
        let opts = TransactionOptions::default();
        assert_eq!(opts.isolation_level, IsolationLevel::ReadCommitted);
        assert!(!opts.read_only);
        assert!(!opts.deferrable);
        assert_eq!(opts.begin_statement(), "BEGIN ISOLATION LEVEL READ COMMITTED");
    }

    #[test]
    fn test_every_isolation_level_statement() {
        let cases = [
            (
                IsolationLevel::ReadUncommitted,
                "BEGIN ISOLATION LEVEL READ UNCOMMITTED",
            ),
            (
                IsolationLevel::ReadCommitted,
                "BEGIN ISOLATION LEVEL READ COMMITTED",
            ),
            (
                IsolationLevel::RepeatableRead,
                "BEGIN ISOLATION LEVEL REPEATABLE READ",
            ),
            (
                IsolationLevel::Serializable,
                "BEGIN ISOLATION LEVEL SERIALIZABLE",
            ),
        ];
        for (level, expected) in cases.iter() {
            assert_eq!(
                TransactionOptions::new()
                    .with_isolation(*level)
                    .begin_statement(),
                *expected
            );
        }
    }

    #[test]
    fn test_full_clause_order() {
        let stmt = TransactionOptions::new()
            .with_isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable()
            .begin_statement();
        assert_eq!(
            stmt,
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }
}