1use std::collections::HashSet;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use grafeo_common::types::{EdgeId, EpochId, NodeId, TransactionId};
7use grafeo_common::utils::error::{Error, Result, TransactionError};
8use grafeo_common::utils::hash::FxHashMap;
9use parking_lot::RwLock;
10
/// Lifecycle state of a transaction tracked by [`TransactionManager`].
///
/// A transaction is created `Active`. `commit` moves it to `Committed`;
/// `abort` and `abort_all_active` move it to `Aborted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TransactionState {
    /// Started and not yet finished; reads and writes may still be recorded.
    Active,
    /// Passed commit validation and was assigned a commit epoch.
    Committed,
    /// Aborted, either individually or by a bulk abort.
    Aborted,
}
22
/// Isolation level requested when a transaction begins.
///
/// Within this file the level is only consulted during commit validation:
/// write-write checks run for every level, while read-set checks run only
/// for [`IsolationLevel::Serializable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum IsolationLevel {
    /// Read committed. Commit validation in this file treats it exactly
    /// like `SnapshotIsolation` (write-write checks only).
    ReadCommitted,

    /// Snapshot isolation: commit fails only on write-write conflicts with
    /// transactions that committed after this one started. This is the default.
    #[default]
    SnapshotIsolation,

    /// Serializable: in addition to the write-write check, commit fails if a
    /// transaction that committed after this one started wrote an entity
    /// that is in this transaction's read set.
    Serializable,
}
65
/// Identifier of a graph entity that can appear in a read or write set.
///
/// Hashable so it can be stored in the `HashSet`s kept per transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EntityId {
    /// A node, identified by its `NodeId`.
    Node(NodeId),
    /// An edge, identified by its `EdgeId`.
    Edge(EdgeId),
}
75
76impl From<NodeId> for EntityId {
77 fn from(id: NodeId) -> Self {
78 Self::Node(id)
79 }
80}
81
82impl From<EdgeId> for EntityId {
83 fn from(id: EdgeId) -> Self {
84 Self::Edge(id)
85 }
86}
87
/// Per-transaction bookkeeping stored by [`TransactionManager`].
pub struct TransactionInfo {
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Isolation level chosen when the transaction began.
    pub isolation_level: IsolationLevel,
    /// Value of the manager's epoch counter at the time the transaction began.
    pub start_epoch: EpochId,
    /// Entities recorded through `record_write` (or installed by `reset_write_set`).
    pub write_set: HashSet<EntityId>,
    /// Entities recorded through `record_read`; only consulted at commit
    /// time for `Serializable` transactions.
    pub read_set: HashSet<EntityId>,
}
101
102impl TransactionInfo {
103 fn new(start_epoch: EpochId, isolation_level: IsolationLevel) -> Self {
105 Self {
106 state: TransactionState::Active,
107 isolation_level,
108 start_epoch,
109 write_set: HashSet::new(),
110 read_set: HashSet::new(),
111 }
112 }
113}
114
/// Tracks transactions, their read/write sets and the global commit epoch.
///
/// Counters are atomics; the two maps are each behind their own `RwLock`.
/// Methods that need both maps (`commit`, `gc`) lock `transactions` first
/// and `committed_epochs` second.
pub struct TransactionManager {
    /// Next transaction ID to hand out; starts at 2 and is incremented by
    /// `begin_with_isolation`.
    next_transaction_id: AtomicU64,
    /// Global epoch counter; advanced by one per successful `commit` and
    /// raised (never lowered) by `sync_epoch`.
    current_epoch: AtomicU64,
    /// Number of `Active` transactions, adjusted on begin/commit/abort.
    /// `record_write` reads it to skip the conflict scan when at most one
    /// transaction is active.
    active_count: AtomicU64,
    /// Every tracked transaction (in any state) until `gc` removes it.
    transactions: RwLock<FxHashMap<TransactionId, TransactionInfo>>,
    /// Commit epoch per committed transaction; written by `commit` and
    /// `mark_committed`, read by commit validation and `gc`.
    committed_epochs: RwLock<FxHashMap<TransactionId, EpochId>>,
}
129
130impl TransactionManager {
131 #[must_use]
133 pub fn new() -> Self {
134 Self {
135 next_transaction_id: AtomicU64::new(2),
138 current_epoch: AtomicU64::new(0),
139 active_count: AtomicU64::new(0),
140 transactions: RwLock::new(FxHashMap::default()),
141 committed_epochs: RwLock::new(FxHashMap::default()),
142 }
143 }
144
145 pub fn begin(&self) -> TransactionId {
147 self.begin_with_isolation(IsolationLevel::default())
148 }
149
150 pub fn begin_with_isolation(&self, isolation_level: IsolationLevel) -> TransactionId {
152 let transaction_id =
153 TransactionId::new(self.next_transaction_id.fetch_add(1, Ordering::Relaxed));
154 let epoch = EpochId::new(self.current_epoch.load(Ordering::Acquire));
155
156 let info = TransactionInfo::new(epoch, isolation_level);
157 self.transactions.write().insert(transaction_id, info);
158 self.active_count.fetch_add(1, Ordering::Relaxed);
159 transaction_id
160 }
161
162 pub fn isolation_level(&self, transaction_id: TransactionId) -> Option<IsolationLevel> {
164 self.transactions
165 .read()
166 .get(&transaction_id)
167 .map(|info| info.isolation_level)
168 }
169
170 pub fn record_write(
181 &self,
182 transaction_id: TransactionId,
183 entity: impl Into<EntityId>,
184 ) -> Result<()> {
185 let entity = entity.into();
186 let mut txns = self.transactions.write();
187
188 if self.active_count.load(Ordering::Relaxed) > 1 {
191 for (other_tx, other_info) in txns.iter() {
192 if *other_tx != transaction_id
193 && other_info.state == TransactionState::Active
194 && other_info.write_set.contains(&entity)
195 {
196 return Err(Error::Transaction(TransactionError::WriteConflict(
197 format!("Write-write conflict on entity {entity:?}"),
198 )));
199 }
200 }
201 }
202
203 let info = txns.get_mut(&transaction_id).ok_or_else(|| {
205 Error::Transaction(TransactionError::InvalidState(
206 "Transaction not found".to_string(),
207 ))
208 })?;
209
210 if info.state != TransactionState::Active {
211 return Err(Error::Transaction(TransactionError::InvalidState(
212 "Transaction is not active".to_string(),
213 )));
214 }
215
216 info.write_set.insert(entity);
217 Ok(())
218 }
219
220 pub fn record_read(
226 &self,
227 transaction_id: TransactionId,
228 entity: impl Into<EntityId>,
229 ) -> Result<()> {
230 let mut txns = self.transactions.write();
231 let info = txns.get_mut(&transaction_id).ok_or_else(|| {
232 Error::Transaction(TransactionError::InvalidState(
233 "Transaction not found".to_string(),
234 ))
235 })?;
236
237 if info.state != TransactionState::Active {
238 return Err(Error::Transaction(TransactionError::InvalidState(
239 "Transaction is not active".to_string(),
240 )));
241 }
242
243 info.read_set.insert(entity.into());
244 Ok(())
245 }
246
247 pub fn commit(&self, transaction_id: TransactionId) -> Result<EpochId> {
265 let mut txns = self.transactions.write();
270 let mut committed = self.committed_epochs.write();
271
272 let (our_isolation, our_start_epoch, our_write_set, our_read_set) = {
274 let info = txns.get(&transaction_id).ok_or_else(|| {
275 Error::Transaction(TransactionError::InvalidState(
276 "Transaction not found".to_string(),
277 ))
278 })?;
279
280 if info.state != TransactionState::Active {
281 return Err(Error::Transaction(TransactionError::InvalidState(
282 "Transaction is not active".to_string(),
283 )));
284 }
285
286 (
287 info.isolation_level,
288 info.start_epoch,
289 info.write_set.clone(),
290 info.read_set.clone(),
291 )
292 };
293
294 for (other_tx, commit_epoch) in committed.iter() {
299 if *other_tx != transaction_id && commit_epoch.as_u64() > our_start_epoch.as_u64() {
300 if let Some(other_info) = txns.get(other_tx) {
302 for entity in &our_write_set {
303 if other_info.write_set.contains(entity) {
304 return Err(Error::Transaction(TransactionError::WriteConflict(
305 format!("Write-write conflict on entity {:?}", entity),
306 )));
307 }
308 }
309 }
310 }
311 }
312
313 if our_isolation == IsolationLevel::Serializable && !our_read_set.is_empty() {
323 for (other_tx, commit_epoch) in committed.iter() {
324 if *other_tx != transaction_id && commit_epoch.as_u64() > our_start_epoch.as_u64() {
325 if let Some(other_info) = txns.get(other_tx) {
327 for entity in &our_read_set {
328 if other_info.write_set.contains(entity) {
329 return Err(Error::Transaction(
330 TransactionError::SerializationFailure(format!(
331 "Read-write conflict on entity {:?}: \
332 another transaction modified data we read",
333 entity
334 )),
335 ));
336 }
337 }
338 }
339 }
340 }
341 }
342
343 let commit_epoch = EpochId::new(self.current_epoch.fetch_add(1, Ordering::SeqCst) + 1);
346
347 if let Some(info) = txns.get_mut(&transaction_id) {
349 info.state = TransactionState::Committed;
350 }
351 self.active_count.fetch_sub(1, Ordering::Relaxed);
352 committed.insert(transaction_id, commit_epoch);
353
354 Ok(commit_epoch)
355 }
356
357 pub fn abort(&self, transaction_id: TransactionId) -> Result<()> {
363 let mut txns = self.transactions.write();
364
365 let info = txns.get_mut(&transaction_id).ok_or_else(|| {
366 Error::Transaction(TransactionError::InvalidState(
367 "Transaction not found".to_string(),
368 ))
369 })?;
370
371 if info.state != TransactionState::Active {
372 return Err(Error::Transaction(TransactionError::InvalidState(
373 "Transaction is not active".to_string(),
374 )));
375 }
376
377 info.state = TransactionState::Aborted;
378 self.active_count.fetch_sub(1, Ordering::Relaxed);
379 Ok(())
380 }
381
382 pub fn get_write_set(&self, transaction_id: TransactionId) -> Result<HashSet<EntityId>> {
391 let txns = self.transactions.read();
392 let info = txns.get(&transaction_id).ok_or_else(|| {
393 Error::Transaction(TransactionError::InvalidState(
394 "Transaction not found".to_string(),
395 ))
396 })?;
397 Ok(info.write_set.clone())
398 }
399
400 pub fn reset_write_set(
406 &self,
407 transaction_id: TransactionId,
408 write_set: HashSet<EntityId>,
409 ) -> Result<()> {
410 let mut txns = self.transactions.write();
411 let info = txns.get_mut(&transaction_id).ok_or_else(|| {
412 Error::Transaction(TransactionError::InvalidState(
413 "Transaction not found".to_string(),
414 ))
415 })?;
416 info.write_set = write_set;
417 Ok(())
418 }
419
420 pub fn abort_all_active(&self) {
424 let mut txns = self.transactions.write();
425 for info in txns.values_mut() {
426 if info.state == TransactionState::Active {
427 info.state = TransactionState::Aborted;
428 self.active_count.fetch_sub(1, Ordering::Relaxed);
429 }
430 }
431 }
432
433 pub fn state(&self, transaction_id: TransactionId) -> Option<TransactionState> {
435 self.transactions
436 .read()
437 .get(&transaction_id)
438 .map(|info| info.state)
439 }
440
441 pub fn start_epoch(&self, transaction_id: TransactionId) -> Option<EpochId> {
443 self.transactions
444 .read()
445 .get(&transaction_id)
446 .map(|info| info.start_epoch)
447 }
448
449 #[must_use]
451 pub fn current_epoch(&self) -> EpochId {
452 EpochId::new(self.current_epoch.load(Ordering::Acquire))
453 }
454
455 pub fn sync_epoch(&self, epoch: EpochId) {
460 self.current_epoch
461 .fetch_max(epoch.as_u64(), Ordering::SeqCst);
462 }
463
464 #[must_use]
469 pub fn min_active_epoch(&self) -> EpochId {
470 let txns = self.transactions.read();
471 txns.values()
472 .filter(|info| info.state == TransactionState::Active)
473 .map(|info| info.start_epoch)
474 .min()
475 .unwrap_or_else(|| self.current_epoch())
476 }
477
478 #[must_use]
480 pub fn active_count(&self) -> usize {
481 self.transactions
482 .read()
483 .values()
484 .filter(|info| info.state == TransactionState::Active)
485 .count()
486 }
487
488 pub fn gc(&self) -> usize {
496 let mut txns = self.transactions.write();
497 let mut committed = self.committed_epochs.write();
498
499 let min_active_start = txns
501 .values()
502 .filter(|info| info.state == TransactionState::Active)
503 .map(|info| info.start_epoch)
504 .min();
505
506 let initial_count = txns.len();
507
508 let to_remove: Vec<TransactionId> = txns
510 .iter()
511 .filter(|(transaction_id, info)| {
512 match info.state {
513 TransactionState::Active => false, TransactionState::Aborted => true, TransactionState::Committed => {
516 if let Some(min_start) = min_active_start {
519 if let Some(commit_epoch) = committed.get(*transaction_id) {
520 commit_epoch.as_u64() < min_start.as_u64()
522 } else {
523 false
525 }
526 } else {
527 true
529 }
530 }
531 }
532 })
533 .map(|(id, _)| *id)
534 .collect();
535
536 for id in &to_remove {
537 txns.remove(id);
538 committed.remove(id);
539 }
540
541 initial_count - txns.len()
542 }
543
544 pub fn mark_committed(&self, transaction_id: TransactionId, epoch: EpochId) {
548 self.committed_epochs.write().insert(transaction_id, epoch);
549 }
550
551 #[must_use]
555 pub fn last_assigned_transaction_id(&self) -> Option<TransactionId> {
556 let next = self.next_transaction_id.load(Ordering::Relaxed);
557 if next > 1 {
558 Some(TransactionId::new(next - 1))
559 } else {
560 None
561 }
562 }
563
564 #[cfg(test)]
566 pub fn committed_epoch(&self, transaction_id: TransactionId) -> Option<EpochId> {
567 self.committed_epochs.read().get(&transaction_id).copied()
568 }
569}
570
571impl Default for TransactionManager {
572 fn default() -> Self {
573 Self::new()
574 }
575}
576
/// Unit tests for `TransactionManager`: lifecycle, epochs, garbage
/// collection, conflict detection and isolation levels.
#[cfg(test)]
mod tests {
    use super::*;

