1use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12
/// Opaque identifier of a distributed transaction.
///
/// A newtype over `String`. It derives `Eq`/`Hash` so it can key the
/// coordinator's and participants' maps. It derives the serde traits so it can
/// be carried inside [`TwoPhaseMessage`]s.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionId(String);
20
21impl TransactionId {
22 pub fn new(id: impl Into<String>) -> Self {
24 Self(id.into())
25 }
26
27 pub fn generate() -> Self {
29 let timestamp = std::time::SystemTime::now()
30 .duration_since(std::time::UNIX_EPOCH)
31 .unwrap_or_default()
32 .as_nanos();
33 Self(format!("txn_{:x}", timestamp))
34 }
35
36 pub fn as_str(&self) -> &str {
38 &self.0
39 }
40}
41
42impl std::fmt::Display for TransactionId {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{}", self.0)
45 }
46}
47
/// Position of a transaction in the coordinator's two-phase-commit state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionState {
    /// Initial state: participants are being added and votes collected.
    Preparing,
    /// Every participant voted to commit; the coordinator may now commit.
    Prepared,
    /// Commit requests have been issued; waiting for all commit acks.
    Committing,
    /// All participants acknowledged the commit.
    Committed,
    /// An abort vote was received or an abort was requested; abort requests are outstanding.
    Aborting,
    /// The abort has been acknowledged.
    Aborted,
    /// Not assigned by the coordinator in this module; presumably used when the outcome
    /// cannot be determined (e.g. during recovery) — confirm with callers.
    Unknown,
}
70
71impl TransactionState {
72 pub fn is_terminal(&self) -> bool {
74 matches!(self, Self::Committed | Self::Aborted)
75 }
76
77 pub fn can_commit(&self) -> bool {
79 matches!(self, Self::Prepared)
80 }
81
82 pub fn can_abort(&self) -> bool {
84 !matches!(self, Self::Committed)
85 }
86}
87
/// A participant's answer to a prepare request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParticipantVote {
    /// The participant validated and staged the operations and is ready to commit.
    Commit,
    /// The participant refuses; any `Abort` vote makes the coordinator abort.
    Abort,
    /// No vote received yet (the initial value for a new participant).
    Pending,
}
102
/// Coordinator-side bookkeeping for one participant of a transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionParticipant {
    /// The participant's node.
    pub node_id: NodeId,
    /// Latest vote recorded for this participant (`Pending` until one arrives).
    pub vote: ParticipantVote,
    /// `true` once the participant voted `Commit`.
    pub prepared: bool,
    /// `true` once the participant acknowledged the commit.
    pub committed: bool,
    /// Not updated anywhere in this module. Presumably the time of the last message
    /// from this participant — NOTE(review): confirm units and who sets it.
    pub last_contact: Option<u64>,
}
116
117impl TransactionParticipant {
118 pub fn new(node_id: NodeId) -> Self {
120 Self {
121 node_id,
122 vote: ParticipantVote::Pending,
123 prepared: false,
124 committed: false,
125 last_contact: None,
126 }
127 }
128
129 pub fn record_prepare(&mut self, vote: ParticipantVote) {
131 self.vote = vote;
132 self.prepared = vote == ParticipantVote::Commit;
133 }
134
135 pub fn record_commit(&mut self) {
137 self.committed = true;
138 }
139}
140
/// Coordinator-side record of one distributed transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTransaction {
    /// Transaction identifier.
    pub id: TransactionId,
    /// Node that created and drives this transaction.
    pub coordinator: NodeId,
    /// Per-participant vote/commit bookkeeping, keyed by node id.
    pub participants: HashMap<NodeId, TransactionParticipant>,
    /// Current position in the two-phase state machine.
    pub state: TransactionState,
    /// Creation time in milliseconds since the Unix epoch (wall clock).
    pub created_at: u64,
    /// Milliseconds after `created_at` beyond which `is_timed_out` reports `true`.
    pub timeout_ms: u64,
    /// Operations in the order they were added; the whole list is sent to every
    /// participant in the prepare phase.
    pub operations: Vec<TransactionOperation>,
}
156
157impl DistributedTransaction {
158 pub fn new(id: TransactionId, coordinator: NodeId, timeout_ms: u64) -> Self {
160 let created_at = std::time::SystemTime::now()
161 .duration_since(std::time::UNIX_EPOCH)
162 .unwrap_or_default()
163 .as_millis() as u64;
164
165 Self {
166 id,
167 coordinator,
168 participants: HashMap::new(),
169 state: TransactionState::Preparing,
170 created_at,
171 timeout_ms,
172 operations: Vec::new(),
173 }
174 }
175
176 pub fn add_participant(&mut self, node_id: NodeId) {
178 if !self.participants.contains_key(&node_id) {
179 self.participants
180 .insert(node_id.clone(), TransactionParticipant::new(node_id));
181 }
182 }
183
184 pub fn add_operation(&mut self, operation: TransactionOperation) {
186 self.operations.push(operation);
187 }
188
189 pub fn all_prepared(&self) -> bool {
191 self.participants.values().all(|p| p.prepared)
192 }
193
194 pub fn all_committed(&self) -> bool {
196 self.participants.values().all(|p| p.committed)
197 }
198
199 pub fn any_abort(&self) -> bool {
201 self.participants
202 .values()
203 .any(|p| p.vote == ParticipantVote::Abort)
204 }
205
206 pub fn is_timed_out(&self) -> bool {
208 let now = std::time::SystemTime::now()
209 .duration_since(std::time::UNIX_EPOCH)
210 .unwrap_or_default()
211 .as_millis() as u64;
212 now - self.created_at > self.timeout_ms
213 }
214
215 pub fn participant_count(&self) -> usize {
217 self.participants.len()
218 }
219
220 pub fn prepared_count(&self) -> usize {
222 self.participants.values().filter(|p| p.prepared).count()
223 }
224}
225
/// A single key-level operation inside a distributed transaction.
///
/// Every variant names the key it touches and a `shard_id`. The shard id is an
/// opaque string to this module and presumably identifies the shard owning the
/// key; confirm with the sharding layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionOperation {
    /// Reads `key`. The only variant for which `is_write()` is `false`.
    Read {
        key: String,
        shard_id: String,
    },
    /// Stores `value` under `key`.
    Write {
        key: String,
        value: Vec<u8>,
        shard_id: String,
    },
    /// Removes `key`.
    Delete {
        key: String,
        shard_id: String,
    },
    /// Sets `key` to `new_value` if its current value matches `expected`.
    /// NOTE(review): presumably `expected: None` means "key must be absent" — confirm
    /// with the commit callback that applies it.
    CompareAndSwap {
        key: String,
        expected: Option<Vec<u8>>,
        new_value: Vec<u8>,
        shard_id: String,
    },
}
257
258impl TransactionOperation {
259 pub fn shard_id(&self) -> &str {
261 match self {
262 Self::Read { shard_id, .. } => shard_id,
263 Self::Write { shard_id, .. } => shard_id,
264 Self::Delete { shard_id, .. } => shard_id,
265 Self::CompareAndSwap { shard_id, .. } => shard_id,
266 }
267 }
268
269 pub fn key(&self) -> &str {
271 match self {
272 Self::Read { key, .. } => key,
273 Self::Write { key, .. } => key,
274 Self::Delete { key, .. } => key,
275 Self::CompareAndSwap { key, .. } => key,
276 }
277 }
278
279 pub fn is_write(&self) -> bool {
281 !matches!(self, Self::Read { .. })
282 }
283}
284
/// Messages exchanged between the coordinator and participants during two-phase commit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TwoPhaseMessage {
    /// Phase 1, coordinator → participant: validate and stage `operations`.
    PrepareRequest {
        txn_id: TransactionId,
        /// The transaction's full operation list as recorded on the coordinator.
        operations: Vec<TransactionOperation>,
    },
    /// Phase 1 reply, participant → coordinator, carrying the participant's vote.
    PrepareResponse {
        txn_id: TransactionId,
        vote: ParticipantVote,
        /// Node that cast the vote.
        participant: NodeId,
    },
    /// Phase 2, coordinator → participant: apply the staged operations.
    CommitRequest {
        txn_id: TransactionId,
    },
    /// Participant acknowledgement of a `CommitRequest`.
    CommitAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Coordinator → participant: discard the staged operations.
    AbortRequest {
        txn_id: TransactionId,
    },
    /// Participant acknowledgement of an `AbortRequest`.
    AbortAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Asks for a transaction's state. Not constructed in this module; presumably used
    /// to resolve in-doubt transactions — confirm with callers.
    StatusQuery {
        txn_id: TransactionId,
    },
    /// Reply to a `StatusQuery`. Not constructed in this module.
    StatusResponse {
        txn_id: TransactionId,
        state: TransactionState,
    },
}
331
/// Coordinator side of two-phase commit.
///
/// Tracks the transactions it started and drives them through the
/// [`TransactionState`] machine. All state is held in memory behind `RwLock`s,
/// so every method takes `&self`.
pub struct TransactionCoordinator {
    /// Recorded as the coordinator of every transaction begun here.
    node_id: NodeId,
    /// Known transactions keyed by id; terminal ones are removed only by `cleanup_completed`.
    transactions: RwLock<HashMap<TransactionId, DistributedTransaction>>,
    /// Timeout (ms) assigned to transactions created by `begin_transaction*`.
    default_timeout_ms: u64,
    /// Ids of transactions that reached `Prepared`. In memory only and never pruned,
    /// including by `cleanup_completed`.
    prepared_log: RwLock<HashSet<TransactionId>>,
}
343
344impl TransactionCoordinator {
345 pub fn new(node_id: NodeId) -> Self {
347 Self {
348 node_id,
349 transactions: RwLock::new(HashMap::new()),
350 default_timeout_ms: 30000,
351 prepared_log: RwLock::new(HashSet::new()),
352 }
353 }
354
355 pub fn with_timeout(node_id: NodeId, timeout_ms: u64) -> Self {
357 Self {
358 node_id,
359 transactions: RwLock::new(HashMap::new()),
360 default_timeout_ms: timeout_ms,
361 prepared_log: RwLock::new(HashSet::new()),
362 }
363 }
364
365 pub fn begin_transaction(&self) -> TransactionId {
367 let txn_id = TransactionId::generate();
368 let txn = DistributedTransaction::new(
369 txn_id.clone(),
370 self.node_id.clone(),
371 self.default_timeout_ms,
372 );
373 self.transactions
374 .write()
375 .unwrap()
376 .insert(txn_id.clone(), txn);
377 txn_id
378 }
379
380 pub fn begin_transaction_with_id(&self, txn_id: TransactionId) {
382 let txn = DistributedTransaction::new(
383 txn_id.clone(),
384 self.node_id.clone(),
385 self.default_timeout_ms,
386 );
387 self.transactions.write().unwrap().insert(txn_id, txn);
388 }
389
390 pub fn add_participant(&self, txn_id: &TransactionId, node_id: NodeId) -> bool {
392 if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
393 txn.add_participant(node_id);
394 true
395 } else {
396 false
397 }
398 }
399
400 pub fn add_operation(&self, txn_id: &TransactionId, operation: TransactionOperation) -> bool {
402 if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
403 txn.add_operation(operation);
404 true
405 } else {
406 false
407 }
408 }
409
410 pub fn prepare(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
412 let txns = self.transactions.read().unwrap();
413 let txn = txns.get(txn_id)?;
414
415 if txn.state != TransactionState::Preparing {
416 return None;
417 }
418
419 let messages: Vec<_> = txn
420 .participants
421 .keys()
422 .map(|node_id| {
423 (
424 node_id.clone(),
425 TwoPhaseMessage::PrepareRequest {
426 txn_id: txn_id.clone(),
427 operations: txn.operations.clone(),
428 },
429 )
430 })
431 .collect();
432
433 Some(messages)
434 }
435
436 pub fn handle_prepare_response(
438 &self,
439 txn_id: &TransactionId,
440 participant: &NodeId,
441 vote: ParticipantVote,
442 ) -> Option<TransactionState> {
443 let mut txns = self.transactions.write().unwrap();
444 let txn = txns.get_mut(txn_id)?;
445
446 if let Some(p) = txn.participants.get_mut(participant) {
447 p.record_prepare(vote);
448 }
449
450 if txn.any_abort() {
452 txn.state = TransactionState::Aborting;
453 Some(TransactionState::Aborting)
454 } else if txn.all_prepared() {
455 txn.state = TransactionState::Prepared;
456 self.prepared_log.write().unwrap().insert(txn_id.clone());
457 Some(TransactionState::Prepared)
458 } else {
459 None
460 }
461 }
462
463 pub fn commit(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
465 let mut txns = self.transactions.write().unwrap();
466 let txn = txns.get_mut(txn_id)?;
467
468 if !txn.state.can_commit() {
469 return None;
470 }
471
472 txn.state = TransactionState::Committing;
473
474 let messages: Vec<_> = txn
475 .participants
476 .keys()
477 .map(|node_id| {
478 (
479 node_id.clone(),
480 TwoPhaseMessage::CommitRequest {
481 txn_id: txn_id.clone(),
482 },
483 )
484 })
485 .collect();
486
487 Some(messages)
488 }
489
490 pub fn handle_commit_ack(
492 &self,
493 txn_id: &TransactionId,
494 participant: &NodeId,
495 ) -> Option<TransactionState> {
496 let mut txns = self.transactions.write().unwrap();
497 let txn = txns.get_mut(txn_id)?;
498
499 if let Some(p) = txn.participants.get_mut(participant) {
500 p.record_commit();
501 }
502
503 if txn.all_committed() {
504 txn.state = TransactionState::Committed;
505 Some(TransactionState::Committed)
506 } else {
507 None
508 }
509 }
510
511 pub fn abort(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
513 let mut txns = self.transactions.write().unwrap();
514 let txn = txns.get_mut(txn_id)?;
515
516 if !txn.state.can_abort() {
517 return None;
518 }
519
520 txn.state = TransactionState::Aborting;
521
522 let messages: Vec<_> = txn
523 .participants
524 .keys()
525 .map(|node_id| {
526 (
527 node_id.clone(),
528 TwoPhaseMessage::AbortRequest {
529 txn_id: txn_id.clone(),
530 },
531 )
532 })
533 .collect();
534
535 Some(messages)
536 }
537
538 pub fn handle_abort_ack(&self, txn_id: &TransactionId, _participant: &NodeId) -> bool {
540 let mut txns = self.transactions.write().unwrap();
541 if let Some(txn) = txns.get_mut(txn_id) {
542 txn.state = TransactionState::Aborted;
543 true
544 } else {
545 false
546 }
547 }
548
549 pub fn get_state(&self, txn_id: &TransactionId) -> Option<TransactionState> {
551 self.transactions
552 .read()
553 .unwrap()
554 .get(txn_id)
555 .map(|t| t.state)
556 }
557
558 pub fn get_transaction(&self, txn_id: &TransactionId) -> Option<DistributedTransaction> {
560 self.transactions.read().unwrap().get(txn_id).cloned()
561 }
562
563 pub fn check_timeouts(&self) -> Vec<TransactionId> {
565 self.transactions
566 .read()
567 .unwrap()
568 .iter()
569 .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
570 .map(|(id, _)| id.clone())
571 .collect()
572 }
573
574 pub fn cleanup_completed(&self) -> usize {
576 let mut txns = self.transactions.write().unwrap();
577 let before = txns.len();
578 txns.retain(|_, txn| !txn.state.is_terminal());
579 before - txns.len()
580 }
581
582 pub fn active_count(&self) -> usize {
584 self.transactions
585 .read()
586 .unwrap()
587 .values()
588 .filter(|t| !t.state.is_terminal())
589 .count()
590 }
591
592 pub fn was_prepared(&self, txn_id: &TransactionId) -> bool {
594 self.prepared_log.read().unwrap().contains(txn_id)
595 }
596}
597
/// Custom prepare-time validator.
///
/// When installed, `ParticipantHandler::handle_prepare` calls it instead of
/// the built-in conflict check. On success, the returned
/// `ValidationResult::locked_keys` are recorded as the transaction's locks.
pub type ValidationCallback = Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> ValidationResult + Send + Sync>;

