aegis_replication/
transaction.rs

1//! Aegis Distributed Transactions
2//!
3//! Two-Phase Commit (2PC) protocol for distributed transaction coordination.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12
13// =============================================================================
14// Transaction ID
15// =============================================================================
16
17/// Unique identifier for distributed transactions.
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub struct TransactionId(String);
20
21impl TransactionId {
22    /// Create a new transaction ID.
23    pub fn new(id: impl Into<String>) -> Self {
24        Self(id.into())
25    }
26
27    /// Generate a unique transaction ID.
28    pub fn generate() -> Self {
29        let timestamp = std::time::SystemTime::now()
30            .duration_since(std::time::UNIX_EPOCH)
31            .unwrap_or_default()
32            .as_nanos();
33        Self(format!("txn_{:x}", timestamp))
34    }
35
36    /// Get the ID as a string.
37    pub fn as_str(&self) -> &str {
38        &self.0
39    }
40}
41
42impl std::fmt::Display for TransactionId {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{}", self.0)
45    }
46}
47
48// =============================================================================
49// Transaction State
50// =============================================================================
51
52/// State of a distributed transaction.
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum TransactionState {
55    /// Transaction is being prepared.
56    Preparing,
57    /// Transaction is prepared and ready to commit.
58    Prepared,
59    /// Transaction is committing.
60    Committing,
61    /// Transaction has been committed.
62    Committed,
63    /// Transaction is aborting.
64    Aborting,
65    /// Transaction has been aborted.
66    Aborted,
67    /// Transaction state is unknown (recovery needed).
68    Unknown,
69}
70
71impl TransactionState {
72    /// Check if the transaction is in a terminal state.
73    pub fn is_terminal(&self) -> bool {
74        matches!(self, Self::Committed | Self::Aborted)
75    }
76
77    /// Check if the transaction can be committed.
78    pub fn can_commit(&self) -> bool {
79        matches!(self, Self::Prepared)
80    }
81
82    /// Check if the transaction can be aborted.
83    pub fn can_abort(&self) -> bool {
84        !matches!(self, Self::Committed)
85    }
86}
87
88// =============================================================================
89// Participant Vote
90// =============================================================================
91
92/// Vote from a participant in 2PC.
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94pub enum ParticipantVote {
95    /// Participant votes to commit.
96    Commit,
97    /// Participant votes to abort.
98    Abort,
99    /// Participant hasn't voted yet.
100    Pending,
101}
102
103// =============================================================================
104// Transaction Participant
105// =============================================================================
106
107/// A participant in a distributed transaction.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct TransactionParticipant {
110    pub node_id: NodeId,
111    pub vote: ParticipantVote,
112    pub prepared: bool,
113    pub committed: bool,
114    pub last_contact: Option<u64>,
115}
116
117impl TransactionParticipant {
118    /// Create a new participant.
119    pub fn new(node_id: NodeId) -> Self {
120        Self {
121            node_id,
122            vote: ParticipantVote::Pending,
123            prepared: false,
124            committed: false,
125            last_contact: None,
126        }
127    }
128
129    /// Record a prepare vote.
130    pub fn record_prepare(&mut self, vote: ParticipantVote) {
131        self.vote = vote;
132        self.prepared = vote == ParticipantVote::Commit;
133    }
134
135    /// Record commit acknowledgment.
136    pub fn record_commit(&mut self) {
137        self.committed = true;
138    }
139}
140
141// =============================================================================
142// Distributed Transaction
143// =============================================================================
144
145/// A distributed transaction managed by 2PC.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DistributedTransaction {
148    pub id: TransactionId,
149    pub coordinator: NodeId,
150    pub participants: HashMap<NodeId, TransactionParticipant>,
151    pub state: TransactionState,
152    pub created_at: u64,
153    pub timeout_ms: u64,
154    pub operations: Vec<TransactionOperation>,
155}
156
157impl DistributedTransaction {
158    /// Create a new distributed transaction.
159    pub fn new(id: TransactionId, coordinator: NodeId, timeout_ms: u64) -> Self {
160        let created_at = std::time::SystemTime::now()
161            .duration_since(std::time::UNIX_EPOCH)
162            .unwrap_or_default()
163            .as_millis() as u64;
164
165        Self {
166            id,
167            coordinator,
168            participants: HashMap::new(),
169            state: TransactionState::Preparing,
170            created_at,
171            timeout_ms,
172            operations: Vec::new(),
173        }
174    }
175
176    /// Add a participant to the transaction.
177    pub fn add_participant(&mut self, node_id: NodeId) {
178        if !self.participants.contains_key(&node_id) {
179            self.participants
180                .insert(node_id.clone(), TransactionParticipant::new(node_id));
181        }
182    }
183
184    /// Add an operation to the transaction.
185    pub fn add_operation(&mut self, operation: TransactionOperation) {
186        self.operations.push(operation);
187    }
188
189    /// Check if all participants are prepared.
190    pub fn all_prepared(&self) -> bool {
191        self.participants.values().all(|p| p.prepared)
192    }
193
194    /// Check if all participants have committed.
195    pub fn all_committed(&self) -> bool {
196        self.participants.values().all(|p| p.committed)
197    }
198
199    /// Check if any participant voted to abort.
200    pub fn any_abort(&self) -> bool {
201        self.participants
202            .values()
203            .any(|p| p.vote == ParticipantVote::Abort)
204    }
205
206    /// Check if the transaction has timed out.
207    pub fn is_timed_out(&self) -> bool {
208        let now = std::time::SystemTime::now()
209            .duration_since(std::time::UNIX_EPOCH)
210            .unwrap_or_default()
211            .as_millis() as u64;
212        now - self.created_at > self.timeout_ms
213    }
214
215    /// Get participant count.
216    pub fn participant_count(&self) -> usize {
217        self.participants.len()
218    }
219
220    /// Get prepared count.
221    pub fn prepared_count(&self) -> usize {
222        self.participants.values().filter(|p| p.prepared).count()
223    }
224}
225
226// =============================================================================
227// Transaction Operation
228// =============================================================================
229
230/// An operation within a distributed transaction.
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum TransactionOperation {
233    /// Read operation.
234    Read {
235        key: String,
236        shard_id: String,
237    },
238    /// Write operation.
239    Write {
240        key: String,
241        value: Vec<u8>,
242        shard_id: String,
243    },
244    /// Delete operation.
245    Delete {
246        key: String,
247        shard_id: String,
248    },
249    /// Compare and swap operation.
250    CompareAndSwap {
251        key: String,
252        expected: Option<Vec<u8>>,
253        new_value: Vec<u8>,
254        shard_id: String,
255    },
256}
257
258impl TransactionOperation {
259    /// Get the shard ID for this operation.
260    pub fn shard_id(&self) -> &str {
261        match self {
262            Self::Read { shard_id, .. } => shard_id,
263            Self::Write { shard_id, .. } => shard_id,
264            Self::Delete { shard_id, .. } => shard_id,
265            Self::CompareAndSwap { shard_id, .. } => shard_id,
266        }
267    }
268
269    /// Get the key for this operation.
270    pub fn key(&self) -> &str {
271        match self {
272            Self::Read { key, .. } => key,
273            Self::Write { key, .. } => key,
274            Self::Delete { key, .. } => key,
275            Self::CompareAndSwap { key, .. } => key,
276        }
277    }
278
279    /// Check if this is a write operation.
280    pub fn is_write(&self) -> bool {
281        !matches!(self, Self::Read { .. })
282    }
283}
284
285// =============================================================================
286// 2PC Messages
287// =============================================================================
288
289/// Messages for 2PC protocol.
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum TwoPhaseMessage {
292    /// Prepare request from coordinator.
293    PrepareRequest {
294        txn_id: TransactionId,
295        operations: Vec<TransactionOperation>,
296    },
297    /// Prepare response from participant.
298    PrepareResponse {
299        txn_id: TransactionId,
300        vote: ParticipantVote,
301        participant: NodeId,
302    },
303    /// Commit request from coordinator.
304    CommitRequest {
305        txn_id: TransactionId,
306    },
307    /// Commit acknowledgment from participant.
308    CommitAck {
309        txn_id: TransactionId,
310        participant: NodeId,
311    },
312    /// Abort request from coordinator.
313    AbortRequest {
314        txn_id: TransactionId,
315    },
316    /// Abort acknowledgment from participant.
317    AbortAck {
318        txn_id: TransactionId,
319        participant: NodeId,
320    },
321    /// Query transaction status (for recovery).
322    StatusQuery {
323        txn_id: TransactionId,
324    },
325    /// Status response.
326    StatusResponse {
327        txn_id: TransactionId,
328        state: TransactionState,
329    },
330}
331
332// =============================================================================
333// Transaction Coordinator
334// =============================================================================
335
336/// Coordinator for distributed transactions using 2PC.
337pub struct TransactionCoordinator {
338    node_id: NodeId,
339    transactions: RwLock<HashMap<TransactionId, DistributedTransaction>>,
340    default_timeout_ms: u64,
341    prepared_log: RwLock<HashSet<TransactionId>>,
342}
343
344impl TransactionCoordinator {
345    /// Create a new transaction coordinator.
346    pub fn new(node_id: NodeId) -> Self {
347        Self {
348            node_id,
349            transactions: RwLock::new(HashMap::new()),
350            default_timeout_ms: 30000,
351            prepared_log: RwLock::new(HashSet::new()),
352        }
353    }
354
355    /// Create with custom timeout.
356    pub fn with_timeout(node_id: NodeId, timeout_ms: u64) -> Self {
357        Self {
358            node_id,
359            transactions: RwLock::new(HashMap::new()),
360            default_timeout_ms: timeout_ms,
361            prepared_log: RwLock::new(HashSet::new()),
362        }
363    }
364
365    /// Begin a new distributed transaction.
366    pub fn begin_transaction(&self) -> TransactionId {
367        let txn_id = TransactionId::generate();
368        let txn = DistributedTransaction::new(
369            txn_id.clone(),
370            self.node_id.clone(),
371            self.default_timeout_ms,
372        );
373        self.transactions
374            .write()
375            .unwrap()
376            .insert(txn_id.clone(), txn);
377        txn_id
378    }
379
380    /// Begin a transaction with a specific ID.
381    pub fn begin_transaction_with_id(&self, txn_id: TransactionId) {
382        let txn = DistributedTransaction::new(
383            txn_id.clone(),
384            self.node_id.clone(),
385            self.default_timeout_ms,
386        );
387        self.transactions.write().unwrap().insert(txn_id, txn);
388    }
389
390    /// Add a participant to a transaction.
391    pub fn add_participant(&self, txn_id: &TransactionId, node_id: NodeId) -> bool {
392        if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
393            txn.add_participant(node_id);
394            true
395        } else {
396            false
397        }
398    }
399
400    /// Add an operation to a transaction.
401    pub fn add_operation(&self, txn_id: &TransactionId, operation: TransactionOperation) -> bool {
402        if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
403            txn.add_operation(operation);
404            true
405        } else {
406            false
407        }
408    }
409
410    /// Phase 1: Prepare - Generate prepare requests for all participants.
411    pub fn prepare(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
412        let txns = self.transactions.read().unwrap();
413        let txn = txns.get(txn_id)?;
414
415        if txn.state != TransactionState::Preparing {
416            return None;
417        }
418
419        let messages: Vec<_> = txn
420            .participants
421            .keys()
422            .map(|node_id| {
423                (
424                    node_id.clone(),
425                    TwoPhaseMessage::PrepareRequest {
426                        txn_id: txn_id.clone(),
427                        operations: txn.operations.clone(),
428                    },
429                )
430            })
431            .collect();
432
433        Some(messages)
434    }
435
436    /// Handle prepare response from a participant.
437    pub fn handle_prepare_response(
438        &self,
439        txn_id: &TransactionId,
440        participant: &NodeId,
441        vote: ParticipantVote,
442    ) -> Option<TransactionState> {
443        let mut txns = self.transactions.write().unwrap();
444        let txn = txns.get_mut(txn_id)?;
445
446        if let Some(p) = txn.participants.get_mut(participant) {
447            p.record_prepare(vote);
448        }
449
450        // Check if we can make a decision
451        if txn.any_abort() {
452            txn.state = TransactionState::Aborting;
453            Some(TransactionState::Aborting)
454        } else if txn.all_prepared() {
455            txn.state = TransactionState::Prepared;
456            self.prepared_log.write().unwrap().insert(txn_id.clone());
457            Some(TransactionState::Prepared)
458        } else {
459            None
460        }
461    }
462
463    /// Phase 2: Commit - Generate commit requests for all participants.
464    pub fn commit(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
465        let mut txns = self.transactions.write().unwrap();
466        let txn = txns.get_mut(txn_id)?;
467
468        if !txn.state.can_commit() {
469            return None;
470        }
471
472        txn.state = TransactionState::Committing;
473
474        let messages: Vec<_> = txn
475            .participants
476            .keys()
477            .map(|node_id| {
478                (
479                    node_id.clone(),
480                    TwoPhaseMessage::CommitRequest {
481                        txn_id: txn_id.clone(),
482                    },
483                )
484            })
485            .collect();
486
487        Some(messages)
488    }
489
490    /// Handle commit acknowledgment from a participant.
491    pub fn handle_commit_ack(
492        &self,
493        txn_id: &TransactionId,
494        participant: &NodeId,
495    ) -> Option<TransactionState> {
496        let mut txns = self.transactions.write().unwrap();
497        let txn = txns.get_mut(txn_id)?;
498
499        if let Some(p) = txn.participants.get_mut(participant) {
500            p.record_commit();
501        }
502
503        if txn.all_committed() {
504            txn.state = TransactionState::Committed;
505            Some(TransactionState::Committed)
506        } else {
507            None
508        }
509    }
510
511    /// Abort a transaction - Generate abort requests.
512    pub fn abort(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
513        let mut txns = self.transactions.write().unwrap();
514        let txn = txns.get_mut(txn_id)?;
515
516        if !txn.state.can_abort() {
517            return None;
518        }
519
520        txn.state = TransactionState::Aborting;
521
522        let messages: Vec<_> = txn
523            .participants
524            .keys()
525            .map(|node_id| {
526                (
527                    node_id.clone(),
528                    TwoPhaseMessage::AbortRequest {
529                        txn_id: txn_id.clone(),
530                    },
531                )
532            })
533            .collect();
534
535        Some(messages)
536    }
537
538    /// Handle abort acknowledgment.
539    pub fn handle_abort_ack(&self, txn_id: &TransactionId, _participant: &NodeId) -> bool {
540        let mut txns = self.transactions.write().unwrap();
541        if let Some(txn) = txns.get_mut(txn_id) {
542            txn.state = TransactionState::Aborted;
543            true
544        } else {
545            false
546        }
547    }
548
549    /// Get transaction state.
550    pub fn get_state(&self, txn_id: &TransactionId) -> Option<TransactionState> {
551        self.transactions
552            .read()
553            .unwrap()
554            .get(txn_id)
555            .map(|t| t.state)
556    }
557
558    /// Get transaction details.
559    pub fn get_transaction(&self, txn_id: &TransactionId) -> Option<DistributedTransaction> {
560        self.transactions.read().unwrap().get(txn_id).cloned()
561    }
562
563    /// Check for timed out transactions.
564    pub fn check_timeouts(&self) -> Vec<TransactionId> {
565        self.transactions
566            .read()
567            .unwrap()
568            .iter()
569            .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
570            .map(|(id, _)| id.clone())
571            .collect()
572    }
573
574    /// Clean up completed transactions.
575    pub fn cleanup_completed(&self) -> usize {
576        let mut txns = self.transactions.write().unwrap();
577        let before = txns.len();
578        txns.retain(|_, txn| !txn.state.is_terminal());
579        before - txns.len()
580    }
581
582    /// Get active transaction count.
583    pub fn active_count(&self) -> usize {
584        self.transactions
585            .read()
586            .unwrap()
587            .values()
588            .filter(|t| !t.state.is_terminal())
589            .count()
590    }
591
592    /// Check if a transaction was prepared (for recovery).
593    pub fn was_prepared(&self, txn_id: &TransactionId) -> bool {
594        self.prepared_log.read().unwrap().contains(txn_id)
595    }
596}
597
598// =============================================================================
599// Transaction Participant Handler
600// =============================================================================
601
602/// Handler for transaction participants.
603pub struct ParticipantHandler {
604    node_id: NodeId,
605    pending_prepares: RwLock<HashMap<TransactionId, Vec<TransactionOperation>>>,
606    prepared: RwLock<HashSet<TransactionId>>,
607    committed: RwLock<HashSet<TransactionId>>,
608}
609
610impl ParticipantHandler {
611    /// Create a new participant handler.
612    pub fn new(node_id: NodeId) -> Self {
613        Self {
614            node_id,
615            pending_prepares: RwLock::new(HashMap::new()),
616            prepared: RwLock::new(HashSet::new()),
617            committed: RwLock::new(HashSet::new()),
618        }
619    }
620
621    /// Handle prepare request.
622    pub fn handle_prepare(
623        &self,
624        txn_id: &TransactionId,
625        operations: Vec<TransactionOperation>,
626    ) -> TwoPhaseMessage {
627        // In a real implementation, this would:
628        // 1. Acquire locks on all keys
629        // 2. Validate all operations can succeed
630        // 3. Write to WAL
631        // For now, we simulate success
632
633        self.pending_prepares
634            .write()
635            .unwrap()
636            .insert(txn_id.clone(), operations);
637        self.prepared.write().unwrap().insert(txn_id.clone());
638
639        TwoPhaseMessage::PrepareResponse {
640            txn_id: txn_id.clone(),
641            vote: ParticipantVote::Commit,
642            participant: self.node_id.clone(),
643        }
644    }
645
646    /// Handle prepare with validation function.
647    pub fn handle_prepare_with_validation<F>(
648        &self,
649        txn_id: &TransactionId,
650        operations: Vec<TransactionOperation>,
651        validator: F,
652    ) -> TwoPhaseMessage
653    where
654        F: FnOnce(&[TransactionOperation]) -> bool,
655    {
656        let vote = if validator(&operations) {
657            self.pending_prepares
658                .write()
659                .unwrap()
660                .insert(txn_id.clone(), operations);
661            self.prepared.write().unwrap().insert(txn_id.clone());
662            ParticipantVote::Commit
663        } else {
664            ParticipantVote::Abort
665        };
666
667        TwoPhaseMessage::PrepareResponse {
668            txn_id: txn_id.clone(),
669            vote,
670            participant: self.node_id.clone(),
671        }
672    }
673
674    /// Handle commit request.
675    pub fn handle_commit(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
676        // Apply the prepared operations
677        self.pending_prepares.write().unwrap().remove(txn_id);
678        self.prepared.write().unwrap().remove(txn_id);
679        self.committed.write().unwrap().insert(txn_id.clone());
680
681        TwoPhaseMessage::CommitAck {
682            txn_id: txn_id.clone(),
683            participant: self.node_id.clone(),
684        }
685    }
686
687    /// Handle abort request.
688    pub fn handle_abort(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
689        // Rollback any prepared state
690        self.pending_prepares.write().unwrap().remove(txn_id);
691        self.prepared.write().unwrap().remove(txn_id);
692
693        TwoPhaseMessage::AbortAck {
694            txn_id: txn_id.clone(),
695            participant: self.node_id.clone(),
696        }
697    }
698
699    /// Check if a transaction is prepared.
700    pub fn is_prepared(&self, txn_id: &TransactionId) -> bool {
701        self.prepared.read().unwrap().contains(txn_id)
702    }
703
704    /// Check if a transaction is committed.
705    pub fn is_committed(&self, txn_id: &TransactionId) -> bool {
706        self.committed.read().unwrap().contains(txn_id)
707    }
708
709    /// Get pending prepare count.
710    pub fn pending_count(&self) -> usize {
711        self.pending_prepares.read().unwrap().len()
712    }
713}
714
715// =============================================================================
716// Tests
717// =============================================================================
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a node identifier from a literal.
    fn node(name: &'static str) -> NodeId {
        NodeId::new(name)
    }

    /// A single write of `[1, 2, 3]` to `key1` on `shard_1`.
    fn write_key1() -> Vec<TransactionOperation> {
        vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }]
    }

    /// Record `choice` as the prepare vote of the already-registered `name`.
    fn cast_vote(txn: &mut DistributedTransaction, name: &'static str, choice: ParticipantVote) {
        txn.participants
            .get_mut(&node(name))
            .expect("participant should be registered")
            .record_prepare(choice);
    }

    /// Drive `txn_id` through a full commit with `node1` as the only participant.
    fn run_single_node_commit(coord: &TransactionCoordinator, txn_id: &TransactionId) {
        let participant = node("node1");
        coord.add_participant(txn_id, participant.clone());
        coord.prepare(txn_id);
        coord.handle_prepare_response(txn_id, &participant, ParticipantVote::Commit);
        coord.commit(txn_id);
        coord.handle_commit_ack(txn_id, &participant);
    }

    #[test]
    fn test_transaction_id() {
        let explicit = TransactionId::new("txn_1");
        assert_eq!(explicit.as_str(), "txn_1");

        let generated = TransactionId::generate();
        assert!(generated.as_str().starts_with("txn_"));
    }

    #[test]
    fn test_transaction_state() {
        assert!(!TransactionState::Preparing.is_terminal());
        for terminal in [TransactionState::Committed, TransactionState::Aborted] {
            assert!(terminal.is_terminal());
        }

        assert!(TransactionState::Prepared.can_commit());
        assert!(!TransactionState::Preparing.can_commit());

        assert!(TransactionState::Preparing.can_abort());
        assert!(!TransactionState::Committed.can_abort());
    }

    #[test]
    fn test_distributed_transaction() {
        let mut txn =
            DistributedTransaction::new(TransactionId::new("txn_1"), node("coord"), 30000);
        assert_eq!(txn.state, TransactionState::Preparing);
        assert_eq!(txn.participant_count(), 0);

        for name in ["node1", "node2"] {
            txn.add_participant(node(name));
        }
        assert_eq!(txn.participant_count(), 2);
        assert!(!txn.all_prepared());

        cast_vote(&mut txn, "node1", ParticipantVote::Commit);
        assert!(!txn.all_prepared());
        assert_eq!(txn.prepared_count(), 1);

        cast_vote(&mut txn, "node2", ParticipantVote::Commit);
        assert!(txn.all_prepared());
        assert!(!txn.any_abort());
    }

    #[test]
    fn test_transaction_abort_vote() {
        let mut txn =
            DistributedTransaction::new(TransactionId::new("txn_1"), node("coord"), 30000);
        for name in ["node1", "node2"] {
            txn.add_participant(node(name));
        }

        cast_vote(&mut txn, "node1", ParticipantVote::Commit);
        cast_vote(&mut txn, "node2", ParticipantVote::Abort);

        assert!(!txn.all_prepared());
        assert!(txn.any_abort());
    }

    #[test]
    fn test_transaction_operation() {
        let write_op = TransactionOperation::Write {
            key: "user:1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        };
        assert_eq!(write_op.key(), "user:1");
        assert_eq!(write_op.shard_id(), "shard_1");
        assert!(write_op.is_write());

        let read_op = TransactionOperation::Read {
            key: "user:2".to_string(),
            shard_id: "shard_2".to_string(),
        };
        assert!(!read_op.is_write());
    }

    #[test]
    fn test_coordinator_begin_transaction() {
        let coord = TransactionCoordinator::new(node("coord"));
        let txn_id = coord.begin_transaction();

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Preparing));
    }

    #[test]
    fn test_coordinator_add_participant() {
        let coord = TransactionCoordinator::new(node("coord"));
        let txn_id = coord.begin_transaction();

        for name in ["node1", "node2"] {
            assert!(coord.add_participant(&txn_id, node(name)));
        }

        let snapshot = coord.get_transaction(&txn_id).expect("transaction exists");
        assert_eq!(snapshot.participant_count(), 2);
    }

    #[test]
    fn test_coordinator_prepare() {
        let coord = TransactionCoordinator::new(node("coord"));
        let txn_id = coord.begin_transaction();
        for name in ["node1", "node2"] {
            coord.add_participant(&txn_id, node(name));
        }

        let messages = coord.prepare(&txn_id).expect("prepare should succeed");
        assert_eq!(messages.len(), 2);

        for (_, msg) in &messages {
            assert!(
                matches!(msg, TwoPhaseMessage::PrepareRequest { txn_id: id, .. } if id == &txn_id),
                "Expected PrepareRequest"
            );
        }
    }

    #[test]
    fn test_coordinator_full_commit() {
        let coord = TransactionCoordinator::new(node("coord"));
        let txn_id = coord.begin_transaction();
        let (first, second) = (node("node1"), node("node2"));

        coord.add_participant(&txn_id, first.clone());
        coord.add_participant(&txn_id, second.clone());

        // Phase 1: both participants vote to commit.
        coord.prepare(&txn_id);
        coord.handle_prepare_response(&txn_id, &first, ParticipantVote::Commit);
        let decided = coord.handle_prepare_response(&txn_id, &second, ParticipantVote::Commit);
        assert_eq!(decided, Some(TransactionState::Prepared));

        // Phase 2: commit is broadcast and both participants acknowledge.
        let commits = coord.commit(&txn_id).expect("commit should succeed");
        assert_eq!(commits.len(), 2);

        coord.handle_commit_ack(&txn_id, &first);
        let finished = coord.handle_commit_ack(&txn_id, &second);
        assert_eq!(finished, Some(TransactionState::Committed));
    }

    #[test]
    fn test_coordinator_abort_on_vote() {
        let coord = TransactionCoordinator::new(node("coord"));
        let txn_id = coord.begin_transaction();
        let (first, second) = (node("node1"), node("node2"));

        coord.add_participant(&txn_id, first.clone());
        coord.add_participant(&txn_id, second.clone());
        coord.prepare(&txn_id);

        coord.handle_prepare_response(&txn_id, &first, ParticipantVote::Commit);
        let decided = coord.handle_prepare_response(&txn_id, &second, ParticipantVote::Abort);
        assert_eq!(decided, Some(TransactionState::Aborting));
    }

    #[test]
    fn test_participant_handler() {
        let handler = ParticipantHandler::new(node("node1"));
        let txn_id = TransactionId::new("txn_1");

        match handler.handle_prepare(&txn_id, write_key1()) {
            TwoPhaseMessage::PrepareResponse { vote, .. } => {
                assert_eq!(vote, ParticipantVote::Commit)
            }
            _ => panic!("Expected PrepareResponse"),
        }
        assert!(handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));

        assert!(
            matches!(handler.handle_commit(&txn_id), TwoPhaseMessage::CommitAck { .. }),
            "Expected CommitAck"
        );
        assert!(!handler.is_prepared(&txn_id));
        assert!(handler.is_committed(&txn_id));
    }

    #[test]
    fn test_participant_abort() {
        let handler = ParticipantHandler::new(node("node1"));
        let txn_id = TransactionId::new("txn_1");

        handler.handle_prepare(&txn_id, write_key1());
        assert!(handler.is_prepared(&txn_id));

        assert!(
            matches!(handler.handle_abort(&txn_id), TwoPhaseMessage::AbortAck { .. }),
            "Expected AbortAck"
        );
        assert!(!handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));
    }

    #[test]
    fn test_coordinator_cleanup() {
        let coord = TransactionCoordinator::new(node("coord"));
        let txn_id = coord.begin_transaction();

        run_single_node_commit(&coord, &txn_id);
        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Committed));

        assert_eq!(coord.cleanup_completed(), 1);
        assert!(coord.get_state(&txn_id).is_none());
    }

    #[test]
    fn test_active_count() {
        let coord = TransactionCoordinator::new(node("coord"));
        let finished = coord.begin_transaction();
        let _still_open = coord.begin_transaction();
        assert_eq!(coord.active_count(), 2);

        run_single_node_commit(&coord, &finished);
        assert_eq!(coord.active_count(), 1);
    }
}