aegis_replication/
transaction.rs

1//! Aegis Distributed Transactions
2//!
3//! Two-Phase Commit (2PC) protocol for distributed transaction coordination.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12
13// =============================================================================
14// Transaction ID
15// =============================================================================
16
17/// Unique identifier for distributed transactions.
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub struct TransactionId(String);
20
21impl TransactionId {
22    /// Create a new transaction ID.
23    pub fn new(id: impl Into<String>) -> Self {
24        Self(id.into())
25    }
26
27    /// Generate a unique transaction ID.
28    pub fn generate() -> Self {
29        let timestamp = std::time::SystemTime::now()
30            .duration_since(std::time::UNIX_EPOCH)
31            .unwrap_or_default()
32            .as_nanos();
33        Self(format!("txn_{:x}", timestamp))
34    }
35
36    /// Get the ID as a string.
37    pub fn as_str(&self) -> &str {
38        &self.0
39    }
40}
41
42impl std::fmt::Display for TransactionId {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{}", self.0)
45    }
46}
47
48// =============================================================================
49// Transaction State
50// =============================================================================
51
52/// State of a distributed transaction.
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum TransactionState {
55    /// Transaction is being prepared.
56    Preparing,
57    /// Transaction is prepared and ready to commit.
58    Prepared,
59    /// Transaction is committing.
60    Committing,
61    /// Transaction has been committed.
62    Committed,
63    /// Transaction is aborting.
64    Aborting,
65    /// Transaction has been aborted.
66    Aborted,
67    /// Transaction state is unknown (recovery needed).
68    Unknown,
69}
70
71impl TransactionState {
72    /// Check if the transaction is in a terminal state.
73    pub fn is_terminal(&self) -> bool {
74        matches!(self, Self::Committed | Self::Aborted)
75    }
76
77    /// Check if the transaction can be committed.
78    pub fn can_commit(&self) -> bool {
79        matches!(self, Self::Prepared)
80    }
81
82    /// Check if the transaction can be aborted.
83    pub fn can_abort(&self) -> bool {
84        !matches!(self, Self::Committed)
85    }
86}
87
88// =============================================================================
89// Participant Vote
90// =============================================================================
91
92/// Vote from a participant in 2PC.
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94pub enum ParticipantVote {
95    /// Participant votes to commit.
96    Commit,
97    /// Participant votes to abort.
98    Abort,
99    /// Participant hasn't voted yet.
100    Pending,
101}
102
103// =============================================================================
104// Transaction Participant
105// =============================================================================
106
107/// A participant in a distributed transaction.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct TransactionParticipant {
110    pub node_id: NodeId,
111    pub vote: ParticipantVote,
112    pub prepared: bool,
113    pub committed: bool,
114    pub last_contact: Option<u64>,
115}
116
117impl TransactionParticipant {
118    /// Create a new participant.
119    pub fn new(node_id: NodeId) -> Self {
120        Self {
121            node_id,
122            vote: ParticipantVote::Pending,
123            prepared: false,
124            committed: false,
125            last_contact: None,
126        }
127    }
128
129    /// Record a prepare vote.
130    pub fn record_prepare(&mut self, vote: ParticipantVote) {
131        self.vote = vote;
132        self.prepared = vote == ParticipantVote::Commit;
133    }
134
135    /// Record commit acknowledgment.
136    pub fn record_commit(&mut self) {
137        self.committed = true;
138    }
139}
140
141// =============================================================================
142// Distributed Transaction
143// =============================================================================
144
145/// A distributed transaction managed by 2PC.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DistributedTransaction {
148    pub id: TransactionId,
149    pub coordinator: NodeId,
150    pub participants: HashMap<NodeId, TransactionParticipant>,
151    pub state: TransactionState,
152    pub created_at: u64,
153    pub timeout_ms: u64,
154    pub operations: Vec<TransactionOperation>,
155}
156
157impl DistributedTransaction {
158    /// Create a new distributed transaction.
159    pub fn new(id: TransactionId, coordinator: NodeId, timeout_ms: u64) -> Self {
160        let created_at = std::time::SystemTime::now()
161            .duration_since(std::time::UNIX_EPOCH)
162            .unwrap_or_default()
163            .as_millis() as u64;
164
165        Self {
166            id,
167            coordinator,
168            participants: HashMap::new(),
169            state: TransactionState::Preparing,
170            created_at,
171            timeout_ms,
172            operations: Vec::new(),
173        }
174    }
175
176    /// Add a participant to the transaction.
177    pub fn add_participant(&mut self, node_id: NodeId) {
178        if !self.participants.contains_key(&node_id) {
179            self.participants
180                .insert(node_id.clone(), TransactionParticipant::new(node_id));
181        }
182    }
183
184    /// Add an operation to the transaction.
185    pub fn add_operation(&mut self, operation: TransactionOperation) {
186        self.operations.push(operation);
187    }
188
189    /// Check if all participants are prepared.
190    pub fn all_prepared(&self) -> bool {
191        self.participants.values().all(|p| p.prepared)
192    }
193
194    /// Check if all participants have committed.
195    pub fn all_committed(&self) -> bool {
196        self.participants.values().all(|p| p.committed)
197    }
198
199    /// Check if any participant voted to abort.
200    pub fn any_abort(&self) -> bool {
201        self.participants
202            .values()
203            .any(|p| p.vote == ParticipantVote::Abort)
204    }
205
206    /// Check if the transaction has timed out.
207    pub fn is_timed_out(&self) -> bool {
208        let now = std::time::SystemTime::now()
209            .duration_since(std::time::UNIX_EPOCH)
210            .unwrap_or_default()
211            .as_millis() as u64;
212        now - self.created_at > self.timeout_ms
213    }
214
215    /// Get participant count.
216    pub fn participant_count(&self) -> usize {
217        self.participants.len()
218    }
219
220    /// Get prepared count.
221    pub fn prepared_count(&self) -> usize {
222        self.participants.values().filter(|p| p.prepared).count()
223    }
224}
225
226// =============================================================================
227// Transaction Operation
228// =============================================================================
229
230/// An operation within a distributed transaction.
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum TransactionOperation {
233    /// Read operation.
234    Read {
235        key: String,
236        shard_id: String,
237    },
238    /// Write operation.
239    Write {
240        key: String,
241        value: Vec<u8>,
242        shard_id: String,
243    },
244    /// Delete operation.
245    Delete {
246        key: String,
247        shard_id: String,
248    },
249    /// Compare and swap operation.
250    CompareAndSwap {
251        key: String,
252        expected: Option<Vec<u8>>,
253        new_value: Vec<u8>,
254        shard_id: String,
255    },
256}
257
258impl TransactionOperation {
259    /// Get the shard ID for this operation.
260    pub fn shard_id(&self) -> &str {
261        match self {
262            Self::Read { shard_id, .. } => shard_id,
263            Self::Write { shard_id, .. } => shard_id,
264            Self::Delete { shard_id, .. } => shard_id,
265            Self::CompareAndSwap { shard_id, .. } => shard_id,
266        }
267    }
268
269    /// Get the key for this operation.
270    pub fn key(&self) -> &str {
271        match self {
272            Self::Read { key, .. } => key,
273            Self::Write { key, .. } => key,
274            Self::Delete { key, .. } => key,
275            Self::CompareAndSwap { key, .. } => key,
276        }
277    }
278
279    /// Check if this is a write operation.
280    pub fn is_write(&self) -> bool {
281        !matches!(self, Self::Read { .. })
282    }
283}
284
285// =============================================================================
286// 2PC Messages
287// =============================================================================
288
289/// Messages for 2PC protocol.
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum TwoPhaseMessage {
292    /// Prepare request from coordinator.
293    PrepareRequest {
294        txn_id: TransactionId,
295        operations: Vec<TransactionOperation>,
296    },
297    /// Prepare response from participant.
298    PrepareResponse {
299        txn_id: TransactionId,
300        vote: ParticipantVote,
301        participant: NodeId,
302    },
303    /// Commit request from coordinator.
304    CommitRequest {
305        txn_id: TransactionId,
306    },
307    /// Commit acknowledgment from participant.
308    CommitAck {
309        txn_id: TransactionId,
310        participant: NodeId,
311    },
312    /// Abort request from coordinator.
313    AbortRequest {
314        txn_id: TransactionId,
315    },
316    /// Abort acknowledgment from participant.
317    AbortAck {
318        txn_id: TransactionId,
319        participant: NodeId,
320    },
321    /// Query transaction status (for recovery).
322    StatusQuery {
323        txn_id: TransactionId,
324    },
325    /// Status response.
326    StatusResponse {
327        txn_id: TransactionId,
328        state: TransactionState,
329    },
330}
331
332// =============================================================================
333// Transaction Coordinator
334// =============================================================================
335
336/// Coordinator for distributed transactions using 2PC.
337pub struct TransactionCoordinator {
338    node_id: NodeId,
339    transactions: RwLock<HashMap<TransactionId, DistributedTransaction>>,
340    default_timeout_ms: u64,
341    prepared_log: RwLock<HashSet<TransactionId>>,
342}
343
344impl TransactionCoordinator {
345    /// Create a new transaction coordinator.
346    pub fn new(node_id: NodeId) -> Self {
347        Self {
348            node_id,
349            transactions: RwLock::new(HashMap::new()),
350            default_timeout_ms: 30000,
351            prepared_log: RwLock::new(HashSet::new()),
352        }
353    }
354
355    /// Create with custom timeout.
356    pub fn with_timeout(node_id: NodeId, timeout_ms: u64) -> Self {
357        Self {
358            node_id,
359            transactions: RwLock::new(HashMap::new()),
360            default_timeout_ms: timeout_ms,
361            prepared_log: RwLock::new(HashSet::new()),
362        }
363    }
364
365    /// Begin a new distributed transaction.
366    pub fn begin_transaction(&self) -> TransactionId {
367        let txn_id = TransactionId::generate();
368        let txn = DistributedTransaction::new(
369            txn_id.clone(),
370            self.node_id.clone(),
371            self.default_timeout_ms,
372        );
373        self.transactions
374            .write()
375            .unwrap()
376            .insert(txn_id.clone(), txn);
377        txn_id
378    }
379
380    /// Begin a transaction with a specific ID.
381    pub fn begin_transaction_with_id(&self, txn_id: TransactionId) {
382        let txn = DistributedTransaction::new(
383            txn_id.clone(),
384            self.node_id.clone(),
385            self.default_timeout_ms,
386        );
387        self.transactions.write().unwrap().insert(txn_id, txn);
388    }
389
390    /// Add a participant to a transaction.
391    pub fn add_participant(&self, txn_id: &TransactionId, node_id: NodeId) -> bool {
392        if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
393            txn.add_participant(node_id);
394            true
395        } else {
396            false
397        }
398    }
399
400    /// Add an operation to a transaction.
401    pub fn add_operation(&self, txn_id: &TransactionId, operation: TransactionOperation) -> bool {
402        if let Some(txn) = self.transactions.write().unwrap().get_mut(txn_id) {
403            txn.add_operation(operation);
404            true
405        } else {
406            false
407        }
408    }
409
410    /// Phase 1: Prepare - Generate prepare requests for all participants.
411    pub fn prepare(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
412        let txns = self.transactions.read().unwrap();
413        let txn = txns.get(txn_id)?;
414
415        if txn.state != TransactionState::Preparing {
416            return None;
417        }
418
419        let messages: Vec<_> = txn
420            .participants
421            .keys()
422            .map(|node_id| {
423                (
424                    node_id.clone(),
425                    TwoPhaseMessage::PrepareRequest {
426                        txn_id: txn_id.clone(),
427                        operations: txn.operations.clone(),
428                    },
429                )
430            })
431            .collect();
432
433        Some(messages)
434    }
435
436    /// Handle prepare response from a participant.
437    pub fn handle_prepare_response(
438        &self,
439        txn_id: &TransactionId,
440        participant: &NodeId,
441        vote: ParticipantVote,
442    ) -> Option<TransactionState> {
443        let mut txns = self.transactions.write().unwrap();
444        let txn = txns.get_mut(txn_id)?;
445
446        if let Some(p) = txn.participants.get_mut(participant) {
447            p.record_prepare(vote);
448        }
449
450        // Check if we can make a decision
451        if txn.any_abort() {
452            txn.state = TransactionState::Aborting;
453            Some(TransactionState::Aborting)
454        } else if txn.all_prepared() {
455            txn.state = TransactionState::Prepared;
456            self.prepared_log.write().unwrap().insert(txn_id.clone());
457            Some(TransactionState::Prepared)
458        } else {
459            None
460        }
461    }
462
463    /// Phase 2: Commit - Generate commit requests for all participants.
464    pub fn commit(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
465        let mut txns = self.transactions.write().unwrap();
466        let txn = txns.get_mut(txn_id)?;
467
468        if !txn.state.can_commit() {
469            return None;
470        }
471
472        txn.state = TransactionState::Committing;
473
474        let messages: Vec<_> = txn
475            .participants
476            .keys()
477            .map(|node_id| {
478                (
479                    node_id.clone(),
480                    TwoPhaseMessage::CommitRequest {
481                        txn_id: txn_id.clone(),
482                    },
483                )
484            })
485            .collect();
486
487        Some(messages)
488    }
489
490    /// Handle commit acknowledgment from a participant.
491    pub fn handle_commit_ack(
492        &self,
493        txn_id: &TransactionId,
494        participant: &NodeId,
495    ) -> Option<TransactionState> {
496        let mut txns = self.transactions.write().unwrap();
497        let txn = txns.get_mut(txn_id)?;
498
499        if let Some(p) = txn.participants.get_mut(participant) {
500            p.record_commit();
501        }
502
503        if txn.all_committed() {
504            txn.state = TransactionState::Committed;
505            Some(TransactionState::Committed)
506        } else {
507            None
508        }
509    }
510
511    /// Abort a transaction - Generate abort requests.
512    pub fn abort(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
513        let mut txns = self.transactions.write().unwrap();
514        let txn = txns.get_mut(txn_id)?;
515
516        if !txn.state.can_abort() {
517            return None;
518        }
519
520        txn.state = TransactionState::Aborting;
521
522        let messages: Vec<_> = txn
523            .participants
524            .keys()
525            .map(|node_id| {
526                (
527                    node_id.clone(),
528                    TwoPhaseMessage::AbortRequest {
529                        txn_id: txn_id.clone(),
530                    },
531                )
532            })
533            .collect();
534
535        Some(messages)
536    }
537
538    /// Handle abort acknowledgment.
539    pub fn handle_abort_ack(&self, txn_id: &TransactionId, _participant: &NodeId) -> bool {
540        let mut txns = self.transactions.write().unwrap();
541        if let Some(txn) = txns.get_mut(txn_id) {
542            txn.state = TransactionState::Aborted;
543            true
544        } else {
545            false
546        }
547    }
548
549    /// Get transaction state.
550    pub fn get_state(&self, txn_id: &TransactionId) -> Option<TransactionState> {
551        self.transactions
552            .read()
553            .unwrap()
554            .get(txn_id)
555            .map(|t| t.state)
556    }
557
558    /// Get transaction details.
559    pub fn get_transaction(&self, txn_id: &TransactionId) -> Option<DistributedTransaction> {
560        self.transactions.read().unwrap().get(txn_id).cloned()
561    }
562
563    /// Check for timed out transactions.
564    pub fn check_timeouts(&self) -> Vec<TransactionId> {
565        self.transactions
566            .read()
567            .unwrap()
568            .iter()
569            .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
570            .map(|(id, _)| id.clone())
571            .collect()
572    }
573
574    /// Clean up completed transactions.
575    pub fn cleanup_completed(&self) -> usize {
576        let mut txns = self.transactions.write().unwrap();
577        let before = txns.len();
578        txns.retain(|_, txn| !txn.state.is_terminal());
579        before - txns.len()
580    }
581
582    /// Get active transaction count.
583    pub fn active_count(&self) -> usize {
584        self.transactions
585            .read()
586            .unwrap()
587            .values()
588            .filter(|t| !t.state.is_terminal())
589            .count()
590    }
591
592    /// Check if a transaction was prepared (for recovery).
593    pub fn was_prepared(&self, txn_id: &TransactionId) -> bool {
594        self.prepared_log.read().unwrap().contains(txn_id)
595    }
596}
597
598// =============================================================================
599// Transaction Participant Handler
600// =============================================================================
601
602/// Callback type for validating operations before prepare.
603pub type ValidationCallback = Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> ValidationResult + Send + Sync>;
604
605/// Callback type for committing operations.
606pub type CommitCallback = Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> Result<(), String> + Send + Sync>;
607
608/// Callback type for aborting/rolling back operations.
609pub type AbortCallback = Box<dyn Fn(&TransactionId) -> Result<(), String> + Send + Sync>;
610
/// Outcome of checking a transaction's operations in the 2PC prepare phase.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the operations may be prepared.
    pub success: bool,
    /// Reason for the rejection; `None` on success.
    pub error: Option<String>,
    /// Keys to hold on behalf of the transaction while it is prepared.
    pub locked_keys: Vec<String>,
}

