1use super::lock::{LockManager, LockMode, LockResult};
6use super::log::{TransactionLog, WalConfig};
7use super::savepoint::TxnSavepoints;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
11use std::time::{Duration, Instant};
12
/// Identifier of a transaction; [`TransactionManager`] hands them out from a
/// counter that starts at 1.
pub type TxnId = u64;

/// Logical timestamp. [`TransactionManager`] issues start, write and commit
/// timestamps from a single counter that starts at 1 and only increases.
pub type Timestamp = u64;
18
19fn tx_lock_error(context: &'static str) -> TxnError {
20 TxnError::Internal(format!("{context} lock poisoned"))
21}
22
23fn read_guard_or_err<'a, T>(
24 lock: &'a RwLock<T>,
25 context: &'static str,
26) -> Result<RwLockReadGuard<'a, T>, TxnError> {
27 lock.read().map_err(|_| tx_lock_error(context))
28}
29
30fn write_guard_or_err<'a, T>(
31 lock: &'a RwLock<T>,
32 context: &'static str,
33) -> Result<RwLockWriteGuard<'a, T>, TxnError> {
34 lock.write().map_err(|_| tx_lock_error(context))
35}
36
/// Takes a shared guard on `lock`, ignoring poisoning.
///
/// A poisoned lock still yields its guard, so callers see whatever state the
/// panicking holder left behind.
fn recover_read_guard<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    lock.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Takes an exclusive guard on `lock`, ignoring poisoning.
///
/// A poisoned lock still yields its guard, so callers see whatever state the
/// panicking holder left behind.
fn recover_write_guard<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    lock.write().unwrap_or_else(|poisoned| poisoned.into_inner())
}
50
51#[derive(Debug, Clone)]
53pub enum TxnError {
54 NotFound(TxnId),
56 AlreadyCommitted(TxnId),
58 AlreadyAborted(TxnId),
60 WriteConflict { key: Vec<u8>, holder: TxnId },
62 Deadlock(Vec<TxnId>),
64 LockLimitExceeded { limit: usize },
66 LockTimeout { key: Vec<u8>, timeout: Duration },
68 ValidationFailed {
70 key: Vec<u8>,
71 expected_ts: Timestamp,
72 actual_ts: Timestamp,
73 },
74 LogError(String),
76 SavepointNotFound(String),
78 Timeout(TxnId),
80 Internal(String),
82}
83
84impl std::fmt::Display for TxnError {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 match self {
87 TxnError::NotFound(id) => write!(f, "Transaction {} not found", id),
88 TxnError::AlreadyCommitted(id) => write!(f, "Transaction {} already committed", id),
89 TxnError::AlreadyAborted(id) => write!(f, "Transaction {} already aborted", id),
90 TxnError::WriteConflict { key, holder } => {
91 write!(f, "Write conflict on {:?}, held by txn {}", key, holder)
92 }
93 TxnError::Deadlock(cycle) => write!(f, "Deadlock detected: {:?}", cycle),
94 TxnError::LockLimitExceeded { limit } => {
95 write!(f, "Lock limit exceeded: max {}", limit)
96 }
97 TxnError::LockTimeout { key, timeout } => {
98 write!(f, "Lock timeout on {:?} after {:?}", key, timeout)
99 }
100 TxnError::ValidationFailed {
101 key,
102 expected_ts,
103 actual_ts,
104 } => {
105 write!(
106 f,
107 "Validation failed for {:?}: expected ts {}, actual {}",
108 key, expected_ts, actual_ts
109 )
110 }
111 TxnError::LogError(msg) => write!(f, "WAL error: {}", msg),
112 TxnError::SavepointNotFound(name) => write!(f, "Savepoint '{}' not found", name),
113 TxnError::Timeout(id) => write!(f, "Transaction {} timed out", id),
114 TxnError::Internal(msg) => write!(f, "Internal error: {}", msg),
115 }
116 }
117}
118
119impl std::error::Error for TxnError {}
120
/// Lifecycle state of a transaction tracked by [`TransactionManager`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TxnState {
    /// Started; reads, writes, locks and savepoints are accepted only in this state.
    Active,
    /// NOTE(review): nothing visible in this file moves a transaction into this
    /// state, yet `commit` and `abort` both accept it as a live state.
    Preparing,
    /// Committed; `commit` and `abort` reject further transitions.
    Committed,
    /// Aborted; `commit` and `abort` reject further transitions.
    Aborted,
}
133
/// Isolation level requested for a transaction.
///
/// NOTE(review): within this file the level is only stored on the
/// [`TxnHandle`]; commit-time validation is governed by
/// [`TxnConfig::optimistic`], not by this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    #[default]
    SnapshotIsolation,
    Serializable,
}

