Skip to main content

rivven_core/
transaction.rs

1//! Native Transaction Support (KIP-98)
2//!
3//! Provides exactly-once semantics with cross-topic atomic writes.
4//!
5//! ## Transaction Protocol
6//!
7//! ```text
8//! Producer                     Transaction Coordinator            Partitions
9//!    │                                   │                            │
10//!    │─── InitProducerId ───────────────>│                            │
11//!    │<── PID=123, Epoch=0 ──────────────│                            │
12//!    │                                   │                            │
13//!    │─── BeginTransaction(TxnId) ──────>│                            │
14//!    │<── OK ────────────────────────────│                            │
15//!    │                                   │                            │
16//!    │─── AddPartitionsToTxn(p1,p2) ────>│                            │
17//!    │<── OK ────────────────────────────│                            │
18//!    │                                   │                            │
19//!    │─── Produce(p1, PID, Seq) ──────────────────────────────────────>│
20//!    │<── OK ───────────────────────────────────────────────────────────│
21//!    │                                   │                            │
22//!    │─── Produce(p2, PID, Seq) ──────────────────────────────────────>│
23//!    │<── OK ───────────────────────────────────────────────────────────│
24//!    │                                   │                            │
25//!    │─── CommitTransaction(TxnId) ─────>│                            │
26//!    │                                   │─── WriteTxnMarker(COMMIT) ─>│
27//!    │                                   │<── OK ─────────────────────│
28//!    │<── OK ────────────────────────────│                            │
29//! ```
30//!
31//! ## Transaction States
32//!
33//! ```text
34//! Empty ──────> Ongoing ──────> PrepareCommit ──────> CompleteCommit
35//!                  │                  │                     │
36//!                  │                  v                     v
37//!                  └───────> PrepareAbort ───────> CompleteAbort
38//!                                    │                     │
39//!                                    └─────────────────────┘
40//! ```
41//!
42//! ## Exactly-Once Guarantees
43//!
44//! 1. **Atomic Writes**: All messages in a transaction are committed or aborted together
45//! 2. **Consumer Isolation**: Consumers only see committed messages (read_committed)
46//! 3. **Fencing**: Old producer instances are fenced via epoch
47//! 4. **Durability**: Transaction state is persisted before acknowledgment
48//!
49
50use crate::idempotent::{ProducerEpoch, ProducerId};
51use serde::{Deserialize, Serialize};
52use std::collections::{HashMap, HashSet};
53use std::sync::atomic::{AtomicU64, Ordering};
54use std::sync::RwLock;
55use std::time::{Duration, Instant, SystemTime};
56
/// Unique identifier for a transaction.
///
/// A plain `String` alias. Together with the producer ID it forms the key
/// under which `TransactionCoordinator` stores a transaction.
pub type TransactionId = String;

