// grafeo_engine/transaction/manager.rs

//! Transaction manager: transaction lifecycle, MVCC epochs, and write-write
//! conflict detection.
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use grafeo_common::types::{EdgeId, EpochId, NodeId, TxId};
7use grafeo_common::utils::error::{Error, Result, TransactionError};
8use grafeo_common::utils::hash::FxHashMap;
9use parking_lot::RwLock;
10
/// Lifecycle state of a transaction.
///
/// A transaction starts out `Active`; committing or aborting it moves it to
/// one of the two terminal states.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TxState {
    /// Still running: reads and writes may be recorded against it.
    Active,
    /// Finished successfully.
    Committed,
    /// Rolled back.
    Aborted,
}
21
22/// Entity identifier for write tracking.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
24pub enum EntityId {
25    /// A node.
26    Node(NodeId),
27    /// An edge.
28    Edge(EdgeId),
29}
30
31impl From<NodeId> for EntityId {
32    fn from(id: NodeId) -> Self {
33        Self::Node(id)
34    }
35}
36
37impl From<EdgeId> for EntityId {
38    fn from(id: EdgeId) -> Self {
39        Self::Edge(id)
40    }
41}
42
43/// Information about an active transaction.
44pub struct TxInfo {
45    /// Transaction state.
46    pub state: TxState,
47    /// Start epoch (snapshot epoch for reads).
48    pub start_epoch: EpochId,
49    /// Set of entities written by this transaction.
50    pub write_set: HashSet<EntityId>,
51    /// Set of entities read by this transaction (for serializable isolation).
52    pub read_set: HashSet<EntityId>,
53}
54
55impl TxInfo {
56    /// Creates a new transaction info.
57    fn new(start_epoch: EpochId) -> Self {
58        Self {
59            state: TxState::Active,
60            start_epoch,
61            write_set: HashSet::new(),
62            read_set: HashSet::new(),
63        }
64    }
65}
66
67/// Manages transactions and MVCC versioning.
68pub struct TransactionManager {
69    /// Next transaction ID.
70    next_tx_id: AtomicU64,
71    /// Current epoch.
72    current_epoch: AtomicU64,
73    /// Active transactions.
74    transactions: RwLock<FxHashMap<TxId, TxInfo>>,
75    /// Committed transaction epochs (for conflict detection).
76    /// Maps TxId -> commit epoch.
77    committed_epochs: RwLock<FxHashMap<TxId, EpochId>>,
78}
79
80impl TransactionManager {
81    /// Creates a new transaction manager.
82    #[must_use]
83    pub fn new() -> Self {
84        Self {
85            // Start at 2 to avoid collision with TxId::SYSTEM (which is 1)
86            // TxId::INVALID = 0, TxId::SYSTEM = 1, user transactions start at 2
87            next_tx_id: AtomicU64::new(2),
88            current_epoch: AtomicU64::new(0),
89            transactions: RwLock::new(FxHashMap::default()),
90            committed_epochs: RwLock::new(FxHashMap::default()),
91        }
92    }
93
94    /// Begins a new transaction.
95    pub fn begin(&self) -> TxId {
96        let tx_id = TxId::new(self.next_tx_id.fetch_add(1, Ordering::Relaxed));
97        let epoch = EpochId::new(self.current_epoch.load(Ordering::Acquire));
98
99        let info = TxInfo::new(epoch);
100        self.transactions.write().insert(tx_id, info);
101        tx_id
102    }
103
104    /// Records a write operation for the transaction.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the transaction is not active.
109    pub fn record_write(&self, tx_id: TxId, entity: impl Into<EntityId>) -> Result<()> {
110        let mut txns = self.transactions.write();
111        let info = txns.get_mut(&tx_id).ok_or_else(|| {
112            Error::Transaction(TransactionError::InvalidState(
113                "Transaction not found".to_string(),
114            ))
115        })?;
116
117        if info.state != TxState::Active {
118            return Err(Error::Transaction(TransactionError::InvalidState(
119                "Transaction is not active".to_string(),
120            )));
121        }
122
123        info.write_set.insert(entity.into());
124        Ok(())
125    }
126
127    /// Records a read operation for the transaction (for serializable isolation).
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the transaction is not active.
132    pub fn record_read(&self, tx_id: TxId, entity: impl Into<EntityId>) -> Result<()> {
133        let mut txns = self.transactions.write();
134        let info = txns.get_mut(&tx_id).ok_or_else(|| {
135            Error::Transaction(TransactionError::InvalidState(
136                "Transaction not found".to_string(),
137            ))
138        })?;
139
140        if info.state != TxState::Active {
141            return Err(Error::Transaction(TransactionError::InvalidState(
142                "Transaction is not active".to_string(),
143            )));
144        }
145
146        info.read_set.insert(entity.into());
147        Ok(())
148    }
149
150    /// Commits a transaction with conflict detection.
151    ///
152    /// # Errors
153    ///
154    /// Returns an error if:
155    /// - The transaction is not active
156    /// - There's a write-write conflict with another committed transaction
157    pub fn commit(&self, tx_id: TxId) -> Result<EpochId> {
158        let mut txns = self.transactions.write();
159        let committed = self.committed_epochs.read();
160
161        // First, validate the transaction exists and is active
162        {
163            let info = txns.get(&tx_id).ok_or_else(|| {
164                Error::Transaction(TransactionError::InvalidState(
165                    "Transaction not found".to_string(),
166                ))
167            })?;
168
169            if info.state != TxState::Active {
170                return Err(Error::Transaction(TransactionError::InvalidState(
171                    "Transaction is not active".to_string(),
172                )));
173            }
174        }
175
176        // Get our write set for conflict checking
177        let our_write_set: HashSet<EntityId> = txns
178            .get(&tx_id)
179            .map(|info| info.write_set.clone())
180            .unwrap_or_default();
181
182        let our_start_epoch = txns
183            .get(&tx_id)
184            .map(|info| info.start_epoch)
185            .unwrap_or(EpochId::new(0));
186
187        // Check for write-write conflicts with other committed transactions
188        for (other_tx, other_info) in txns.iter() {
189            if *other_tx == tx_id {
190                continue;
191            }
192            if other_info.state == TxState::Committed {
193                // Check if any of our writes conflict with their writes
194                for entity in &our_write_set {
195                    if other_info.write_set.contains(entity) {
196                        return Err(Error::Transaction(TransactionError::WriteConflict(
197                            format!("Write-write conflict on entity {:?}", entity),
198                        )));
199                    }
200                }
201            }
202        }
203
204        // Also check against recently committed transactions
205        for (other_tx, commit_epoch) in committed.iter() {
206            if *other_tx != tx_id && commit_epoch.as_u64() > our_start_epoch.as_u64() {
207                // Check if that transaction wrote to any of our entities
208                if let Some(other_info) = txns.get(other_tx) {
209                    for entity in &our_write_set {
210                        if other_info.write_set.contains(entity) {
211                            return Err(Error::Transaction(TransactionError::WriteConflict(
212                                format!("Write-write conflict on entity {:?}", entity),
213                            )));
214                        }
215                    }
216                }
217            }
218        }
219
220        // Commit successful - advance epoch atomically
221        // SeqCst ensures all threads see commits in a consistent total order
222        let commit_epoch = EpochId::new(self.current_epoch.fetch_add(1, Ordering::SeqCst) + 1);
223
224        // Now update state
225        if let Some(info) = txns.get_mut(&tx_id) {
226            info.state = TxState::Committed;
227        }
228
229        // Record commit epoch (need to drop read lock first)
230        drop(committed);
231        self.committed_epochs.write().insert(tx_id, commit_epoch);
232
233        Ok(commit_epoch)
234    }
235
236    /// Aborts a transaction.
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if the transaction is not active.
241    pub fn abort(&self, tx_id: TxId) -> Result<()> {
242        let mut txns = self.transactions.write();
243
244        let info = txns.get_mut(&tx_id).ok_or_else(|| {
245            Error::Transaction(TransactionError::InvalidState(
246                "Transaction not found".to_string(),
247            ))
248        })?;
249
250        if info.state != TxState::Active {
251            return Err(Error::Transaction(TransactionError::InvalidState(
252                "Transaction is not active".to_string(),
253            )));
254        }
255
256        info.state = TxState::Aborted;
257        Ok(())
258    }
259
260    /// Returns the write set of a transaction.
261    ///
262    /// This returns a copy of the entities written by this transaction,
263    /// used for rollback to discard uncommitted versions.
264    pub fn get_write_set(&self, tx_id: TxId) -> Result<HashSet<EntityId>> {
265        let txns = self.transactions.read();
266        let info = txns.get(&tx_id).ok_or_else(|| {
267            Error::Transaction(TransactionError::InvalidState(
268                "Transaction not found".to_string(),
269            ))
270        })?;
271        Ok(info.write_set.clone())
272    }
273
274    /// Aborts all active transactions.
275    ///
276    /// Used during database shutdown.
277    pub fn abort_all_active(&self) {
278        let mut txns = self.transactions.write();
279        for info in txns.values_mut() {
280            if info.state == TxState::Active {
281                info.state = TxState::Aborted;
282            }
283        }
284    }
285
286    /// Returns the state of a transaction.
287    pub fn state(&self, tx_id: TxId) -> Option<TxState> {
288        self.transactions.read().get(&tx_id).map(|info| info.state)
289    }
290
291    /// Returns the start epoch of a transaction.
292    pub fn start_epoch(&self, tx_id: TxId) -> Option<EpochId> {
293        self.transactions
294            .read()
295            .get(&tx_id)
296            .map(|info| info.start_epoch)
297    }
298
299    /// Returns the current epoch.
300    #[must_use]
301    pub fn current_epoch(&self) -> EpochId {
302        EpochId::new(self.current_epoch.load(Ordering::Acquire))
303    }
304
305    /// Returns the minimum epoch that must be preserved for active transactions.
306    ///
307    /// This is used for garbage collection - versions visible at this epoch
308    /// must be preserved.
309    #[must_use]
310    pub fn min_active_epoch(&self) -> EpochId {
311        let txns = self.transactions.read();
312        txns.values()
313            .filter(|info| info.state == TxState::Active)
314            .map(|info| info.start_epoch)
315            .min()
316            .unwrap_or_else(|| self.current_epoch())
317    }
318
319    /// Returns the number of active transactions.
320    #[must_use]
321    pub fn active_count(&self) -> usize {
322        self.transactions
323            .read()
324            .values()
325            .filter(|info| info.state == TxState::Active)
326            .count()
327    }
328
329    /// Cleans up completed transactions that are no longer needed for conflict detection.
330    ///
331    /// A committed transaction's write set must be preserved until all transactions
332    /// that started before its commit have completed. This ensures write-write
333    /// conflict detection works correctly.
334    ///
335    /// Returns the number of transactions cleaned up.
336    pub fn gc(&self) -> usize {
337        let mut txns = self.transactions.write();
338        let mut committed = self.committed_epochs.write();
339
340        // Find the minimum start epoch among active transactions
341        let min_active_start = txns
342            .values()
343            .filter(|info| info.state == TxState::Active)
344            .map(|info| info.start_epoch)
345            .min();
346
347        let initial_count = txns.len();
348
349        // Collect transactions safe to remove
350        let to_remove: Vec<TxId> = txns
351            .iter()
352            .filter(|(tx_id, info)| {
353                match info.state {
354                    TxState::Active => false, // Never remove active transactions
355                    TxState::Aborted => true, // Always safe to remove aborted transactions
356                    TxState::Committed => {
357                        // Only remove committed transactions if their commit epoch
358                        // is older than all active transactions' start epochs
359                        if let Some(min_start) = min_active_start {
360                            if let Some(commit_epoch) = committed.get(*tx_id) {
361                                // Safe to remove if committed before all active txns started
362                                commit_epoch.as_u64() < min_start.as_u64()
363                            } else {
364                                // No commit epoch recorded, keep it to be safe
365                                false
366                            }
367                        } else {
368                            // No active transactions, safe to remove all committed
369                            true
370                        }
371                    }
372                }
373            })
374            .map(|(id, _)| *id)
375            .collect();
376
377        for id in &to_remove {
378            txns.remove(id);
379            committed.remove(id);
380        }
381
382        initial_count - txns.len()
383    }
384
385    /// Marks a transaction as committed at a specific epoch.
386    ///
387    /// Used during recovery to restore transaction state.
388    pub fn mark_committed(&self, tx_id: TxId, epoch: EpochId) {
389        self.committed_epochs.write().insert(tx_id, epoch);
390    }
391
392    /// Returns the last assigned transaction ID.
393    ///
394    /// Returns `None` if no transactions have been started yet.
395    #[must_use]
396    pub fn last_assigned_tx_id(&self) -> Option<TxId> {
397        let next = self.next_tx_id.load(Ordering::Relaxed);
398        if next > 1 {
399            Some(TxId::new(next - 1))
400        } else {
401            None
402        }
403    }
404}
405
406impl Default for TransactionManager {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_begin_commit() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        assert_eq!(mgr.state(tx), Some(TxState::Active));

        let commit_epoch = mgr.commit(tx).unwrap();
        assert_eq!(mgr.state(tx), Some(TxState::Committed));
        assert!(commit_epoch.as_u64() > 0);
    }

    #[test]
    fn test_begin_abort() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.abort(tx).unwrap();
        assert_eq!(mgr.state(tx), Some(TxState::Aborted));
    }

    #[test]
    fn test_epoch_advancement() {
        let mgr = TransactionManager::new();

        let initial_epoch = mgr.current_epoch();

        let tx = mgr.begin();
        let commit_epoch = mgr.commit(tx).unwrap();

        assert!(mgr.current_epoch().as_u64() > initial_epoch.as_u64());
        assert!(commit_epoch.as_u64() > initial_epoch.as_u64());
    }

    #[test]
    fn test_gc_preserves_needed_write_sets() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        mgr.commit(tx1).unwrap();
        // tx2 still active - started before tx1 committed
        assert_eq!(mgr.active_count(), 1);

        // GC must NOT remove tx1: tx2 might need its write set for conflict detection.
        assert_eq!(mgr.gc(), 0);

        assert_eq!(mgr.state(tx1), Some(TxState::Committed));
        assert_eq!(mgr.state(tx2), Some(TxState::Active));
    }

    #[test]
    fn test_gc_removes_old_commits() {
        let mgr = TransactionManager::new();

        // tx1 commits at epoch 1
        let tx1 = mgr.begin();
        mgr.commit(tx1).unwrap();

        // tx2 starts at epoch 1, commits at epoch 2
        let tx2 = mgr.begin();
        mgr.commit(tx2).unwrap();

        // tx3 starts at epoch 2
        let tx3 = mgr.begin();

        // - tx1 committed at epoch 1 < tx3 start (2) -> safe to GC
        // - tx2 committed at epoch 2, not < tx3 start (2) -> kept
        assert_eq!(mgr.gc(), 1);

        assert_eq!(mgr.state(tx1), None);
        assert_eq!(mgr.state(tx2), Some(TxState::Committed));
        assert_eq!(mgr.state(tx3), Some(TxState::Active));

        // With no active transactions left, tx2 and tx3 are both collectable.
        mgr.commit(tx3).unwrap();
        assert_eq!(mgr.gc(), 2);
    }

    #[test]
    fn test_gc_removes_aborted() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        mgr.abort(tx1).unwrap();

        // Aborted transactions are always safe to remove, even with tx2 active.
        assert_eq!(mgr.gc(), 1);

        assert_eq!(mgr.state(tx1), None);
        assert_eq!(mgr.state(tx2), Some(TxState::Active));
    }

    #[test]
    fn test_write_tracking() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.record_write(tx, NodeId::new(1)).unwrap();
        mgr.record_write(tx, NodeId::new(2)).unwrap();
        mgr.record_write(tx, EdgeId::new(100)).unwrap();

        // No concurrent writers, so the commit cannot conflict.
        assert!(mgr.commit(tx).is_ok());
    }

    #[test]
    fn test_min_active_epoch() {
        let mgr = TransactionManager::new();

        // No active transactions - falls back to the current epoch.
        assert_eq!(mgr.min_active_epoch(), mgr.current_epoch());

        let tx1 = mgr.begin();
        let epoch1 = mgr.start_epoch(tx1).unwrap();

        // Advance the epoch with an unrelated commit.
        let tx2 = mgr.begin();
        mgr.commit(tx2).unwrap();

        let _tx3 = mgr.begin();

        // The earliest active snapshot is tx1's.
        assert_eq!(mgr.min_active_epoch(), epoch1);
    }

    #[test]
    fn test_abort_all_active() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        let tx3 = mgr.begin();

        mgr.commit(tx1).unwrap();

        mgr.abort_all_active();

        assert_eq!(mgr.state(tx1), Some(TxState::Committed)); // Already committed
        assert_eq!(mgr.state(tx2), Some(TxState::Aborted));
        assert_eq!(mgr.state(tx3), Some(TxState::Aborted));
    }

    #[test]
    fn test_start_epoch_snapshot() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let start1 = mgr.start_epoch(tx1).unwrap();

        // Committing tx1 advances the epoch.
        mgr.commit(tx1).unwrap();

        let tx2 = mgr.begin();
        let start2 = mgr.start_epoch(tx2).unwrap();

        assert!(start2.as_u64() > start1.as_u64());
    }

    #[test]
    fn test_write_write_conflict_detection() {
        let mgr = TransactionManager::new();

        // Both transactions start at the same epoch and write the same entity.
        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        let entity = NodeId::new(42);
        mgr.record_write(tx1, entity).unwrap();
        mgr.record_write(tx2, entity).unwrap();

        // First committer wins.
        assert!(mgr.commit(tx1).is_ok());

        let result2 = mgr.commit(tx2);
        assert!(result2.is_err());
        assert!(
            result2
                .unwrap_err()
                .to_string()
                .contains("Write-write conflict"),
            "Expected write-write conflict error"
        );
    }

    #[test]
    fn test_disjoint_writes_do_not_conflict() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        mgr.record_write(tx1, NodeId::new(1)).unwrap();
        mgr.record_write(tx2, NodeId::new(2)).unwrap();

        assert!(mgr.commit(tx1).is_ok());
        assert!(mgr.commit(tx2).is_ok());
    }

    #[test]
    fn test_aborted_writer_does_not_conflict() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        let entity = NodeId::new(7);
        mgr.record_write(tx1, entity).unwrap();
        mgr.record_write(tx2, entity).unwrap();

        mgr.abort(tx1).unwrap();
        assert!(mgr.commit(tx2).is_ok());
    }

    #[test]
    fn test_finished_transaction_rejects_operations() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.commit(tx).unwrap();

        assert!(mgr.record_write(tx, NodeId::new(1)).is_err());
        assert!(mgr.record_read(tx, NodeId::new(1)).is_err());
        assert!(mgr.commit(tx).is_err());
        assert!(mgr.abort(tx).is_err());
        assert_eq!(mgr.state(tx), Some(TxState::Committed));
    }

    #[test]
    fn test_unknown_transaction_is_rejected() {
        let mgr = TransactionManager::new();
        let ghost = TxId::new(9_999);

        assert_eq!(mgr.state(ghost), None);
        assert!(mgr.start_epoch(ghost).is_none());
        assert!(mgr.record_write(ghost, NodeId::new(1)).is_err());
        assert!(mgr.commit(ghost).is_err());
        assert!(mgr.abort(ghost).is_err());
        assert!(mgr.get_write_set(ghost).is_err());
    }

    #[test]
    fn test_get_write_set_excludes_reads() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.record_write(tx, NodeId::new(1)).unwrap();
        mgr.record_write(tx, EdgeId::new(5)).unwrap();
        mgr.record_read(tx, NodeId::new(9)).unwrap();

        let writes = mgr.get_write_set(tx).unwrap();
        assert_eq!(writes.len(), 2);
        assert!(writes.contains(&EntityId::Node(NodeId::new(1))));
        assert!(writes.contains(&EntityId::Edge(EdgeId::new(5))));
        assert!(!writes.contains(&EntityId::Node(NodeId::new(9))));
    }

    #[test]
    fn test_last_assigned_tx_id_tracks_begin() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        assert!(mgr.last_assigned_tx_id() == Some(tx1));

        let tx2 = mgr.begin();
        assert!(mgr.last_assigned_tx_id() == Some(tx2));
    }

    #[test]
    fn test_commit_epoch_monotonicity() {
        let mgr = TransactionManager::new();

        let epochs: Vec<u64> = (0..10)
            .map(|_| {
                let tx = mgr.begin();
                mgr.commit(tx).unwrap().as_u64()
            })
            .collect();

        // Verify strict monotonicity between consecutive commits.
        for (i, pair) in epochs.windows(2).enumerate() {
            assert!(
                pair[1] > pair[0],
                "Epoch {} ({}) should be greater than epoch {} ({})",
                i + 1,
                pair[1],
                i,
                pair[0]
            );
        }
    }

    #[test]
    fn test_concurrent_commits_via_threads() {
        use std::sync::Arc;
        use std::thread;

        let mgr = Arc::new(TransactionManager::new());
        let num_threads: usize = 10;
        let commits_per_thread: usize = 100;

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let mgr = Arc::clone(&mgr);
                thread::spawn(move || {
                    (0..commits_per_thread)
                        .map(|_| {
                            let tx = mgr.begin();
                            mgr.commit(tx).unwrap().as_u64()
                        })
                        .collect::<Vec<u64>>()
                })
            })
            .collect();

        let mut all_epochs: Vec<u64> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();

        let total = all_epochs.len();
        assert_eq!(total, num_threads * commits_per_thread);

        // All epochs should be unique (no duplicates).
        all_epochs.sort_unstable();
        all_epochs.dedup();
        assert_eq!(
            all_epochs.len(),
            total,
            "All commit epochs should be unique"
        );

        // Final epoch should equal number of commits.
        assert_eq!(
            mgr.current_epoch().as_u64(),
            (num_threads * commits_per_thread) as u64,
            "Final epoch should equal total commits"
        );
    }
}