Skip to main content

aegis_replication/
transaction.rs

1//! Aegis Distributed Transactions
2//!
3//! Two-Phase Commit (2PC) protocol for distributed transaction coordination.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12
13// =============================================================================
14// Transaction ID
15// =============================================================================
16
17/// Unique identifier for distributed transactions.
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub struct TransactionId(String);
20
21impl TransactionId {
22    /// Create a new transaction ID.
23    pub fn new(id: impl Into<String>) -> Self {
24        Self(id.into())
25    }
26
27    /// Generate a unique transaction ID.
28    pub fn generate() -> Self {
29        let timestamp = std::time::SystemTime::now()
30            .duration_since(std::time::UNIX_EPOCH)
31            .unwrap_or_default()
32            .as_nanos();
33        Self(format!("txn_{:x}", timestamp))
34    }
35
36    /// Get the ID as a string.
37    pub fn as_str(&self) -> &str {
38        &self.0
39    }
40}
41
42impl std::fmt::Display for TransactionId {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{}", self.0)
45    }
46}
47
48// =============================================================================
49// Transaction State
50// =============================================================================
51
52/// State of a distributed transaction.
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum TransactionState {
55    /// Transaction is being prepared.
56    Preparing,
57    /// Transaction is prepared and ready to commit.
58    Prepared,
59    /// Transaction is committing.
60    Committing,
61    /// Transaction has been committed.
62    Committed,
63    /// Transaction is aborting.
64    Aborting,
65    /// Transaction has been aborted.
66    Aborted,
67    /// Transaction state is unknown (recovery needed).
68    Unknown,
69}
70
71impl TransactionState {
72    /// Check if the transaction is in a terminal state.
73    pub fn is_terminal(&self) -> bool {
74        matches!(self, Self::Committed | Self::Aborted)
75    }
76
77    /// Check if the transaction can be committed.
78    pub fn can_commit(&self) -> bool {
79        matches!(self, Self::Prepared)
80    }
81
82    /// Check if the transaction can be aborted.
83    pub fn can_abort(&self) -> bool {
84        !matches!(self, Self::Committed)
85    }
86}
87
88// =============================================================================
89// Participant Vote
90// =============================================================================
91
/// Vote from a participant in 2PC.
///
/// Every participant starts at `Pending` and moves to `Commit` or `Abort`
/// when its prepare response is recorded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParticipantVote {
    /// Participant votes to commit.
    Commit,
    /// Participant votes to abort.
    Abort,
    /// Participant hasn't voted yet.
    Pending,
}
102
103// =============================================================================
104// Transaction Participant
105// =============================================================================
106
107/// A participant in a distributed transaction.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct TransactionParticipant {
110    pub node_id: NodeId,
111    pub vote: ParticipantVote,
112    pub prepared: bool,
113    pub committed: bool,
114    pub last_contact: Option<u64>,
115}
116
117impl TransactionParticipant {
118    /// Create a new participant.
119    pub fn new(node_id: NodeId) -> Self {
120        Self {
121            node_id,
122            vote: ParticipantVote::Pending,
123            prepared: false,
124            committed: false,
125            last_contact: None,
126        }
127    }
128
129    /// Record a prepare vote.
130    pub fn record_prepare(&mut self, vote: ParticipantVote) {
131        self.vote = vote;
132        self.prepared = vote == ParticipantVote::Commit;
133    }
134
135    /// Record commit acknowledgment.
136    pub fn record_commit(&mut self) {
137        self.committed = true;
138    }
139}
140
141// =============================================================================
142// Distributed Transaction
143// =============================================================================
144
145/// A distributed transaction managed by 2PC.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DistributedTransaction {
148    pub id: TransactionId,
149    pub coordinator: NodeId,
150    pub participants: HashMap<NodeId, TransactionParticipant>,
151    pub state: TransactionState,
152    pub created_at: u64,
153    pub timeout_ms: u64,
154    pub operations: Vec<TransactionOperation>,
155}
156
157impl DistributedTransaction {
158    /// Create a new distributed transaction.
159    pub fn new(id: TransactionId, coordinator: NodeId, timeout_ms: u64) -> Self {
160        let created_at = std::time::SystemTime::now()
161            .duration_since(std::time::UNIX_EPOCH)
162            .unwrap_or_default()
163            .as_millis() as u64;
164
165        Self {
166            id,
167            coordinator,
168            participants: HashMap::new(),
169            state: TransactionState::Preparing,
170            created_at,
171            timeout_ms,
172            operations: Vec::new(),
173        }
174    }
175
176    /// Add a participant to the transaction.
177    pub fn add_participant(&mut self, node_id: NodeId) {
178        if !self.participants.contains_key(&node_id) {
179            self.participants
180                .insert(node_id.clone(), TransactionParticipant::new(node_id));
181        }
182    }
183
184    /// Add an operation to the transaction.
185    pub fn add_operation(&mut self, operation: TransactionOperation) {
186        self.operations.push(operation);
187    }
188
189    /// Check if all participants are prepared.
190    pub fn all_prepared(&self) -> bool {
191        self.participants.values().all(|p| p.prepared)
192    }
193
194    /// Check if all participants have committed.
195    pub fn all_committed(&self) -> bool {
196        self.participants.values().all(|p| p.committed)
197    }
198
199    /// Check if any participant voted to abort.
200    pub fn any_abort(&self) -> bool {
201        self.participants
202            .values()
203            .any(|p| p.vote == ParticipantVote::Abort)
204    }
205
206    /// Check if the transaction has timed out.
207    pub fn is_timed_out(&self) -> bool {
208        let now = std::time::SystemTime::now()
209            .duration_since(std::time::UNIX_EPOCH)
210            .unwrap_or_default()
211            .as_millis() as u64;
212        now - self.created_at > self.timeout_ms
213    }
214
215    /// Get participant count.
216    pub fn participant_count(&self) -> usize {
217        self.participants.len()
218    }
219
220    /// Get prepared count.
221    pub fn prepared_count(&self) -> usize {
222        self.participants.values().filter(|p| p.prepared).count()
223    }
224}
225
226// =============================================================================
227// Transaction Operation
228// =============================================================================
229
230/// An operation within a distributed transaction.
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum TransactionOperation {
233    /// Read operation.
234    Read { key: String, shard_id: String },
235    /// Write operation.
236    Write {
237        key: String,
238        value: Vec<u8>,
239        shard_id: String,
240    },
241    /// Delete operation.
242    Delete { key: String, shard_id: String },
243    /// Compare and swap operation.
244    CompareAndSwap {
245        key: String,
246        expected: Option<Vec<u8>>,
247        new_value: Vec<u8>,
248        shard_id: String,
249    },
250}
251
252impl TransactionOperation {
253    /// Get the shard ID for this operation.
254    pub fn shard_id(&self) -> &str {
255        match self {
256            Self::Read { shard_id, .. } => shard_id,
257            Self::Write { shard_id, .. } => shard_id,
258            Self::Delete { shard_id, .. } => shard_id,
259            Self::CompareAndSwap { shard_id, .. } => shard_id,
260        }
261    }
262
263    /// Get the key for this operation.
264    pub fn key(&self) -> &str {
265        match self {
266            Self::Read { key, .. } => key,
267            Self::Write { key, .. } => key,
268            Self::Delete { key, .. } => key,
269            Self::CompareAndSwap { key, .. } => key,
270        }
271    }
272
273    /// Check if this is a write operation.
274    pub fn is_write(&self) -> bool {
275        !matches!(self, Self::Read { .. })
276    }
277}
278
279// =============================================================================
280// 2PC Messages
281// =============================================================================
282
/// Messages for 2PC protocol.
///
/// Requests travel from the coordinator to participants; responses and
/// acknowledgments travel back and carry the sending participant's node ID.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TwoPhaseMessage {
    /// Prepare request from coordinator.
    ///
    /// Carries the full list of operations of the transaction.
    PrepareRequest {
        txn_id: TransactionId,
        operations: Vec<TransactionOperation>,
    },
    /// Prepare response from participant.
    PrepareResponse {
        txn_id: TransactionId,
        vote: ParticipantVote,
        participant: NodeId,
    },
    /// Commit request from coordinator.
    CommitRequest { txn_id: TransactionId },
    /// Commit acknowledgment from participant.
    CommitAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Abort request from coordinator.
    AbortRequest { txn_id: TransactionId },
    /// Abort acknowledgment from participant.
    AbortAck {
        txn_id: TransactionId,
        participant: NodeId,
    },
    /// Query transaction status (for recovery).
    StatusQuery { txn_id: TransactionId },
    /// Status response.
    StatusResponse {
        txn_id: TransactionId,
        state: TransactionState,
    },
}
319
320// =============================================================================
321// Transaction Coordinator
322// =============================================================================
323
324/// Coordinator for distributed transactions using 2PC.
325pub struct TransactionCoordinator {
326    node_id: NodeId,
327    transactions: RwLock<HashMap<TransactionId, DistributedTransaction>>,
328    default_timeout_ms: u64,
329    prepared_log: RwLock<HashSet<TransactionId>>,
330}
331
332impl TransactionCoordinator {
333    /// Create a new transaction coordinator.
334    pub fn new(node_id: NodeId) -> Self {
335        Self {
336            node_id,
337            transactions: RwLock::new(HashMap::new()),
338            default_timeout_ms: 30000,
339            prepared_log: RwLock::new(HashSet::new()),
340        }
341    }
342
343    /// Create with custom timeout.
344    pub fn with_timeout(node_id: NodeId, timeout_ms: u64) -> Self {
345        Self {
346            node_id,
347            transactions: RwLock::new(HashMap::new()),
348            default_timeout_ms: timeout_ms,
349            prepared_log: RwLock::new(HashSet::new()),
350        }
351    }
352
353    /// Begin a new distributed transaction.
354    pub fn begin_transaction(&self) -> TransactionId {
355        let txn_id = TransactionId::generate();
356        let txn = DistributedTransaction::new(
357            txn_id.clone(),
358            self.node_id.clone(),
359            self.default_timeout_ms,
360        );
361        self.transactions
362            .write()
363            .expect("transaction coordinator transactions lock poisoned")
364            .insert(txn_id.clone(), txn);
365        txn_id
366    }
367
368    /// Begin a transaction with a specific ID.
369    pub fn begin_transaction_with_id(&self, txn_id: TransactionId) {
370        let txn = DistributedTransaction::new(
371            txn_id.clone(),
372            self.node_id.clone(),
373            self.default_timeout_ms,
374        );
375        self.transactions
376            .write()
377            .expect("transaction coordinator transactions lock poisoned")
378            .insert(txn_id, txn);
379    }
380
381    /// Add a participant to a transaction.
382    pub fn add_participant(&self, txn_id: &TransactionId, node_id: NodeId) -> bool {
383        if let Some(txn) = self
384            .transactions
385            .write()
386            .expect("transaction coordinator transactions lock poisoned")
387            .get_mut(txn_id)
388        {
389            txn.add_participant(node_id);
390            true
391        } else {
392            false
393        }
394    }
395
396    /// Add an operation to a transaction.
397    pub fn add_operation(&self, txn_id: &TransactionId, operation: TransactionOperation) -> bool {
398        if let Some(txn) = self
399            .transactions
400            .write()
401            .expect("transaction coordinator transactions lock poisoned")
402            .get_mut(txn_id)
403        {
404            txn.add_operation(operation);
405            true
406        } else {
407            false
408        }
409    }
410
411    /// Phase 1: Prepare - Generate prepare requests for all participants.
412    pub fn prepare(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
413        let txns = self
414            .transactions
415            .read()
416            .expect("transaction coordinator transactions lock poisoned");
417        let txn = txns.get(txn_id)?;
418
419        if txn.state != TransactionState::Preparing {
420            return None;
421        }
422
423        let messages: Vec<_> = txn
424            .participants
425            .keys()
426            .map(|node_id| {
427                (
428                    node_id.clone(),
429                    TwoPhaseMessage::PrepareRequest {
430                        txn_id: txn_id.clone(),
431                        operations: txn.operations.clone(),
432                    },
433                )
434            })
435            .collect();
436
437        Some(messages)
438    }
439
440    /// Handle prepare response from a participant.
441    pub fn handle_prepare_response(
442        &self,
443        txn_id: &TransactionId,
444        participant: &NodeId,
445        vote: ParticipantVote,
446    ) -> Option<TransactionState> {
447        let mut txns = self
448            .transactions
449            .write()
450            .expect("transaction coordinator transactions lock poisoned");
451        let txn = txns.get_mut(txn_id)?;
452
453        if let Some(p) = txn.participants.get_mut(participant) {
454            p.record_prepare(vote);
455        }
456
457        // Check if we can make a decision
458        if txn.any_abort() {
459            txn.state = TransactionState::Aborting;
460            Some(TransactionState::Aborting)
461        } else if txn.all_prepared() {
462            txn.state = TransactionState::Prepared;
463            self.prepared_log
464                .write()
465                .expect("transaction coordinator prepared_log lock poisoned")
466                .insert(txn_id.clone());
467            Some(TransactionState::Prepared)
468        } else {
469            None
470        }
471    }
472
473    /// Phase 2: Commit - Generate commit requests for all participants.
474    pub fn commit(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
475        let mut txns = self
476            .transactions
477            .write()
478            .expect("transaction coordinator transactions lock poisoned");
479        let txn = txns.get_mut(txn_id)?;
480
481        if !txn.state.can_commit() {
482            return None;
483        }
484
485        txn.state = TransactionState::Committing;
486
487        let messages: Vec<_> = txn
488            .participants
489            .keys()
490            .map(|node_id| {
491                (
492                    node_id.clone(),
493                    TwoPhaseMessage::CommitRequest {
494                        txn_id: txn_id.clone(),
495                    },
496                )
497            })
498            .collect();
499
500        Some(messages)
501    }
502
503    /// Handle commit acknowledgment from a participant.
504    pub fn handle_commit_ack(
505        &self,
506        txn_id: &TransactionId,
507        participant: &NodeId,
508    ) -> Option<TransactionState> {
509        let mut txns = self
510            .transactions
511            .write()
512            .expect("transaction coordinator transactions lock poisoned");
513        let txn = txns.get_mut(txn_id)?;
514
515        if let Some(p) = txn.participants.get_mut(participant) {
516            p.record_commit();
517        }
518
519        if txn.all_committed() {
520            txn.state = TransactionState::Committed;
521            Some(TransactionState::Committed)
522        } else {
523            None
524        }
525    }
526
527    /// Abort a transaction - Generate abort requests.
528    pub fn abort(&self, txn_id: &TransactionId) -> Option<Vec<(NodeId, TwoPhaseMessage)>> {
529        let mut txns = self
530            .transactions
531            .write()
532            .expect("transaction coordinator transactions lock poisoned");
533        let txn = txns.get_mut(txn_id)?;
534
535        if !txn.state.can_abort() {
536            return None;
537        }
538
539        txn.state = TransactionState::Aborting;
540
541        let messages: Vec<_> = txn
542            .participants
543            .keys()
544            .map(|node_id| {
545                (
546                    node_id.clone(),
547                    TwoPhaseMessage::AbortRequest {
548                        txn_id: txn_id.clone(),
549                    },
550                )
551            })
552            .collect();
553
554        Some(messages)
555    }
556
557    /// Handle abort acknowledgment.
558    pub fn handle_abort_ack(&self, txn_id: &TransactionId, _participant: &NodeId) -> bool {
559        let mut txns = self
560            .transactions
561            .write()
562            .expect("transaction coordinator transactions lock poisoned");
563        if let Some(txn) = txns.get_mut(txn_id) {
564            txn.state = TransactionState::Aborted;
565            true
566        } else {
567            false
568        }
569    }
570
571    /// Get transaction state.
572    pub fn get_state(&self, txn_id: &TransactionId) -> Option<TransactionState> {
573        self.transactions
574            .read()
575            .expect("transaction coordinator transactions lock poisoned")
576            .get(txn_id)
577            .map(|t| t.state)
578    }
579
580    /// Get transaction details.
581    pub fn get_transaction(&self, txn_id: &TransactionId) -> Option<DistributedTransaction> {
582        self.transactions
583            .read()
584            .expect("transaction coordinator transactions lock poisoned")
585            .get(txn_id)
586            .cloned()
587    }
588
589    /// Check for timed out transactions.
590    pub fn check_timeouts(&self) -> Vec<TransactionId> {
591        self.transactions
592            .read()
593            .expect("transaction coordinator transactions lock poisoned")
594            .iter()
595            .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
596            .map(|(id, _)| id.clone())
597            .collect()
598    }
599
600    /// Clean up completed transactions.
601    pub fn cleanup_completed(&self) -> usize {
602        let mut txns = self
603            .transactions
604            .write()
605            .expect("transaction coordinator transactions lock poisoned");
606        let before = txns.len();
607        txns.retain(|_, txn| !txn.state.is_terminal());
608        before - txns.len()
609    }
610
611    /// Get active transaction count.
612    pub fn active_count(&self) -> usize {
613        self.transactions
614            .read()
615            .expect("transaction coordinator transactions lock poisoned")
616            .values()
617            .filter(|t| !t.state.is_terminal())
618            .count()
619    }
620
621    /// Check if a transaction was prepared (for recovery).
622    pub fn was_prepared(&self, txn_id: &TransactionId) -> bool {
623        self.prepared_log
624            .read()
625            .expect("transaction coordinator prepared_log lock poisoned")
626            .contains(txn_id)
627    }
628}
629
630// =============================================================================
631// Transaction Participant Handler
632// =============================================================================
633
/// Callback type for validating operations before prepare.
///
/// Receives the transaction ID and its operations and returns a
/// [`ValidationResult`]; a result with `success == false` makes the
/// participant vote `Abort`.
pub type ValidationCallback =
    Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> ValidationResult + Send + Sync>;