/// Default transaction timeout (60 seconds).
///
/// Used by `TransactionCoordinator::new` and applied by `begin_transaction`
/// whenever the caller does not pass an explicit timeout.
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum pending transactions per producer.
///
/// NOTE(review): no code in the visible part of this file reads this
/// constant; the coordinator currently allows one transaction per producer.
/// Confirm where (or whether) this limit is enforced.
pub const MAX_PENDING_TRANSACTIONS: usize = 5;
65
/// Transaction state machine.
///
/// A transaction is created in `Ongoing`, moves to `PrepareCommit` or
/// `PrepareAbort` in phase 1 of the two-phase protocol, and is finished by
/// the matching `Complete*` step. `Dead` is assigned when an inactivity
/// timeout is detected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionState {
    /// No active transaction
    Empty,

    /// Transaction in progress, accepting writes
    Ongoing,

    /// Preparing to commit (2PC phase 1)
    PrepareCommit,

    /// Preparing to abort (2PC phase 1)
    PrepareAbort,

    /// Commit complete (2PC phase 2)
    CompleteCommit,

    /// Abort complete (2PC phase 2)
    CompleteAbort,

    /// Transaction has expired without completion (inactivity timeout hit)
    Dead,
}
90
91impl TransactionState {
92    /// Check if transaction is in a terminal state
93    pub fn is_terminal(&self) -> bool {
94        matches!(
95            self,
96            TransactionState::Empty
97                | TransactionState::CompleteCommit
98                | TransactionState::CompleteAbort
99                | TransactionState::Dead
100        )
101    }
102
103    /// Check if transaction is still active (can accept writes)
104    pub fn is_active(&self) -> bool {
105        matches!(self, TransactionState::Ongoing)
106    }
107
108    /// Check if transaction can transition to commit
109    pub fn can_commit(&self) -> bool {
110        matches!(self, TransactionState::Ongoing)
111    }
112
113    /// Check if transaction can transition to abort
114    pub fn can_abort(&self) -> bool {
115        matches!(
116            self,
117            TransactionState::Ongoing
118                | TransactionState::PrepareCommit
119                | TransactionState::PrepareAbort
120        )
121    }
122}
123
/// Result of a transaction operation.
///
/// Returned directly by the coordinator's mutating methods, and used as the
/// error type of `prepare_commit` / `prepare_abort`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionResult {
    /// Operation succeeded
    Ok,

    /// No transaction is stored under the given (producer ID, transaction ID)
    InvalidTransactionId,

    /// Transaction is in wrong state for this operation
    InvalidTransactionState {
        /// State the transaction was found in
        current: TransactionState,
        /// Human-readable description of the state(s) the operation requires
        expected: &'static str,
    },

    /// Producer epoch does not match the epoch recorded on the transaction
    ProducerFenced {
        /// Epoch stored on the transaction
        expected_epoch: ProducerEpoch,
        /// Epoch supplied with the request
        received_epoch: ProducerEpoch,
    },

    /// Transaction has timed out (no activity within its timeout)
    TransactionTimeout,

    /// Too many pending transactions
    ///
    /// NOTE(review): not produced by any code visible in this file.
    TooManyTransactions,

    /// The producer already has a transaction under a different ID
    /// (returned by `begin_transaction`)
    ConcurrentTransaction,

    /// Partition not part of transaction (it was never added to it)
    PartitionNotInTransaction { topic: String, partition: u32 },
}
157
/// A partition involved in a transaction.
///
/// Identified by topic name plus partition number. Hashable, so it can be
/// stored in the transaction's `HashSet` of partitions.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionPartition {
    /// Topic name
    pub topic: String,
    /// Partition number within the topic
    pub partition: u32,
}
164
165impl TransactionPartition {
166    pub fn new(topic: impl Into<String>, partition: u32) -> Self {
167        Self {
168            topic: topic.into(),
169            partition,
170        }
171    }
172}
173
/// Pending write in a transaction (not yet committed).
///
/// One entry is appended by `Transaction::add_write` for every write that is
/// recorded against the transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingWrite {
    /// Target partition
    pub partition: TransactionPartition,

    /// Sequence number for this write
    pub sequence: i32,

    /// Offset assigned by the partition leader
    pub offset: u64,

    /// Write timestamp (wall clock, taken when the write was recorded)
    #[serde(with = "crate::serde_utils::system_time")]
    pub timestamp: SystemTime,
}
190
/// Consumer offset to be committed with the transaction.
///
/// Recorded via `Transaction::add_offset_commit`; one entry per call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionOffsetCommit {
    /// Consumer group
    pub group_id: String,

    /// (partition, offset) pairs to commit for the group
    pub offsets: Vec<(TransactionPartition, i64)>,
}
200
/// Active transaction state.
///
/// Holds everything the coordinator tracks for one transaction: identity,
/// owning producer, state-machine position, enlisted partitions, recorded
/// writes and consumer offsets, plus timing information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transaction {
    /// Transaction ID (unique per producer)
    pub txn_id: TransactionId,

    /// Producer ID owning this transaction
    pub producer_id: ProducerId,

    /// Producer epoch (for fencing)
    pub producer_epoch: ProducerEpoch,

    /// Current state
    pub state: TransactionState,

    /// Partitions involved in this transaction
    pub partitions: HashSet<TransactionPartition>,

    /// Pending writes (not yet committed)
    pub pending_writes: Vec<PendingWrite>,

    /// Consumer offsets to commit with this transaction
    pub offset_commits: Vec<TransactionOffsetCommit>,

    /// Transaction start time (wall clock)
    #[serde(with = "crate::serde_utils::system_time")]
    pub started_at: SystemTime,

    /// Transaction timeout, measured against `last_activity`
    #[serde(with = "crate::serde_utils::duration")]
    pub timeout: Duration,

    /// Last activity timestamp (monotonic clock).
    ///
    /// Skipped by serde, so it is `None` on a deserialized value;
    /// `is_timed_out` treats `None` as already timed out.
    #[serde(skip)]
    pub last_activity: Option<Instant>,
}
237
238impl Transaction {
239    /// Create a new transaction
240    pub fn new(
241        txn_id: TransactionId,
242        producer_id: ProducerId,
243        producer_epoch: ProducerEpoch,
244        timeout: Duration,
245    ) -> Self {
246        Self {
247            txn_id,
248            producer_id,
249            producer_epoch,
250            state: TransactionState::Ongoing,
251            partitions: HashSet::new(),
252            pending_writes: Vec::new(),
253            offset_commits: Vec::new(),
254            started_at: SystemTime::now(),
255            timeout,
256            last_activity: Some(Instant::now()),
257        }
258    }
259
260    /// Check if transaction has timed out
261    pub fn is_timed_out(&self) -> bool {
262        self.last_activity
263            .map(|t| t.elapsed() > self.timeout)
264            .unwrap_or(true)
265    }
266
267    /// Update last activity timestamp
268    pub fn touch(&mut self) {
269        self.last_activity = Some(Instant::now());
270    }
271
272    /// Add a partition to the transaction
273    pub fn add_partition(&mut self, partition: TransactionPartition) {
274        self.partitions.insert(partition);
275        self.touch();
276    }
277
278    /// Record a pending write
279    pub fn add_write(&mut self, partition: TransactionPartition, sequence: i32, offset: u64) {
280        self.pending_writes.push(PendingWrite {
281            partition,
282            sequence,
283            offset,
284            timestamp: SystemTime::now(),
285        });
286        self.touch();
287    }
288
289    /// Add consumer offset commit
290    pub fn add_offset_commit(
291        &mut self,
292        group_id: String,
293        offsets: Vec<(TransactionPartition, i64)>,
294    ) {
295        self.offset_commits
296            .push(TransactionOffsetCommit { group_id, offsets });
297        self.touch();
298    }
299
300    /// Get total number of writes
301    pub fn write_count(&self) -> usize {
302        self.pending_writes.len()
303    }
304
305    /// Get all affected partitions
306    pub fn affected_partitions(&self) -> impl Iterator<Item = &TransactionPartition> {
307        self.partitions.iter()
308    }
309}
310
/// Transaction marker type written to partition logs.
///
/// Records the outcome of a transaction for a partition.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionMarker {
    /// Transaction committed
    Commit,

    /// Transaction aborted
    Abort,
}
320
/// Statistics for transaction coordinator.
///
/// All fields are lock-free atomic counters updated with relaxed ordering;
/// they are independent of each other, so a reader may observe them at
/// slightly different points in time.
#[derive(Debug, Default)]
pub struct TransactionStats {
    /// Total transactions initiated
    transactions_started: AtomicU64,

