1use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12
/// Opaque identifier of a distributed transaction.
///
/// Newtype over `String`; used as the key of the coordinator's and participants' transaction
/// maps and sets, hence the `Eq`/`Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionId(String);
20
21impl TransactionId {
22 pub fn new(id: impl Into<String>) -> Self {
24 Self(id.into())
25 }
26
27 pub fn generate() -> Self {
29 let timestamp = std::time::SystemTime::now()
30 .duration_since(std::time::UNIX_EPOCH)
31 .unwrap_or_default()
32 .as_nanos();
33 Self(format!("txn_{:x}", timestamp))
34 }
35
36 pub fn as_str(&self) -> &str {
38 &self.0
39 }
40}
41
42impl std::fmt::Display for TransactionId {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{}", self.0)
45 }
46}
47
/// Lifecycle state of a distributed transaction as tracked by the coordinator.
///
/// Commit path: `Preparing` -> `Prepared` -> `Committing` -> `Committed`.
/// Abort path: `Aborting` -> `Aborted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionState {
    /// Initial state: participants/operations are collected and prepare votes gathered.
    Preparing,
    /// Every participant voted `Commit`; the only state from which a commit may start.
    Prepared,
    /// Commit decided and `CommitRequest`s issued; waiting for commit acknowledgements.
    Committing,
    /// All participants acknowledged the commit. Terminal.
    Committed,
    /// Abort decided (an abort vote or an explicit abort); waiting for acknowledgement.
    Aborting,
    /// Abort acknowledged. Terminal.
    Aborted,
    /// Not assigned by any code in this module; presumably used in status replies for
    /// transactions the responder does not know — TODO confirm.
    Unknown,
}
70
71impl TransactionState {
72 pub fn is_terminal(&self) -> bool {
74 matches!(self, Self::Committed | Self::Aborted)
75 }
76
77 pub fn can_commit(&self) -> bool {
79 matches!(self, Self::Prepared)
80 }
81
82 pub fn can_abort(&self) -> bool {
84 !matches!(self, Self::Committed)
85 }
86}
87
/// A participant's answer to a prepare request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParticipantVote {
    /// The participant prepared successfully and can commit.
    Commit,
    /// The participant refuses; a single abort vote moves the coordinator to `Aborting`.
    Abort,
    /// No vote received yet (initial value of a new participant record).
    Pending,
}
102
/// Coordinator-side record of one participant's progress in a transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionParticipant {
    /// Node taking part in the transaction.
    pub node_id: NodeId,
    /// Last prepare vote recorded for the node (`Pending` until one arrives).
    pub vote: ParticipantVote,
    /// `true` iff the last recorded vote was `Commit`.
    pub prepared: bool,
    /// `true` once the node's commit acknowledgement has been recorded.
    pub committed: bool,
    /// NOTE(review): never written by this module; presumably the time of the last message
    /// from the node — confirm units and who updates it.
    pub last_contact: Option<u64>,
}
116
117impl TransactionParticipant {
118 pub fn new(node_id: NodeId) -> Self {
120 Self {
121 node_id,
122 vote: ParticipantVote::Pending,
123 prepared: false,
124 committed: false,
125 last_contact: None,
126 }
127 }
128
129 pub fn record_prepare(&mut self, vote: ParticipantVote) {
131 self.vote = vote;
132 self.prepared = vote == ParticipantVote::Commit;
133 }
134
135 pub fn record_commit(&mut self) {
137 self.committed = true;
138 }
139}
140
/// Coordinator-side state of one distributed transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTransaction {
    /// Transaction identifier.
    pub id: TransactionId,
    /// Node coordinating the transaction.
    pub coordinator: NodeId,
    /// Participants keyed by node id.
    pub participants: HashMap<NodeId, TransactionParticipant>,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Creation time in milliseconds since the Unix epoch (wall clock, set by `new`).
    pub created_at: u64,
    /// Milliseconds after `created_at` after which the transaction counts as timed out.
    pub timeout_ms: u64,
    /// Operations shipped to every participant in the prepare request.
    pub operations: Vec<TransactionOperation>,
}
156
157impl DistributedTransaction {
158 pub fn new(id: TransactionId, coordinator: NodeId, timeout_ms: u64) -> Self {
160 let created_at = std::time::SystemTime::now()
161 .duration_since(std::time::UNIX_EPOCH)
162 .unwrap_or_default()
163 .as_millis() as u64;
164
165 Self {
166 id,
167 coordinator,
168 participants: HashMap::new(),
169 state: TransactionState::Preparing,
170 created_at,
171 timeout_ms,
172 operations: Vec::new(),
173 }
174 }
175
176 pub fn add_participant(&mut self, node_id: NodeId) {
178 if !self.participants.contains_key(&node_id) {
179 self.participants
180 .insert(node_id.clone(), TransactionParticipant::new(node_id));
181 }
182 }
183
184 pub fn add_operation(&mut self, operation: TransactionOperation) {
186 self.operations.push(operation);
187 }
188
189 pub fn all_prepared(&self) -> bool {
191 self.participants.values().all(|p| p.prepared)
192 }
193
194 pub fn all_committed(&self) -> bool {
196 self.participants.values().all(|p| p.committed)
197 }
198
199 pub fn any_abort(&self) -> bool {
201 self.participants
202 .values()
203 .any(|p| p.vote == ParticipantVote::Abort)
204 }
205
206 pub fn is_timed_out(&self) -> bool {
208 let now = std::time::SystemTime::now()
209 .duration_since(std::time::UNIX_EPOCH)
210 .unwrap_or_default()
211 .as_millis() as u64;
212 now - self.created_at > self.timeout_ms
213 }
214
215 pub fn participant_count(&self) -> usize {
217 self.participants.len()
218 }
219
220 pub fn prepared_count(&self) -> usize {
222 self.participants.values().filter(|p| p.prepared).count()
223 }
224}
225
/// A single operation executed as part of a distributed transaction.
///
/// Every variant names the key it touches and the shard it is routed to.
/// NOTE(review): `shard_id` presumably identifies the shard owning `key` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionOperation {
    /// Read of `key`.
    Read {
        key: String,
        shard_id: String,
    },
    /// Write of `value` under `key`.
    Write {
        key: String,
        value: Vec<u8>,
        shard_id: String,
    },
    /// Deletion of `key`.
    Delete {
        key: String,
        shard_id: String,
    },
    /// Conditional write of `new_value` to `key`. NOTE(review): `expected: None` presumably
    /// means "key must be absent" — confirm with the executing participant.
    CompareAndSwap {
        key: String,
        expected: Option<Vec<u8>>,
        new_value: Vec<u8>,
        shard_id: String,
    },
}
257
258impl TransactionOperation {
259 pub fn shard_id(&self) -> &str {
261 match self {
262 Self::Read { shard_id, .. } => shard_id,
263 Self::Write { shard_id, .. } => shard_id,
264 Self::Delete { shard_id, .. } => shard_id,
265 Self::CompareAndSwap { shard_id, .. } => shard_id,
266 }
267 }
268
269 pub fn key(&self) -> &str {
271 match self {
272 Self::Read { key, .. } => key,
273 Self::Write { key, .. } => key,
274 Self::Delete { key, .. } => key,
275 Self::CompareAndSwap { key, .. } => key,
276 }
277 }
278
279 pub fn is_write(&self) -> bool {
281 !matches!(self, Self::Read { .. })
282 }
283}
284
/// Messages exchanged between the coordinator and participants during two-phase commit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TwoPhaseMessage {
    /// Coordinator -> participant: prepare `operations` and vote.
    PrepareRequest {
        txn_id: TransactionId,
        operations: Vec<TransactionOperation>,
    },
    /// Participant -> coordinator: the participant's vote for `txn_id`.
    PrepareResponse {
        txn_id: TransactionId,
        vote: ParticipantVote,
        participant: NodeId,
    },
    /// Coordinator -> participant: commit the prepared transaction.
    CommitRequest {
        txn_id: TransactionId,
    },
    /// Participant -> coordinator: the commit was applied.
    CommitAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Coordinator -> participant: discard the transaction.
    AbortRequest {
        txn_id: TransactionId,
    },
    /// Participant -> coordinator: the abort was applied.
    AbortAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Request for the state of `txn_id`. NOTE(review): not produced or handled in this
    /// module; presumably used for recovery after a crash or timeout — confirm.
    StatusQuery {
        txn_id: TransactionId,
    },
    /// Reply to `StatusQuery`. NOTE(review): not produced in this module — confirm sender.
    StatusResponse {
        txn_id: TransactionId,
        state: TransactionState,
    },
}
331
/// Coordinator side of two-phase commit.
///
/// Tracks transactions and builds the messages to send to participants. It performs no I/O
/// itself: callers deliver the returned messages and feed participant replies back in.
pub struct TransactionCoordinator {
    /// Node this coordinator runs on; recorded as the coordinator of every transaction it begins.
    node_id: NodeId,
    /// All tracked transactions, keyed by id.
    transactions: RwLock<HashMap<TransactionId, DistributedTransaction>>,
    /// Timeout, in milliseconds, given to newly begun transactions.
    default_timeout_ms: u64,
    /// Ids of transactions that reached `Prepared`.
    /// NOTE(review): never pruned (not even by `cleanup_completed`), so it grows unbounded.
    prepared_log: RwLock<HashSet<TransactionId>>,
}
343
344impl TransactionCoordinator {
345 pub fn new(node_id: NodeId) -> Self {
347 Self {
348 node_id,
349 transactions: RwLock::new(HashMap::new()),
350 default_timeout_ms: 30000,
351 prepared_log: RwLock::new(HashSet::new()),
352 }
353 }
354
355 pub fn with_timeout(node_id: NodeId, timeout_ms: u64) -> Self {
357 Self {
358 node_id,
359 transactions: RwLock::new(HashMap::new()),
360 default_timeout_ms: timeout_ms,
361 prepared_log: RwLock::new(HashSet::new()),
362 }
363 }
364
365 pub fn begin_transaction(&self) -> TransactionId {
367 let txn_id = TransactionId::generate();
368 let txn = DistributedTransaction::new(
369 txn_id.clone(),
370 self.node_id.clone(),
371 self.default_timeout_ms,
372 );
373 self.transactions
374 .write()
375 .unwrap()
376 .insert(txn_id.clone(), txn);
377 txn_id
378 }
379
380 pub fn begin_transaction_with_id(&self, txn_id: TransactionId) {
382 let txn = DistributedTransaction::new(
383 txn_id.clone(),
384 self.node_id.clone(),
385 self.default_timeout_ms,
386 );
387 self.transactions.write().unwrap().insert(txn_id, txn);
388 }
389
390 pub fn add_participant(&self, txn_id: &TransactionId, node_id: NodeId) -> bool {
392 if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
393 txn.add_participant(node_id);
394 true
395 } else {
396 false
397 }
398 }
399
400 pub fn add_operation(&self, txn_id: &TransactionId, operation: TransactionOperation) -> bool {
402 if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
403 txn.add_operation(operation);
404 true
405 } else {
406 false
407 }
408 }
409
410 pub fn prepare(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
412 let txns = self.transactions.read().unwrap();
413 let txn = txns.get(txn_id)?;
414
415 if txn.state != TransactionState::Preparing {
416 return None;
417 }
418
419 let messages: Vec<_> = txn
420 .participants
421 .keys()
422 .map(|node_id| {
423 (
424 node_id.clone(),
425 TwoPhaseMessage::PrepareRequest {
426 txn_id: txn_id.clone(),
427 operations: txn.operations.clone(),
428 },
429 )
430 })
431 .collect();
432
433 Some(messages)
434 }
435
436 pub fn handle_prepare_response(
438 &self,
439 txn_id: &TransactionId,
440 participant: &NodeId,
441 vote: ParticipantVote,
442 ) -> Option<TransactionState> {
443 let mut txns = self.transactions.write().unwrap();
444 let txn = txns.get_mut(txn_id)?;
445
446 if let Some(p) = txn.participants.get_mut(participant) {
447 p.record_prepare(vote);
448 }
449
450 if txn.any_abort() {
452 txn.state = TransactionState::Aborting;
453 Some(TransactionState::Aborting)
454 } else if txn.all_prepared() {
455 txn.state = TransactionState::Prepared;
456 self.prepared_log.write().unwrap().insert(txn_id.clone());
457 Some(TransactionState::Prepared)
458 } else {
459 None
460 }
461 }
462
463 pub fn commit(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
465 let mut txns = self.transactions.write().unwrap();
466 let txn = txns.get_mut(txn_id)?;
467
468 if !txn.state.can_commit() {
469 return None;
470 }
471
472 txn.state = TransactionState::Committing;
473
474 let messages: Vec<_> = txn
475 .participants
476 .keys()
477 .map(|node_id| {
478 (
479 node_id.clone(),
480 TwoPhaseMessage::CommitRequest {
481 txn_id: txn_id.clone(),
482 },
483 )
484 })
485 .collect();
486
487 Some(messages)
488 }
489
490 pub fn handle_commit_ack(
492 &self,
493 txn_id: &TransactionId,
494 participant: &NodeId,
495 ) -> Option<TransactionState> {
496 let mut txns = self.transactions.write().unwrap();
497 let txn = txns.get_mut(txn_id)?;
498
499 if let Some(p) = txn.participants.get_mut(participant) {
500 p.record_commit();
501 }
502
503 if txn.all_committed() {
504 txn.state = TransactionState::Committed;
505 Some(TransactionState::Committed)
506 } else {
507 None
508 }
509 }
510
511 pub fn abort(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
513 let mut txns = self.transactions.write().unwrap();
514 let txn = txns.get_mut(txn_id)?;
515
516 if !txn.state.can_abort() {
517 return None;
518 }
519
520 txn.state = TransactionState::Aborting;
521
522 let messages: Vec<_> = txn
523 .participants
524 .keys()
525 .map(|node_id| {
526 (
527 node_id.clone(),
528 TwoPhaseMessage::AbortRequest {
529 txn_id: txn_id.clone(),
530 },
531 )
532 })
533 .collect();
534
535 Some(messages)
536 }
537
538 pub fn handle_abort_ack(&self, txn_id: &TransactionId, _participant: &NodeId) -> bool {
540 let mut txns = self.transactions.write().unwrap();
541 if let Some(txn) = txns.get_mut(txn_id) {
542 txn.state = TransactionState::Aborted;
543 true
544 } else {
545 false
546 }
547 }
548
549 pub fn get_state(&self, txn_id: &TransactionId) -> Option<TransactionState> {
551 self.transactions
552 .read()
553 .unwrap()
554 .get(txn_id)
555 .map(|t| t.state)
556 }
557
558 pub fn get_transaction(&self, txn_id: &TransactionId) -> Option<DistributedTransaction> {
560 self.transactions.read().unwrap().get(txn_id).cloned()
561 }
562
563 pub fn check_timeouts(&self) -> Vec<TransactionId> {
565 self.transactions
566 .read()
567 .unwrap()
568 .iter()
569 .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
570 .map(|(id, _)| id.clone())
571 .collect()
572 }
573
574 pub fn cleanup_completed(&self) -> usize {
576 let mut txns = self.transactions.write().unwrap();
577 let before = txns.len();
578 txns.retain(|_, txn| !txn.state.is_terminal());
579 before - txns.len()
580 }
581
582 pub fn active_count(&self) -> usize {
584 self.transactions
585 .read()
586 .unwrap()
587 .values()
588 .filter(|t| !t.state.is_terminal())
589 .count()
590 }
591
592 pub fn was_prepared(&self, txn_id: &TransactionId) -> bool {
594 self.prepared_log.read().unwrap().contains(txn_id)
595 }
596}
597
/// Participant side of two-phase commit: answers coordinator requests and tracks which
/// transactions it has prepared and committed. It performs no I/O; replies are returned
/// to the caller.
pub struct ParticipantHandler {
    /// Node this handler answers for; stamped on every reply.
    node_id: NodeId,
    /// Operations of transactions voted `Commit` but not yet committed or aborted.
    pending_prepares: RwLock<HashMap<TransactionId, Vec<TransactionOperation>>>,
    /// Ids of transactions currently prepared.
    prepared: RwLock<HashSet<TransactionId>>,
    /// Ids of transactions committed by this handler (never pruned in this module).
    committed: RwLock<HashSet<TransactionId>>,
}
609
610impl ParticipantHandler {
611 pub fn new(node_id: NodeId) -> Self {
613 Self {
614 node_id,
615 pending_prepares: RwLock::new(HashMap::new()),
616 prepared: RwLock::new(HashSet::new()),
617 committed: RwLock::new(HashSet::new()),
618 }
619 }
620
621 pub fn handle_prepare(
623 &self,
624 txn_id: &TransactionId,
625 operations: Vec<TransactionOperation>,
626 ) -> TwoPhaseMessage {
627 self.pending_prepares
634 .write()
635 .unwrap()
636 .insert(txn_id.clone(), operations);
637 self.prepared.write().unwrap().insert(txn_id.clone());
638
639 TwoPhaseMessage::PrepareResponse {
640 txn_id: txn_id.clone(),
641 vote: ParticipantVote::Commit,
642 participant: self.node_id.clone(),
643 }
644 }
645
646 pub fn handle_prepare_with_validation<F>(
648 &self,
649 txn_id: &TransactionId,
650 operations: Vec<TransactionOperation>,
651 validator: F,
652 ) -> TwoPhaseMessage
653 where
654 F: FnOnce(&[TransactionOperation]) -> bool,
655 {
656 let vote = if validator(&operations) {
657 self.pending_prepares
658 .write()
659 .unwrap()
660 .insert(txn_id.clone(), operations);
661 self.prepared.write().unwrap().insert(txn_id.clone());
662 ParticipantVote::Commit
663 } else {
664 ParticipantVote::Abort
665 };
666
667 TwoPhaseMessage::PrepareResponse {
668 txn_id: txn_id.clone(),
669 vote,
670 participant: self.node_id.clone(),
671 }
672 }
673
674 pub fn handle_commit(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
676 self.pending_prepares.write().unwrap().remove(txn_id);
678 self.prepared.write().unwrap().remove(txn_id);
679 self.committed.write().unwrap().insert(txn_id.clone());
680
681 TwoPhaseMessage::CommitAck {
682 txn_id: txn_id.clone(),
683 participant: self.node_id.clone(),
684 }
685 }
686
687 pub fn handle_abort(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
689 self.pending_prepares.write().unwrap().remove(txn_id);
691 self.prepared.write().unwrap().remove(txn_id);
692
693 TwoPhaseMessage::AbortAck {
694 txn_id: txn_id.clone(),
695 participant: self.node_id.clone(),
696 }
697 }
698
699 pub fn is_prepared(&self, txn_id: &TransactionId) -> bool {
701 self.prepared.read().unwrap().contains(txn_id)
702 }
703
704 pub fn is_committed(&self, txn_id: &TransactionId) -> bool {
706 self.committed.read().unwrap().contains(txn_id)
707 }
708
709 pub fn pending_count(&self) -> usize {
711 self.pending_prepares.read().unwrap().len()
712 }
713}
714
// Unit tests for the transaction types, the coordinator state machine and the participant
// handler. Coordinator tests drive the protocol by hand: they call `prepare`/`commit` and feed
// participant replies back in directly, without any transport.
#[cfg(test)]
mod tests {
    use super::*;