/// Callback type for committing operations.
///
/// Receives the transaction ID and the operations stored at prepare time.
pub type CommitCallback =
    Box<dyn Fn(&TransactionId, &[TransactionOperation]) -> Result<(), String> + Send + Sync>;

/// Callback type for aborting/rolling back operations.
///
/// Receives only the transaction ID.
pub type AbortCallback = Box<dyn Fn(&TransactionId) -> Result<(), String> + Send + Sync>;
644
/// Outcome of validating a transaction's operations during 2PC prepare.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when validation passed.
    pub success: bool,
    /// Reason for the failure; `None` on success.
    pub error: Option<String>,
    /// Keys locked on behalf of the transaction; empty on failure.
    pub locked_keys: Vec<String>,
}

impl ValidationResult {
    /// Build a passing result that records the keys locked for the transaction.
    pub fn success(locked_keys: Vec<String>) -> Self {
        ValidationResult {
            success: true,
            error: None,
            locked_keys,
        }
    }

    /// Build a failing result carrying `error` as the reason; no keys are locked.
    pub fn failure(error: impl Into<String>) -> Self {
        let reason: String = error.into();
        ValidationResult {
            success: false,
            error: Some(reason),
            locked_keys: Vec::new(),
        }
    }
}
675
/// Handler for transaction participants with storage integration.
///
/// Keeps the participant-side bookkeeping for 2PC and optionally delegates
/// validation, commit and abort to caller-supplied callbacks.
pub struct ParticipantHandler {
    /// Identity of this node; stamped on every response it produces.
    node_id: NodeId,
    /// Operations stored when a transaction is prepared, removed on commit or abort.
    pending_prepares: RwLock<HashMap<TransactionId, Vec<TransactionOperation>>>,
    /// Transactions that were prepared and not yet committed or aborted.
    prepared: RwLock<HashSet<TransactionId>>,
    /// Transactions for which a commit request has been handled.
    committed: RwLock<HashSet<TransactionId>>,
    /// Keys reported as locked by validation, per transaction.
    locked_keys: RwLock<HashMap<TransactionId, Vec<String>>>,
    /// Optional validation hook used by `handle_prepare`.
    validation_callback: RwLock<Option<ValidationCallback>>,
    /// Optional hook run by `handle_commit`.
    commit_callback: RwLock<Option<CommitCallback>>,
    /// Optional hook run by `handle_abort`.
    abort_callback: RwLock<Option<AbortCallback>>,
}
687
688impl ParticipantHandler {
689    /// Create a new participant handler.
690    pub fn new(node_id: NodeId) -> Self {
691        Self {
692            node_id,
693            pending_prepares: RwLock::new(HashMap::new()),
694            prepared: RwLock::new(HashSet::new()),
695            committed: RwLock::new(HashSet::new()),
696            locked_keys: RwLock::new(HashMap::new()),
697            validation_callback: RwLock::new(None),
698            commit_callback: RwLock::new(None),
699            abort_callback: RwLock::new(None),
700        }
701    }
702
703    /// Set the validation callback for prepare phase.
704    pub fn set_validation_callback(&self, callback: ValidationCallback) {
705        *self
706            .validation_callback
707            .write()
708            .expect("participant handler validation_callback lock poisoned") = Some(callback);
709    }
710
711    /// Set the commit callback for commit phase.
712    pub fn set_commit_callback(&self, callback: CommitCallback) {
713        *self
714            .commit_callback
715            .write()
716            .expect("participant handler commit_callback lock poisoned") = Some(callback);
717    }
718
719    /// Set the abort callback for abort/rollback phase.
720    pub fn set_abort_callback(&self, callback: AbortCallback) {
721        *self
722            .abort_callback
723            .write()
724            .expect("participant handler abort_callback lock poisoned") = Some(callback);
725    }
726
727    /// Handle prepare request with full storage integration.
728    pub fn handle_prepare(
729        &self,
730        txn_id: &TransactionId,
731        operations: Vec<TransactionOperation>,
732    ) -> TwoPhaseMessage {
733        // Check if there's a validation callback
734        let validation_result = {
735            let callback_guard = self
736                .validation_callback
737                .read()
738                .expect("participant handler validation_callback lock poisoned");
739            if let Some(ref callback) = *callback_guard {
740                // Run validation against storage
741                callback(txn_id, &operations)
742            } else {
743                // No callback - perform basic validation (key conflict check)
744                self.basic_validation(txn_id, &operations)
745            }
746        };
747
748        if !validation_result.success {
749            return TwoPhaseMessage::PrepareResponse {
750                txn_id: txn_id.clone(),
751                vote: ParticipantVote::Abort,
752                participant: self.node_id.clone(),
753            };
754        }
755
756        // Store locked keys for later release
757        self.locked_keys
758            .write()
759            .expect("participant handler locked_keys lock poisoned")
760            .insert(txn_id.clone(), validation_result.locked_keys);
761
762        // Store operations for commit phase
763        self.pending_prepares
764            .write()
765            .expect("participant handler pending_prepares lock poisoned")
766            .insert(txn_id.clone(), operations);
767        self.prepared
768            .write()
769            .expect("participant handler prepared lock poisoned")
770            .insert(txn_id.clone());
771
772        TwoPhaseMessage::PrepareResponse {
773            txn_id: txn_id.clone(),
774            vote: ParticipantVote::Commit,
775            participant: self.node_id.clone(),
776        }
777    }
778
779    /// Basic validation when no custom callback is provided.
780    fn basic_validation(
781        &self,
782        txn_id: &TransactionId,
783        operations: &[TransactionOperation],
784    ) -> ValidationResult {
785        // Check for key conflicts with other pending transactions
786        let pending = self
787            .pending_prepares
788            .read()
789            .expect("participant handler pending_prepares lock poisoned");
790        let locked = self
791            .locked_keys
792            .read()
793            .expect("participant handler locked_keys lock poisoned");
794
795        let mut keys_to_lock = Vec::new();
796        for op in operations {
797            let key = op.key().to_string();
798
799            // Check if key is locked by another transaction
800            for (other_txn_id, other_keys) in locked.iter() {
801                if other_txn_id != txn_id && other_keys.contains(&key) {
802                    return ValidationResult::failure(format!(
803                        "Key '{}' is locked by transaction {}",
804                        key, other_txn_id
805                    ));
806                }
807            }
808
809            // Check if key is being modified by another pending transaction
810            for (other_txn_id, other_ops) in pending.iter() {
811                if other_txn_id != txn_id {
812                    for other_op in other_ops {
813                        if other_op.key() == key && other_op.is_write() && op.is_write() {
814                            return ValidationResult::failure(format!(
815                                "Write conflict on key '{}' with transaction {}",
816                                key, other_txn_id
817                            ));
818                        }
819                    }
820                }
821            }
822
823            if op.is_write() {
824                keys_to_lock.push(key);
825            }
826        }
827
828        ValidationResult::success(keys_to_lock)
829    }
830
831    /// Handle prepare with validation function.
832    pub fn handle_prepare_with_validation<F>(
833        &self,
834        txn_id: &TransactionId,
835        operations: Vec<TransactionOperation>,
836        validator: F,
837    ) -> TwoPhaseMessage
838    where
839        F: FnOnce(&[TransactionOperation]) -> bool,
840    {
841        let vote = if validator(&operations) {
842            self.pending_prepares
843                .write()
844                .expect("participant handler pending_prepares lock poisoned")
845                .insert(txn_id.clone(), operations);
846            self.prepared
847                .write()
848                .expect("participant handler prepared lock poisoned")
849                .insert(txn_id.clone());
850            ParticipantVote::Commit
851        } else {
852            ParticipantVote::Abort
853        };
854
855        TwoPhaseMessage::PrepareResponse {
856            txn_id: txn_id.clone(),
857            vote,
858            participant: self.node_id.clone(),
859        }
860    }
861
862    /// Handle commit request with storage integration.
863    pub fn handle_commit(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
864        // Get the operations to commit
865        let operations = self
866            .pending_prepares
867            .write()
868            .expect("participant handler pending_prepares lock poisoned")
869            .remove(txn_id);
870
871        // Execute commit callback if set
872        if let Some(ref callback) = *self
873            .commit_callback
874            .read()
875            .expect("participant handler commit_callback lock poisoned")
876        {
877            if let Some(ops) = &operations {
878                if let Err(e) = callback(txn_id, ops) {
879                    // Log error but continue - we must commit after prepare
880                    tracing::error!("Commit callback failed for {}: {}", txn_id, e);
881                }
882            }
883        }
884
885        // Clean up state
886        self.prepared
887            .write()
888            .expect("participant handler prepared lock poisoned")
889            .remove(txn_id);
890        self.locked_keys
891            .write()
892            .expect("participant handler locked_keys lock poisoned")
893            .remove(txn_id);
894        self.committed
895            .write()
896            .expect("participant handler committed lock poisoned")
897            .insert(txn_id.clone());
898
899        TwoPhaseMessage::CommitAck {
900            txn_id: txn_id.clone(),
901            participant: self.node_id.clone(),
902        }
903    }
904
905    /// Handle abort request with proper cleanup.
906    pub fn handle_abort(&self, txn_id: &TransactionId) -> TwoPhaseMessage {
907        // Execute abort callback if set
908        if let Some(ref callback) = *self
909            .abort_callback
910            .read()
911            .expect("participant handler abort_callback lock poisoned")
912        {
913            if let Err(e) = callback(txn_id) {
914                tracing::error!("Abort callback failed for {}: {}", txn_id, e);
915            }
916        }
917
918        // Rollback any prepared state and release locks
919        self.pending_prepares
920            .write()
921            .expect("participant handler pending_prepares lock poisoned")
922            .remove(txn_id);
923        self.prepared
924            .write()
925            .expect("participant handler prepared lock poisoned")
926            .remove(txn_id);
927        self.locked_keys
928            .write()
929            .expect("participant handler locked_keys lock poisoned")
930            .remove(txn_id);
931
932        TwoPhaseMessage::AbortAck {
933            txn_id: txn_id.clone(),
934            participant: self.node_id.clone(),
935        }
936    }
937
938    /// Get currently locked keys for a transaction.
939    pub fn get_locked_keys(&self, txn_id: &TransactionId) -> Vec<String> {
940        self.locked_keys
941            .read()
942            .expect("participant handler locked_keys lock poisoned")
943            .get(txn_id)
944            .cloned()
945            .unwrap_or_default()
946    }
947
948    /// Check if a key is locked by any transaction.
949    pub fn is_key_locked(&self, key: &str) -> Option<TransactionId> {
950        let locked = self
951            .locked_keys
952            .read()
953            .expect("participant handler locked_keys lock poisoned");
954        for (txn_id, keys) in locked.iter() {
955            if keys.iter().any(|k| k == key) {
956                return Some(txn_id.clone());
957            }
958        }
959        None
960    }
961
962    /// Check if a transaction is prepared.
963    pub fn is_prepared(&self, txn_id: &TransactionId) -> bool {
964        self.prepared
965            .read()
966            .expect("participant handler prepared lock poisoned")
967            .contains(txn_id)
968    }
969
970    /// Check if a transaction is committed.
971    pub fn is_committed(&self, txn_id: &TransactionId) -> bool {
972        self.committed
973            .read()
974            .expect("participant handler committed lock poisoned")
975            .contains(txn_id)
976    }
977
978    /// Get pending prepare count.
979    pub fn pending_count(&self) -> usize {
980        self.pending_prepares
981            .read()
982            .expect("participant handler pending_prepares lock poisoned")
983            .len()
984    }
985}
986
987// =============================================================================
988// Tests
989// =============================================================================
990
#[cfg(test)]
mod tests {
    //! Unit tests for the 2PC building blocks: transaction ids and states,
    //! the coordinator state machine, and the participant handler.
    //!
    //! Setup steps (`add_participant`, `prepare`, `commit`) are asserted
    //! rather than silently discarded, so a broken precondition fails at the
    //! step that caused it instead of at a later, unrelated assertion.