    /// Total transactions committed
    transactions_committed: AtomicU64,

    /// Total transactions aborted
    transactions_aborted: AtomicU64,

    /// Total transactions timed out
    transactions_timed_out: AtomicU64,

    /// Currently active transactions (gauge; never goes below zero)
    active_transactions: AtomicU64,
}

impl TransactionStats {
    /// Creates a stats block with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Lowers the active-transaction gauge by one, stopping at zero.
    ///
    /// A plain `fetch_sub` wraps around to `u64::MAX` if a completion is
    /// ever recorded without a matching start; saturating keeps the gauge
    /// meaningful in that case.
    fn decrement_active(&self) {
        // The closure always returns `Some`, so `fetch_update` cannot fail;
        // the returned previous value is not needed.
        let _ = self.active_transactions.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |active| Some(active.saturating_sub(1)),
        );
    }

    /// Counts a newly started transaction and raises the active gauge.
    pub fn record_start(&self) {
        self.transactions_started.fetch_add(1, Ordering::Relaxed);
        self.active_transactions.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts a committed transaction and lowers the active gauge.
    pub fn record_commit(&self) {
        self.transactions_committed.fetch_add(1, Ordering::Relaxed);
        self.decrement_active();
    }

    /// Counts an aborted transaction and lowers the active gauge.
    pub fn record_abort(&self) {
        self.transactions_aborted.fetch_add(1, Ordering::Relaxed);
        self.decrement_active();
    }

    /// Counts a timed-out transaction and lowers the active gauge.
    pub fn record_timeout(&self) {
        self.transactions_timed_out.fetch_add(1, Ordering::Relaxed);
        self.decrement_active();
    }

    /// Total number of transactions started.
    pub fn transactions_started(&self) -> u64 {
        self.transactions_started.load(Ordering::Relaxed)
    }

    /// Total number of transactions committed.
    pub fn transactions_committed(&self) -> u64 {
        self.transactions_committed.load(Ordering::Relaxed)
    }

    /// Total number of transactions aborted.
    pub fn transactions_aborted(&self) -> u64 {
        self.transactions_aborted.load(Ordering::Relaxed)
    }

    /// Total number of transactions that timed out.
    pub fn transactions_timed_out(&self) -> u64 {
        self.transactions_timed_out.load(Ordering::Relaxed)
    }

    /// Number of transactions currently counted as active.
    pub fn active_transactions(&self) -> u64 {
        self.active_transactions.load(Ordering::Relaxed)
    }
}
385
/// Snapshot of transaction stats for serialization.
///
/// Plain-integer copy of the `TransactionStats` counters, produced through
/// `From<&TransactionStats>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionStatsSnapshot {
    /// Total transactions initiated
    pub transactions_started: u64,
    /// Total transactions committed
    pub transactions_committed: u64,
    /// Total transactions aborted
    pub transactions_aborted: u64,
    /// Total transactions timed out
    pub transactions_timed_out: u64,
    /// Transactions active at the time the snapshot was taken
    pub active_transactions: u64,
}
395
396impl From<&TransactionStats> for TransactionStatsSnapshot {
397    fn from(stats: &TransactionStats) -> Self {
398        Self {
399            transactions_started: stats.transactions_started(),
400            transactions_committed: stats.transactions_committed(),
401            transactions_aborted: stats.transactions_aborted(),
402            transactions_timed_out: stats.transactions_timed_out(),
403            active_transactions: stats.active_transactions(),
404        }
405    }
406}
407
/// Transaction coordinator manages all active transactions
///
/// This is a per-broker component that tracks transactions for producers
/// assigned to this broker as their transaction coordinator.
///
/// State lives in two maps, each behind its own `RwLock`. A poisoned lock
/// makes the accessing method panic (`expect`).
pub struct TransactionCoordinator {
    /// Active transactions by (producer_id, txn_id)
    transactions: RwLock<HashMap<(ProducerId, TransactionId), Transaction>>,

    /// Producer to transaction mapping (for single-txn-per-producer enforcement)
    producer_transactions: RwLock<HashMap<ProducerId, TransactionId>>,

    /// Default transaction timeout, used when `begin_transaction` gets `None`
    default_timeout: Duration,

