1use std::future::Future;
89use std::time::Duration;
90use tracing::debug;
91
92use crate::error::QueryResult;
93
/// Transaction isolation level, rendered to SQL by [`IsolationLevel::as_sql`].
///
/// [`Default`] yields [`IsolationLevel::ReadCommitted`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`; this is the default variant.
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the upper-case SQL keyword(s) naming this isolation level.
    ///
    /// The returned text does not include an `ISOLATION LEVEL` prefix.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        }
    }
}
119
/// Transaction access mode, rendered to SQL by [`AccessMode::as_sql`].
///
/// [`Default`] yields [`AccessMode::ReadWrite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum AccessMode {
    /// Rendered as `READ WRITE`; this is the default variant.
    #[default]
    ReadWrite,
    /// Rendered as `READ ONLY`.
    ReadOnly,
}

impl AccessMode {
    /// Returns the upper-case SQL keywords naming this access mode.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            AccessMode::ReadWrite => "READ WRITE",
            AccessMode::ReadOnly => "READ ONLY",
        }
    }
}
139
140#[derive(Debug, Clone, Default)]
142pub struct TransactionConfig {
143 pub isolation: IsolationLevel,
145 pub access_mode: AccessMode,
147 pub timeout: Option<Duration>,
149 pub deferrable: bool,
151}
152
153impl TransactionConfig {
154 pub fn new() -> Self {
156 Self::default()
157 }
158
159 pub fn isolation(mut self, level: IsolationLevel) -> Self {
161 self.isolation = level;
162 self
163 }
164
165 pub fn access_mode(mut self, mode: AccessMode) -> Self {
167 self.access_mode = mode;
168 self
169 }
170
171 pub fn timeout(mut self, timeout: Duration) -> Self {
173 self.timeout = Some(timeout);
174 self
175 }
176
177 pub fn read_only(self) -> Self {
179 self.access_mode(AccessMode::ReadOnly)
180 }
181
182 pub fn deferrable(mut self) -> Self {
184 self.deferrable = true;
185 self
186 }
187
188 pub fn to_begin_sql(&self) -> String {
190 let mut parts = vec!["BEGIN"];
191
192 parts.push("ISOLATION LEVEL");
194 parts.push(self.isolation.as_sql());
195
196 parts.push(self.access_mode.as_sql());
198
199 if self.deferrable
201 && self.isolation == IsolationLevel::Serializable
202 && self.access_mode == AccessMode::ReadOnly
203 {
204 parts.push("DEFERRABLE");
205 }
206
207 let sql = parts.join(" ");
208 debug!(isolation = %self.isolation.as_sql(), access_mode = %self.access_mode.as_sql(), "Transaction BEGIN");
209 sql
210 }
211}
212
213pub struct Transaction<E> {
218 engine: E,
219 config: TransactionConfig,
220 committed: bool,
221 savepoint_count: u32,
222}
223
224impl<E> Transaction<E> {
225 pub fn new(engine: E, config: TransactionConfig) -> Self {
227 Self {
228 engine,
229 config,
230 committed: false,
231 savepoint_count: 0,
232 }
233 }
234
235 pub fn config(&self) -> &TransactionConfig {
237 &self.config
238 }
239
240 pub fn engine(&self) -> &E {
242 &self.engine
243 }
244
245 pub fn savepoint_name(&mut self) -> String {
247 self.savepoint_count += 1;
248 format!("sp_{}", self.savepoint_count)
249 }
250
251 pub fn mark_committed(&mut self) {
253 self.committed = true;
254 }
255
256 pub fn is_committed(&self) -> bool {
258 self.committed
259 }
260}
261
262pub struct TransactionBuilder<E, F, Fut, T>
264where
265 F: FnOnce(Transaction<E>) -> Fut,
266 Fut: Future<Output = QueryResult<T>>,
267{
268 engine: E,
269 callback: F,
270 config: TransactionConfig,
271}
272
273impl<E, F, Fut, T> TransactionBuilder<E, F, Fut, T>
274where
275 F: FnOnce(Transaction<E>) -> Fut,
276 Fut: Future<Output = QueryResult<T>>,
277{
278 pub fn new(engine: E, callback: F) -> Self {
280 Self {
281 engine,
282 callback,
283 config: TransactionConfig::default(),
284 }
285 }
286
287 pub fn isolation(mut self, level: IsolationLevel) -> Self {
289 self.config.isolation = level;
290 self
291 }
292
293 pub fn read_only(mut self) -> Self {
295 self.config.access_mode = AccessMode::ReadOnly;
296 self
297 }
298
299 pub fn timeout(mut self, timeout: Duration) -> Self {
301 self.config.timeout = Some(timeout);
302 self
303 }
304
305 pub fn deferrable(mut self) -> Self {
307 self.config.deferrable = true;
308 self
309 }
310}
311
312pub struct InteractiveTransaction<E> {
314 inner: Transaction<E>,
315 started: bool,
316}
317
318impl<E> InteractiveTransaction<E> {
319 pub fn new(engine: E) -> Self {
321 Self {
322 inner: Transaction::new(engine, TransactionConfig::default()),
323 started: false,
324 }
325 }
326
327 pub fn with_config(engine: E, config: TransactionConfig) -> Self {
329 Self {
330 inner: Transaction::new(engine, config),
331 started: false,
332 }
333 }
334
335 pub fn engine(&self) -> &E {
337 &self.inner.engine
338 }
339
340 pub fn is_started(&self) -> bool {
342 self.started
343 }
344
345 pub fn begin_sql(&self) -> String {
347 self.inner.config.to_begin_sql()
348 }
349
350 pub fn commit_sql(&self) -> &'static str {
352 "COMMIT"
353 }
354
355 pub fn rollback_sql(&self) -> &'static str {
357 "ROLLBACK"
358 }
359
360 pub fn savepoint_sql(&mut self, name: Option<&str>) -> String {
362 let name = name
363 .map(|s| s.to_string())
364 .unwrap_or_else(|| self.inner.savepoint_name());
365 format!("SAVEPOINT {}", name)
366 }
367
368 pub fn rollback_to_sql(&self, name: &str) -> String {
370 format!("ROLLBACK TO SAVEPOINT {}", name)
371 }
372
373 pub fn release_savepoint_sql(&self, name: &str) -> String {
375 format!("RELEASE SAVEPOINT {}", name)
376 }
377
378 pub fn mark_started(&mut self) {
380 self.started = true;
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stand-in engine; the types under test place no bounds on it.
    #[derive(Clone)]
    struct MockEngine;

    /// Every isolation level maps to its SQL keyword(s).
    #[test]
    fn test_isolation_level() {
        assert_eq!(
            IsolationLevel::ReadUncommitted.as_sql(),
            "READ UNCOMMITTED"
        );
        assert_eq!(IsolationLevel::ReadCommitted.as_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.as_sql(), "REPEATABLE READ");
        assert_eq!(IsolationLevel::Serializable.as_sql(), "SERIALIZABLE");
        assert_eq!(IsolationLevel::default(), IsolationLevel::ReadCommitted);
    }

    /// Every access mode maps to its SQL keywords.
    #[test]
    fn test_access_mode() {
        assert_eq!(AccessMode::ReadWrite.as_sql(), "READ WRITE");
        assert_eq!(AccessMode::ReadOnly.as_sql(), "READ ONLY");
        assert_eq!(AccessMode::default(), AccessMode::ReadWrite);
    }

    #[test]
    fn test_transaction_config_default() {
        let config = TransactionConfig::new();
        assert_eq!(config.isolation, IsolationLevel::ReadCommitted);
        assert_eq!(config.access_mode, AccessMode::ReadWrite);
        assert!(config.timeout.is_none());
        assert!(!config.deferrable);
    }

    #[test]
    fn test_transaction_config_builder() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable()
            .timeout(Duration::from_secs(30));

        assert_eq!(config.isolation, IsolationLevel::Serializable);
        assert_eq!(config.access_mode, AccessMode::ReadOnly);
        assert!(config.deferrable);
        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
    }

    /// `access_mode` can switch the mode in both directions.
    #[test]
    fn test_transaction_config_access_mode_setter() {
        let config = TransactionConfig::new()
            .read_only()
            .access_mode(AccessMode::ReadWrite);
        assert_eq!(config.access_mode, AccessMode::ReadWrite);
    }

    /// The whole statement is pinned so stray or reordered parts are caught.
    #[test]
    fn test_begin_sql() {
        let sql = TransactionConfig::new().to_begin_sql();
        assert_eq!(sql, "BEGIN ISOLATION LEVEL READ COMMITTED READ WRITE");
    }

    #[test]
    fn test_begin_sql_serializable_deferrable() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable();
        assert_eq!(
            config.to_begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }

    /// `DEFERRABLE` is dropped unless the transaction is serializable AND read-only.
    #[test]
    fn test_begin_sql_deferrable_ignored_when_not_applicable() {
        let read_write = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .deferrable();
        assert_eq!(
            read_write.to_begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ WRITE"
        );

        let not_serializable = TransactionConfig::new().read_only().deferrable();
        assert_eq!(
            not_serializable.to_begin_sql(),
            "BEGIN ISOLATION LEVEL READ COMMITTED READ ONLY"
        );
    }

    /// The timeout does not show up in the BEGIN statement.
    #[test]
    fn test_begin_sql_ignores_timeout() {
        let config = TransactionConfig::new().timeout(Duration::from_secs(1));
        assert_eq!(
            config.to_begin_sql(),
            "BEGIN ISOLATION LEVEL READ COMMITTED READ WRITE"
        );
    }

    #[test]
    fn test_transaction_state() {
        let mut tx = Transaction::new(MockEngine, TransactionConfig::new().read_only());
        assert_eq!(tx.config().access_mode, AccessMode::ReadOnly);
        assert!(!tx.is_committed());

        assert_eq!(tx.savepoint_name(), "sp_1");
        assert_eq!(tx.savepoint_name(), "sp_2");

        tx.mark_committed();
        assert!(tx.is_committed());
    }

    #[test]
    fn test_interactive_transaction() {
        let mut tx = InteractiveTransaction::new(MockEngine);
        assert!(!tx.is_started());

        let begin = tx.begin_sql();
        assert_eq!(begin, "BEGIN ISOLATION LEVEL READ COMMITTED READ WRITE");

        let sp = tx.savepoint_sql(Some("test_sp"));
        assert_eq!(sp, "SAVEPOINT test_sp");

        let rollback_to = tx.rollback_to_sql("test_sp");
        assert_eq!(rollback_to, "ROLLBACK TO SAVEPOINT test_sp");

        let release = tx.release_savepoint_sql("test_sp");
        assert_eq!(release, "RELEASE SAVEPOINT test_sp");

        assert_eq!(tx.commit_sql(), "COMMIT");
        assert_eq!(tx.rollback_sql(), "ROLLBACK");

        tx.mark_started();
        assert!(tx.is_started());
    }

    /// Generated names count up from `sp_1`; explicit names leave the counter alone.
    #[test]
    fn test_interactive_transaction_generated_savepoints() {
        let mut tx = InteractiveTransaction::new(MockEngine);

        assert_eq!(tx.savepoint_sql(None), "SAVEPOINT sp_1");
        assert_eq!(tx.savepoint_sql(Some("manual")), "SAVEPOINT manual");
        assert_eq!(tx.savepoint_sql(None), "SAVEPOINT sp_2");
    }

    #[test]
    fn test_interactive_transaction_with_config() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable();
        let tx = InteractiveTransaction::with_config(MockEngine, config);

        assert!(!tx.is_started());
        assert_eq!(
            tx.begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }
}
465