    // `new` keeps the given string verbatim; `generate` yields a `txn_`-prefixed id.
    #[test]
    fn test_transaction_id() {
        let id1 = TransactionId::new("txn_1");
        let id2 = TransactionId::generate();

        assert_eq!(id1.as_str(), "txn_1");
        assert!(id2.as_str().starts_with("txn_"));
    }

    // Terminal / commit / abort predicates on a few representative states.
    #[test]
    fn test_transaction_state() {
        assert!(!TransactionState::Preparing.is_terminal());
        assert!(TransactionState::Committed.is_terminal());
        assert!(TransactionState::Aborted.is_terminal());

        assert!(TransactionState::Prepared.can_commit());
        assert!(!TransactionState::Preparing.can_commit());

        assert!(TransactionState::Preparing.can_abort());
        assert!(!TransactionState::Committed.can_abort());
    }

    // A transaction becomes fully prepared only after every participant votes Commit.
    #[test]
    fn test_distributed_transaction() {
        let txn_id = TransactionId::new("txn_1");
        let coordinator = NodeId::new("coord");
        let mut txn = DistributedTransaction::new(txn_id, coordinator, 30000);

        assert_eq!(txn.state, TransactionState::Preparing);
        assert_eq!(txn.participant_count(), 0);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        assert_eq!(txn.participant_count(), 2);
        assert!(!txn.all_prepared());

        txn.participants
            .get_mut(&NodeId::new("node1"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);

        // One of two participants prepared: not yet all prepared.
        assert!(!txn.all_prepared());
        assert_eq!(txn.prepared_count(), 1);

        txn.participants
            .get_mut(&NodeId::new("node2"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);

        assert!(txn.all_prepared());
        assert!(!txn.any_abort());
    }

    // A single Abort vote is detected and prevents "all prepared".
    #[test]
    fn test_transaction_abort_vote() {
        let txn_id = TransactionId::new("txn_1");
        let mut txn = DistributedTransaction::new(txn_id, NodeId::new("coord"), 30000);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        txn.participants
            .get_mut(&NodeId::new("node1"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);
        txn.participants
            .get_mut(&NodeId::new("node2"))
            .unwrap()
            .record_prepare(ParticipantVote::Abort);

        assert!(!txn.all_prepared());
        assert!(txn.any_abort());
    }

    // Accessors on operations; only Read is classified as non-write.
    #[test]
    fn test_transaction_operation() {
        let write_op = TransactionOperation::Write {
            key: "user:1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        };

        assert_eq!(write_op.key(), "user:1");
        assert_eq!(write_op.shard_id(), "shard_1");
        assert!(write_op.is_write());

        let read_op = TransactionOperation::Read {
            key: "user:2".to_string(),
            shard_id: "shard_2".to_string(),
        };

        assert!(!read_op.is_write());
    }

    // A newly begun transaction is tracked and starts in Preparing.
    #[test]
    fn test_coordinator_begin_transaction() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.get_state(&txn_id).is_some());
        assert_eq!(
            coord.get_state(&txn_id).unwrap(),
            TransactionState::Preparing
        );
    }

    // Participants added through the coordinator are visible on the transaction snapshot.
    #[test]
    fn test_coordinator_add_participant() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let txn = coord.get_transaction(&txn_id).unwrap();
        assert_eq!(txn.participant_count(), 2);
    }

    // `prepare` yields one PrepareRequest per participant, tagged with the transaction id.
    #[test]
    fn test_coordinator_prepare() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        coord.add_participant(&txn_id, NodeId::new("node1"));
        coord.add_participant(&txn_id, NodeId::new("node2"));

        let messages = coord.prepare(&txn_id).unwrap();
        assert_eq!(messages.len(), 2);

        for (_, msg) in &messages {
            match msg {
                TwoPhaseMessage::PrepareRequest { txn_id: id, .. } => {
                    assert_eq!(id, &txn_id);
                }
                _ => panic!("Expected PrepareRequest"),
            }
        }
    }