/// Tunables for [`TransactionManager`].
///
/// Start from [`TxnConfig::new`] (or `Default`) and adjust with the `with_*`
/// builder methods; every field is also public.
#[derive(Debug, Clone)]
pub struct TxnConfig {
    /// Isolation level used by `TransactionManager::begin`.
    pub isolation_level: IsolationLevel,
    /// Maximum time `acquire_lock` waits for the lock manager.
    pub lock_timeout: Duration,
    /// Intended maximum transaction lifetime.
    /// NOTE(review): not enforced anywhere in this file.
    pub txn_timeout: Duration,
    /// When true, commit checks the read set against later commits.
    pub optimistic: bool,
    /// When true, the manager tries to create a write-ahead log.
    pub wal_enabled: bool,
    /// WAL flush policy.
    pub wal_sync: WalSyncMode,
}

/// WAL flush policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalSyncMode {
    /// Commit flushes the log right after writing its commit record.
    EveryCommit,
    /// Flush at the given interval.
    /// NOTE(review): commit does not flush in this mode; presumably something
    /// else drives the periodic flush — confirm.
    Periodic(Duration),
    /// Commit issues no explicit flush.
    None,
}

impl TxnConfig {
    /// Default configuration: snapshot isolation, 30 s lock timeout, 300 s
    /// transaction timeout, optimistic validation on, WAL on with a flush on
    /// every commit.
    pub fn new() -> Self {
        Self {
            isolation_level: IsolationLevel::default(),
            lock_timeout: Duration::from_secs(30),
            txn_timeout: Duration::from_secs(300),
            optimistic: true,
            wal_enabled: true,
            wal_sync: WalSyncMode::EveryCommit,
        }
    }

    /// Sets the default isolation level.
    #[must_use]
    pub fn with_isolation(mut self, level: IsolationLevel) -> Self {
        self.isolation_level = level;
        self
    }

    /// Sets how long lock acquisition may wait.
    #[must_use]
    pub fn with_lock_timeout(mut self, timeout: Duration) -> Self {
        self.lock_timeout = timeout;
        self
    }

    /// Sets the transaction timeout.
    #[must_use]
    pub fn with_txn_timeout(mut self, timeout: Duration) -> Self {
        self.txn_timeout = timeout;
        self
    }

    /// Enables or disables optimistic commit-time validation.
    #[must_use]
    pub fn with_optimistic(mut self, enabled: bool) -> Self {
        self.optimistic = enabled;
        self
    }

    /// Enables or disables the write-ahead log.
    #[must_use]
    pub fn with_wal_enabled(mut self, enabled: bool) -> Self {
        self.wal_enabled = enabled;
        self
    }

    /// Sets the WAL flush policy.
    #[must_use]
    pub fn with_wal_sync(mut self, mode: WalSyncMode) -> Self {
        self.wal_sync = mode;
        self
    }
}

