grafeo_engine/transaction/
manager.rs1use std::collections::HashSet;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use grafeo_common::types::{EdgeId, EpochId, NodeId, TxId};
7use grafeo_common::utils::error::{Error, Result, TransactionError};
8use grafeo_common::utils::hash::FxHashMap;
9use parking_lot::RwLock;
10
/// Lifecycle state of a transaction tracked by the transaction manager.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TxState {
    /// Begun and still accepting reads and writes.
    Active,
    /// Successfully committed.
    Committed,
    /// Rolled back, either explicitly or by aborting all active transactions.
    Aborted,
}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
24pub enum EntityId {
25 Node(NodeId),
27 Edge(EdgeId),
29}
30
31impl From<NodeId> for EntityId {
32 fn from(id: NodeId) -> Self {
33 Self::Node(id)
34 }
35}
36
37impl From<EdgeId> for EntityId {
38 fn from(id: EdgeId) -> Self {
39 Self::Edge(id)
40 }
41}
42
43pub struct TxInfo {
45 pub state: TxState,
47 pub start_epoch: EpochId,
49 pub write_set: HashSet<EntityId>,
51 pub read_set: HashSet<EntityId>,
53}
54
55impl TxInfo {
56 fn new(start_epoch: EpochId) -> Self {
58 Self {
59 state: TxState::Active,
60 start_epoch,
61 write_set: HashSet::new(),
62 read_set: HashSet::new(),
63 }
64 }
65}
66
67pub struct TransactionManager {
69 next_tx_id: AtomicU64,
71 current_epoch: AtomicU64,
73 transactions: RwLock<FxHashMap<TxId, TxInfo>>,
75 committed_epochs: RwLock<FxHashMap<TxId, EpochId>>,
78}
79
80impl TransactionManager {
81 #[must_use]
83 pub fn new() -> Self {
84 Self {
85 next_tx_id: AtomicU64::new(2),
88 current_epoch: AtomicU64::new(0),
89 transactions: RwLock::new(FxHashMap::default()),
90 committed_epochs: RwLock::new(FxHashMap::default()),
91 }
92 }
93
94 pub fn begin(&self) -> TxId {
96 let tx_id = TxId::new(self.next_tx_id.fetch_add(1, Ordering::Relaxed));
97 let epoch = EpochId::new(self.current_epoch.load(Ordering::Acquire));
98
99 let info = TxInfo::new(epoch);
100 self.transactions.write().insert(tx_id, info);
101 tx_id
102 }
103
104 pub fn record_write(&self, tx_id: TxId, entity: impl Into<EntityId>) -> Result<()> {
110 let mut txns = self.transactions.write();
111 let info = txns.get_mut(&tx_id).ok_or_else(|| {
112 Error::Transaction(TransactionError::InvalidState(
113 "Transaction not found".to_string(),
114 ))
115 })?;
116
117 if info.state != TxState::Active {
118 return Err(Error::Transaction(TransactionError::InvalidState(
119 "Transaction is not active".to_string(),
120 )));
121 }
122
123 info.write_set.insert(entity.into());
124 Ok(())
125 }
126
127 pub fn record_read(&self, tx_id: TxId, entity: impl Into<EntityId>) -> Result<()> {
133 let mut txns = self.transactions.write();
134 let info = txns.get_mut(&tx_id).ok_or_else(|| {
135 Error::Transaction(TransactionError::InvalidState(
136 "Transaction not found".to_string(),
137 ))
138 })?;
139
140 if info.state != TxState::Active {
141 return Err(Error::Transaction(TransactionError::InvalidState(
142 "Transaction is not active".to_string(),
143 )));
144 }
145
146 info.read_set.insert(entity.into());
147 Ok(())
148 }
149
150 pub fn commit(&self, tx_id: TxId) -> Result<EpochId> {
158 let mut txns = self.transactions.write();
159 let committed = self.committed_epochs.read();
160
161 {
163 let info = txns.get(&tx_id).ok_or_else(|| {
164 Error::Transaction(TransactionError::InvalidState(
165 "Transaction not found".to_string(),
166 ))
167 })?;
168
169 if info.state != TxState::Active {
170 return Err(Error::Transaction(TransactionError::InvalidState(
171 "Transaction is not active".to_string(),
172 )));
173 }
174 }
175
176 let our_write_set: HashSet<EntityId> = txns
178 .get(&tx_id)
179 .map(|info| info.write_set.clone())
180 .unwrap_or_default();
181
182 let our_start_epoch = txns
183 .get(&tx_id)
184 .map(|info| info.start_epoch)
185 .unwrap_or(EpochId::new(0));
186
187 for (other_tx, other_info) in txns.iter() {
189 if *other_tx == tx_id {
190 continue;
191 }
192 if other_info.state == TxState::Committed {
193 for entity in &our_write_set {
195 if other_info.write_set.contains(entity) {
196 return Err(Error::Transaction(TransactionError::WriteConflict(
197 format!("Write-write conflict on entity {:?}", entity),
198 )));
199 }
200 }
201 }
202 }
203
204 for (other_tx, commit_epoch) in committed.iter() {
206 if *other_tx != tx_id && commit_epoch.as_u64() > our_start_epoch.as_u64() {
207 if let Some(other_info) = txns.get(other_tx) {
209 for entity in &our_write_set {
210 if other_info.write_set.contains(entity) {
211 return Err(Error::Transaction(TransactionError::WriteConflict(
212 format!("Write-write conflict on entity {:?}", entity),
213 )));
214 }
215 }
216 }
217 }
218 }
219
220 let commit_epoch = EpochId::new(self.current_epoch.fetch_add(1, Ordering::SeqCst) + 1);
223
224 if let Some(info) = txns.get_mut(&tx_id) {
226 info.state = TxState::Committed;
227 }
228
229 drop(committed);
231 self.committed_epochs.write().insert(tx_id, commit_epoch);
232
233 Ok(commit_epoch)
234 }
235
236 pub fn abort(&self, tx_id: TxId) -> Result<()> {
242 let mut txns = self.transactions.write();
243
244 let info = txns.get_mut(&tx_id).ok_or_else(|| {
245 Error::Transaction(TransactionError::InvalidState(
246 "Transaction not found".to_string(),
247 ))
248 })?;
249
250 if info.state != TxState::Active {
251 return Err(Error::Transaction(TransactionError::InvalidState(
252 "Transaction is not active".to_string(),
253 )));
254 }
255
256 info.state = TxState::Aborted;
257 Ok(())
258 }
259
260 pub fn get_write_set(&self, tx_id: TxId) -> Result<HashSet<EntityId>> {
265 let txns = self.transactions.read();
266 let info = txns.get(&tx_id).ok_or_else(|| {
267 Error::Transaction(TransactionError::InvalidState(
268 "Transaction not found".to_string(),
269 ))
270 })?;
271 Ok(info.write_set.clone())
272 }
273
274 pub fn abort_all_active(&self) {
278 let mut txns = self.transactions.write();
279 for info in txns.values_mut() {
280 if info.state == TxState::Active {
281 info.state = TxState::Aborted;
282 }
283 }
284 }
285
286 pub fn state(&self, tx_id: TxId) -> Option<TxState> {
288 self.transactions.read().get(&tx_id).map(|info| info.state)
289 }
290
291 pub fn start_epoch(&self, tx_id: TxId) -> Option<EpochId> {
293 self.transactions
294 .read()
295 .get(&tx_id)
296 .map(|info| info.start_epoch)
297 }
298
299 #[must_use]
301 pub fn current_epoch(&self) -> EpochId {
302 EpochId::new(self.current_epoch.load(Ordering::Acquire))
303 }
304
305 #[must_use]
310 pub fn min_active_epoch(&self) -> EpochId {
311 let txns = self.transactions.read();
312 txns.values()
313 .filter(|info| info.state == TxState::Active)
314 .map(|info| info.start_epoch)
315 .min()
316 .unwrap_or_else(|| self.current_epoch())
317 }
318
319 #[must_use]
321 pub fn active_count(&self) -> usize {
322 self.transactions
323 .read()
324 .values()
325 .filter(|info| info.state == TxState::Active)
326 .count()
327 }
328
329 pub fn gc(&self) -> usize {
337 let mut txns = self.transactions.write();
338 let mut committed = self.committed_epochs.write();
339
340 let min_active_start = txns
342 .values()
343 .filter(|info| info.state == TxState::Active)
344 .map(|info| info.start_epoch)
345 .min();
346
347 let initial_count = txns.len();
348
349 let to_remove: Vec<TxId> = txns
351 .iter()
352 .filter(|(tx_id, info)| {
353 match info.state {
354 TxState::Active => false, TxState::Aborted => true, TxState::Committed => {
357 if let Some(min_start) = min_active_start {
360 if let Some(commit_epoch) = committed.get(*tx_id) {
361 commit_epoch.as_u64() < min_start.as_u64()
363 } else {
364 false
366 }
367 } else {
368 true
370 }
371 }
372 }
373 })
374 .map(|(id, _)| *id)
375 .collect();
376
377 for id in &to_remove {
378 txns.remove(id);
379 committed.remove(id);
380 }
381
382 initial_count - txns.len()
383 }
384
385 pub fn mark_committed(&self, tx_id: TxId, epoch: EpochId) {
389 self.committed_epochs.write().insert(tx_id, epoch);
390 }
391
392 #[must_use]
396 pub fn last_assigned_tx_id(&self) -> Option<TxId> {
397 let next = self.next_tx_id.load(Ordering::Relaxed);
398 if next > 1 {
399 Some(TxId::new(next - 1))
400 } else {
401 None
402 }
403 }
404}
405
406impl Default for TransactionManager {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_begin_commit() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        assert_eq!(mgr.state(tx), Some(TxState::Active));

        let commit_epoch = mgr.commit(tx).unwrap();
        assert_eq!(mgr.state(tx), Some(TxState::Committed));
        assert!(commit_epoch.as_u64() > 0);
    }

    #[test]
    fn test_begin_abort() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.abort(tx).unwrap();
        assert_eq!(mgr.state(tx), Some(TxState::Aborted));
    }

    #[test]
    fn test_epoch_advancement() {
        let mgr = TransactionManager::new();

        let initial_epoch = mgr.current_epoch();

        let tx = mgr.begin();
        let commit_epoch = mgr.commit(tx).unwrap();

        assert!(mgr.current_epoch().as_u64() > initial_epoch.as_u64());
        assert!(commit_epoch.as_u64() > initial_epoch.as_u64());
    }

    #[test]
    fn test_gc_preserves_needed_write_sets() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        mgr.commit(tx1).unwrap();
        assert_eq!(mgr.active_count(), 1);

        // tx1 committed after tx2 started, so tx2 may still conflict with it.
        let cleaned = mgr.gc();
        assert_eq!(cleaned, 0);

        assert_eq!(mgr.state(tx1), Some(TxState::Committed));
        assert_eq!(mgr.state(tx2), Some(TxState::Active));
    }

    #[test]
    fn test_gc_removes_old_commits() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        mgr.commit(tx1).unwrap();

        let tx2 = mgr.begin();
        mgr.commit(tx2).unwrap();

        let tx3 = mgr.begin();

        // Only tx1 committed strictly before tx3's start epoch.
        let cleaned = mgr.gc();
        assert_eq!(cleaned, 1);
        assert_eq!(mgr.state(tx1), None);
        assert_eq!(mgr.state(tx2), Some(TxState::Committed));
        assert_eq!(mgr.state(tx3), Some(TxState::Active));

        // With nothing active, every committed transaction is reclaimable.
        mgr.commit(tx3).unwrap();
        let cleaned = mgr.gc();
        assert_eq!(cleaned, 2);
    }

    #[test]
    fn test_gc_removes_aborted() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        mgr.abort(tx1).unwrap();
        let cleaned = mgr.gc();
        assert_eq!(cleaned, 1);

        assert_eq!(mgr.state(tx1), None);
        assert_eq!(mgr.state(tx2), Some(TxState::Active));
    }

    #[test]
    fn test_write_tracking() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();

        mgr.record_write(tx, NodeId::new(1)).unwrap();
        mgr.record_write(tx, NodeId::new(2)).unwrap();
        mgr.record_write(tx, EdgeId::new(100)).unwrap();

        assert!(mgr.commit(tx).is_ok());
    }

    #[test]
    fn test_write_set_contents() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.record_write(tx, NodeId::new(1)).unwrap();
        mgr.record_write(tx, NodeId::new(1)).unwrap();
        mgr.record_write(tx, EdgeId::new(7)).unwrap();

        // Duplicate writes collapse into one entry.
        let write_set = mgr.get_write_set(tx).unwrap();
        assert_eq!(write_set.len(), 2);
        assert!(write_set.contains(&EntityId::Node(NodeId::new(1))));
        assert!(write_set.contains(&EntityId::Edge(EdgeId::new(7))));
    }

    #[test]
    fn test_operations_on_finished_tx_fail() {
        let mgr = TransactionManager::new();

        let committed = mgr.begin();
        mgr.commit(committed).unwrap();
        assert!(mgr.record_write(committed, NodeId::new(1)).is_err());
        assert!(mgr.record_read(committed, NodeId::new(1)).is_err());
        assert!(mgr.commit(committed).is_err());
        assert!(mgr.abort(committed).is_err());

        let aborted = mgr.begin();
        mgr.abort(aborted).unwrap();
        assert!(mgr.record_write(aborted, NodeId::new(2)).is_err());
        assert!(mgr.record_read(aborted, NodeId::new(2)).is_err());
        assert!(mgr.commit(aborted).is_err());
        assert_eq!(mgr.state(aborted), Some(TxState::Aborted));
    }

    #[test]
    fn test_unknown_tx_is_rejected() {
        let mgr = TransactionManager::new();
        let unknown = TxId::new(9_999);

        assert!(mgr.record_write(unknown, NodeId::new(1)).is_err());
        assert!(mgr.record_read(unknown, NodeId::new(1)).is_err());
        assert!(mgr.commit(unknown).is_err());
        assert!(mgr.abort(unknown).is_err());
        assert!(mgr.get_write_set(unknown).is_err());
        assert_eq!(mgr.state(unknown), None);
        assert_eq!(mgr.start_epoch(unknown), None);
    }

    #[test]
    fn test_min_active_epoch() {
        let mgr = TransactionManager::new();

        assert_eq!(mgr.min_active_epoch(), mgr.current_epoch());

        let tx1 = mgr.begin();
        let epoch1 = mgr.start_epoch(tx1).unwrap();

        let tx2 = mgr.begin();
        mgr.commit(tx2).unwrap();

        let _tx3 = mgr.begin();

        assert_eq!(mgr.min_active_epoch(), epoch1);
    }

    #[test]
    fn test_abort_all_active() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        let tx3 = mgr.begin();

        mgr.commit(tx1).unwrap();
        mgr.abort_all_active();

        assert_eq!(mgr.state(tx1), Some(TxState::Committed));
        assert_eq!(mgr.state(tx2), Some(TxState::Aborted));
        assert_eq!(mgr.state(tx3), Some(TxState::Aborted));
    }

    #[test]
    fn test_start_epoch_snapshot() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let start1 = mgr.start_epoch(tx1).unwrap();

        mgr.commit(tx1).unwrap();

        let tx2 = mgr.begin();
        let start2 = mgr.start_epoch(tx2).unwrap();

        assert!(start2.as_u64() > start1.as_u64());
    }

    #[test]
    fn test_write_write_conflict_detection() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        let entity = NodeId::new(42);
        mgr.record_write(tx1, entity).unwrap();
        mgr.record_write(tx2, entity).unwrap();

        let result1 = mgr.commit(tx1);
        assert!(result1.is_ok());

        let result2 = mgr.commit(tx2);
        assert!(result2.is_err());
        assert!(
            result2
                .unwrap_err()
                .to_string()
                .contains("Write-write conflict"),
            "Expected write-write conflict error"
        );
    }

    #[test]
    fn test_disjoint_concurrent_writes_do_not_conflict() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        mgr.record_write(tx1, NodeId::new(1)).unwrap();
        mgr.record_write(tx2, NodeId::new(2)).unwrap();

        assert!(mgr.commit(tx1).is_ok());
        assert!(mgr.commit(tx2).is_ok());
    }

    #[test]
    fn test_last_assigned_tx_id_tracks_begin() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        assert_eq!(mgr.last_assigned_tx_id(), Some(tx));

        let tx_next = mgr.begin();
        assert_eq!(mgr.last_assigned_tx_id(), Some(tx_next));
    }

    #[test]
    fn test_commit_epoch_monotonicity() {
        let mgr = TransactionManager::new();

        let mut epochs = Vec::new();

        for _ in 0..10 {
            let tx = mgr.begin();
            let epoch = mgr.commit(tx).unwrap();
            epochs.push(epoch.as_u64());
        }

        for i in 1..epochs.len() {
            assert!(
                epochs[i] > epochs[i - 1],
                "Epoch {} ({}) should be greater than epoch {} ({})",
                i,
                epochs[i],
                i - 1,
                epochs[i - 1]
            );
        }
    }

    #[test]
    fn test_concurrent_commits_via_threads() {
        use std::sync::Arc;
        use std::thread;

        let mgr = Arc::new(TransactionManager::new());
        let num_threads = 10;
        let commits_per_thread = 100;

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let mgr = Arc::clone(&mgr);
                thread::spawn(move || {
                    let mut epochs = Vec::new();
                    for _ in 0..commits_per_thread {
                        let tx = mgr.begin();
                        let epoch = mgr.commit(tx).unwrap();
                        epochs.push(epoch.as_u64());
                    }
                    epochs
                })
            })
            .collect();

        let mut all_epochs: Vec<u64> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();

        // Dedup after sorting: if any epoch was handed out twice, the length shrinks.
        all_epochs.sort_unstable();
        let total_commits = all_epochs.len();
        all_epochs.dedup();
        assert_eq!(
            all_epochs.len(),
            total_commits,
            "All commit epochs should be unique"
        );

        assert_eq!(
            mgr.current_epoch().as_u64(),
            (num_threads * commits_per_thread) as u64,
            "Final epoch should equal total commits"
        );
    }
}