1use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12
/// Identifier of a distributed transaction.
///
/// A thin newtype around the identifier string. It is hashable and
/// comparable so it can key the coordinator's and participant's maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionId(String);
20
21impl TransactionId {
22 pub fn new(id: impl Into<String>) -> Self {
24 Self(id.into())
25 }
26
27 pub fn generate() -> Self {
29 let timestamp = std::time::SystemTime::now()
30 .duration_since(std::time::UNIX_EPOCH)
31 .unwrap_or_default()
32 .as_nanos();
33 Self(format!("txn_{:x}", timestamp))
34 }
35
36 pub fn as_str(&self) -> &str {
38 &self.0
39 }
40}
41
42impl std::fmt::Display for TransactionId {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{}", self.0)
45 }
46}
47
/// Lifecycle state of a distributed transaction as tracked by the coordinator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionState {
    /// Initial state; prepare votes are still being collected.
    Preparing,
    /// Every participant voted to commit.
    Prepared,
    /// Commit requests have been issued; acknowledgements are outstanding.
    Committing,
    /// Every participant acknowledged the commit (terminal).
    Committed,
    /// An abort was decided; acknowledgements are outstanding.
    Aborting,
    /// The abort was acknowledged (terminal).
    Aborted,
    /// Not assigned anywhere in this file.
    // NOTE(review): presumably the answer to a status query for an unknown
    // transaction; confirm against callers.
    Unknown,
}
70
71impl TransactionState {
72 pub fn is_terminal(&self) -> bool {
74 matches!(self, Self::Committed | Self::Aborted)
75 }
76
77 pub fn can_commit(&self) -> bool {
79 matches!(self, Self::Prepared)
80 }
81
82 pub fn can_abort(&self) -> bool {
84 !matches!(self, Self::Committed)
85 }
86}
87
/// Vote a participant casts in the prepare phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParticipantVote {
    /// The participant is able to commit.
    Commit,
    /// The participant refuses; the transaction must abort.
    Abort,
    /// No vote recorded yet (initial value for a new participant).
    Pending,
}
102
/// Coordinator-side bookkeeping for one participant of a transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionParticipant {
    /// Node this record belongs to.
    pub node_id: NodeId,
    /// Most recently recorded prepare vote.
    pub vote: ParticipantVote,
    /// `true` once a `Commit` vote has been recorded.
    pub prepared: bool,
    /// `true` once the participant's commit has been recorded.
    pub committed: bool,
    /// Initialised to `None`; never updated in this file.
    // NOTE(review): presumably a timestamp of the last message; unit unknown, confirm.
    pub last_contact: Option<u64>,
}
116
117impl TransactionParticipant {
118 pub fn new(node_id: NodeId) -> Self {
120 Self {
121 node_id,
122 vote: ParticipantVote::Pending,
123 prepared: false,
124 committed: false,
125 last_contact: None,
126 }
127 }
128
129 pub fn record_prepare(&mut self, vote: ParticipantVote) {
131 self.vote = vote;
132 self.prepared = vote == ParticipantVote::Commit;
133 }
134
135 pub fn record_commit(&mut self) {
137 self.committed = true;
138 }
139}
140
/// Coordinator-side view of one distributed transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTransaction {
    /// Identifier of this transaction.
    pub id: TransactionId,
    /// Node acting as coordinator.
    pub coordinator: NodeId,
    /// Per-participant bookkeeping, keyed by node.
    pub participants: HashMap<NodeId, TransactionParticipant>,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Creation time in milliseconds since the UNIX epoch.
    pub created_at: u64,
    /// Age in milliseconds after which the transaction counts as timed out.
    pub timeout_ms: u64,
    /// Operations in insertion order.
    pub operations: Vec<TransactionOperation>,
}
156
157impl DistributedTransaction {
158 pub fn new(id: TransactionId, coordinator: NodeId, timeout_ms: u64) -> Self {
160 let created_at = std::time::SystemTime::now()
161 .duration_since(std::time::UNIX_EPOCH)
162 .unwrap_or_default()
163 .as_millis() as u64;
164
165 Self {
166 id,
167 coordinator,
168 participants: HashMap::new(),
169 state: TransactionState::Preparing,
170 created_at,
171 timeout_ms,
172 operations: Vec::new(),
173 }
174 }
175
176 pub fn add_participant(&mut self, node_id: NodeId) {
178 if !self.participants.contains_key(&node_id) {
179 self.participants
180 .insert(node_id.clone(), TransactionParticipant::new(node_id));
181 }
182 }
183
184 pub fn add_operation(&mut self, operation: TransactionOperation) {
186 self.operations.push(operation);
187 }
188
189 pub fn all_prepared(&self) -> bool {
191 self.participants.values().all(|p| p.prepared)
192 }
193
194 pub fn all_committed(&self) -> bool {
196 self.participants.values().all(|p| p.committed)
197 }
198
199 pub fn any_abort(&self) -> bool {
201 self.participants
202 .values()
203 .any(|p| p.vote == ParticipantVote::Abort)
204 }
205
206 pub fn is_timed_out(&self) -> bool {
208 let now = std::time::SystemTime::now()
209 .duration_since(std::time::UNIX_EPOCH)
210 .unwrap_or_default()
211 .as_millis() as u64;
212 now - self.created_at > self.timeout_ms
213 }
214
215 pub fn participant_count(&self) -> usize {
217 self.participants.len()
218 }
219
220 pub fn prepared_count(&self) -> usize {
222 self.participants.values().filter(|p| p.prepared).count()
223 }
224}
225
/// A single keyed operation carried by a transaction.
///
/// Every variant names the key it touches and the shard holding that key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionOperation {
    /// Read of `key`; the only variant not treated as a write.
    Read { key: String, shard_id: String },
    /// Store `value` under `key`.
    Write {
        key: String,
        value: Vec<u8>,
        shard_id: String,
    },
    /// Remove `key`.
    Delete { key: String, shard_id: String },
    /// Replace the value of `key` with `new_value` if it matches `expected`.
    // NOTE(review): presumably `expected: None` means "key must be absent"; confirm.
    CompareAndSwap {
        key: String,
        expected: Option<Vec<u8>>,
        new_value: Vec<u8>,
        shard_id: String,
    },
}
251
252impl TransactionOperation {
253 pub fn shard_id(&self) -> &str {
255 match self {
256 Self::Read { shard_id, .. } => shard_id,
257 Self::Write { shard_id, .. } => shard_id,
258 Self::Delete { shard_id, .. } => shard_id,
259 Self::CompareAndSwap { shard_id, .. } => shard_id,
260 }
261 }
262
263 pub fn key(&self) -> &str {
265 match self {
266 Self::Read { key, .. } => key,
267 Self::Write { key, .. } => key,
268 Self::Delete { key, .. } => key,
269 Self::CompareAndSwap { key, .. } => key,
270 }
271 }
272
273 pub fn is_write(&self) -> bool {
275 !matches!(self, Self::Read { .. })
276 }
277}
278
/// Messages exchanged between the coordinator and participants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TwoPhaseMessage {
    /// Coordinator -> participant: validate and stage `operations`.
    PrepareRequest {
        txn_id: TransactionId,
        operations: Vec<TransactionOperation>,
    },
    /// Participant -> coordinator: the participant's vote.
    PrepareResponse {
        txn_id: TransactionId,
        vote: ParticipantVote,
        participant: NodeId,
    },
    /// Coordinator -> participant: commit the transaction.
    CommitRequest { txn_id: TransactionId },
    /// Participant -> coordinator: commit request was handled.
    CommitAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Coordinator -> participant: abort the transaction.
    AbortRequest { txn_id: TransactionId },
    /// Participant -> coordinator: abort request was handled.
    AbortAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Asks for the state of a transaction (not produced in this file).
    StatusQuery { txn_id: TransactionId },
    /// Carries the state of a transaction (not produced in this file).
    StatusResponse {
        txn_id: TransactionId,
        state: TransactionState,
    },
}
319
/// Coordinator side of the two-phase commit protocol.
///
/// Produces the messages to send and consumes the responses; it performs no
/// I/O itself. State is kept behind `RwLock`s and all methods take `&self`.
pub struct TransactionCoordinator {
    /// Identity of this node; recorded as coordinator of new transactions.
    node_id: NodeId,
    /// Transactions known to this coordinator, keyed by id.
    transactions: RwLock<HashMap<TransactionId, DistributedTransaction>>,
    /// Timeout in milliseconds given to newly created transactions.
    default_timeout_ms: u64,
    /// Ids of transactions that reached `Prepared`; never pruned in this file.
    prepared_log: RwLock<HashSet<TransactionId>>,
}
331
332impl TransactionCoordinator {
333 pub fn new(node_id: NodeId) -> Self {
335 Self {
336 node_id,
337 transactions: RwLock::new(HashMap::new()),
338 default_timeout_ms: 30000,
339 prepared_log: RwLock::new(HashSet::new()),
340 }
341 }
342
343 pub fn with_timeout(node_id: NodeId, timeout_ms: u64) -> Self {
345 Self {
346 node_id,
347 transactions: RwLock::new(HashMap::new()),
348 default_timeout_ms: timeout_ms,
349 prepared_log: RwLock::new(HashSet::new()),
350 }
351 }
352
353 pub fn begin_transaction(&self) -> TransactionId {
355 let txn_id = TransactionId::generate();
356 let txn = DistributedTransaction::new(
357 txn_id.clone(),
358 self.node_id.clone(),
359 self.default_timeout_ms,
360 );
361 self.transactions
362 .write()
363 .expect("transaction coordinator transactions lock poisoned")
364 .insert(txn_id.clone(), txn);
365 txn_id
366 }
367
368 pub fn begin_transaction_with_id(&self, txn_id: TransactionId) {
370 let txn = DistributedTransaction::new(
371 txn_id.clone(),
372 self.node_id.clone(),
373 self.default_timeout_ms,
374 );
375 self.transactions
376 .write()
377 .expect("transaction coordinator transactions lock poisoned")
378 .insert(txn_id, txn);
379 }
380
381 pub fn add_participant(&self, txn_id: &TransactionId, node_id: NodeId) -> bool {
383 if let Some(txn) = self
384 .transactions
385 .write()
386 .expect("transaction coordinator transactions lock poisoned")
387 .get_mut(txn_id)
388 {
389 txn.add_participant(node_id);
390 true
391 } else {
392 false
393 }
394 }
395
396 pub fn add_operation(&self, txn_id: &TransactionId, operation: TransactionOperation) -> bool {
398 if let Some(txn) = self
399 .transactions
400 .write()
401 .expect("transaction coordinator transactions lock poisoned")
402 .get_mut(txn_id)
403 {
404 txn.add_operation(operation);
405 true
406 } else {
407 false
408 }
409 }
410
411 pub fn prepare(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
413 let txns = self
414 .transactions
415 .read()
416 .expect("transaction coordinator transactions lock poisoned");
417 let txn = txns.get(txn_id)?;
418
419 if txn.state != TransactionState::Preparing {
420 return None;
421 }
422
423 let messages: Vec<_> = txn
424 .participants
425 .keys()
426 .map(|node_id| {
427 (
428 node_id.clone(),
429 TwoPhaseMessage::PrepareRequest {
430 txn_id: txn_id.clone(),
431 operations: txn.operations.clone(),
432 },
433 )
434 })
435 .collect();
436
437 Some(messages)
438 }
439
440 pub fn handle_prepare_response(
442 &self,
443 txn_id: &TransactionId,
444 participant: &NodeId,
445 vote: ParticipantVote,
446 ) -> Option<TransactionState> {
447 let mut txns = self
448 .transactions
449 .write()
450 .expect("transaction coordinator transactions lock poisoned");
451 let txn = txns.get_mut(txn_id)?;
452
453 if let Some(p) = txn.participants.get_mut(participant) {
454 p.record_prepare(vote);
455 }
456
457 if txn.any_abort() {
459 txn.state = TransactionState::Aborting;
460 Some(TransactionState::Aborting)
461 } else if txn.all_prepared() {
462 txn.state = TransactionState::Prepared;
463 self.prepared_log
464 .write()
465 .expect("transaction coordinator prepared_log lock poisoned")
466 .insert(txn_id.clone());
467 Some(TransactionState::Prepared)
468 } else {
469 None
470 }
471 }
472
473 pub fn commit(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
475 let mut txns = self
476 .transactions
477 .write()
478 .expect("transaction coordinator transactions lock poisoned");
479 let txn = txns.get_mut(txn_id)?;
480
481 if !txn.state.can_commit() {
482 return None;
483 }
484
485 txn.state = TransactionState::Committing;
486
487 let messages: Vec<_> = txn
488 .participants
489 .keys()
490 .map(|node_id| {
491 (
492 node_id.clone(),
493 TwoPhaseMessage::CommitRequest {
494 txn_id: txn_id.clone(),
495 },
496 )
497 })
498 .collect();
499
500 Some(messages)
501 }
502
503 pub fn handle_commit_ack(
505 &self,
506 txn_id: &TransactionId,
507 participant: &NodeId,
508 ) -> Option<TransactionState> {
509 let mut txns = self
510 .transactions
511 .write()
512 .expect("transaction coordinator transactions lock poisoned");
513 let txn = txns.get_mut(txn_id)?;
514
515 if let Some(p) = txn.participants.get_mut(participant) {
516 p.record_commit();
517 }
518
519 if txn.all_committed() {
520 txn.state = TransactionState::Committed;
521 Some(TransactionState::Committed)
522 } else {
523 None
524 }
525 }
526
527 pub fn abort(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
529 let mut txns = self
530 .transactions
531 .write()
532 .expect("transaction coordinator transactions lock poisoned");
533 let txn = txns.get_mut(txn_id)?;
534
535 if !txn.state.can_abort() {
536 return None;
537 }
538
539 txn.state = TransactionState::Aborting;
540
541 let messages: Vec<_> = txn
542 .participants
543 .keys()
544 .map(|node_id| {
545 (
546 node_id.clone(),
547 TwoPhaseMessage::AbortRequest {
548 txn_id: txn_id.clone(),
549 },
550 )
551 })
552 .collect();
553
554 Some(messages)
555 }
556
557 pub fn handle_abort_ack(&self, txn_id: &TransactionId, _participant: &NodeId) -> bool {
559 let mut txns = self
560 .transactions
561 .write()
562 .expect("transaction coordinator transactions lock poisoned");
563 if let Some(txn) = txns.get_mut(txn_id) {
564 txn.state = TransactionState::Aborted;
565 true
566 } else {
567 false
568 }
569 }
570
571 pub fn get_state(&self, txn_id: &TransactionId) -> Option<TransactionState> {
573 self.transactions
574 .read()
575 .expect("transaction coordinator transactions lock poisoned")
576 .get(txn_id)
577 .map(|t| t.state)
578 }
579
580 pub fn get_transaction(&self, txn_id: &TransactionId) -> Option<DistributedTransaction> {
582 self.transactions
583 .read()
584 .expect("transaction coordinator transactions lock poisoned")
585 .get(txn_id)
586 .cloned()
587 }
588
589 pub fn check_timeouts(&self) -> Vec<TransactionId> {
591 self.transactions
592 .read()
593 .expect("transaction coordinator transactions lock poisoned")
594 .iter()
595 .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
596 .map(|(id, _)| id.clone())
597 .collect()
598 }
599
600 pub fn cleanup_completed(&self) -> usize {
602 let mut txns = self
603 .transactions
604 .write()
605 .expect("transaction coordinator transactions lock poisoned");
606 let before = txns.len();
607 txns.retain(|_, txn| !txn.state.is_terminal());
608 before - txns.len()
609 }
610
611 pub fn active_count(&self) -> usize {
613 self.transactions
614 .read()
615 .expect("transaction coordinator transactions lock poisoned")
616 .values()
617 .filter(|t| !t.state.is_terminal())
618 .count()
619 }
620
621 pub fn was_prepared(&self, txn_id: &TransactionId) -> bool {
623 self.prepared_log
624 .read()
625 .expect("transaction coordinator prepared_log lock poisoned")
626 .contains(txn_id)
627 }
628}
629
/// Custom prepare-phase validator installed on a `ParticipantHandler`.
///
/// Receives the transaction id and its operations and returns the
/// validation outcome, including the keys to lock on success.
pub type ValidationCallback =
    Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> ValidationResult + Send + Sync>;