impl Default for TxnConfig {
    /// Same as [`TxnConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
213
214#[derive(Debug, Clone)]
216pub struct TxnHandle {
217 pub id: TxnId,
219 pub start_ts: Timestamp,
221 pub isolation: IsolationLevel,
223}
224
225impl TxnHandle {
226 pub fn id(&self) -> TxnId {
228 self.id
229 }
230
231 pub fn start_ts(&self) -> Timestamp {
233 self.start_ts
234 }
235}
236
/// Bookkeeping [`TransactionManager`] keeps for every transaction it has begun.
struct TransactionState {
    /// Handle returned to the caller at begin time.
    handle: TxnHandle,
    /// Current lifecycle state.
    state: TxnState,
    /// Wall-clock instant the transaction began; `cleanup` ages finished
    /// transactions by it.
    start_time: Instant,
    /// Keys read, each with the timestamp the caller reported for the read;
    /// checked at commit when optimistic validation is enabled.
    read_set: Vec<(Vec<u8>, Timestamp)>,
    /// Writes recorded so far, in order; rolling back to a savepoint truncates it.
    write_set: Vec<WriteEntry>,
    /// Savepoints created by this transaction.
    savepoints: TxnSavepoints,
    /// Keys for which the lock manager granted this transaction a lock.
    locks_held: Vec<Vec<u8>>,
}

/// One recorded write.
///
/// `old_value == None` corresponds to an insert and `new_value == None` to a
/// delete, matching the WAL record type `record_write` chooses.
#[derive(Debug, Clone)]
struct WriteEntry {
    key: Vec<u8>,
    old_value: Option<Vec<u8>>,
    new_value: Option<Vec<u8>>,
    /// Timestamp drawn when the write was recorded (not the commit timestamp).
    timestamp: Timestamp,
}
267
268pub struct Transaction {
270 id: TxnId,
272 coordinator: Arc<TransactionManager>,
274}
275
276impl Transaction {
277 pub fn id(&self) -> TxnId {
279 self.id
280 }
281
282 pub fn record_read(&self, key: &[u8], read_ts: Timestamp) {
284 self.coordinator.record_read(self.id, key, read_ts);
285 }
286
287 pub fn record_write(&self, key: &[u8], old_value: Option<&[u8]>, new_value: Option<&[u8]>) {
289 self.coordinator
290 .record_write(self.id, key, old_value, new_value);
291 }
292
293 pub fn savepoint(&self, name: &str) -> Result<(), TxnError> {
295 self.coordinator.create_savepoint(self.id, name)
296 }
297
298 pub fn rollback_to(&self, name: &str) -> Result<(), TxnError> {
300 self.coordinator.rollback_to_savepoint(self.id, name)
301 }
302
303 pub fn commit(self) -> Result<(), TxnError> {
305 self.coordinator.commit(self.id)
306 }
307
308 pub fn abort(self) -> Result<(), TxnError> {
310 self.coordinator.abort(self.id)
311 }
312}
313
/// Central transaction coordinator.
///
/// Issues transaction ids and timestamps, tracks per-transaction state,
/// forwards lock requests to a [`LockManager`], optionally writes a WAL, and
/// performs commit-time read-set validation.
pub struct TransactionManager {
    /// Configuration supplied at construction.
    config: TxnConfig,
    /// Next transaction id to hand out (starts at 1).
    next_id: AtomicU64,
    /// Logical clock for start, write and commit timestamps (starts at 1).
    current_ts: AtomicU64,
    /// State of every transaction not yet removed by `cleanup`.
    transactions: RwLock<HashMap<TxnId, TransactionState>>,
    /// Lock table used by `acquire_lock`.
    lock_manager: LockManager,
    /// Write-ahead log; `None` when disabled or when it could not be created.
    log: Option<TransactionLog>,
    /// Latest commit timestamp of each key written by a committed transaction.
    committed_ts: RwLock<HashMap<Vec<u8>, Timestamp>>,
}
331
332impl TransactionManager {
333 pub fn new(config: TxnConfig) -> Self {
335 let log = if config.wal_enabled {
336 Some(TransactionLog::new(WalConfig::default()))
337 } else {
338 None
339 };
340
341 Self {
342 config,
343 next_id: AtomicU64::new(1),
344 current_ts: AtomicU64::new(1),
345 transactions: RwLock::new(HashMap::new()),
346 lock_manager: LockManager::with_defaults(),
347 log: log.and_then(|r| r.ok()),
348 committed_ts: RwLock::new(HashMap::new()),
349 }
350 }
351
352 pub fn with_default_config() -> Self {
354 Self::new(TxnConfig::default())
355 }
356
357 pub fn config(&self) -> &TxnConfig {
359 &self.config
360 }
361
362 fn next_timestamp(&self) -> Timestamp {
364 self.current_ts.fetch_add(1, Ordering::SeqCst)
365 }
366
367 pub fn begin(&self) -> TxnHandle {
369 self.begin_with_isolation(self.config.isolation_level)
370 }
371
372 pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> TxnHandle {
374 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
375 let start_ts = self.next_timestamp();
376
377 let handle = TxnHandle {
378 id,
379 start_ts,
380 isolation,
381 };
382
383 let state = TransactionState {
384 handle: handle.clone(),
385 state: TxnState::Active,
386 start_time: Instant::now(),
387 read_set: Vec::new(),
388 write_set: Vec::new(),
389 savepoints: TxnSavepoints::new(id),
390 locks_held: Vec::new(),
391 };
392
393 if let Some(ref log) = self.log {
395 let _ = log.log_begin(id);
396 }
397
398 recover_write_guard(&self.transactions).insert(id, state);
399
400 handle
401 }
402
403 pub fn begin_transaction(self: &Arc<Self>) -> Transaction {
405 let handle = self.begin();
406 Transaction {
407 id: handle.id,
408 coordinator: Arc::clone(self),
409 }
410 }
411
412 pub fn record_read(&self, txn_id: TxnId, key: &[u8], read_ts: Timestamp) {
414 let mut txns = recover_write_guard(&self.transactions);
415 if let Some(state) = txns.get_mut(&txn_id) {
416 if state.state == TxnState::Active {
417 state.read_set.push((key.to_vec(), read_ts));
418 }
419 }
420 }
421
422 pub fn record_write(
424 &self,
425 txn_id: TxnId,
426 key: &[u8],
427 old_value: Option<&[u8]>,
428 new_value: Option<&[u8]>,
429 ) {
430 let timestamp = self.next_timestamp();
431
432 let mut txns = recover_write_guard(&self.transactions);
433 if let Some(state) = txns.get_mut(&txn_id) {
434 if state.state == TxnState::Active {
435 let entry = WriteEntry {
436 key: key.to_vec(),
437 old_value: old_value.map(|v| v.to_vec()),
438 new_value: new_value.map(|v| v.to_vec()),
439 timestamp,
440 };
441
442 if let Some(ref log) = self.log {
444 if let Some(old) = old_value {
445 if let Some(new) = new_value {
446 let _ =
447 log.log_update(txn_id, key.to_vec(), old.to_vec(), new.to_vec());
448 } else {
449 let _ = log.log_delete(txn_id, key.to_vec(), old.to_vec());
450 }
451 } else if let Some(new) = new_value {
452 let _ = log.log_insert(txn_id, key.to_vec(), new.to_vec());
453 }
454 }
455
456 state.write_set.push(entry);
457 }
458 }
459 }
460
461 pub fn acquire_lock(&self, txn_id: TxnId, key: &[u8], mode: LockMode) -> Result<(), TxnError> {
463 {
465 let txns = read_guard_or_err(&self.transactions, "transaction manager state")?;
466 let state = txns.get(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
467 if state.state != TxnState::Active {
468 return Err(TxnError::AlreadyAborted(txn_id));
469 }
470 }
471
472 match self
474 .lock_manager
475 .acquire_with_timeout(txn_id, key, mode, self.config.lock_timeout)
476 {
477 LockResult::Granted | LockResult::Upgraded | LockResult::AlreadyHeld => {
478 let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
480 if let Some(state) = txns.get_mut(&txn_id) {
481 if !state.locks_held.contains(&key.to_vec()) {
482 state.locks_held.push(key.to_vec());
483 }
484 }
485 Ok(())
486 }
487 LockResult::Waiting => {
488 Err(TxnError::Internal(
490 "Lock returned Waiting unexpectedly".to_string(),
491 ))
492 }
493 LockResult::Timeout => Err(TxnError::LockTimeout {
494 key: key.to_vec(),
495 timeout: self.config.lock_timeout,
496 }),
497 LockResult::Deadlock(cycle) => Err(TxnError::Deadlock(cycle)),
498 LockResult::LockLimitExceeded => Err(TxnError::LockLimitExceeded {
499 limit: self.lock_manager.config().max_locks_per_txn,
500 }),
501 LockResult::TxnNotFound => Err(TxnError::NotFound(txn_id)),
502 }
503 }
504
505 fn release_locks(&self, txn_id: TxnId) {
507 let locks = {
508 let txns = recover_read_guard(&self.transactions);
509 txns.get(&txn_id)
510 .map(|s| s.locks_held.clone())
511 .unwrap_or_default()
512 };
513
514 for key in locks {
515 self.lock_manager.release(txn_id, &key);
516 }
517 }
518
519 fn validate(&self, txn_id: TxnId) -> Result<(), TxnError> {
521 let txns = read_guard_or_err(&self.transactions, "transaction manager state")?;
522 let state = txns.get(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
523
524 if !self.config.optimistic {
525 return Ok(());
526 }
527
528 let committed = read_guard_or_err(&self.committed_ts, "transaction manager committed_ts")?;
529
530 for (key, read_ts) in &state.read_set {
532 if let Some(&commit_ts) = committed.get(key) {
533 if commit_ts > *read_ts && commit_ts > state.handle.start_ts {
534 return Err(TxnError::ValidationFailed {
535 key: key.clone(),
536 expected_ts: *read_ts,
537 actual_ts: commit_ts,
538 });
539 }
540 }
541 }
542
543 Ok(())
544 }
545
546 pub fn commit(&self, txn_id: TxnId) -> Result<(), TxnError> {
548 self.validate(txn_id)?;
550
551 let commit_ts = self.next_timestamp();
552
553 {
555 let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
556 let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
557
558 match state.state {
559 TxnState::Active | TxnState::Preparing => {
560 state.state = TxnState::Committed;
561 }
562 TxnState::Committed => return Err(TxnError::AlreadyCommitted(txn_id)),
563 TxnState::Aborted => return Err(TxnError::AlreadyAborted(txn_id)),
564 }
565
566 let mut committed =
568 write_guard_or_err(&self.committed_ts, "transaction manager committed_ts")?;
569 for entry in &state.write_set {
570 committed.insert(entry.key.clone(), commit_ts);
571 }
572 }
573
574 if let Some(ref log) = self.log {
576 let _ = log.log_commit(txn_id);
577
578 if matches!(self.config.wal_sync, WalSyncMode::EveryCommit) {
580 let _ = log.flush();
581 }
582 }
583
584 self.release_locks(txn_id);
586
587 Ok(())
588 }
589
590 pub fn abort(&self, txn_id: TxnId) -> Result<(), TxnError> {
592 {
594 let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
595 let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
596
597 match state.state {
598 TxnState::Active | TxnState::Preparing => {
599 state.state = TxnState::Aborted;
600 }
601 TxnState::Committed => return Err(TxnError::AlreadyCommitted(txn_id)),
602 TxnState::Aborted => return Err(TxnError::AlreadyAborted(txn_id)),
603 }
604 }
605
606 if let Some(ref log) = self.log {
608 let _ = log.log_abort(txn_id);
609 }
610
611 self.release_locks(txn_id);
613
614 Ok(())
615 }
616
617 pub fn create_savepoint(&self, txn_id: TxnId, name: &str) -> Result<(), TxnError> {
619 let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
620 let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
621
622 if state.state != TxnState::Active {
623 return Err(TxnError::AlreadyAborted(txn_id));
624 }
625
626 let write_set_index = state.write_set.len();
627 let lock_count = state.locks_held.len();
628 state
630 .savepoints
631 .create(name.to_string(), 0, lock_count, write_set_index);
632
633 Ok(())
634 }
635
636 pub fn rollback_to_savepoint(&self, txn_id: TxnId, name: &str) -> Result<(), TxnError> {
638 let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
639 let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
640
641 if state.state != TxnState::Active {
642 return Err(TxnError::AlreadyAborted(txn_id));
643 }
644
645 let savepoint = state
646 .savepoints
647 .get(name)
648 .ok_or_else(|| TxnError::SavepointNotFound(name.to_string()))?;
649
650 state.write_set.truncate(savepoint.write_set_index);
652
653 state.savepoints.release(name);
655
656 Ok(())
657 }
658
659 pub fn get_state(&self, txn_id: TxnId) -> Option<TxnState> {
661 self.transactions
662 .read()
663 .unwrap_or_else(|poisoned| poisoned.into_inner())
664 .get(&txn_id)
665 .map(|s| s.state)
666 }
667
668 pub fn is_active(&self, txn_id: TxnId) -> bool {
670 self.get_state(txn_id) == Some(TxnState::Active)
671 }
672
673 pub fn active_count(&self) -> usize {
675 self.transactions
676 .read()
677 .unwrap_or_else(|poisoned| poisoned.into_inner())
678 .values()
679 .filter(|s| s.state == TxnState::Active)
680 .count()
681 }
682
683 pub fn oldest_active_ts(&self) -> Option<Timestamp> {
685 self.transactions
686 .read()
687 .unwrap_or_else(|poisoned| poisoned.into_inner())
688 .values()
689 .filter(|s| s.state == TxnState::Active)
690 .map(|s| s.handle.start_ts)
691 .min()
692 }
693
694 pub fn cleanup(&self, max_age: Duration) {
696 let mut txns = recover_write_guard(&self.transactions);
697 let now = Instant::now();
698
699 txns.retain(|_, state| {
700 if state.state == TxnState::Active {
701 true
702 } else {
703 now.duration_since(state.start_time) < max_age
704 }
705 });
706 }
707}
708
709impl Default for TransactionManager {
710 fn default() -> Self {
711 Self::with_default_config()
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_begin_commit() {
        let tm = TransactionManager::with_default_config();

        let handle = tm.begin();
        assert!(tm.is_active(handle.id));

        tm.commit(handle.id).unwrap();
        assert!(!tm.is_active(handle.id));
        assert_eq!(tm.get_state(handle.id), Some(TxnState::Committed));
    }

    #[test]
    fn test_begin_abort() {
        let tm = TransactionManager::with_default_config();

        let handle = tm.begin();
        assert!(tm.is_active(handle.id));

        tm.abort(handle.id).unwrap();
        assert!(!tm.is_active(handle.id));
        assert_eq!(tm.get_state(handle.id), Some(TxnState::Aborted));
    }

    #[test]
    fn test_double_commit() {
        let tm = TransactionManager::with_default_config();

        let handle = tm.begin();
        tm.commit(handle.id).unwrap();

        assert!(matches!(
            tm.commit(handle.id),
            Err(TxnError::AlreadyCommitted(_))
        ));
    }

    #[test]
    fn test_commit_after_abort() {
        let tm = TransactionManager::with_default_config();

        let handle = tm.begin();
        tm.abort(handle.id).unwrap();

        assert!(matches!(
            tm.commit(handle.id),
            Err(TxnError::AlreadyAborted(_))
        ));
    }

    #[test]
    fn test_unknown_transaction() {
        let tm = TransactionManager::with_default_config();

        assert!(matches!(tm.commit(999), Err(TxnError::NotFound(999))));
        assert!(matches!(tm.abort(999), Err(TxnError::NotFound(999))));
        assert_eq!(tm.get_state(999), None);
    }

    #[test]
    fn test_transaction_wrapper() {
        let tm = Arc::new(TransactionManager::with_default_config());

        let txn = tm.begin_transaction();
        let id = txn.id();

        txn.record_write(b"key1", None, Some(b"value1"));
        txn.commit().unwrap();

        assert!(!tm.is_active(id));
    }

    #[test]
    fn test_savepoints() {
        let tm = TransactionManager::with_default_config();

        let handle = tm.begin();

        tm.record_write(handle.id, b"key1", None, Some(b"v1"));
        tm.create_savepoint(handle.id, "sp1").unwrap();

        tm.record_write(handle.id, b"key2", None, Some(b"v2"));
        tm.record_write(handle.id, b"key3", None, Some(b"v3"));

        // Rolling back discards the two writes made after `sp1`.
        tm.rollback_to_savepoint(handle.id, "sp1").unwrap();

        tm.commit(handle.id).unwrap();
        assert_eq!(tm.get_state(handle.id), Some(TxnState::Committed));

        // Only the write made before the savepoint is published at commit.
        let committed = tm.committed_ts.read().unwrap();
        assert!(committed.contains_key(&b"key1"[..]));
        assert!(!committed.contains_key(&b"key2"[..]));
        assert!(!committed.contains_key(&b"key3"[..]));
    }

    #[test]
    fn test_rollback_to_unknown_savepoint() {
        let tm = TransactionManager::with_default_config();

        let handle = tm.begin();
        assert!(matches!(
            tm.rollback_to_savepoint(handle.id, "missing"),
            Err(TxnError::SavepointNotFound(ref name)) if name.as_str() == "missing"
        ));
    }

    #[test]
    fn test_commit_fails_validation_after_conflicting_commit() {
        let tm = TransactionManager::with_default_config();

        let reader = tm.begin();
        tm.record_read(reader.id, b"shared", reader.start_ts);

        let writer = tm.begin();
        tm.record_write(writer.id, b"shared", None, Some(b"new"));
        tm.commit(writer.id).unwrap();

        assert!(matches!(
            tm.commit(reader.id),
            Err(TxnError::ValidationFailed { .. })
        ));
    }

    #[test]
    fn test_validation_skipped_when_not_optimistic() {
        let tm = TransactionManager::new(TxnConfig::new().with_optimistic(false));

        let reader = tm.begin();
        tm.record_read(reader.id, b"shared", reader.start_ts);

        let writer = tm.begin();
        tm.record_write(writer.id, b"shared", None, Some(b"new"));
        tm.commit(writer.id).unwrap();

        tm.commit(reader.id).unwrap();
        assert_eq!(tm.get_state(reader.id), Some(TxnState::Committed));
    }

    #[test]
    fn test_isolation_levels() {
        let tm = TransactionManager::with_default_config();

        let h1 = tm.begin_with_isolation(IsolationLevel::ReadCommitted);
        let h2 = tm.begin_with_isolation(IsolationLevel::SnapshotIsolation);

        assert_eq!(h1.isolation, IsolationLevel::ReadCommitted);
        assert_eq!(h2.isolation, IsolationLevel::SnapshotIsolation);

        tm.abort(h1.id).unwrap();
        tm.abort(h2.id).unwrap();
    }

    #[test]
    fn test_active_count() {
        let tm = TransactionManager::with_default_config();

        assert_eq!(tm.active_count(), 0);

        let h1 = tm.begin();
        let h2 = tm.begin();
        assert_eq!(tm.active_count(), 2);

        tm.commit(h1.id).unwrap();
        assert_eq!(tm.active_count(), 1);

        tm.abort(h2.id).unwrap();
        assert_eq!(tm.active_count(), 0);
    }

    #[test]
    fn test_oldest_active_ts() {
        let tm = TransactionManager::with_default_config();

        let h1 = tm.begin();
        let ts1 = h1.start_ts;

        let _h2 = tm.begin();

        assert_eq!(tm.oldest_active_ts(), Some(ts1));
    }

    #[test]
    fn test_cleanup_keeps_active() {
        let tm = TransactionManager::with_default_config();

        let done = tm.begin();
        let live = tm.begin();
        tm.commit(done.id).unwrap();

        tm.cleanup(Duration::ZERO);

        assert_eq!(tm.get_state(done.id), None);
        assert!(tm.is_active(live.id));
    }

    #[test]
    fn test_config() {
        let config = TxnConfig::new()
            .with_isolation(IsolationLevel::Serializable)
            .with_lock_timeout(Duration::from_secs(10))
            .with_optimistic(false);

        assert_eq!(config.isolation_level, IsolationLevel::Serializable);
        assert_eq!(config.lock_timeout, Duration::from_secs(10));
        assert!(!config.optimistic);
    }
}