/// Applies a transaction's staged operations at commit time.
///
/// An `Err` is logged by the handler, but the commit is still acknowledged.
pub type CommitCallback = Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> Result<(), String> + Send + Sync>;

/// Invoked when a transaction is aborted; an `Err` is logged and otherwise ignored.
pub type AbortCallback = Box<dyn Fn(&TransactionId) -> Result<(), String> + Send + Sync>;
610
/// Outcome of validating a transaction's operations during prepare.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` if the participant may vote `Commit`.
    pub success: bool,
    /// Human-readable reason when `success` is `false`.
    pub error: Option<String>,
    /// Keys to reserve for the transaction until it commits or aborts (empty on failure).
    pub locked_keys: Vec<String>,
}
621
622impl ValidationResult {
623 pub fn success(locked_keys: Vec<String>) -> Self {
625 Self {
626 success: true,
627 error: None,
628 locked_keys,
629 }
630 }
631
632 pub fn failure(error: impl Into<String>) -> Self {
634 Self {
635 success: false,
636 error: Some(error.into()),
637 locked_keys: vec![],
638 }
639 }
640}
641
/// Participant side of two-phase commit: answers prepare, commit and abort requests.
///
/// All state is in memory behind `RwLock`s, so every method takes `&self`.
pub struct ParticipantHandler {
    /// Reported as `participant` in every response.
    node_id: NodeId,
    /// Operations of transactions that voted `Commit` and await a decision.
    pending_prepares: RwLock<HashMap<TransactionId, Vec<TransactionOperation>>>,
    /// Ids that voted `Commit` and have not yet been committed or aborted.
    prepared: RwLock<HashSet<TransactionId>>,
    /// Ids for which `handle_commit` ran; never pruned in this module.
    committed: RwLock<HashSet<TransactionId>>,
    /// Keys reserved per transaction at prepare time; released on commit or abort.
    locked_keys: RwLock<HashMap<TransactionId, Vec<String>>>,
    /// Optional replacement for the built-in conflict check.
    validation_callback: RwLock<Option<ValidationCallback>>,
    /// Optional hook that applies operations on commit.
    commit_callback: RwLock<Option<CommitCallback>>,
    /// Optional hook run on abort.
    abort_callback: RwLock<Option<AbortCallback>>,
}
653
654impl ParticipantHandler {
655 pub fn new(node_id: NodeId) -> Self {
657 Self {
658 node_id,
659 pending_prepares: RwLock::new(HashMap::new()),
660 prepared: RwLock::new(HashSet::new()),
661 committed: RwLock::new(HashSet::new()),
662 locked_keys: RwLock::new(HashMap::new()),
663 validation_callback: RwLock::new(None),
664 commit_callback: RwLock::new(None),
665 abort_callback: RwLock::new(None),
666 }
667 }
668
669 pub fn set_validation_callback(&self, callback: ValidationCallback) {
671 *self.validation_callback.write().unwrap() = Some(callback);
672 }
673
674 pub fn set_commit_callback(&self, callback: CommitCallback) {
676 *self.commit_callback.write().unwrap() = Some(callback);
677 }
678
679 pub fn set_abort_callback(&self, callback: AbortCallback) {
681 *self.abort_callback.write().unwrap() = Some(callback);
682 }
683
684 pub fn handle_prepare(
686 &self,
687 txn_id: &TransactionId,
688 operations: Vec<TransactionOperation>,
689 ) -> TwoPhaseMessage {
690 let validation_result = {
692 let callback_guard = self.validation_callback.read().unwrap();
693 if let Some(ref callback) = *callback_guard {
694 callback(txn_id, &operations)
696 } else {
697 self.basic_validation(txn_id, &operations)
699 }
700 };
701
702 if !validation_result.success {
703 return TwoPhaseMessage::PrepareResponse {
704 txn_id: txn_id.clone(),
705 vote: ParticipantVote::Abort,
706 participant: self.node_id.clone(),
707 };
708 }
709
710 self.locked_keys
712 .write()
713 .unwrap()
714 .insert(txn_id.clone(), validation_result.locked_keys);
715
716 self.pending_prepares
718 .write()
719 .unwrap()
720 .insert(txn_id.clone(), operations);
721 self.prepared.write().unwrap().insert(txn_id.clone());
722
723 TwoPhaseMessage::PrepareResponse {
724 txn_id: txn_id.clone(),
725 vote: ParticipantVote::Commit,
726 participant: self.node_id.clone(),
727 }
728 }
729
730 fn basic_validation(&self, txn_id: &TransactionId, operations: &[TransactionOperation]) -> ValidationResult {
732 let pending = self.pending_prepares.read().unwrap();
734 let locked = self.locked_keys.read().unwrap();
735
736 let mut keys_to_lock = Vec::new();
737 for op in operations {
738 let key = op.key().to_string();
739
740 for (other_txn_id, other_keys) in locked.iter() {
742 if other_txn_id != txn_id && other_keys.contains(&key) {
743 return ValidationResult::failure(format!(
744 "Key '{}' is locked by transaction {}",
745 key, other_txn_id
746 ));
747 }
748 }
749
750 for (other_txn_id, other_ops) in pending.iter() {
752 if other_txn_id != txn_id {
753 for other_op in other_ops {
754 if other_op.key() == key && other_op.is_write() && op.is_write() {
755 return ValidationResult::failure(format!(
756 "Write conflict on key '{}' with transaction {}",
757 key, other_txn_id
758 ));
759 }
760 }
761 }
762 }
763
764 if op.is_write() {
765 keys_to_lock.push(key);
766 }
767 }
768
769 ValidationResult::success(keys_to_lock)
770 }
771
772 pub fn handle_prepare_with_validation<F>(
774 &self,
775 txn_id: &TransactionId,
776 operations: Vec<TransactionOperation>,
777 validator: F,
778 ) -> TwoPhaseMessage
779 where
780 F: FnOnce(&[TransactionOperation]) -> bool,
781 {
782 let vote = if validator(&operations) {
783 self.pending_prepares
784 .write()
785 .unwrap()
786 .insert(txn_id.clone(), operations);
787 self.prepared.write().unwrap().insert(txn_id.clone());
788 ParticipantVote::Commit
789 } else {
790 ParticipantVote::Abort
791 };
792
793 TwoPhaseMessage::PrepareResponse {
794 txn_id: txn_id.clone(),
795 vote,
796 participant: self.node_id.clone(),
797 }
798 }
799
800 pub fn handle_commit(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
802 let operations = self.pending_prepares.write().unwrap().remove(txn_id);
804
805 if let Some(ref callback) = *self.commit_callback.read().unwrap() {
807 if let Some(ops) = &operations {
808 if let Err(e) = callback(txn_id, ops) {
809 tracing::error!("Commit callback failed for {}: {}", txn_id, e);
811 }
812 }
813 }
814
815 self.prepared.write().unwrap().remove(txn_id);
817 self.locked_keys.write().unwrap().remove(txn_id);
818 self.committed.write().unwrap().insert(txn_id.clone());
819
820 TwoPhaseMessage::CommitAck {
821 txn_id: txn_id.clone(),
822 participant: self.node_id.clone(),
823 }
824 }
825
826 pub fn handle_abort(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
828 if let Some(ref callback) = *self.abort_callback.read().unwrap() {
830 if let Err(e) = callback(txn_id) {
831 tracing::error!("Abort callback failed for {}: {}", txn_id, e);
832 }
833 }
834
835 self.pending_prepares.write().unwrap().remove(txn_id);
837 self.prepared.write().unwrap().remove(txn_id);
838 self.locked_keys.write().unwrap().remove(txn_id);
839
840 TwoPhaseMessage::AbortAck {
841 txn_id: txn_id.clone(),
842 participant: self.node_id.clone(),
843 }
844 }
845
846 pub fn get_locked_keys(&self, txn_id: &TransactionId) -> Vec<String> {
848 self.locked_keys
849 .read()
850 .unwrap()
851 .get(txn_id)
852 .cloned()
853 .unwrap_or_default()
854 }
855
856 pub fn is_key_locked(&self, key: &str) -> Option<TransactionId> {
858 let locked = self.locked_keys.read().unwrap();
859 for (txn_id, keys) in locked.iter() {
860 if keys.iter().any(|k| k == key) {
861 return Some(txn_id.clone());
862 }
863 }
864 None
865 }
866
867 pub fn is_prepared(&self, txn_id: &TransactionId) -> bool {
869 self.prepared.read().unwrap().contains(txn_id)
870 }
871
872 pub fn is_committed(&self, txn_id: &TransactionId) -> bool {
874 self.committed.read().unwrap().contains(txn_id)
875 }
876
877 pub fn pending_count(&self) -> usize {
879 self.pending_prepares.read().unwrap().len()
880 }
881}
882
// Unit tests for the transaction types, the coordinator state machine and the
// participant handler. No networking is involved: messages returned by one
// side are inspected but never delivered to the other.
#[cfg(test)]
mod tests {
    use super::*;