    use super::*;

    /// Builds a coordinator whose own node id is `"coord"`.
    fn new_coordinator() -> TransactionCoordinator {
        TransactionCoordinator::new(NodeId::new("coord"))
    }

    /// Returns the single-write operation list shared by the participant tests.
    fn single_write_ops() -> Vec<TransactionOperation> {
        vec![TransactionOperation::Write {
            key: "key1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        }]
    }

    /// Records `vote` as the prepare vote of participant `node` on `txn`.
    fn record_vote(txn: &mut DistributedTransaction, node: &str, vote: ParticipantVote) {
        txn.participants
            .get_mut(&NodeId::new(node))
            .expect("participant should have been added to the transaction")
            .record_prepare(vote);
    }

    /// Drives `txn_id` through both phases with a single participant, `"node1"`.
    fn commit_with_single_participant(coord: &TransactionCoordinator, txn_id: &TransactionId) {
        let node = NodeId::new("node1");

        assert!(coord.add_participant(txn_id, node.clone()));
        coord
            .prepare(txn_id)
            .expect("prepare should succeed for a freshly started transaction");
        coord.handle_prepare_response(txn_id, &node, ParticipantVote::Commit);
        coord
            .commit(txn_id)
            .expect("commit should succeed once every participant voted commit");
        coord.handle_commit_ack(txn_id, &node);
    }

    #[test]
    fn test_transaction_id() {
        let explicit = TransactionId::new("txn_1");
        assert_eq!(explicit.as_str(), "txn_1");

        let generated = TransactionId::generate();
        assert!(generated.as_str().starts_with("txn_"));
    }

    #[test]
    fn test_transaction_state() {
        // Terminal states.
        assert!(!TransactionState::Preparing.is_terminal());
        assert!(TransactionState::Committed.is_terminal());
        assert!(TransactionState::Aborted.is_terminal());

        // Commit eligibility.
        assert!(TransactionState::Prepared.can_commit());
        assert!(!TransactionState::Preparing.can_commit());

        // Abort eligibility.
        assert!(TransactionState::Preparing.can_abort());
        assert!(!TransactionState::Committed.can_abort());
    }

    #[test]
    fn test_distributed_transaction() {
        let mut txn =
            DistributedTransaction::new(TransactionId::new("txn_1"), NodeId::new("coord"), 30000);

        assert_eq!(txn.state, TransactionState::Preparing);
        assert_eq!(txn.participant_count(), 0);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        assert_eq!(txn.participant_count(), 2);
        assert!(!txn.all_prepared());

        // One commit vote out of two is not enough.
        record_vote(&mut txn, "node1", ParticipantVote::Commit);
        assert!(!txn.all_prepared());
        assert_eq!(txn.prepared_count(), 1);

        // The second commit vote completes the prepare phase.
        record_vote(&mut txn, "node2", ParticipantVote::Commit);
        assert!(txn.all_prepared());
        assert!(!txn.any_abort());
    }

    #[test]
    fn test_transaction_abort_vote() {
        let mut txn =
            DistributedTransaction::new(TransactionId::new("txn_1"), NodeId::new("coord"), 30000);

        txn.add_participant(NodeId::new("node1"));
        txn.add_participant(NodeId::new("node2"));

        record_vote(&mut txn, "node1", ParticipantVote::Commit);
        record_vote(&mut txn, "node2", ParticipantVote::Abort);

        assert!(!txn.all_prepared());
        assert!(txn.any_abort());
    }

    #[test]
    fn test_transaction_operation() {
        let write_op = TransactionOperation::Write {
            key: "user:1".to_string(),
            value: vec![1, 2, 3],
            shard_id: "shard_1".to_string(),
        };
        assert_eq!(write_op.key(), "user:1");
        assert_eq!(write_op.shard_id(), "shard_1");
        assert!(write_op.is_write());

        let read_op = TransactionOperation::Read {
            key: "user:2".to_string(),
            shard_id: "shard_2".to_string(),
        };
        assert!(!read_op.is_write());
    }

    #[test]
    fn test_coordinator_begin_transaction() {
        let coord = new_coordinator();
        let txn_id = coord.begin_transaction();

        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Preparing));
    }