impl ValidationResult {
    /// Accepting result that locks `locked_keys`.
    pub fn success(locked_keys: Vec<String>) -> Self {
        ValidationResult {
            locked_keys,
            error: None,
            success: true,
        }
    }

    /// Rejecting result carrying `error` and locking nothing.
    pub fn failure(error: impl Into<String>) -> Self {
        let message: String = error.into();
        ValidationResult {
            success: false,
            error: Some(message),
            locked_keys: Vec::new(),
        }
    }
}
641
642/// Handler for transaction participants with storage integration.
643pub struct ParticipantHandler {
644    node_id: NodeId,
645    pending_prepares: RwLock<HashMap<TransactionId, Vec<TransactionOperation>>>,
646    prepared: RwLock<HashSet<TransactionId>>,
647    committed: RwLock<HashSet<TransactionId>>,
648    locked_keys: RwLock<HashMap<TransactionId, Vec<String>>>,
649    validation_callback: RwLock<Option<ValidationCallback>>,
650    commit_callback: RwLock<Option<CommitCallback>>,
651    abort_callback: RwLock<Option<AbortCallback>>,
652}
653
654impl ParticipantHandler {
655    /// Create a new participant handler.
656    pub fn new(node_id: NodeId) -> Self {
657        Self {
658            node_id,
659            pending_prepares: RwLock::new(HashMap::new()),
660            prepared: RwLock::new(HashSet::new()),
661            committed: RwLock::new(HashSet::new()),
662            locked_keys: RwLock::new(HashMap::new()),
663            validation_callback: RwLock::new(None),
664            commit_callback: RwLock::new(None),
665            abort_callback: RwLock::new(None),
666        }
667    }
668
669    /// Set the validation callback for prepare phase.
670    pub fn set_validation_callback(&self, callback: ValidationCallback) {
671        *self.validation_callback.write().unwrap() = Some(callback);
672    }
673
674    /// Set the commit callback for commit phase.
675    pub fn set_commit_callback(&self, callback: CommitCallback) {
676        *self.commit_callback.write().unwrap() = Some(callback);
677    }
678
679    /// Set the abort callback for abort/rollback phase.
680    pub fn set_abort_callback(&self, callback: AbortCallback) {
681        *self.abort_callback.write().unwrap() = Some(callback);
682    }
683
684    /// Handle prepare request with full storage integration.
685    pub fn handle_prepare(
686        &self,
687        txn_id: &TransactionId,
688        operations: Vec<TransactionOperation>,
689    ) -> TwoPhaseMessage {
690        // Check if there's a validation callback
691        let validation_result = {
692            let callback_guard = self.validation_callback.read().unwrap();
693            if let Some(ref callback) = *callback_guard {
694                // Run validation against storage
695                callback(txn_id, &operations)
696            } else {
697                // No callback - perform basic validation (key conflict check)
698                self.basic_validation(txn_id, &operations)
699            }
700        };
701
702        if !validation_result.success {
703            return TwoPhaseMessage::PrepareResponse {
704                txn_id: txn_id.clone(),
705                vote: ParticipantVote::Abort,
706                participant: self.node_id.clone(),
707            };
708        }
709
710        // Store locked keys for later release
711        self.locked_keys
712            .write()
713            .unwrap()
714            .insert(txn_id.clone(), validation_result.locked_keys);
715
716        // Store operations for commit phase
717        self.pending_prepares
718            .write()
719            .unwrap()
720            .insert(txn_id.clone(), operations);
721        self.prepared.write().unwrap().insert(txn_id.clone());
722
723        TwoPhaseMessage::PrepareResponse {
724            txn_id: txn_id.clone(),
725            vote: ParticipantVote::Commit,
726            participant: self.node_id.clone(),
727        }
728    }
729
730    /// Basic validation when no custom callback is provided.
731    fn basic_validation(&self, txn_id: &TransactionId, operations: &[TransactionOperation]) -> ValidationResult {
732        // Check for key conflicts with other pending transactions
733        let pending = self.pending_prepares.read().unwrap();
734        let locked = self.locked_keys.read().unwrap();
735
736        let mut keys_to_lock = Vec::new();
737        for op in operations {
738            let key = op.key().to_string();
739
740            // Check if key is locked by another transaction
741            for (other_txn_id, other_keys) in locked.iter() {
742                if other_txn_id != txn_id && other_keys.contains(&key) {
743                    return ValidationResult::failure(format!(
744                        "Key '{}' is locked by transaction {}",
745                        key, other_txn_id
746                    ));
747                }
748            }
749
750            // Check if key is being modified by another pending transaction
751            for (other_txn_id, other_ops) in pending.iter() {
752                if other_txn_id != txn_id {
753                    for other_op in other_ops {
754                        if other_op.key() == key && other_op.is_write() && op.is_write() {
755                            return ValidationResult::failure(format!(
756                                "Write conflict on key '{}' with transaction {}",
757                                key, other_txn_id
758                            ));
759                        }
760                    }
761                }
762            }
763
764            if op.is_write() {
765                keys_to_lock.push(key);
766            }
767        }
768
769        ValidationResult::success(keys_to_lock)
770    }
771
772    /// Handle prepare with validation function.
773    pub fn handle_prepare_with_validation<F>(
774        &self,
775        txn_id: &TransactionId,
776        operations: Vec<TransactionOperation>,
777        validator: F,
778    ) -> TwoPhaseMessage
779    where
780        F: FnOnce(&[TransactionOperation]) -> bool,
781    {
782        let vote = if validator(&operations) {
783            self.pending_prepares
784                .write()
785                .unwrap()
786                .insert(txn_id.clone(), operations);
787            self.prepared.write().unwrap().insert(txn_id.clone());
788            ParticipantVote::Commit
789        } else {
790            ParticipantVote::Abort
791        };
792
793        TwoPhaseMessage::PrepareResponse {
794            txn_id: txn_id.clone(),
795            vote,
796            participant: self.node_id.clone(),
797        }
798    }
799
800    /// Handle commit request with storage integration.
801    pub fn handle_commit(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
802        // Get the operations to commit
803        let operations = self.pending_prepares.write().unwrap().remove(txn_id);
804
805        // Execute commit callback if set
806        if let Some(ref callback) = *self.commit_callback.read().unwrap() {
807            if let Some(ops) = &operations {
808                if let Err(e) = callback(txn_id, ops) {
809                    // Log error but continue - we must commit after prepare
810                    tracing::error!("Commit callback failed for {}: {}", txn_id, e);
811                }
812            }
813        }
814
815        // Clean up state
816        self.prepared.write().unwrap().remove(txn_id);
817        self.locked_keys.write().unwrap().remove(txn_id);
818        self.committed.write().unwrap().insert(txn_id.clone());
819
820        TwoPhaseMessage::CommitAck {
821            txn_id: txn_id.clone(),
822            participant: self.node_id.clone(),
823        }
824    }
825
826    /// Handle abort request with proper cleanup.
827    pub fn handle_abort(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
828        // Execute abort callback if set
829        if let Some(ref callback) = *self.abort_callback.read().unwrap() {
830            if let Err(e) = callback(txn_id) {
831                tracing::error!("Abort callback failed for {}: {}", txn_id, e);
832            }
833        }
834
835        // Rollback any prepared state and release locks
836        self.pending_prepares.write().unwrap().remove(txn_id);
837        self.prepared.write().unwrap().remove(txn_id);
838        self.locked_keys.write().unwrap().remove(txn_id);
839
840        TwoPhaseMessage::AbortAck {
841            txn_id: txn_id.clone(),
842            participant: self.node_id.clone(),
843        }
844    }
845
846    /// Get currently locked keys for a transaction.
847    pub fn get_locked_keys(&self, txn_id: &TransactionId) -> Vec<String> {
848        self.locked_keys
849            .read()
850            .unwrap()
851            .get(txn_id)
852            .cloned()
853            .unwrap_or_default()
854    }
855
856    /// Check if a key is locked by any transaction.
857    pub fn is_key_locked(&self, key: &str) -> Option<TransactionId> {
858        let locked = self.locked_keys.read().unwrap();
859        for (txn_id, keys) in locked.iter() {
860            if keys.iter().any(|k| k == key) {
861                return Some(txn_id.clone());
862            }
863        }
864        None
865    }
866
867    /// Check if a transaction is prepared.
868    pub fn is_prepared(&self, txn_id: &TransactionId) -> bool {
869        self.prepared.read().unwrap().contains(txn_id)
870    }
871
872    /// Check if a transaction is committed.
873    pub fn is_committed(&self, txn_id: &TransactionId) -> bool {
874        self.committed.read().unwrap().contains(txn_id)
875    }
876
877    /// Get pending prepare count.
878    pub fn pending_count(&self) -> usize {
879        self.pending_prepares.read().unwrap().len()
880    }
881}
882
883// =============================================================================
884// Tests
885// =============================================================================
886
#[cfg(test)]
mod tests {
    // Unit tests for the 2PC building blocks: transaction IDs, state predicates,
    // per-transaction vote bookkeeping, the coordinator state machine, and the
    // participant-side handler.