    // Explicit ids are kept verbatim; generated ids carry the `txn_` prefix.
    #[test]
    fn test_transaction_id() {
        let id1 = TransactionId::new("txn_1");
        let id2 = TransactionId::generate();

        assert_eq!(id1.as_str(), "txn_1");
        assert!(id2.as_str().starts_with("txn_"));
    }

    // Terminal/commit/abort predicates on a few representative states.
    #[test]
    fn test_transaction_state() {
        assert!(!TransactionState::Preparing.is_terminal());
        assert!(TransactionState::Committed.is_terminal());
        assert!(TransactionState::Aborted.is_terminal());

        assert!(TransactionState::Prepared.can_commit());
        assert!(!TransactionState::Preparing.can_commit());

        assert!(TransactionState::Preparing.can_abort());
        assert!(!TransactionState::Committed.can_abort());
    }

    // Participant registration and vote counting on a bare DistributedTransaction.
    #[test]
    fn test_distributed_transaction() {
        let txn_id = TransactionId::new("txn_1");
        let coordinator = NodeId::new("coord");
        let mut txn = DistributedTransaction::new(txn_id, coordinator, 30000);

        assert_eq!(txn.state, TransactionState::Preparing);
        assert_eq!(txn.participant_count(), 0);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        assert_eq!(txn.participant_count(), 2);
        assert!(!txn.all_prepared());

        // One of two Commit votes: not yet fully prepared.
        txn.participants
            .get_mut(&NodeId::new("node1"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);

        assert!(!txn.all_prepared());
        assert_eq!(txn.prepared_count(), 1);

        txn.participants
            .get_mut(&NodeId::new("node2"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);

        assert!(txn.all_prepared());
        assert!(!txn.any_abort());
    }

    // A single Abort vote prevents `all_prepared` and is reported by `any_abort`.
    #[test]
    fn test_transaction_abort_vote() {
        let txn_id = TransactionId::new("txn_1");
        let mut txn = DistributedTransaction::new(txn_id, NodeId::new("coord"), 30000);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        txn.participants
            .get_mut(&NodeId::new("node1"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);
        txn.participants
            .get_mut(&NodeId::new("node2"))
            .unwrap()
            .record_prepare(ParticipantVote::Abort);

        assert!(!txn.all_prepared());
        assert!(txn.any_abort());
    }

    // Accessors on Write and Read operations.
    #[test]
    fn test_transaction_operation() {
        let write_op = TransactionOperation::Write {
            key: "user:1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        };

        assert_eq!(write_op.key(), "user:1");
        assert_eq!(write_op.shard_id(), "shard_1");
        assert!(write_op.is_write());

        let read_op = TransactionOperation::Read {
            key: "user:2".to_string(),
            shard_id: "shard_2".to_string(),
        };

        assert!(!read_op.is_write());
    }

    // A new transaction is tracked and starts in Preparing.
    #[test]
    fn test_coordinator_begin_transaction() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.get_state(&txn_id).is_some());
        assert_eq!(
            coord.get_state(&txn_id).unwrap(),
            TransactionState::Preparing
        );
    }

    #[test]
    fn test_coordinator_add_participant() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let txn = coord.get_transaction(&txn_id).unwrap();
        assert_eq!(txn.participant_count(), 2);
    }

    // `prepare` yields one PrepareRequest per participant for the right transaction.
    #[test]
    fn test_coordinator_prepare() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        coord.add_participant(&txn_id, NodeId::new("node1"));
        coord.add_participant(&txn_id, NodeId::new("node2"));

        let messages = coord.prepare(&txn_id).unwrap();
        assert_eq!(messages.len(), 2);

        for (_, msg) in &messages {
            match msg {
                TwoPhaseMessage::PrepareRequest { txn_id: id, .. } => {
                    assert_eq!(id, &txn_id);
                }
                _ => panic!("Expected PrepareRequest"),
            }
        }
    }