/// Hook invoked on commit with the operations staged during prepare.
/// An `Err` is logged by the handler; it does not stop the commit.
pub type CommitCallback =
    Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> Result<(), String> + Send + Sync>;

/// Hook invoked on abort. An `Err` is logged by the handler.
pub type AbortCallback = Box<dyn Fn(&TransactionId) -> Result<(), String> + Send + Sync>;
644
/// Outcome of validating a prepare request.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the operations may be prepared.
    pub success: bool,
    /// Reason for a failed validation; `None` on success.
    pub error: Option<String>,
    /// Keys to lock for the transaction; empty on failure.
    pub locked_keys: Vec<String>,
}
655
656impl ValidationResult {
657 pub fn success(locked_keys: Vec<String>) -> Self {
659 Self {
660 success: true,
661 error: None,
662 locked_keys,
663 }
664 }
665
666 pub fn failure(error: impl Into<String>) -> Self {
668 Self {
669 success: false,
670 error: Some(error.into()),
671 locked_keys: vec![],
672 }
673 }
674}
675
/// Participant side of the two-phase commit protocol.
///
/// Turns incoming prepare/commit/abort requests into response messages and
/// tracks staged operations and key locks. State is kept behind `RwLock`s
/// and all methods take `&self`.
pub struct ParticipantHandler {
    /// Identity of this node; echoed in every response.
    node_id: NodeId,
    /// Operations staged by a successful prepare, keyed by transaction.
    pending_prepares: RwLock<HashMap<TransactionId, Vec<TransactionOperation>>>,
    /// Transactions that voted commit and await the coordinator's decision.
    prepared: RwLock<HashSet<TransactionId>>,
    /// Transactions for which a commit was handled; never pruned in this file.
    committed: RwLock<HashSet<TransactionId>>,
    /// Keys locked on behalf of each prepared transaction.
    locked_keys: RwLock<HashMap<TransactionId, Vec<String>>>,
    /// Optional replacement for the built-in prepare validation.
    validation_callback: RwLock<Option<ValidationCallback>>,
    /// Optional hook run when a commit is handled.
    commit_callback: RwLock<Option<CommitCallback>>,
    /// Optional hook run when an abort is handled.
    abort_callback: RwLock<Option<AbortCallback>>,
}
687
688impl ParticipantHandler {
689 pub fn new(node_id: NodeId) -> Self {
691 Self {
692 node_id,
693 pending_prepares: RwLock::new(HashMap::new()),
694 prepared: RwLock::new(HashSet::new()),
695 committed: RwLock::new(HashSet::new()),
696 locked_keys: RwLock::new(HashMap::new()),
697 validation_callback: RwLock::new(None),
698 commit_callback: RwLock::new(None),
699 abort_callback: RwLock::new(None),
700 }
701 }
702
703 pub fn set_validation_callback(&self, callback: ValidationCallback) {
705 *self
706 .validation_callback
707 .write()
708 .expect("participant handler validation_callback lock poisoned") = Some(callback);
709 }
710
711 pub fn set_commit_callback(&self, callback: CommitCallback) {
713 *self
714 .commit_callback
715 .write()
716 .expect("participant handler commit_callback lock poisoned") = Some(callback);
717 }
718
719 pub fn set_abort_callback(&self, callback: AbortCallback) {
721 *self
722 .abort_callback
723 .write()
724 .expect("participant handler abort_callback lock poisoned") = Some(callback);
725 }
726
727 pub fn handle_prepare(
729 &self,
730 txn_id: &TransactionId,
731 operations: Vec<TransactionOperation>,
732 ) -> TwoPhaseMessage {
733 let validation_result = {
735 let callback_guard = self
736 .validation_callback
737 .read()
738 .expect("participant handler validation_callback lock poisoned");
739 if let Some(ref callback) = *callback_guard {
740 callback(txn_id, &operations)
742 } else {
743 self.basic_validation(txn_id, &operations)
745 }
746 };
747
748 if !validation_result.success {
749 return TwoPhaseMessage::PrepareResponse {
750 txn_id: txn_id.clone(),
751 vote: ParticipantVote::Abort,
752 participant: self.node_id.clone(),
753 };
754 }
755
756 self.locked_keys
758 .write()
759 .expect("participant handler locked_keys lock poisoned")
760 .insert(txn_id.clone(), validation_result.locked_keys);
761
762 self.pending_prepares
764 .write()
765 .expect("participant handler pending_prepares lock poisoned")
766 .insert(txn_id.clone(), operations);
767 self.prepared
768 .write()
769 .expect("participant handler prepared lock poisoned")
770 .insert(txn_id.clone());
771
772 TwoPhaseMessage::PrepareResponse {
773 txn_id: txn_id.clone(),
774 vote: ParticipantVote::Commit,
775 participant: self.node_id.clone(),
776 }
777 }
778
779 fn basic_validation(
781 &self,
782 txn_id: &TransactionId,
783 operations: &[TransactionOperation],
784 ) -> ValidationResult {
785 let pending = self
787 .pending_prepares
788 .read()
789 .expect("participant handler pending_prepares lock poisoned");
790 let locked = self
791 .locked_keys
792 .read()
793 .expect("participant handler locked_keys lock poisoned");
794
795 let mut keys_to_lock = Vec::new();
796 for op in operations {
797 let key = op.key().to_string();
798
799 for (other_txn_id, other_keys) in locked.iter() {
801 if other_txn_id != txn_id && other_keys.contains(&key) {
802 return ValidationResult::failure(format!(
803 "Key '{}' is locked by transaction {}",
804 key, other_txn_id
805 ));
806 }
807 }
808
809 for (other_txn_id, other_ops) in pending.iter() {
811 if other_txn_id != txn_id {
812 for other_op in other_ops {
813 if other_op.key() == key && other_op.is_write() && op.is_write() {
814 return ValidationResult::failure(format!(
815 "Write conflict on key '{}' with transaction {}",
816 key, other_txn_id
817 ));
818 }
819 }
820 }
821 }
822
823 if op.is_write() {
824 keys_to_lock.push(key);
825 }
826 }
827
828 ValidationResult::success(keys_to_lock)
829 }
830
831 pub fn handle_prepare_with_validation<F>(
833 &self,
834 txn_id: &TransactionId,
835 operations: Vec<TransactionOperation>,
836 validator: F,
837 ) -> TwoPhaseMessage
838 where
839 F: FnOnce(&[TransactionOperation]) -> bool,
840 {
841 let vote = if validator(&operations) {
842 self.pending_prepares
843 .write()
844 .expect("participant handler pending_prepares lock poisoned")
845 .insert(txn_id.clone(), operations);
846 self.prepared
847 .write()
848 .expect("participant handler prepared lock poisoned")
849 .insert(txn_id.clone());
850 ParticipantVote::Commit
851 } else {
852 ParticipantVote::Abort
853 };
854
855 TwoPhaseMessage::PrepareResponse {
856 txn_id: txn_id.clone(),
857 vote,
858 participant: self.node_id.clone(),
859 }
860 }
861
862 pub fn handle_commit(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
864 let operations = self
866 .pending_prepares
867 .write()
868 .expect("participant handler pending_prepares lock poisoned")
869 .remove(txn_id);
870
871 if let Some(ref callback) = *self
873 .commit_callback
874 .read()
875 .expect("participant handler commit_callback lock poisoned")
876 {
877 if let Some(ops) = &operations {
878 if let Err(e) = callback(txn_id, ops) {
879 tracing::error!("Commit callback failed for {}: {}", txn_id, e);
881 }
882 }
883 }
884
885 self.prepared
887 .write()
888 .expect("participant handler prepared lock poisoned")
889 .remove(txn_id);
890 self.locked_keys
891 .write()
892 .expect("participant handler locked_keys lock poisoned")
893 .remove(txn_id);
894 self.committed
895 .write()
896 .expect("participant handler committed lock poisoned")
897 .insert(txn_id.clone());
898
899 TwoPhaseMessage::CommitAck {
900 txn_id: txn_id.clone(),
901 participant: self.node_id.clone(),
902 }
903 }
904
905 pub fn handle_abort(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
907 if let Some(ref callback) = *self
909 .abort_callback
910 .read()
911 .expect("participant handler abort_callback lock poisoned")
912 {
913 if let Err(e) = callback(txn_id) {
914 tracing::error!("Abort callback failed for {}: {}", txn_id, e);
915 }
916 }
917
918 self.pending_prepares
920 .write()
921 .expect("participant handler pending_prepares lock poisoned")
922 .remove(txn_id);
923 self.prepared
924 .write()
925 .expect("participant handler prepared lock poisoned")
926 .remove(txn_id);
927 self.locked_keys
928 .write()
929 .expect("participant handler locked_keys lock poisoned")
930 .remove(txn_id);
931
932 TwoPhaseMessage::AbortAck {
933 txn_id: txn_id.clone(),
934 participant: self.node_id.clone(),
935 }
936 }
937
938 pub fn get_locked_keys(&self, txn_id: &TransactionId) -> Vec<String> {
940 self.locked_keys
941 .read()
942 .expect("participant handler locked_keys lock poisoned")
943 .get(txn_id)
944 .cloned()
945 .unwrap_or_default()
946 }
947
948 pub fn is_key_locked(&self, key: &str) -> Option<TransactionId> {
950 let locked = self
951 .locked_keys
952 .read()
953 .expect("participant handler locked_keys lock poisoned");
954 for (txn_id, keys) in locked.iter() {
955 if keys.iter().any(|k| k == key) {
956 return Some(txn_id.clone());
957 }
958 }
959 None
960 }
961
962 pub fn is_prepared(&self, txn_id: &TransactionId) -> bool {
964 self.prepared
965 .read()
966 .expect("participant handler prepared lock poisoned")
967 .contains(txn_id)
968 }
969
970 pub fn is_committed(&self, txn_id: &TransactionId) -> bool {
972 self.committed
973 .read()
974 .expect("participant handler committed lock poisoned")
975 .contains(txn_id)
976 }
977
978 pub fn pending_count(&self) -> usize {
980 self.pending_prepares
981 .read()
982 .expect("participant handler pending_prepares lock poisoned")
983 .len()
984 }
985}
986
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a write of `[1, 2, 3]` to `key` on `shard_1`.
    fn write_op(key: &str) -> TransactionOperation {
        TransactionOperation::Write {
            key: key.to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }
    }

    /// Builds a transaction with participants `node1` and `node2`.
    fn two_node_transaction() -> DistributedTransaction {
        let mut txn =
            DistributedTransaction::new(TransactionId::new("txn_1"), NodeId::new("coord"), 30000);
        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));
        txn
    }

    /// Records `vote` for participant `node` directly on the transaction.
    fn cast_vote(txn: &mut DistributedTransaction, node: &str, vote: ParticipantVote) {
        txn.participants
            .get_mut(&NodeId::new(node))
            .expect("participant must be registered")
            .record_prepare(vote);
    }

    /// Drives `txn_id` through the whole protocol with participant `node1`.
    fn commit_with_single_participant(coord: &TransactionCoordinator, txn_id: &TransactionId) {
        let node = NodeId::new("node1");
        assert!(coord.add_participant(txn_id, node.clone()));
        assert!(coord.prepare(txn_id).is_some());
        assert_eq!(
            coord.handle_prepare_response(txn_id, &node, ParticipantVote::Commit),
            Some(TransactionState::Prepared)
        );
        assert!(coord.commit(txn_id).is_some());
        assert_eq!(
            coord.handle_commit_ack(txn_id, &node),
            Some(TransactionState::Committed)
        );
    }

    #[test]
    fn test_transaction_id() {
        let id1 = TransactionId::new("txn_1");
        let id2 = TransactionId::generate();

        assert_eq!(id1.as_str(), "txn_1");
        assert!(id2.as_str().starts_with("txn_"));
    }

    #[test]
    fn test_transaction_state() {
        assert!(!TransactionState::Preparing.is_terminal());
        assert!(TransactionState::Committed.is_terminal());
        assert!(TransactionState::Aborted.is_terminal());

        assert!(TransactionState::Prepared.can_commit());
        assert!(!TransactionState::Preparing.can_commit());

        assert!(TransactionState::Preparing.can_abort());
        assert!(!TransactionState::Committed.can_abort());
    }

    #[test]
    fn test_distributed_transaction() {
        let mut txn =
            DistributedTransaction::new(TransactionId::new("txn_1"), NodeId::new("coord"), 30000);

        assert_eq!(txn.state, TransactionState::Preparing);
        assert_eq!(txn.participant_count(), 0);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        assert_eq!(txn.participant_count(), 2);
        assert!(!txn.all_prepared());

        cast_vote(&mut txn, "node1", ParticipantVote::Commit);

        assert!(!txn.all_prepared());
        assert_eq!(txn.prepared_count(), 1);

        cast_vote(&mut txn, "node2", ParticipantVote::Commit);

        assert!(txn.all_prepared());
        assert!(!txn.any_abort());
    }

    #[test]
    fn test_transaction_abort_vote() {
        let mut txn = two_node_transaction();

        cast_vote(&mut txn, "node1", ParticipantVote::Commit);
        cast_vote(&mut txn, "node2", ParticipantVote::Abort);

        assert!(!txn.all_prepared());
        assert!(txn.any_abort());
    }

    #[test]
    fn test_transaction_operation() {
        let write = write_op("user:1");

        assert_eq!(write.key(), "user:1");
        assert_eq!(write.shard_id(), "shard_1");
        assert!(write.is_write());

        let read = TransactionOperation::Read {
            key: "user:2".to_string(),
            shard_id: "shard_2".to_string(),
        };

        assert!(!read.is_write());
    }

    #[test]
    fn test_coordinator_begin_transaction() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Preparing));
    }

    #[test]
    fn test_coordinator_add_participant() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let txn = coord
            .get_transaction(&txn_id)
            .expect("transaction was just created");
        assert_eq!(txn.participant_count(), 2);
    }

    #[test]
    fn test_coordinator_prepare() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        coord.add_participant(&txn_id, NodeId::new("node1"));
        coord.add_participant(&txn_id, NodeId::new("node2"));

        let messages = coord
            .prepare(&txn_id)
            .expect("a preparing transaction yields prepare requests");
        assert_eq!(messages.len(), 2);

        for (_, msg) in &messages {
            match msg {
                TwoPhaseMessage::PrepareRequest { txn_id: id, .. } => assert_eq!(id, &txn_id),
                other => panic!("Expected PrepareRequest, got {:?}", other),
            }
        }
    }

    #[test]
    fn test_coordinator_full_commit() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        coord.add_participant(&txn_id, node1.clone());
        coord.add_participant(&txn_id, node2.clone());

        assert!(coord.prepare(&txn_id).is_some());

        assert_eq!(
            coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit),
            None
        );
        assert_eq!(
            coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Commit),
            Some(TransactionState::Prepared)
        );

        let messages = coord
            .commit(&txn_id)
            .expect("a prepared transaction can commit");
        assert_eq!(messages.len(), 2);

        assert_eq!(coord.handle_commit_ack(&txn_id, &node1), None);
        assert_eq!(
            coord.handle_commit_ack(&txn_id, &node2),
            Some(TransactionState::Committed)
        );
    }

    #[test]
    fn test_coordinator_abort_on_vote() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        coord.add_participant(&txn_id, node1.clone());
        coord.add_participant(&txn_id, node2.clone());

        assert!(coord.prepare(&txn_id).is_some());

        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Abort);

        assert_eq!(state, Some(TransactionState::Aborting));
    }

    #[test]
    fn test_participant_handler() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let response = handler.handle_prepare(&txn_id, vec![write_op("key1")]);
        assert!(
            matches!(
                response,
                TwoPhaseMessage::PrepareResponse {
                    vote: ParticipantVote::Commit,
                    ..
                }
            ),
            "Expected PrepareResponse with a commit vote, got {:?}",
            response
        );

        assert!(handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));

        let commit_response = handler.handle_commit(&txn_id);
        assert!(
            matches!(commit_response, TwoPhaseMessage::CommitAck { .. }),
            "Expected CommitAck, got {:?}",
            commit_response
        );

        assert!(!handler.is_prepared(&txn_id));
        assert!(handler.is_committed(&txn_id));
    }

    #[test]
    fn test_participant_abort() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        handler.handle_prepare(&txn_id, vec![write_op("key1")]);
        assert!(handler.is_prepared(&txn_id));

        let abort_response = handler.handle_abort(&txn_id);
        assert!(
            matches!(abort_response, TwoPhaseMessage::AbortAck { .. }),
            "Expected AbortAck, got {:?}",
            abort_response
        );

        assert!(!handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));
    }

    #[test]
    fn test_coordinator_cleanup() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        let txn_id = coord.begin_transaction();
        commit_with_single_participant(&coord, &txn_id);

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Committed));

        let cleaned = coord.cleanup_completed();
        assert_eq!(cleaned, 1);
        assert!(coord.get_state(&txn_id).is_none());
    }

    #[test]
    fn test_active_count() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        // Explicit ids keep the two transactions distinct regardless of how
        // coarse the system clock behind `TransactionId::generate` is.
        let txn1 = TransactionId::new("txn_active_1");
        let txn2 = TransactionId::new("txn_active_2");
        coord.begin_transaction_with_id(txn1.clone());
        coord.begin_transaction_with_id(txn2);

        assert_eq!(coord.active_count(), 2);

        commit_with_single_participant(&coord, &txn1);

        assert_eq!(coord.active_count(), 1);
    }
}