1#![allow(dead_code)]
2
3use std::future::Future;
91use std::time::Duration;
92use tracing::debug;
93
94use crate::error::QueryResult;
95
/// SQL transaction isolation level.
///
/// The derived `Default` is [`IsolationLevel::ReadCommitted`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`; this is the default level.
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keywords naming this level, i.e. the text that follows
    /// `ISOLATION LEVEL` in a `BEGIN` statement.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        }
    }
}
121
/// Whether a transaction may write.
///
/// The derived `Default` is [`AccessMode::ReadWrite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum AccessMode {
    /// Rendered as `READ WRITE`; this is the default mode.
    #[default]
    ReadWrite,
    /// Rendered as `READ ONLY`.
    ReadOnly,
}

impl AccessMode {
    /// Returns the SQL keywords for this access mode, as used in a `BEGIN`
    /// statement.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            AccessMode::ReadWrite => "READ WRITE",
            AccessMode::ReadOnly => "READ ONLY",
        }
    }
}
141
142#[derive(Debug, Clone, Default)]
144pub struct TransactionConfig {
145 pub isolation: IsolationLevel,
147 pub access_mode: AccessMode,
149 pub timeout: Option<Duration>,
151 pub deferrable: bool,
153}
154
155impl TransactionConfig {
156 pub fn new() -> Self {
158 Self::default()
159 }
160
161 pub fn isolation(mut self, level: IsolationLevel) -> Self {
163 self.isolation = level;
164 self
165 }
166
167 pub fn access_mode(mut self, mode: AccessMode) -> Self {
169 self.access_mode = mode;
170 self
171 }
172
173 pub fn timeout(mut self, timeout: Duration) -> Self {
175 self.timeout = Some(timeout);
176 self
177 }
178
179 pub fn read_only(self) -> Self {
181 self.access_mode(AccessMode::ReadOnly)
182 }
183
184 pub fn deferrable(mut self) -> Self {
186 self.deferrable = true;
187 self
188 }
189
190 pub fn to_begin_sql(&self) -> String {
192 let mut parts = vec!["BEGIN"];
193
194 parts.push("ISOLATION LEVEL");
196 parts.push(self.isolation.as_sql());
197
198 parts.push(self.access_mode.as_sql());
200
201 if self.deferrable
203 && self.isolation == IsolationLevel::Serializable
204 && self.access_mode == AccessMode::ReadOnly
205 {
206 parts.push("DEFERRABLE");
207 }
208
209 let sql = parts.join(" ");
210 debug!(isolation = %self.isolation.as_sql(), access_mode = %self.access_mode.as_sql(), "Transaction BEGIN");
211 sql
212 }
213}
214
215pub struct Transaction<E> {
220 engine: E,
221 config: TransactionConfig,
222 committed: bool,
223 savepoint_count: u32,
224}
225
226impl<E> Transaction<E> {
227 pub fn new(engine: E, config: TransactionConfig) -> Self {
229 Self {
230 engine,
231 config,
232 committed: false,
233 savepoint_count: 0,
234 }
235 }
236
237 pub fn config(&self) -> &TransactionConfig {
239 &self.config
240 }
241
242 pub fn engine(&self) -> &E {
244 &self.engine
245 }
246
247 pub fn savepoint_name(&mut self) -> String {
249 self.savepoint_count += 1;
250 format!("sp_{}", self.savepoint_count)
251 }
252
253 pub fn mark_committed(&mut self) {
255 self.committed = true;
256 }
257
258 pub fn is_committed(&self) -> bool {
260 self.committed
261 }
262}
263
264pub struct TransactionBuilder<E, F, Fut, T>
266where
267 F: FnOnce(Transaction<E>) -> Fut,
268 Fut: Future<Output = QueryResult<T>>,
269{
270 engine: E,
271 callback: F,
272 config: TransactionConfig,
273}
274
275impl<E, F, Fut, T> TransactionBuilder<E, F, Fut, T>
276where
277 F: FnOnce(Transaction<E>) -> Fut,
278 Fut: Future<Output = QueryResult<T>>,
279{
280 pub fn new(engine: E, callback: F) -> Self {
282 Self {
283 engine,
284 callback,
285 config: TransactionConfig::default(),
286 }
287 }
288
289 pub fn isolation(mut self, level: IsolationLevel) -> Self {
291 self.config.isolation = level;
292 self
293 }
294
295 pub fn read_only(mut self) -> Self {
297 self.config.access_mode = AccessMode::ReadOnly;
298 self
299 }
300
301 pub fn timeout(mut self, timeout: Duration) -> Self {
303 self.config.timeout = Some(timeout);
304 self
305 }
306
307 pub fn deferrable(mut self) -> Self {
309 self.config.deferrable = true;
310 self
311 }
312}
313
314pub struct InteractiveTransaction<E> {
316 inner: Transaction<E>,
317 started: bool,
318}
319
320impl<E> InteractiveTransaction<E> {
321 pub fn new(engine: E) -> Self {
323 Self {
324 inner: Transaction::new(engine, TransactionConfig::default()),
325 started: false,
326 }
327 }
328
329 pub fn with_config(engine: E, config: TransactionConfig) -> Self {
331 Self {
332 inner: Transaction::new(engine, config),
333 started: false,
334 }
335 }
336
337 pub fn engine(&self) -> &E {
339 &self.inner.engine
340 }
341
342 pub fn is_started(&self) -> bool {
344 self.started
345 }
346
347 pub fn begin_sql(&self) -> String {
349 self.inner.config.to_begin_sql()
350 }
351
352 pub fn commit_sql(&self) -> &'static str {
354 "COMMIT"
355 }
356
357 pub fn rollback_sql(&self) -> &'static str {
359 "ROLLBACK"
360 }
361
362 pub fn savepoint_sql(&mut self, name: Option<&str>) -> String {
364 let name = name
365 .map(|s| s.to_string())
366 .unwrap_or_else(|| self.inner.savepoint_name());
367 format!("SAVEPOINT {}", name)
368 }
369
370 pub fn rollback_to_sql(&self, name: &str) -> String {
372 format!("ROLLBACK TO SAVEPOINT {}", name)
373 }
374
375 pub fn release_savepoint_sql(&self, name: &str) -> String {
377 format!("RELEASE SAVEPOINT {}", name)
378 }
379
380 pub fn mark_started(&mut self) {
382 self.started = true;
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_isolation_level() {
        assert_eq!(IsolationLevel::ReadCommitted.as_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::Serializable.as_sql(), "SERIALIZABLE");
        assert_eq!(IsolationLevel::ReadUncommitted.as_sql(), "READ UNCOMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.as_sql(), "REPEATABLE READ");
    }

    #[test]
    fn test_access_mode() {
        assert_eq!(AccessMode::ReadWrite.as_sql(), "READ WRITE");
        assert_eq!(AccessMode::ReadOnly.as_sql(), "READ ONLY");
    }

    #[test]
    fn test_transaction_config_default() {
        let config = TransactionConfig::new();
        assert_eq!(config.isolation, IsolationLevel::ReadCommitted);
        assert_eq!(config.access_mode, AccessMode::ReadWrite);
        assert!(config.timeout.is_none());
        assert!(!config.deferrable);
    }

    #[test]
    fn test_transaction_config_builder() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable()
            .timeout(Duration::from_secs(30));

        assert_eq!(config.isolation, IsolationLevel::Serializable);
        assert_eq!(config.access_mode, AccessMode::ReadOnly);
        assert!(config.deferrable);
        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_begin_sql() {
        let config = TransactionConfig::new();
        let sql = config.to_begin_sql();
        assert!(sql.contains("BEGIN"));
        assert!(sql.contains("ISOLATION LEVEL READ COMMITTED"));
        assert!(sql.contains("READ WRITE"));
        // Pin the exact statement, not just its fragments.
        assert_eq!(sql, "BEGIN ISOLATION LEVEL READ COMMITTED READ WRITE");
    }

    #[test]
    fn test_begin_sql_serializable_deferrable() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable();
        let sql = config.to_begin_sql();
        assert!(sql.contains("SERIALIZABLE"));
        assert!(sql.contains("READ ONLY"));
        assert!(sql.contains("DEFERRABLE"));
        assert_eq!(
            sql,
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }

    #[test]
    fn test_begin_sql_deferrable_ignored_unless_serializable_read_only() {
        // Serializable but read-write: DEFERRABLE must be dropped.
        let read_write = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .deferrable();
        assert_eq!(
            read_write.to_begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ WRITE"
        );

        // Read-only but not serializable: DEFERRABLE must be dropped.
        let repeatable = TransactionConfig::new()
            .isolation(IsolationLevel::RepeatableRead)
            .read_only()
            .deferrable();
        assert_eq!(
            repeatable.to_begin_sql(),
            "BEGIN ISOLATION LEVEL REPEATABLE READ READ ONLY"
        );
    }

    #[test]
    fn test_transaction_state() {
        let mut tx = Transaction::new((), TransactionConfig::new().read_only());
        assert!(!tx.is_committed());
        assert_eq!(tx.config().access_mode, AccessMode::ReadOnly);
        assert_eq!(tx.savepoint_name(), "sp_1");
        assert_eq!(tx.savepoint_name(), "sp_2");
        tx.mark_committed();
        assert!(tx.is_committed());
    }

    #[test]
    fn test_interactive_transaction() {
        #[derive(Clone)]
        struct MockEngine;

        let mut tx = InteractiveTransaction::new(MockEngine);
        assert!(!tx.is_started());

        let begin = tx.begin_sql();
        assert!(begin.contains("BEGIN"));

        let sp = tx.savepoint_sql(Some("test_sp"));
        assert_eq!(sp, "SAVEPOINT test_sp");

        let rollback_to = tx.rollback_to_sql("test_sp");
        assert_eq!(rollback_to, "ROLLBACK TO SAVEPOINT test_sp");

        let release = tx.release_savepoint_sql("test_sp");
        assert_eq!(release, "RELEASE SAVEPOINT test_sp");
    }

    #[test]
    fn test_interactive_auto_savepoints_and_lifecycle() {
        let mut tx = InteractiveTransaction::new(());
        // An explicit name must not consume an auto-generated number.
        assert_eq!(tx.savepoint_sql(Some("named")), "SAVEPOINT named");
        assert_eq!(tx.savepoint_sql(None), "SAVEPOINT sp_1");
        assert_eq!(tx.savepoint_sql(None), "SAVEPOINT sp_2");

        assert_eq!(tx.commit_sql(), "COMMIT");
        assert_eq!(tx.rollback_sql(), "ROLLBACK");

        assert!(!tx.is_started());
        tx.mark_started();
        assert!(tx.is_started());
    }

    #[test]
    fn test_interactive_with_config() {
        let config = TransactionConfig::new()
            .isolation(IsolationLevel::Serializable)
            .read_only()
            .deferrable();
        let tx = InteractiveTransaction::with_config(42u8, config);
        assert_eq!(*tx.engine(), 42);
        assert_eq!(
            tx.begin_sql(),
            "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
    }
}