    // Happy path: all Commit votes -> Prepared -> commit -> all acks -> Committed.
    #[test]
    fn test_coordinator_full_commit() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        coord.add_participant(&txn_id, node1.clone());
        coord.add_participant(&txn_id, node2.clone());

        coord.prepare(&txn_id);

        // The second Commit vote completes the prepare phase.
        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Commit);

        assert_eq!(state, Some(TransactionState::Prepared));

        let messages = coord.commit(&txn_id).unwrap();
        assert_eq!(messages.len(), 2);

        // The second ack completes the commit.
        coord.handle_commit_ack(&txn_id, &node1);
        let final_state = coord.handle_commit_ack(&txn_id, &node2);

        assert_eq!(final_state, Some(TransactionState::Committed));
    }

    // An Abort vote moves the transaction to Aborting.
    #[test]
    fn test_coordinator_abort_on_vote() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        coord.add_participant(&txn_id, node1.clone());
        coord.add_participant(&txn_id, node2.clone());

        coord.prepare(&txn_id);

        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Abort);

        assert_eq!(state, Some(TransactionState::Aborting));
    }

    // Participant: prepare votes Commit and marks prepared; commit moves it to committed.
    #[test]
    fn test_participant_handler() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let ops = vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }];

        let response = handler.handle_prepare(&txn_id, ops);
        match response {
            TwoPhaseMessage::PrepareResponse { vote, .. } => {
                assert_eq!(vote, ParticipantVote::Commit);
            }
            _ => panic!("Expected PrepareResponse"),
        }

        assert!(handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));

        let commit_response = handler.handle_commit(&txn_id);
        match commit_response {
            TwoPhaseMessage::CommitAck { .. } => {}
            _ => panic!("Expected CommitAck"),
        }

        assert!(!handler.is_prepared(&txn_id));
        assert!(handler.is_committed(&txn_id));
    }

    // Participant: abort clears the prepared mark without committing.
    #[test]
    fn test_participant_abort() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let ops = vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }];

        handler.handle_prepare(&txn_id, ops);
        assert!(handler.is_prepared(&txn_id));

        let abort_response = handler.handle_abort(&txn_id);
        match abort_response {
            TwoPhaseMessage::AbortAck { .. } => {}
            _ => panic!("Expected AbortAck"),
        }

        assert!(!handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));
    }

    // `cleanup_completed` removes a committed transaction from the coordinator.
    #[test]
    fn test_coordinator_cleanup() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        let txn_id = coord.begin_transaction();
        coord.add_participant(&txn_id, NodeId::new("node1"));
        coord.prepare(&txn_id);
        coord.handle_prepare_response(&txn_id, &NodeId::new("node1"), ParticipantVote::Commit);
        coord.commit(&txn_id);
        coord.handle_commit_ack(&txn_id, &NodeId::new("node1"));

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Committed));

        let cleaned = coord.cleanup_completed();
        assert_eq!(cleaned, 1);
        assert!(coord.get_state(&txn_id).is_none());
    }

    // Committing one of two transactions drops the active count from 2 to 1.
    // Assumes the two generated ids differ, so both transactions are tracked.
    #[test]
    fn test_active_count() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        let txn1 = coord.begin_transaction();
        let _txn2 = coord.begin_transaction();

        assert_eq!(coord.active_count(), 2);

        coord.add_participant(&txn1, NodeId::new("node1"));
        coord.prepare(&txn1);
        coord.handle_prepare_response(&txn1, &NodeId::new("node1"), ParticipantVote::Commit);
        coord.commit(&txn1);
        coord.handle_commit_ack(&txn1, &NodeId::new("node1"));

        assert_eq!(coord.active_count(), 1);
    }
}