    /// Statistics
    stats: TransactionStats,
}
425
426impl Default for TransactionCoordinator {
427    fn default() -> Self {
428        Self::new()
429    }
430}
431
432impl TransactionCoordinator {
433    /// Create a new transaction coordinator
434    pub fn new() -> Self {
435        Self {
436            transactions: RwLock::new(HashMap::new()),
437            producer_transactions: RwLock::new(HashMap::new()),
438            default_timeout: DEFAULT_TRANSACTION_TIMEOUT,
439            stats: TransactionStats::new(),
440        }
441    }
442
443    /// Create with custom default timeout
444    pub fn with_timeout(timeout: Duration) -> Self {
445        Self {
446            transactions: RwLock::new(HashMap::new()),
447            producer_transactions: RwLock::new(HashMap::new()),
448            default_timeout: timeout,
449            stats: TransactionStats::new(),
450        }
451    }
452
453    /// Get statistics
454    pub fn stats(&self) -> &TransactionStats {
455        &self.stats
456    }
457
458    /// Begin a new transaction
459    pub fn begin_transaction(
460        &self,
461        txn_id: TransactionId,
462        producer_id: ProducerId,
463        producer_epoch: ProducerEpoch,
464        timeout: Option<Duration>,
465    ) -> TransactionResult {
466        // Check if producer already has an active transaction
467        {
468            let producer_txns = self
469                .producer_transactions
470                .read()
471                .expect("transaction manager lock poisoned");
472            if let Some(existing_txn_id) = producer_txns.get(&producer_id) {
473                if existing_txn_id != &txn_id {
474                    return TransactionResult::ConcurrentTransaction;
475                }
476                // Same txn_id - check if we're resuming
477                let transactions = self
478                    .transactions
479                    .read()
480                    .expect("transaction manager lock poisoned");
481                if let Some(txn) = transactions.get(&(producer_id, txn_id.clone())) {
482                    if txn.producer_epoch != producer_epoch {
483                        return TransactionResult::ProducerFenced {
484                            expected_epoch: txn.producer_epoch,
485                            received_epoch: producer_epoch,
486                        };
487                    }
488                    if txn.state.is_active() {
489                        return TransactionResult::Ok; // Already active
490                    }
491                }
492            }
493        }
494
495        // Create new transaction
496        let txn = Transaction::new(
497            txn_id.clone(),
498            producer_id,
499            producer_epoch,
500            timeout.unwrap_or(self.default_timeout),
501        );
502
503        {
504            let mut transactions = self
505                .transactions
506                .write()
507                .expect("transaction manager lock poisoned");
508            let mut producer_txns = self
509                .producer_transactions
510                .write()
511                .expect("transaction manager lock poisoned");
512
513            transactions.insert((producer_id, txn_id.clone()), txn);
514            producer_txns.insert(producer_id, txn_id);
515        }
516
517        self.stats.record_start();
518        TransactionResult::Ok
519    }
520
521    /// Add partitions to an active transaction
522    pub fn add_partitions_to_transaction(
523        &self,
524        txn_id: &TransactionId,
525        producer_id: ProducerId,
526        producer_epoch: ProducerEpoch,
527        partitions: Vec<TransactionPartition>,
528    ) -> TransactionResult {
529        let mut transactions = self
530            .transactions
531            .write()
532            .expect("transaction manager lock poisoned");
533
534        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
535            Some(t) => t,
536            None => return TransactionResult::InvalidTransactionId,
537        };
538
539        // Validate epoch
540        if txn.producer_epoch != producer_epoch {
541            return TransactionResult::ProducerFenced {
542                expected_epoch: txn.producer_epoch,
543                received_epoch: producer_epoch,
544            };
545        }
546
547        // Check state
548        if !txn.state.is_active() {
549            return TransactionResult::InvalidTransactionState {
550                current: txn.state,
551                expected: "Ongoing",
552            };
553        }
554
555        // Check timeout
556        if txn.is_timed_out() {
557            txn.state = TransactionState::Dead;
558            self.stats.record_timeout();
559            return TransactionResult::TransactionTimeout;
560        }
561
562        // Add partitions
563        for partition in partitions {
564            txn.add_partition(partition);
565        }
566
567        TransactionResult::Ok
568    }
569
570    /// Record a write within a transaction
571    pub fn add_write_to_transaction(
572        &self,
573        txn_id: &TransactionId,
574        producer_id: ProducerId,
575        producer_epoch: ProducerEpoch,
576        partition: TransactionPartition,
577        sequence: i32,
578        offset: u64,
579    ) -> TransactionResult {
580        let mut transactions = self
581            .transactions
582            .write()
583            .expect("transaction manager lock poisoned");
584
585        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
586            Some(t) => t,
587            None => return TransactionResult::InvalidTransactionId,
588        };
589
590        // Validate epoch
591        if txn.producer_epoch != producer_epoch {
592            return TransactionResult::ProducerFenced {
593                expected_epoch: txn.producer_epoch,
594                received_epoch: producer_epoch,
595            };
596        }
597
598        // Check state
599        if !txn.state.is_active() {
600            return TransactionResult::InvalidTransactionState {
601                current: txn.state,
602                expected: "Ongoing",
603            };
604        }
605
606        // Check timeout
607        if txn.is_timed_out() {
608            txn.state = TransactionState::Dead;
609            self.stats.record_timeout();
610            return TransactionResult::TransactionTimeout;
611        }
612
613        // Verify partition is part of transaction
614        if !txn.partitions.contains(&partition) {
615            return TransactionResult::PartitionNotInTransaction {
616                topic: partition.topic,
617                partition: partition.partition,
618            };
619        }
620
621        // Record the write
622        txn.add_write(partition, sequence, offset);
623
624        TransactionResult::Ok
625    }
626
627    /// Add consumer offset commit to transaction (for exactly-once consume-transform-produce)
628    pub fn add_offsets_to_transaction(
629        &self,
630        txn_id: &TransactionId,
631        producer_id: ProducerId,
632        producer_epoch: ProducerEpoch,
633        group_id: String,
634        offsets: Vec<(TransactionPartition, i64)>,
635    ) -> TransactionResult {
636        let mut transactions = self
637            .transactions
638            .write()
639            .expect("transaction manager lock poisoned");
640
641        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
642            Some(t) => t,
643            None => return TransactionResult::InvalidTransactionId,
644        };
645
646        // Validate epoch
647        if txn.producer_epoch != producer_epoch {
648            return TransactionResult::ProducerFenced {
649                expected_epoch: txn.producer_epoch,
650                received_epoch: producer_epoch,
651            };
652        }
653
654        // Check state
655        if !txn.state.is_active() {
656            return TransactionResult::InvalidTransactionState {
657                current: txn.state,
658                expected: "Ongoing",
659            };
660        }
661
662        // Check timeout
663        if txn.is_timed_out() {
664            txn.state = TransactionState::Dead;
665            self.stats.record_timeout();
666            return TransactionResult::TransactionTimeout;
667        }
668
669        // Add offset commit
670        txn.add_offset_commit(group_id, offsets);
671
672        TransactionResult::Ok
673    }
674
675    /// Prepare to commit a transaction (2PC phase 1)
676    ///
677    /// Returns the transaction data needed for committing to partitions
678    pub fn prepare_commit(
679        &self,
680        txn_id: &TransactionId,
681        producer_id: ProducerId,
682        producer_epoch: ProducerEpoch,
683    ) -> Result<Transaction, TransactionResult> {
684        let mut transactions = self
685            .transactions
686            .write()
687            .expect("transaction manager lock poisoned");
688
689        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
690            Some(t) => t,
691            None => return Err(TransactionResult::InvalidTransactionId),
692        };
693
694        // Validate epoch
695        if txn.producer_epoch != producer_epoch {
696            return Err(TransactionResult::ProducerFenced {
697                expected_epoch: txn.producer_epoch,
698                received_epoch: producer_epoch,
699            });
700        }
701
702        // Check state
703        if !txn.state.can_commit() {
704            return Err(TransactionResult::InvalidTransactionState {
705                current: txn.state,
706                expected: "Ongoing",
707            });
708        }
709
710        // Check timeout
711        if txn.is_timed_out() {
712            txn.state = TransactionState::Dead;
713            self.stats.record_timeout();
714            return Err(TransactionResult::TransactionTimeout);
715        }
716
717        // Transition to PrepareCommit
718        txn.state = TransactionState::PrepareCommit;
719        txn.touch();
720
721        Ok(txn.clone())
722    }
723
724    /// Complete the commit (2PC phase 2)
725    pub fn complete_commit(
726        &self,
727        txn_id: &TransactionId,
728        producer_id: ProducerId,
729    ) -> TransactionResult {
730        let mut transactions = self
731            .transactions
732            .write()
733            .expect("transaction manager lock poisoned");
734        let mut producer_txns = self
735            .producer_transactions
736            .write()
737            .expect("transaction manager lock poisoned");
738
739        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
740            Some(t) => t,
741            None => return TransactionResult::InvalidTransactionId,
742        };
743
744        if txn.state != TransactionState::PrepareCommit {
745            return TransactionResult::InvalidTransactionState {
746                current: txn.state,
747                expected: "PrepareCommit",
748            };
749        }
750
751        txn.state = TransactionState::CompleteCommit;
752
753        // Clean up
754        transactions.remove(&(producer_id, txn_id.clone()));
755        producer_txns.remove(&producer_id);
756
757        self.stats.record_commit();
758        TransactionResult::Ok
759    }
760
761    /// Prepare to abort a transaction (2PC phase 1)
762    pub fn prepare_abort(
763        &self,
764        txn_id: &TransactionId,
765        producer_id: ProducerId,
766        producer_epoch: ProducerEpoch,
767    ) -> Result<Transaction, TransactionResult> {
768        let mut transactions = self
769            .transactions
770            .write()
771            .expect("transaction manager lock poisoned");
772
773        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
774            Some(t) => t,
775            None => return Err(TransactionResult::InvalidTransactionId),
776        };
777
778        // Validate epoch
779        if txn.producer_epoch != producer_epoch {
780            return Err(TransactionResult::ProducerFenced {
781                expected_epoch: txn.producer_epoch,
782                received_epoch: producer_epoch,
783            });
784        }
785
786        // Check state - abort is allowed from more states than commit
787        if !txn.state.can_abort() {
788            return Err(TransactionResult::InvalidTransactionState {
789                current: txn.state,
790                expected: "Ongoing or PrepareCommit",
791            });
792        }
793
794        // Transition to PrepareAbort
795        txn.state = TransactionState::PrepareAbort;
796        txn.touch();
797
798        Ok(txn.clone())
799    }
800
801    /// Complete the abort (2PC phase 2)
802    pub fn complete_abort(
803        &self,
804        txn_id: &TransactionId,
805        producer_id: ProducerId,
806    ) -> TransactionResult {
807        let mut transactions = self
808            .transactions
809            .write()
810            .expect("transaction manager lock poisoned");
811        let mut producer_txns = self
812            .producer_transactions
813            .write()
814            .expect("transaction manager lock poisoned");
815
816        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
817            Some(t) => t,
818            None => return TransactionResult::InvalidTransactionId,
819        };
820
821        if txn.state != TransactionState::PrepareAbort {
822            return TransactionResult::InvalidTransactionState {
823                current: txn.state,
824                expected: "PrepareAbort",
825            };
826        }
827
828        txn.state = TransactionState::CompleteAbort;
829
830        // Clean up
831        transactions.remove(&(producer_id, txn_id.clone()));
832        producer_txns.remove(&producer_id);
833
834        self.stats.record_abort();
835        TransactionResult::Ok
836    }
837
838    /// Get current transaction state for a producer
839    pub fn get_transaction(
840        &self,
841        txn_id: &TransactionId,
842        producer_id: ProducerId,
843    ) -> Option<Transaction> {
844        let transactions = self
845            .transactions
846            .read()
847            .expect("transaction manager lock poisoned");
848        transactions.get(&(producer_id, txn_id.clone())).cloned()
849    }
850
851    /// Check if a producer has an active transaction
852    pub fn has_active_transaction(&self, producer_id: ProducerId) -> bool {
853        let producer_txns = self
854            .producer_transactions
855            .read()
856            .expect("transaction manager lock poisoned");
857        producer_txns.contains_key(&producer_id)
858    }
859
860    /// Get active transaction ID for a producer
861    pub fn get_active_transaction_id(&self, producer_id: ProducerId) -> Option<TransactionId> {
862        let producer_txns = self
863            .producer_transactions
864            .read()
865            .expect("transaction manager lock poisoned");
866        producer_txns.get(&producer_id).cloned()
867    }
868
869    /// Clean up timed-out transactions
870    pub fn cleanup_timed_out_transactions(&self) -> Vec<Transaction> {
871        let mut timed_out = Vec::new();
872        let mut transactions = self
873            .transactions
874            .write()
875            .expect("transaction manager lock poisoned");
876        let mut producer_txns = self
877            .producer_transactions
878            .write()
879            .expect("transaction manager lock poisoned");
880
881        let keys_to_remove: Vec<_> = transactions
882            .iter()
883            .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
884            .map(|(k, _)| k.clone())
885            .collect();
886
887        for key in keys_to_remove {
888            if let Some(mut txn) = transactions.remove(&key) {
889                txn.state = TransactionState::Dead;
890                producer_txns.remove(&txn.producer_id);
891                self.stats.record_timeout();
892                timed_out.push(txn);
893            }
894        }
895
896        timed_out
897    }
898
899    /// Get number of active transactions
900    pub fn active_count(&self) -> usize {
901        let transactions = self
902            .transactions
903            .read()
904            .expect("transaction manager lock poisoned");
905        transactions
906            .values()
907            .filter(|t| !t.state.is_terminal())
908            .count()
909    }
910}
911
912// ============================================================================
913// Tests
914// ============================================================================
915
916#[cfg(test)]
917mod tests {
918    use super::*;
919
920    #[test]
921    fn test_transaction_state_transitions() {
922        // Test terminal states
923        assert!(TransactionState::Empty.is_terminal());
924        assert!(TransactionState::CompleteCommit.is_terminal());
925        assert!(TransactionState::CompleteAbort.is_terminal());
926        assert!(TransactionState::Dead.is_terminal());
927
928        // Test active states
929        assert!(!TransactionState::Ongoing.is_terminal());
930        assert!(!TransactionState::PrepareCommit.is_terminal());
931        assert!(!TransactionState::PrepareAbort.is_terminal());
932
933        // Test can_commit
934        assert!(TransactionState::Ongoing.can_commit());
935        assert!(!TransactionState::Empty.can_commit());
936        assert!(!TransactionState::PrepareCommit.can_commit());
937
938        // Test can_abort
939        assert!(TransactionState::Ongoing.can_abort());
940        assert!(TransactionState::PrepareCommit.can_abort());
941        assert!(TransactionState::PrepareAbort.can_abort());
942        assert!(!TransactionState::Empty.can_abort());
943    }
944
945    #[test]
946    fn test_begin_transaction() {
947        let coordinator = TransactionCoordinator::new();
948
949        // Begin first transaction
950        let result = coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
951        assert_eq!(result, TransactionResult::Ok);
952
953        // Verify transaction exists
954        let txn = coordinator.get_transaction(&"txn-1".to_string(), 1);
955        assert!(txn.is_some());
956        let txn = txn.unwrap();
957        assert_eq!(txn.state, TransactionState::Ongoing);
958        assert_eq!(txn.producer_id, 1);
959        assert_eq!(txn.producer_epoch, 0);
960
961        // Stats
962        assert_eq!(coordinator.stats().transactions_started(), 1);
963        assert_eq!(coordinator.stats().active_transactions(), 1);
964    }
965
966    #[test]
967    fn test_concurrent_transaction_rejection() {
968        let coordinator = TransactionCoordinator::new();
969
970        // Begin first transaction
971        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
972
973        // Try to begin another transaction for same producer
974        let result = coordinator.begin_transaction("txn-2".to_string(), 1, 0, None);
975        assert_eq!(result, TransactionResult::ConcurrentTransaction);
976    }
977
978    #[test]
979    fn test_add_partitions_to_transaction() {
980        let coordinator = TransactionCoordinator::new();
981        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
982
983        // Add partitions
984        let result = coordinator.add_partitions_to_transaction(
985            &"txn-1".to_string(),
986            1,
987            0,
988            vec![
989                TransactionPartition::new("topic-1", 0),
990                TransactionPartition::new("topic-1", 1),
991                TransactionPartition::new("topic-2", 0),
992            ],
993        );
994        assert_eq!(result, TransactionResult::Ok);
995
996        // Verify partitions added
997        let txn = coordinator
998            .get_transaction(&"txn-1".to_string(), 1)
999            .unwrap();
1000        assert_eq!(txn.partitions.len(), 3);
1001    }
1002
1003    #[test]
1004    fn test_add_write_to_transaction() {
1005        let coordinator = TransactionCoordinator::new();
1006        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1007
1008        let partition = TransactionPartition::new("topic-1", 0);
1009        coordinator.add_partitions_to_transaction(
1010            &"txn-1".to_string(),
1011            1,
1012            0,
1013            vec![partition.clone()],
1014        );
1015
1016        // Record write
1017        let result =
1018            coordinator.add_write_to_transaction(&"txn-1".to_string(), 1, 0, partition, 0, 100);
1019        assert_eq!(result, TransactionResult::Ok);
1020
1021        // Verify write recorded
1022        let txn = coordinator
1023            .get_transaction(&"txn-1".to_string(), 1)
1024            .unwrap();
1025        assert_eq!(txn.pending_writes.len(), 1);
1026        assert_eq!(txn.pending_writes[0].offset, 100);
1027        assert_eq!(txn.pending_writes[0].sequence, 0);
1028    }
1029
1030    #[test]
1031    fn test_write_to_non_registered_partition() {
1032        let coordinator = TransactionCoordinator::new();
1033        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1034
1035        // Try to write to partition not added to transaction
1036        let result = coordinator.add_write_to_transaction(
1037            &"txn-1".to_string(),
1038            1,
1039            0,
1040            TransactionPartition::new("topic-1", 0),
1041            0,
1042            100,
1043        );
1044
1045        assert!(matches!(
1046            result,
1047            TransactionResult::PartitionNotInTransaction { .. }
1048        ));
1049    }
1050
1051    #[test]
1052    fn test_commit_transaction() {
1053        let coordinator = TransactionCoordinator::new();
1054        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1055
1056        let partition = TransactionPartition::new("topic-1", 0);
1057        coordinator.add_partitions_to_transaction(
1058            &"txn-1".to_string(),
1059            1,
1060            0,
1061            vec![partition.clone()],
1062        );
1063        coordinator.add_write_to_transaction(&"txn-1".to_string(), 1, 0, partition, 0, 100);
1064
1065        // Prepare commit
1066        let txn = coordinator.prepare_commit(&"txn-1".to_string(), 1, 0);
1067        assert!(txn.is_ok());
1068        let txn = txn.unwrap();
1069        assert_eq!(txn.state, TransactionState::PrepareCommit);
1070
1071        // Complete commit
1072        let result = coordinator.complete_commit(&"txn-1".to_string(), 1);
1073        assert_eq!(result, TransactionResult::Ok);
1074
1075        // Transaction should be removed
1076        assert!(coordinator
1077            .get_transaction(&"txn-1".to_string(), 1)
1078            .is_none());
1079        assert!(!coordinator.has_active_transaction(1));
1080
1081        // Stats
1082        assert_eq!(coordinator.stats().transactions_committed(), 1);
1083        assert_eq!(coordinator.stats().active_transactions(), 0);
1084    }
1085
1086    #[test]
1087    fn test_abort_transaction() {
1088        let coordinator = TransactionCoordinator::new();
1089        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1090
1091        let partition = TransactionPartition::new("topic-1", 0);
1092        coordinator.add_partitions_to_transaction(
1093            &"txn-1".to_string(),
1094            1,
1095            0,
1096            vec![partition.clone()],
1097        );
1098        coordinator.add_write_to_transaction(&"txn-1".to_string(), 1, 0, partition, 0, 100);
1099
1100        // Prepare abort
1101        let txn = coordinator.prepare_abort(&"txn-1".to_string(), 1, 0);
1102        assert!(txn.is_ok());
1103
1104        // Complete abort
1105        let result = coordinator.complete_abort(&"txn-1".to_string(), 1);
1106        assert_eq!(result, TransactionResult::Ok);
1107
1108        // Transaction should be removed
1109        assert!(coordinator
1110            .get_transaction(&"txn-1".to_string(), 1)
1111            .is_none());
1112
1113        // Stats
1114        assert_eq!(coordinator.stats().transactions_aborted(), 1);
1115    }
1116
1117    #[test]
1118    fn test_producer_fencing() {
1119        let coordinator = TransactionCoordinator::new();
1120        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1121
1122        // Try with wrong epoch
1123        let result = coordinator.add_partitions_to_transaction(
1124            &"txn-1".to_string(),
1125            1,
1126            1, // Wrong epoch
1127            vec![TransactionPartition::new("topic-1", 0)],
1128        );
1129
1130        assert!(matches!(
1131            result,
1132            TransactionResult::ProducerFenced {
1133                expected_epoch: 0,
1134                received_epoch: 1
1135            }
1136        ));
1137    }
1138
1139    #[test]
1140    fn test_transaction_timeout() {
1141        // Create coordinator with very short timeout
1142        let coordinator = TransactionCoordinator::with_timeout(Duration::from_millis(1));
1143        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1144
1145        // Wait for timeout
1146        std::thread::sleep(Duration::from_millis(5));
1147
1148        // Try to add partitions - should fail with timeout
1149        let result = coordinator.add_partitions_to_transaction(
1150            &"txn-1".to_string(),
1151            1,
1152            0,
1153            vec![TransactionPartition::new("topic-1", 0)],
1154        );
1155
1156        assert_eq!(result, TransactionResult::TransactionTimeout);
1157    }
1158
1159    #[test]
1160    fn test_cleanup_timed_out_transactions() {
1161        let coordinator = TransactionCoordinator::with_timeout(Duration::from_millis(1));
1162
1163        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1164        coordinator.begin_transaction("txn-2".to_string(), 2, 0, None);
1165
1166        // Wait for timeout
1167        std::thread::sleep(Duration::from_millis(5));
1168
1169        // Cleanup
1170        let timed_out = coordinator.cleanup_timed_out_transactions();
1171        assert_eq!(timed_out.len(), 2);
1172
1173        // Transactions should be gone
1174        assert_eq!(coordinator.active_count(), 0);
1175        assert_eq!(coordinator.stats().transactions_timed_out(), 2);
1176    }
1177
1178    #[test]
1179    fn test_add_offsets_to_transaction() {
1180        let coordinator = TransactionCoordinator::new();
1181        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1182
1183        // Add consumer offsets
1184        let result = coordinator.add_offsets_to_transaction(
1185            &"txn-1".to_string(),
1186            1,
1187            0,
1188            "consumer-group-1".to_string(),
1189            vec![
1190                (TransactionPartition::new("input-topic", 0), 42),
1191                (TransactionPartition::new("input-topic", 1), 100),
1192            ],
1193        );
1194        assert_eq!(result, TransactionResult::Ok);
1195
1196        // Verify
1197        let txn = coordinator
1198            .get_transaction(&"txn-1".to_string(), 1)
1199            .unwrap();
1200        assert_eq!(txn.offset_commits.len(), 1);
1201        assert_eq!(txn.offset_commits[0].group_id, "consumer-group-1");
1202        assert_eq!(txn.offset_commits[0].offsets.len(), 2);
1203    }
1204
1205    #[test]
1206    fn test_invalid_state_transitions() {
1207        let coordinator = TransactionCoordinator::new();
1208        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1209
1210        // Prepare commit
1211        coordinator
1212            .prepare_commit(&"txn-1".to_string(), 1, 0)
1213            .unwrap();
1214
1215        // Try to add partitions after prepare - should fail
1216        let result = coordinator.add_partitions_to_transaction(
1217            &"txn-1".to_string(),
1218            1,
1219            0,
1220            vec![TransactionPartition::new("topic-1", 0)],
1221        );
1222        assert!(matches!(
1223            result,
1224            TransactionResult::InvalidTransactionState { .. }
1225        ));
1226    }
1227
1228    #[test]
1229    fn test_abort_from_prepare_commit() {
1230        let coordinator = TransactionCoordinator::new();
1231        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1232
1233        // Prepare commit
1234        coordinator
1235            .prepare_commit(&"txn-1".to_string(), 1, 0)
1236            .unwrap();
1237
1238        // Abort should still be allowed from PrepareCommit
1239        let result = coordinator.prepare_abort(&"txn-1".to_string(), 1, 0);
1240        assert!(result.is_ok());
1241
1242        let result = coordinator.complete_abort(&"txn-1".to_string(), 1);
1243        assert_eq!(result, TransactionResult::Ok);
1244    }
1245
1246    #[test]
1247    fn test_transaction_partition_hash() {
1248        let p1 = TransactionPartition::new("topic", 0);
1249        let p2 = TransactionPartition::new("topic", 0);
1250        let p3 = TransactionPartition::new("topic", 1);
1251
1252        assert_eq!(p1, p2);
1253        assert_ne!(p1, p3);
1254
1255        let mut set = HashSet::new();
1256        set.insert(p1.clone());
1257        set.insert(p2); // Should not add (duplicate)
1258        set.insert(p3);
1259        assert_eq!(set.len(), 2);
1260    }
1261
1262    #[test]
1263    fn test_resume_same_transaction() {
1264        let coordinator = TransactionCoordinator::new();
1265
1266        // Begin transaction
1267        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1268
1269        // Try to begin same transaction again - should succeed (idempotent)
1270        let result = coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1271        assert_eq!(result, TransactionResult::Ok);
1272
1273        // Only one transaction should exist
1274        assert_eq!(coordinator.active_count(), 1);
1275        assert_eq!(coordinator.stats().transactions_started(), 1);
1276    }
1277
1278    #[test]
1279    fn test_stats_snapshot() {
1280        let coordinator = TransactionCoordinator::new();
1281        coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
1282        coordinator
1283            .prepare_commit(&"txn-1".to_string(), 1, 0)
1284            .unwrap();
1285        coordinator.complete_commit(&"txn-1".to_string(), 1);
1286
1287        let snapshot: TransactionStatsSnapshot = coordinator.stats().into();
1288        assert_eq!(snapshot.transactions_started, 1);
1289        assert_eq!(snapshot.transactions_committed, 1);
1290        assert_eq!(snapshot.active_transactions, 0);
1291    }
1292}