    use super::*;

    /// Drives `txn_id` through a complete, successful 2PC round with `node` as its
    /// only participant, asserting that every step succeeds so a failure is
    /// reported at the step that broke rather than as a later state mismatch.
    fn commit_with_single_participant(
        coord: &TransactionCoordinator,
        txn_id: &TransactionId,
        node: &NodeId,
    ) {
        assert!(coord.add_participant(txn_id, node.clone()));
        coord
            .prepare(txn_id)
            .expect("prepare should succeed once a participant is registered");
        // With a single participant, its commit vote is also the last vote.
        assert_eq!(
            coord.handle_prepare_response(txn_id, node, ParticipantVote::Commit),
            Some(TransactionState::Prepared)
        );
        coord
            .commit(txn_id)
            .expect("commit should succeed once every participant voted commit");
        assert_eq!(
            coord.handle_commit_ack(txn_id, node),
            Some(TransactionState::Committed)
        );
    }

    /// Explicit IDs round-trip through `as_str`; generated IDs carry the `txn_` prefix.
    #[test]
    fn test_transaction_id() {
        let id1 = TransactionId::new("txn_1");
        let id2 = TransactionId::generate();

        assert_eq!(id1.as_str(), "txn_1");
        assert!(id2.as_str().starts_with("txn_"));
    }