    // Happy path: both vote Commit, the coordinator commits, both ack.
    #[test]
    fn test_coordinator_full_commit() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        coord.add_participant(&txn_id, node1.clone());
        coord.add_participant(&txn_id, node2.clone());

        // Phase 1: prepare messages are produced but not inspected here.
        coord.prepare(&txn_id);

        // First vote leaves the decision open; the second completes it.
        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Commit);

        assert_eq!(state, Some(TransactionState::Prepared));

        // Phase 2: one CommitRequest per participant.
        let messages = coord.commit(&txn_id).unwrap();
        assert_eq!(messages.len(), 2);

        // The final ack moves the transaction to Committed.
        coord.handle_commit_ack(&txn_id, &node1);
        let final_state = coord.handle_commit_ack(&txn_id, &node2);

        assert_eq!(final_state, Some(TransactionState::Committed));
    }

    // One Abort vote moves the transaction to Aborting.
    #[test]
    fn test_coordinator_abort_on_vote() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        coord.add_participant(&txn_id, node1.clone());
        coord.add_participant(&txn_id, node2.clone());

        coord.prepare(&txn_id);

        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Abort);

        assert_eq!(state, Some(TransactionState::Aborting));
    }

    // Participant: prepare a write (default validation), then commit it.
    #[test]
    fn test_participant_handler() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let ops = vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }];

        // No conflicting transactions, so the built-in validation votes Commit.
        let response = handler.handle_prepare(&txn_id, ops);
        match response {
            TwoPhaseMessage::PrepareResponse { vote, .. } => {
                assert_eq!(vote, ParticipantVote::Commit);
            }
            _ => panic!("Expected PrepareResponse"),
        }

        assert!(handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));

        // Committing clears the prepared marker and records the commit.
        let commit_response = handler.handle_commit(&txn_id);
        match commit_response {
            TwoPhaseMessage::CommitAck { .. } => {}
            _ => panic!("Expected CommitAck"),
        }

        assert!(!handler.is_prepared(&txn_id));
        assert!(handler.is_committed(&txn_id));
    }

    // Participant: an abort after prepare leaves the txn neither prepared nor committed.
    #[test]
    fn test_participant_abort() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let ops = vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }];

        handler.handle_prepare(&txn_id, ops);
        assert!(handler.is_prepared(&txn_id));

        let abort_response = handler.handle_abort(&txn_id);
        match abort_response {
            TwoPhaseMessage::AbortAck { .. } => {}
            _ => panic!("Expected AbortAck"),
        }

        assert!(!handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));
    }

    // A committed transaction is removed by `cleanup_completed`.
    #[test]
    fn test_coordinator_cleanup() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        // Drive a single-participant transaction all the way to Committed.
        let txn_id = coord.begin_transaction();
        coord.add_participant(&txn_id, NodeId::new("node1"));
        coord.prepare(&txn_id);
        coord.handle_prepare_response(&txn_id, &NodeId::new("node1"), ParticipantVote::Commit);
        coord.commit(&txn_id);
        coord.handle_commit_ack(&txn_id, &NodeId::new("node1"));

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Committed));

        let cleaned = coord.cleanup_completed();
        assert_eq!(cleaned, 1);
        assert!(coord.get_state(&txn_id).is_none());
    }

    // `active_count` excludes transactions that reached a terminal state.
    #[test]
    fn test_active_count() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        let txn1 = coord.begin_transaction();
        let _txn2 = coord.begin_transaction();

        assert_eq!(coord.active_count(), 2);

        // Commit txn1 end to end; txn2 stays in Preparing.
        coord.add_participant(&txn1, NodeId::new("node1"));
        coord.prepare(&txn1);
        coord.handle_prepare_response(&txn1, &NodeId::new("node1"), ParticipantVote::Commit);
        coord.commit(&txn1);
        coord.handle_commit_ack(&txn1, &NodeId::new("node1"));

        assert_eq!(coord.active_count(), 1);
    }
}