    #[test]
    fn test_coordinator_add_participant() {
        let coord = new_coordinator();
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let txn = coord
            .get_transaction(&txn_id)
            .expect("transaction should exist after begin_transaction");
        assert_eq!(txn.participant_count(), 2);
    }

    #[test]
    fn test_coordinator_prepare() {
        let coord = new_coordinator();
        let txn_id = coord.begin_transaction();

        assert!(coord.add_participant(&txn_id, NodeId::new("node1")));
        assert!(coord.add_participant(&txn_id, NodeId::new("node2")));

        let messages = coord
            .prepare(&txn_id)
            .expect("prepare should succeed for a freshly started transaction");
        assert_eq!(messages.len(), 2);

        // Every outgoing message must be a prepare request for this transaction.
        for (_, msg) in &messages {
            match msg {
                TwoPhaseMessage::PrepareRequest { txn_id: id, .. } => assert_eq!(id, &txn_id),
                _ => panic!("Expected PrepareRequest"),
            }
        }
    }

    #[test]
    fn test_coordinator_full_commit() {
        let coord = new_coordinator();
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        assert!(coord.add_participant(&txn_id, node1.clone()));
        assert!(coord.add_participant(&txn_id, node2.clone()));

        // Phase 1: Prepare
        coord
            .prepare(&txn_id)
            .expect("prepare should succeed for a freshly started transaction");

        // Both participants vote commit
        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Commit);
        assert_eq!(state, Some(TransactionState::Prepared));