    /// `is_terminal`, `can_commit` and `can_abort` on representative states.
    #[test]
    fn test_transaction_state() {
        assert!(!TransactionState::Preparing.is_terminal());
        assert!(TransactionState::Committed.is_terminal());
        assert!(TransactionState::Aborted.is_terminal());

        assert!(TransactionState::Prepared.can_commit());
        assert!(!TransactionState::Preparing.can_commit());

        assert!(TransactionState::Preparing.can_abort());
        assert!(!TransactionState::Committed.can_abort());
    }

    /// A transaction is only fully prepared once every participant has voted commit.
    #[test]
    fn test_distributed_transaction() {
        let txn_id = TransactionId::new("txn_1");
        let coordinator = NodeId::new("coord");
        let mut txn = DistributedTransaction::new(txn_id, coordinator, 30000);

        assert_eq!(txn.state, TransactionState::Preparing);
        assert_eq!(txn.participant_count(), 0);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        assert_eq!(txn.participant_count(), 2);
        assert!(!txn.all_prepared());

        txn.participants
            .get_mut(&NodeId::new("node1"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);

        assert!(!txn.all_prepared());
        assert_eq!(txn.prepared_count(), 1);

        txn.participants
            .get_mut(&NodeId::new("node2"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);

        assert!(txn.all_prepared());
        assert!(!txn.any_abort());
    }

    /// A single abort vote means the transaction is not prepared and reports an abort.
    #[test]
    fn test_transaction_abort_vote() {
        let txn_id = TransactionId::new("txn_1");
        let mut txn = DistributedTransaction::new(txn_id, NodeId::new("coord"), 30000);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        txn.participants
            .get_mut(&NodeId::new("node1"))
            .unwrap()
            .record_prepare(ParticipantVote::Commit);
        txn.participants
            .get_mut(&NodeId::new("node2"))
            .unwrap()
            .record_prepare(ParticipantVote::Abort);

        assert!(!txn.all_prepared());
        assert!(txn.any_abort());
    }

    /// Operation accessors expose key/shard and distinguish writes from reads.
    #[test]
    fn test_transaction_operation() {
        let write_op = TransactionOperation::Write {
            key: "user:1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        };

        assert_eq!(write_op.key(), "user:1");
        assert_eq!(write_op.shard_id(), "shard_1");
        assert!(write_op.is_write());

        let read_op = TransactionOperation::Read {
            key: "user:2".to_string(),
            shard_id: "shard_2".to_string(),
        };

        assert!(!read_op.is_write());
    }

    /// A freshly begun transaction is tracked and starts in `Preparing`.
    #[test]
    fn test_coordinator_begin_transaction() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        // Comparing against `Some(..)` covers both "is tracked" and "has the right state".
        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Preparing));
    }