    /// A new transaction is `Active`; committing it yields a non-zero epoch.
    #[test]
    fn test_begin_commit() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        assert_eq!(mgr.state(tx), Some(TransactionState::Active));

        let commit_epoch = mgr.commit(tx).unwrap();
        assert_eq!(mgr.state(tx), Some(TransactionState::Committed));
        assert!(commit_epoch.as_u64() > 0);
    }

    /// Aborting an active transaction moves it to `Aborted`.
    #[test]
    fn test_begin_abort() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.abort(tx).unwrap();
        assert_eq!(mgr.state(tx), Some(TransactionState::Aborted));
    }

    /// A commit advances the global epoch past its previous value.
    #[test]
    fn test_epoch_advancement() {
        let mgr = TransactionManager::new();

        let initial_epoch = mgr.current_epoch();

        let tx = mgr.begin();
        let commit_epoch = mgr.commit(tx).unwrap();

        assert!(mgr.current_epoch().as_u64() > initial_epoch.as_u64());
        assert!(commit_epoch.as_u64() > initial_epoch.as_u64());
    }

    /// A committed transaction is kept while a transaction that started
    /// before its commit is still active.
    #[test]
    fn test_gc_preserves_needed_write_sets() {
        let mgr = TransactionManager::new();

        // Both start at epoch 0.
        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        // tx1 commits at epoch 1, which is not below tx2's start epoch.
        mgr.commit(tx1).unwrap();
        assert_eq!(mgr.active_count(), 1);

        let cleaned = mgr.gc();
        assert_eq!(cleaned, 0);

        assert_eq!(mgr.state(tx1), Some(TransactionState::Committed));
        assert_eq!(mgr.state(tx2), Some(TransactionState::Active));
    }

    /// Committed transactions are collected once their commit epoch is below
    /// every active transaction's start epoch.
    #[test]
    fn test_gc_removes_old_commits() {
        let mgr = TransactionManager::new();

        // tx1: starts at epoch 0, commits at epoch 1.
        let tx1 = mgr.begin();
        mgr.commit(tx1).unwrap();

        // tx2: starts at epoch 1, commits at epoch 2.
        let tx2 = mgr.begin();
        mgr.commit(tx2).unwrap();

        // tx3: starts at epoch 2 and stays active.
        let tx3 = mgr.begin();

        let cleaned = mgr.gc();
        // Only tx1 (commit epoch 1 < 2) is removable; tx2 (2 < 2 is false) stays.
        assert_eq!(cleaned, 1);
        assert_eq!(mgr.state(tx1), None);
        assert_eq!(mgr.state(tx2), Some(TransactionState::Committed));
        assert_eq!(mgr.state(tx3), Some(TransactionState::Active));

        // With nothing active, every committed transaction is removable.
        mgr.commit(tx3).unwrap();
        let cleaned = mgr.gc();
        assert_eq!(cleaned, 2);
    }

    /// Aborted transactions are collected even while others are active.
    #[test]
    fn test_gc_removes_aborted() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        mgr.abort(tx1).unwrap();
        let cleaned = mgr.gc();
        assert_eq!(cleaned, 1);

        assert_eq!(mgr.state(tx1), None);
        assert_eq!(mgr.state(tx2), Some(TransactionState::Active));
    }

    /// Writes to nodes and edges can be recorded and then committed.
    #[test]
    fn test_write_tracking() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();

        mgr.record_write(tx, NodeId::new(1)).unwrap();
        mgr.record_write(tx, NodeId::new(2)).unwrap();
        mgr.record_write(tx, EdgeId::new(100)).unwrap();

        assert!(mgr.commit(tx).is_ok());
    }

    /// `min_active_epoch` is the oldest active start epoch, or the current
    /// epoch when nothing is active.
    #[test]
    fn test_min_active_epoch() {
        let mgr = TransactionManager::new();

        // No active transactions: falls back to the current epoch.
        assert_eq!(mgr.min_active_epoch(), mgr.current_epoch());

        let tx1 = mgr.begin();
        let epoch1 = mgr.start_epoch(tx1).unwrap();

        // A commit in between advances the epoch for later transactions.
        let tx2 = mgr.begin();
        mgr.commit(tx2).unwrap();

        let _tx3 = mgr.begin();

        // tx1 is still the oldest active transaction.
        assert_eq!(mgr.min_active_epoch(), epoch1);
    }

    /// `abort_all_active` aborts active transactions and leaves committed
    /// ones alone.
    #[test]
    fn test_abort_all_active() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        let tx3 = mgr.begin();

        mgr.commit(tx1).unwrap();
        mgr.abort_all_active();

        assert_eq!(mgr.state(tx1), Some(TransactionState::Committed));
        assert_eq!(mgr.state(tx2), Some(TransactionState::Aborted));
        assert_eq!(mgr.state(tx3), Some(TransactionState::Aborted));
    }

    /// A transaction started after a commit has a later start epoch.
    #[test]
    fn test_start_epoch_snapshot() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let start1 = mgr.start_epoch(tx1).unwrap();

        mgr.commit(tx1).unwrap();

        let tx2 = mgr.begin();
        let start2 = mgr.start_epoch(tx2).unwrap();

        assert!(start2.as_u64() > start1.as_u64());
    }

    /// A second active writer of the same entity is rejected at
    /// `record_write` time; the first writer can still commit.
    #[test]
    fn test_write_write_conflict_detection() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        let entity = NodeId::new(42);
        mgr.record_write(tx1, entity).unwrap();

        let result = mgr.record_write(tx2, entity);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Write-write conflict"),
            "Expected write-write conflict error"
        );

        let result1 = mgr.commit(tx1);
        assert!(result1.is_ok());
    }

    /// Sequential commits receive strictly increasing epochs.
    #[test]
    fn test_commit_epoch_monotonicity() {
        let mgr = TransactionManager::new();

        let mut epochs = Vec::new();

        for _ in 0..10 {
            let tx = mgr.begin();
            let epoch = mgr.commit(tx).unwrap();
            epochs.push(epoch.as_u64());
        }

        for i in 1..epochs.len() {
            assert!(
                epochs[i] > epochs[i - 1],
                "Epoch {} ({}) should be greater than epoch {} ({})",
                i,
                epochs[i],
                i - 1,
                epochs[i - 1]
            );
        }
    }

    /// Commits from many threads get unique epochs, and the final epoch
    /// equals the total number of commits.
    #[test]
    fn test_concurrent_commits_via_threads() {
        use std::sync::Arc;
        use std::thread;

        let mgr = Arc::new(TransactionManager::new());
        let num_threads = 10;
        let commits_per_thread = 100;

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let mgr = Arc::clone(&mgr);
                thread::spawn(move || {
                    let mut epochs = Vec::new();
                    for _ in 0..commits_per_thread {
                        let tx = mgr.begin();
                        let epoch = mgr.commit(tx).unwrap();
                        epochs.push(epoch.as_u64());
                    }
                    epochs
                })
            })
            .collect();

        let mut all_epochs: Vec<u64> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();

        // After sorting, `dedup` would shrink the vector if any epoch repeated.
        all_epochs.sort_unstable();
        let unique_count = all_epochs.len();
        all_epochs.dedup();
        assert_eq!(
            all_epochs.len(),
            unique_count,
            "All commit epochs should be unique"
        );

        // Both factors are small positive literals, so the cast cannot lose the sign.
        #[allow(clippy::cast_sign_loss)]
        let expected_epoch = (num_threads * commits_per_thread) as u64;
        assert_eq!(
            mgr.current_epoch().as_u64(),
            expected_epoch,
            "Final epoch should equal total commits"
        );
    }

    /// `begin` uses snapshot isolation.
    #[test]
    fn test_isolation_level_default() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        assert_eq!(
            mgr.isolation_level(tx),
            Some(IsolationLevel::SnapshotIsolation)
        );
    }

    /// `begin_with_isolation` stores the requested level.
    #[test]
    fn test_isolation_level_explicit() {
        let mgr = TransactionManager::new();

        let transaction_rc = mgr.begin_with_isolation(IsolationLevel::ReadCommitted);
        let transaction_si = mgr.begin_with_isolation(IsolationLevel::SnapshotIsolation);
        let transaction_ser = mgr.begin_with_isolation(IsolationLevel::Serializable);

        assert_eq!(
            mgr.isolation_level(transaction_rc),
            Some(IsolationLevel::ReadCommitted)
        );
        assert_eq!(
            mgr.isolation_level(transaction_si),
            Some(IsolationLevel::SnapshotIsolation)
        );
        assert_eq!(
            mgr.isolation_level(transaction_ser),
            Some(IsolationLevel::Serializable)
        );
    }

    /// A serializable reader fails to commit when a concurrent transaction
    /// wrote the entity it read and committed first.
    #[test]
    fn test_ssi_read_write_conflict_detected() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin_with_isolation(IsolationLevel::Serializable);

        let tx2 = mgr.begin();

        let entity = NodeId::new(42);
        mgr.record_read(tx1, entity).unwrap();

        // tx2 commits after tx1 started, so its write is visible to tx1's validation.
        mgr.record_write(tx2, entity).unwrap();
        mgr.commit(tx2).unwrap();

        let result = mgr.commit(tx1);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Serialization failure"),
            "Expected serialization failure error"
        );
    }

    /// The same interleaving succeeds when the reader uses the default
    /// (snapshot) isolation level.
    #[test]
    fn test_ssi_no_conflict_when_not_serializable() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();

        let tx2 = mgr.begin();

        let entity = NodeId::new(42);
        mgr.record_read(tx1, entity).unwrap();

        mgr.record_write(tx2, entity).unwrap();
        mgr.commit(tx2).unwrap();

        let result = mgr.commit(tx1);
        assert!(
            result.is_ok(),
            "Snapshot Isolation should not detect read-write conflicts"
        );
    }

    /// A write committed before the serializable reader started is not a
    /// conflict.
    #[test]
    fn test_ssi_no_conflict_when_write_before_read() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let entity = NodeId::new(42);
        mgr.record_write(tx1, entity).unwrap();
        mgr.commit(tx1).unwrap();

        // tx2 starts at tx1's commit epoch, which is not "after" tx2's start.
        let tx2 = mgr.begin_with_isolation(IsolationLevel::Serializable);
        mgr.record_read(tx2, entity).unwrap();

        let result = mgr.commit(tx2);
        assert!(
            result.is_ok(),
            "Should not conflict when writer committed before reader started"
        );
    }

    /// Write skew: two serializable transactions read both accounts and each
    /// writes a different one; only the first commit may succeed.
    #[test]
    fn test_write_skew_prevented_by_ssi() {
        let mgr = TransactionManager::new();

        let account_a = NodeId::new(1);
        let account_b = NodeId::new(2);

        let tx1 = mgr.begin_with_isolation(IsolationLevel::Serializable);
        let tx2 = mgr.begin_with_isolation(IsolationLevel::Serializable);

        // Both transactions read both accounts.
        mgr.record_read(tx1, account_a).unwrap();
        mgr.record_read(tx1, account_b).unwrap();
        mgr.record_read(tx2, account_a).unwrap();
        mgr.record_read(tx2, account_b).unwrap();

        // Disjoint writes, so there is no write-write conflict.
        mgr.record_write(tx1, account_a).unwrap();
        mgr.record_write(tx2, account_b).unwrap();

        let result1 = mgr.commit(tx1);
        assert!(result1.is_ok(), "First commit should succeed");

        // tx2 read account_a, which tx1 wrote and committed after tx2 started.
        let result2 = mgr.commit(tx2);
        assert!(result2.is_err(), "Second commit should fail due to SSI");
        assert!(
            result2
                .unwrap_err()
                .to_string()
                .contains("Serialization failure"),
            "Expected serialization failure error for write skew prevention"
        );
    }

    /// A read-committed reader commits even though the entity it read was
    /// overwritten by a concurrent committed transaction.
    #[test]
    fn test_read_committed_allows_non_repeatable_reads() {
        let mgr = TransactionManager::new();

        let tx1 = mgr.begin_with_isolation(IsolationLevel::ReadCommitted);
        let entity = NodeId::new(42);

        mgr.record_read(tx1, entity).unwrap();

        let tx2 = mgr.begin();
        mgr.record_write(tx2, entity).unwrap();
        mgr.commit(tx2).unwrap();

        let result = mgr.commit(tx1);
        assert!(
            result.is_ok(),
            "ReadCommitted should allow non-repeatable reads"
        );
    }

    /// `Debug` output of each isolation level is its variant name.
    #[test]
    fn test_isolation_level_debug() {
        assert_eq!(
            format!("{:?}", IsolationLevel::ReadCommitted),
            "ReadCommitted"
        );
        assert_eq!(
            format!("{:?}", IsolationLevel::SnapshotIsolation),
            "SnapshotIsolation"
        );
        assert_eq!(
            format!("{:?}", IsolationLevel::Serializable),
            "Serializable"
        );
    }

    /// `Default` for `IsolationLevel` is snapshot isolation.
    #[test]
    fn test_isolation_level_default_trait() {
        let default: IsolationLevel = Default::default();
        assert_eq!(default, IsolationLevel::SnapshotIsolation);
    }

    /// Two serializable transactions that only read the same entity both commit.
    #[test]
    fn test_ssi_concurrent_reads_no_conflict() {
        let mgr = TransactionManager::new();

        let entity = NodeId::new(42);

        let tx1 = mgr.begin_with_isolation(IsolationLevel::Serializable);
        let tx2 = mgr.begin_with_isolation(IsolationLevel::Serializable);

        mgr.record_read(tx1, entity).unwrap();
        mgr.record_read(tx2, entity).unwrap();

        assert!(mgr.commit(tx1).is_ok());
        assert!(mgr.commit(tx2).is_ok());
    }

    /// Write-write conflicts are rejected at `record_write` for serializable
    /// transactions as well.
    #[test]
    fn test_ssi_write_write_conflict() {
        let mgr = TransactionManager::new();

        let entity = NodeId::new(42);

        let tx1 = mgr.begin_with_isolation(IsolationLevel::Serializable);
        let tx2 = mgr.begin_with_isolation(IsolationLevel::Serializable);

        mgr.record_write(tx1, entity).unwrap();

        let result = mgr.record_write(tx2, entity);
        assert!(
            result.is_err(),
            "Second record_write should fail with write-write conflict"
        );

        assert!(mgr.commit(tx1).is_ok());
    }

    /// Repeats the write-skew scenario sequentially on one manager, cleaning
    /// up with abort + gc between rounds.
    #[test]
    fn test_ssi_concurrent_commit_race() {
        use std::sync::Arc;

        let mgr = Arc::new(TransactionManager::new());

        for _ in 0..100 {
            let entity_a = NodeId::new(1);
            let entity_b = NodeId::new(2);

            let tx1 = mgr.begin_with_isolation(IsolationLevel::Serializable);
            let tx2 = mgr.begin_with_isolation(IsolationLevel::Serializable);

            mgr.record_read(tx1, entity_a).unwrap();
            mgr.record_read(tx1, entity_b).unwrap();
            mgr.record_read(tx2, entity_a).unwrap();
            mgr.record_read(tx2, entity_b).unwrap();

            mgr.record_write(tx1, entity_a).unwrap();
            mgr.record_write(tx2, entity_b).unwrap();

            mgr.commit(tx1).unwrap();

            let result = mgr.commit(tx2);
            assert!(
                result.is_err(),
                "SSI should detect read-write conflict on entity_a"
            );

            // A failed commit leaves tx2 active; abort it so gc can clear the round.
            let _ = mgr.abort(tx2);
            mgr.gc();
        }
    }

    /// Races the two write-skew commits on separate threads; they must never
    /// both succeed.
    #[test]
    fn test_ssi_concurrent_commit_barrier() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let mgr = Arc::new(TransactionManager::new());
        let mut both_ok_count = 0;

        for _ in 0..50 {
            let entity_a = NodeId::new(1);
            let entity_b = NodeId::new(2);

            let tx1 = mgr.begin_with_isolation(IsolationLevel::Serializable);
            let tx2 = mgr.begin_with_isolation(IsolationLevel::Serializable);

            mgr.record_read(tx1, entity_a).unwrap();
            mgr.record_read(tx1, entity_b).unwrap();
            mgr.record_read(tx2, entity_a).unwrap();
            mgr.record_read(tx2, entity_b).unwrap();

            mgr.record_write(tx1, entity_a).unwrap();
            mgr.record_write(tx2, entity_b).unwrap();

            let mgr1 = Arc::clone(&mgr);
            let mgr2 = Arc::clone(&mgr);
            // The barrier releases both threads together to maximize contention.
            let barrier = Arc::new(Barrier::new(2));
            let b1 = Arc::clone(&barrier);
            let b2 = Arc::clone(&barrier);

            let h1 = thread::spawn(move || {
                b1.wait();
                mgr1.commit(tx1)
            });
            let h2 = thread::spawn(move || {
                b2.wait();
                mgr2.commit(tx2)
            });

            let r1 = h1.join().unwrap();
            let r2 = h2.join().unwrap();

            if r1.is_ok() && r2.is_ok() {
                both_ok_count += 1;
            }

            // A failed commit leaves the transaction active; abort it before gc.
            if r1.is_err() {
                let _ = mgr.abort(tx1);
            }
            if r2.is_err() {
                let _ = mgr.abort(tx2);
            }
            mgr.gc();
        }

        assert_eq!(
            both_ok_count, 0,
            "SSI must prevent both concurrent write-skew commits from succeeding"
        );
    }

    /// `commit` records the commit epoch before it returns.
    #[test]
    fn test_committed_epoch_present_after_commit() {
        let mgr = TransactionManager::new();

        let tx = mgr.begin();
        mgr.record_write(tx, NodeId::new(1)).unwrap();
        let epoch = mgr.commit(tx).unwrap();

        assert_eq!(
            mgr.committed_epoch(tx),
            Some(epoch),
            "committed_epochs must contain tx immediately after commit()"
        );
    }
}