        // Phase 2: Commit
        let messages = coord
            .commit(&txn_id)
            .expect("commit should succeed once every participant voted commit");
        assert_eq!(messages.len(), 2);

        // Both participants acknowledge
        coord.handle_commit_ack(&txn_id, &node1);
        let final_state = coord.handle_commit_ack(&txn_id, &node2);
        assert_eq!(final_state, Some(TransactionState::Committed));
    }

    #[test]
    fn test_coordinator_abort_on_vote() {
        let coord = new_coordinator();
        let txn_id = coord.begin_transaction();

        let node1 = NodeId::new("node1");
        let node2 = NodeId::new("node2");

        assert!(coord.add_participant(&txn_id, node1.clone()));
        assert!(coord.add_participant(&txn_id, node2.clone()));

        coord
            .prepare(&txn_id)
            .expect("prepare should succeed for a freshly started transaction");

        // A single abort vote must move the transaction to Aborting.
        coord.handle_prepare_response(&txn_id, &node1, ParticipantVote::Commit);
        let state = coord.handle_prepare_response(&txn_id, &node2, ParticipantVote::Abort);
        assert_eq!(state, Some(TransactionState::Aborting));
    }

    #[test]
    fn test_participant_handler() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        // Handle prepare
        match handler.handle_prepare(&txn_id, single_write_ops()) {
            TwoPhaseMessage::PrepareResponse { vote, .. } => {
                assert_eq!(vote, ParticipantVote::Commit);
            }
            _ => panic!("Expected PrepareResponse"),
        }
        assert!(handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));

        // Handle commit
        assert!(
            matches!(
                handler.handle_commit(&txn_id),
                TwoPhaseMessage::CommitAck { .. }
            ),
            "Expected CommitAck"
        );
        assert!(!handler.is_prepared(&txn_id));
        assert!(handler.is_committed(&txn_id));
    }

    #[test]
    fn test_participant_abort() {
        let handler = ParticipantHandler::new(NodeId::new("node1"));
        let txn_id = TransactionId::new("txn_1");

        handler.handle_prepare(&txn_id, single_write_ops());
        assert!(handler.is_prepared(&txn_id));

        assert!(
            matches!(
                handler.handle_abort(&txn_id),
                TwoPhaseMessage::AbortAck { .. }
            ),
            "Expected AbortAck"
        );
        assert!(!handler.is_prepared(&txn_id));
        assert!(!handler.is_committed(&txn_id));
    }

    #[test]
    fn test_coordinator_cleanup() {
        let coord = new_coordinator();

        // Create and complete a transaction
        let txn_id = coord.begin_transaction();
        commit_with_single_participant(&coord, &txn_id);
        assert_eq!(coord.get_state(&txn_id), Some(TransactionState::Committed));

        // The completed transaction is the only one eligible for cleanup.
        let cleaned = coord.cleanup_completed();
        assert_eq!(cleaned, 1);
        assert!(coord.get_state(&txn_id).is_none());
    }

    #[test]
    fn test_active_count() {
        let coord = new_coordinator();

        let txn1 = coord.begin_transaction();
        let _txn2 = coord.begin_transaction();
        assert_eq!(coord.active_count(), 2);

        // Complete txn1
        commit_with_single_participant(&coord, &txn1);
        assert_eq!(coord.active_count(), 1);
    }
}