    /// Participants registered through the coordinator are visible on the transaction.
    #[test]
    fn test_coordinator_add_participant() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let txn = coord.get_transaction(&txn_id).unwrap();
        assert_eq!(txn.participant_count(), 2);
    }

    /// `prepare` emits one `PrepareRequest` per participant, tagged with the transaction ID.
    #[test]
    fn test_coordinator_prepare() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let messages = coord.prepare(&txn_id).unwrap();
        assert_eq!(messages.len(), 2);

        for (_, msg) in &messages {
            match msg {
                TwoPhaseMessage::PrepareRequest { txn_id: id, .. } => {
                    assert_eq!(id, &txn_id);
                }
                _ => panic!("Expected PrepareRequest"),
            }
        }
    }

    /// Happy path: all participants vote commit, then all acknowledge the commit.
    #[test]
    fn test_coordinator_full_commit() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        assert!(coord.add_participant(&txn_id, node1.clone()));
        assert!(coord.add_participant(&txn_id, node2.clone()));

        // Phase 1: Prepare. Check the result so a failed prepare is reported here.
        let prepare_messages = coord
            .prepare(&txn_id)
            .expect("prepare should succeed with participants registered");
        assert_eq!(prepare_messages.len(), 2);

        // Both participants vote commit; only the final vote completes phase 1.
        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Commit);

        assert_eq!(state, Some(TransactionState::Prepared));

        // Phase 2: Commit
        let messages = coord.commit(&txn_id).unwrap();
        assert_eq!(messages.len(), 2);

        // Both participants acknowledge
        coord.handle_commit_ack(&txn_id, &node1);
        let final_state = coord.handle_commit_ack(&txn_id, &node2);

        assert_eq!(final_state, Some(TransactionState::Committed));
    }

    /// One abort vote during phase 1 moves the transaction to `Aborting`.
    #[test]
    fn test_coordinator_abort_on_vote() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        assert!(coord.add_participant(&txn_id, node1.clone()));
        assert!(coord.add_participant(&txn_id, node2.clone()));

        coord
            .prepare(&txn_id)
            .expect("prepare should succeed with participants registered");

        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Abort);

        assert_eq!(state, Some(TransactionState::Aborting));
    }

    /// Participant votes commit on prepare, then moves the txn from prepared to committed.
    #[test]
    fn test_participant_handler() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let ops = vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }];

        // Handle prepare
        match handler.handle_prepare(&txn_id, ops) {
            TwoPhaseMessage::PrepareResponse { vote, .. } => {
                assert_eq!(vote, ParticipantVote::Commit);
            }
            _ => panic!("Expected PrepareResponse"),
        }

        assert!(handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));

        // Handle commit
        assert!(
            matches!(
                handler.handle_commit(&txn_id),
                TwoPhaseMessage::CommitAck { .. }
            ),
            "Expected CommitAck"
        );

        assert!(!handler.is_prepared(&txn_id));
        assert!(handler.is_committed(&txn_id));
    }

    /// Aborting a prepared transaction clears it without marking it committed.
    #[test]
    fn test_participant_abort() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        let ops = vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }];

        handler.handle_prepare(&txn_id, ops);
        assert!(handler.is_prepared(&txn_id));

        assert!(
            matches!(
                handler.handle_abort(&txn_id),
                TwoPhaseMessage::AbortAck { .. }
            ),
            "Expected AbortAck"
        );

        assert!(!handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));
    }

    /// `cleanup_completed` removes committed transactions from the coordinator.
    #[test]
    fn test_coordinator_cleanup() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        // Create and complete a transaction
        let txn_id = coord.begin_transaction();
        commit_with_single_participant(&coord, &txn_id, &NodeId::new("node1"));

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Committed));

        let cleaned = coord.cleanup_completed();
        assert_eq!(cleaned, 1);
        assert!(coord.get_state(&txn_id).is_none());
    }

    /// Committed transactions no longer count as active.
    #[test]
    fn test_active_count() {
        let coord = TransactionCoordinator::new(NodeId::new("coord"));

        let txn1 = coord.begin_transaction();
        let _txn2 = coord.begin_transaction();

        assert_eq!(coord.active_count(), 2);

        // Complete txn1
        commit_with_single_participant(&coord, &txn1, &NodeId::new("node1"));

        assert_eq!(coord.active_count(), 1);
    }
}