1use crate::idempotent::{ProducerEpoch, ProducerId};
51use serde::{Deserialize, Serialize};
52use std::collections::{HashMap, HashSet};
53use std::sync::atomic::{AtomicU64, Ordering};
54use std::sync::RwLock;
55use std::time::{Duration, Instant, SystemTime};
56
/// Caller-chosen identifier of a transaction. Coordinator state is keyed by
/// `(ProducerId, TransactionId)`.
pub type TransactionId = String;

/// Timeout that `TransactionCoordinator::new` applies to transactions begun
/// without an explicit timeout.
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(60);

/// Presumably a cap on pending transactions (per producer or global — unclear).
///
/// NOTE(review): nothing in this file reads this constant, and the coordinator
/// never returns `TransactionResult::TooManyTransactions`; confirm whether the
/// limit is enforced elsewhere or still to be implemented.
pub const MAX_PENDING_TRANSACTIONS: usize = 5;
65
/// Lifecycle state of a [`Transaction`].
///
/// The coordinator creates transactions as `Ongoing`, moves them to
/// `PrepareCommit` or `PrepareAbort` on request, and marks them `Dead` when they
/// time out.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionState {
    /// Not assigned by the coordinator in this file; classified as terminal.
    Empty,

    /// The only state that accepts new partitions, writes and offset commits.
    Ongoing,

    /// A commit was prepared; the transaction may now only be completed or aborted.
    PrepareCommit,

    /// An abort was prepared; awaiting completion.
    PrepareAbort,

    /// The commit finished. Classified as terminal.
    CompleteCommit,

    /// The abort finished. Classified as terminal.
    CompleteAbort,

    /// The transaction exceeded its timeout. Classified as terminal.
    Dead,
}
90
91impl TransactionState {
92 pub fn is_terminal(&self) -> bool {
94 matches!(
95 self,
96 TransactionState::Empty
97 | TransactionState::CompleteCommit
98 | TransactionState::CompleteAbort
99 | TransactionState::Dead
100 )
101 }
102
103 pub fn is_active(&self) -> bool {
105 matches!(self, TransactionState::Ongoing)
106 }
107
108 pub fn can_commit(&self) -> bool {
110 matches!(self, TransactionState::Ongoing)
111 }
112
113 pub fn can_abort(&self) -> bool {
115 matches!(
116 self,
117 TransactionState::Ongoing
118 | TransactionState::PrepareCommit
119 | TransactionState::PrepareAbort
120 )
121 }
122}
123
/// Status returned by `TransactionCoordinator` operations; `Ok` means success.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionResult {
    /// The operation succeeded.
    Ok,

    /// No transaction is registered under the given `(producer_id, txn_id)`.
    InvalidTransactionId,

    /// The transaction's state does not permit the requested operation.
    InvalidTransactionState {
        /// State the transaction was actually in.
        current: TransactionState,
        /// Human-readable description of the acceptable state(s).
        expected: &'static str,
    },

    /// The caller's epoch differs from the epoch recorded on the transaction.
    ProducerFenced {
        /// Epoch recorded on the transaction.
        expected_epoch: ProducerEpoch,
        /// Epoch supplied by the caller.
        received_epoch: ProducerEpoch,
    },

    /// The transaction exceeded its timeout and has been marked `Dead`.
    TransactionTimeout,

    /// NOTE(review): never returned in this file; presumably tied to
    /// `MAX_PENDING_TRANSACTIONS` — confirm.
    TooManyTransactions,

    /// Returned by `begin_transaction` when the producer's existing transaction
    /// prevents starting the requested one.
    ConcurrentTransaction,

    /// A write targeted a partition that was not first added to the transaction.
    PartitionNotInTransaction { topic: String, partition: u32 },
}
157
/// A `(topic, partition)` pair that a transaction touches; hashable so it can
/// live in the transaction's partition set.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionPartition {
    /// Topic name.
    pub topic: String,
    /// Partition index within `topic`.
    pub partition: u32,
}
164
165impl TransactionPartition {
166 pub fn new(topic: impl Into<String>, partition: u32) -> Self {
167 Self {
168 topic: topic.into(),
169 partition,
170 }
171 }
172}
173
/// A write recorded against a transaction by `Transaction::add_write`.
///
/// The coordinator only records writes to partitions already added to the
/// transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingWrite {
    /// Partition the write went to.
    pub partition: TransactionPartition,

    /// Sequence number supplied by the caller (presumably the producer's
    /// idempotent sequence — confirm).
    pub sequence: i32,

    /// Offset supplied by the caller for this write.
    pub offset: u64,

    /// Wall-clock time at which the write was recorded.
    #[serde(with = "crate::serde_utils::system_time")]
    pub timestamp: SystemTime,
}
190
/// Consumer-group offsets recorded against a transaction by
/// `Transaction::add_offset_commit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionOffsetCommit {
    /// Consumer group the offsets belong to.
    pub group_id: String,

    /// `(partition, offset)` pairs recorded for `group_id`.
    pub offsets: Vec<(TransactionPartition, i64)>,
}
200
/// Everything the coordinator tracks for one producer transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transaction {
    /// Caller-chosen transaction identifier.
    pub txn_id: TransactionId,

    /// Producer that owns the transaction.
    pub producer_id: ProducerId,

    /// Epoch the transaction was begun with; coordinator operations carrying a
    /// different epoch are rejected as `ProducerFenced`.
    pub producer_epoch: ProducerEpoch,

    /// Current lifecycle state.
    pub state: TransactionState,

    /// Partitions added to the transaction.
    pub partitions: HashSet<TransactionPartition>,

    /// Writes recorded so far, in insertion order.
    pub pending_writes: Vec<PendingWrite>,

    /// Consumer-group offset commits recorded so far, in insertion order.
    pub offset_commits: Vec<TransactionOffsetCommit>,

    /// Wall-clock creation time.
    #[serde(with = "crate::serde_utils::system_time")]
    pub started_at: SystemTime,

    /// Maximum idle time, measured from `last_activity`, before the
    /// transaction counts as timed out.
    #[serde(with = "crate::serde_utils::duration")]
    pub timeout: Duration,

    /// Monotonic time of the last mutation. Not serialized, so it is `None`
    /// after deserialization, which `is_timed_out` treats as timed out.
    #[serde(skip)]
    pub last_activity: Option<Instant>,
}
237
238impl Transaction {
239 pub fn new(
241 txn_id: TransactionId,
242 producer_id: ProducerId,
243 producer_epoch: ProducerEpoch,
244 timeout: Duration,
245 ) -> Self {
246 Self {
247 txn_id,
248 producer_id,
249 producer_epoch,
250 state: TransactionState::Ongoing,
251 partitions: HashSet::new(),
252 pending_writes: Vec::new(),
253 offset_commits: Vec::new(),
254 started_at: SystemTime::now(),
255 timeout,
256 last_activity: Some(Instant::now()),
257 }
258 }
259
260 pub fn is_timed_out(&self) -> bool {
262 self.last_activity
263 .map(|t| t.elapsed() > self.timeout)
264 .unwrap_or(true)
265 }
266
267 pub fn touch(&mut self) {
269 self.last_activity = Some(Instant::now());
270 }
271
272 pub fn add_partition(&mut self, partition: TransactionPartition) {
274 self.partitions.insert(partition);
275 self.touch();
276 }
277
278 pub fn add_write(&mut self, partition: TransactionPartition, sequence: i32, offset: u64) {
280 self.pending_writes.push(PendingWrite {
281 partition,
282 sequence,
283 offset,
284 timestamp: SystemTime::now(),
285 });
286 self.touch();
287 }
288
289 pub fn add_offset_commit(
291 &mut self,
292 group_id: String,
293 offsets: Vec<(TransactionPartition, i64)>,
294 ) {
295 self.offset_commits
296 .push(TransactionOffsetCommit { group_id, offsets });
297 self.touch();
298 }
299
300 pub fn write_count(&self) -> usize {
302 self.pending_writes.len()
303 }
304
305 pub fn affected_partitions(&self) -> impl Iterator<Item = &TransactionPartition> {
307 self.partitions.iter()
308 }
309}
310
/// Outcome marker for a finished transaction.
///
/// NOTE(review): not referenced elsewhere in this file; presumably written to the
/// affected partitions by the caller when completing a commit or abort — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionMarker {
    /// The transaction committed.
    Commit,

    /// The transaction aborted.
    Abort,
}
320
/// Transaction counters updated through relaxed atomic operations.
///
/// `transactions_*` are running totals; `active_transactions` is a gauge that
/// `record_start` increments and `record_commit` / `record_abort` /
/// `record_timeout` decrement. Each counter is read independently, so values
/// read together may be mutually inconsistent under concurrent updates.
#[derive(Debug, Default)]
pub struct TransactionStats {
    /// Number of `record_start` calls.
    transactions_started: AtomicU64,

    /// Number of `record_commit` calls.
    transactions_committed: AtomicU64,

    /// Number of `record_abort` calls.
    transactions_aborted: AtomicU64,

    /// Number of `record_timeout` calls.
    transactions_timed_out: AtomicU64,

    /// Started minus finished transactions, never below zero.
    active_transactions: AtomicU64,
}

impl TransactionStats {
    /// Creates a set of counters, all zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts a newly started transaction and raises the active gauge.
    pub fn record_start(&self) {
        self.transactions_started.fetch_add(1, Ordering::Relaxed);
        self.active_transactions.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts a committed transaction and lowers the active gauge.
    pub fn record_commit(&self) {
        self.transactions_committed.fetch_add(1, Ordering::Relaxed);
        self.decrement_active();
    }

    /// Counts an aborted transaction and lowers the active gauge.
    pub fn record_abort(&self) {
        self.transactions_aborted.fetch_add(1, Ordering::Relaxed);
        self.decrement_active();
    }

    /// Counts a timed-out transaction and lowers the active gauge.
    pub fn record_timeout(&self) {
        self.transactions_timed_out.fetch_add(1, Ordering::Relaxed);
        self.decrement_active();
    }

    /// Total transactions started.
    pub fn transactions_started(&self) -> u64 {
        self.transactions_started.load(Ordering::Relaxed)
    }

    /// Total transactions committed.
    pub fn transactions_committed(&self) -> u64 {
        self.transactions_committed.load(Ordering::Relaxed)
    }

    /// Total transactions aborted.
    pub fn transactions_aborted(&self) -> u64 {
        self.transactions_aborted.load(Ordering::Relaxed)
    }

    /// Total transactions timed out.
    pub fn transactions_timed_out(&self) -> u64 {
        self.transactions_timed_out.load(Ordering::Relaxed)
    }

    /// Current value of the active-transactions gauge.
    pub fn active_transactions(&self) -> u64 {
        self.active_transactions.load(Ordering::Relaxed)
    }

    /// Lowers the active gauge by one, saturating at zero. A plain `fetch_sub`
    /// would wrap an unbalanced decrement around to `u64::MAX`.
    fn decrement_active(&self) {
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        let _ = self
            .active_transactions
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
                Some(n.saturating_sub(1))
            });
    }
}
385
/// Plain, serializable copy of the [`TransactionStats`] counters.
///
/// Built by loading each counter separately, so under concurrent updates the
/// fields may not all reflect the same instant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionStatsSnapshot {
    /// Total transactions started.
    pub transactions_started: u64,
    /// Total transactions committed.
    pub transactions_committed: u64,
    /// Total transactions aborted.
    pub transactions_aborted: u64,
    /// Total transactions timed out.
    pub transactions_timed_out: u64,
    /// Active-transactions gauge at snapshot time.
    pub active_transactions: u64,
}
395
396impl From<&TransactionStats> for TransactionStatsSnapshot {
397 fn from(stats: &TransactionStats) -> Self {
398 Self {
399 transactions_started: stats.transactions_started(),
400 transactions_committed: stats.transactions_committed(),
401 transactions_aborted: stats.transactions_aborted(),
402 transactions_timed_out: stats.transactions_timed_out(),
403 active_transactions: stats.active_transactions(),
404 }
405 }
406}
407
/// In-memory registry of producer transactions that enforces their state
/// machine, epoch fencing and timeouts.
pub struct TransactionCoordinator {
    /// All registered transactions, keyed by `(producer_id, txn_id)`.
    transactions: RwLock<HashMap<(ProducerId, TransactionId), Transaction>>,

    /// The transaction id each producer is currently associated with.
    producer_transactions: RwLock<HashMap<ProducerId, TransactionId>>,

    /// Timeout used when `begin_transaction` is not given one.
    default_timeout: Duration,

    /// Counters updated as transactions start, commit, abort and time out.
    stats: TransactionStats,
}
425
426impl Default for TransactionCoordinator {
427 fn default() -> Self {
428 Self::new()
429 }
430}
431
432impl TransactionCoordinator {
433 pub fn new() -> Self {
435 Self {
436 transactions: RwLock::new(HashMap::new()),
437 producer_transactions: RwLock::new(HashMap::new()),
438 default_timeout: DEFAULT_TRANSACTION_TIMEOUT,
439 stats: TransactionStats::new(),
440 }
441 }
442
443 pub fn with_timeout(timeout: Duration) -> Self {
445 Self {
446 transactions: RwLock::new(HashMap::new()),
447 producer_transactions: RwLock::new(HashMap::new()),
448 default_timeout: timeout,
449 stats: TransactionStats::new(),
450 }
451 }
452
453 pub fn stats(&self) -> &TransactionStats {
455 &self.stats
456 }
457
458 pub fn begin_transaction(
460 &self,
461 txn_id: TransactionId,
462 producer_id: ProducerId,
463 producer_epoch: ProducerEpoch,
464 timeout: Option<Duration>,
465 ) -> TransactionResult {
466 {
468 let producer_txns = self
469 .producer_transactions
470 .read()
471 .expect("transaction manager lock poisoned");
472 if let Some(existing_txn_id) = producer_txns.get(&producer_id) {
473 if existing_txn_id != &txn_id {
474 return TransactionResult::ConcurrentTransaction;
475 }
476 let transactions = self
478 .transactions
479 .read()
480 .expect("transaction manager lock poisoned");
481 if let Some(txn) = transactions.get(&(producer_id, txn_id.clone())) {
482 if txn.producer_epoch != producer_epoch {
483 return TransactionResult::ProducerFenced {
484 expected_epoch: txn.producer_epoch,
485 received_epoch: producer_epoch,
486 };
487 }
488 if txn.state.is_active() {
489 return TransactionResult::Ok; }
491 }
492 }
493 }
494
495 let txn = Transaction::new(
497 txn_id.clone(),
498 producer_id,
499 producer_epoch,
500 timeout.unwrap_or(self.default_timeout),
501 );
502
503 {
504 let mut transactions = self
505 .transactions
506 .write()
507 .expect("transaction manager lock poisoned");
508 let mut producer_txns = self
509 .producer_transactions
510 .write()
511 .expect("transaction manager lock poisoned");
512
513 transactions.insert((producer_id, txn_id.clone()), txn);
514 producer_txns.insert(producer_id, txn_id);
515 }
516
517 self.stats.record_start();
518 TransactionResult::Ok
519 }
520
521 pub fn add_partitions_to_transaction(
523 &self,
524 txn_id: &TransactionId,
525 producer_id: ProducerId,
526 producer_epoch: ProducerEpoch,
527 partitions: Vec<TransactionPartition>,
528 ) -> TransactionResult {
529 let mut transactions = self
530 .transactions
531 .write()
532 .expect("transaction manager lock poisoned");
533
534 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
535 Some(t) => t,
536 None => return TransactionResult::InvalidTransactionId,
537 };
538
539 if txn.producer_epoch != producer_epoch {
541 return TransactionResult::ProducerFenced {
542 expected_epoch: txn.producer_epoch,
543 received_epoch: producer_epoch,
544 };
545 }
546
547 if !txn.state.is_active() {
549 return TransactionResult::InvalidTransactionState {
550 current: txn.state,
551 expected: "Ongoing",
552 };
553 }
554
555 if txn.is_timed_out() {
557 txn.state = TransactionState::Dead;
558 self.stats.record_timeout();
559 return TransactionResult::TransactionTimeout;
560 }
561
562 for partition in partitions {
564 txn.add_partition(partition);
565 }
566
567 TransactionResult::Ok
568 }
569
570 pub fn add_write_to_transaction(
572 &self,
573 txn_id: &TransactionId,
574 producer_id: ProducerId,
575 producer_epoch: ProducerEpoch,
576 partition: TransactionPartition,
577 sequence: i32,
578 offset: u64,
579 ) -> TransactionResult {
580 let mut transactions = self
581 .transactions
582 .write()
583 .expect("transaction manager lock poisoned");
584
585 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
586 Some(t) => t,
587 None => return TransactionResult::InvalidTransactionId,
588 };
589
590 if txn.producer_epoch != producer_epoch {
592 return TransactionResult::ProducerFenced {
593 expected_epoch: txn.producer_epoch,
594 received_epoch: producer_epoch,
595 };
596 }
597
598 if !txn.state.is_active() {
600 return TransactionResult::InvalidTransactionState {
601 current: txn.state,
602 expected: "Ongoing",
603 };
604 }
605
606 if txn.is_timed_out() {
608 txn.state = TransactionState::Dead;
609 self.stats.record_timeout();
610 return TransactionResult::TransactionTimeout;
611 }
612
613 if !txn.partitions.contains(&partition) {
615 return TransactionResult::PartitionNotInTransaction {
616 topic: partition.topic,
617 partition: partition.partition,
618 };
619 }
620
621 txn.add_write(partition, sequence, offset);
623
624 TransactionResult::Ok
625 }
626
627 pub fn add_offsets_to_transaction(
629 &self,
630 txn_id: &TransactionId,
631 producer_id: ProducerId,
632 producer_epoch: ProducerEpoch,
633 group_id: String,
634 offsets: Vec<(TransactionPartition, i64)>,
635 ) -> TransactionResult {
636 let mut transactions = self
637 .transactions
638 .write()
639 .expect("transaction manager lock poisoned");
640
641 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
642 Some(t) => t,
643 None => return TransactionResult::InvalidTransactionId,
644 };
645
646 if txn.producer_epoch != producer_epoch {
648 return TransactionResult::ProducerFenced {
649 expected_epoch: txn.producer_epoch,
650 received_epoch: producer_epoch,
651 };
652 }
653
654 if !txn.state.is_active() {
656 return TransactionResult::InvalidTransactionState {
657 current: txn.state,
658 expected: "Ongoing",
659 };
660 }
661
662 if txn.is_timed_out() {
664 txn.state = TransactionState::Dead;
665 self.stats.record_timeout();
666 return TransactionResult::TransactionTimeout;
667 }
668
669 txn.add_offset_commit(group_id, offsets);
671
672 TransactionResult::Ok
673 }
674
675 pub fn prepare_commit(
679 &self,
680 txn_id: &TransactionId,
681 producer_id: ProducerId,
682 producer_epoch: ProducerEpoch,
683 ) -> Result<Transaction, TransactionResult> {
684 let mut transactions = self
685 .transactions
686 .write()
687 .expect("transaction manager lock poisoned");
688
689 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
690 Some(t) => t,
691 None => return Err(TransactionResult::InvalidTransactionId),
692 };
693
694 if txn.producer_epoch != producer_epoch {
696 return Err(TransactionResult::ProducerFenced {
697 expected_epoch: txn.producer_epoch,
698 received_epoch: producer_epoch,
699 });
700 }
701
702 if !txn.state.can_commit() {
704 return Err(TransactionResult::InvalidTransactionState {
705 current: txn.state,
706 expected: "Ongoing",
707 });
708 }
709
710 if txn.is_timed_out() {
712 txn.state = TransactionState::Dead;
713 self.stats.record_timeout();
714 return Err(TransactionResult::TransactionTimeout);
715 }
716
717 txn.state = TransactionState::PrepareCommit;
719 txn.touch();
720
721 Ok(txn.clone())
722 }
723
724 pub fn complete_commit(
726 &self,
727 txn_id: &TransactionId,
728 producer_id: ProducerId,
729 ) -> TransactionResult {
730 let mut transactions = self
731 .transactions
732 .write()
733 .expect("transaction manager lock poisoned");
734 let mut producer_txns = self
735 .producer_transactions
736 .write()
737 .expect("transaction manager lock poisoned");
738
739 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
740 Some(t) => t,
741 None => return TransactionResult::InvalidTransactionId,
742 };
743
744 if txn.state != TransactionState::PrepareCommit {
745 return TransactionResult::InvalidTransactionState {
746 current: txn.state,
747 expected: "PrepareCommit",
748 };
749 }
750
751 txn.state = TransactionState::CompleteCommit;
752
753 transactions.remove(&(producer_id, txn_id.clone()));
755 producer_txns.remove(&producer_id);
756
757 self.stats.record_commit();
758 TransactionResult::Ok
759 }
760
761 pub fn prepare_abort(
763 &self,
764 txn_id: &TransactionId,
765 producer_id: ProducerId,
766 producer_epoch: ProducerEpoch,
767 ) -> Result<Transaction, TransactionResult> {
768 let mut transactions = self
769 .transactions
770 .write()
771 .expect("transaction manager lock poisoned");
772
773 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
774 Some(t) => t,
775 None => return Err(TransactionResult::InvalidTransactionId),
776 };
777
778 if txn.producer_epoch != producer_epoch {
780 return Err(TransactionResult::ProducerFenced {
781 expected_epoch: txn.producer_epoch,
782 received_epoch: producer_epoch,
783 });
784 }
785
786 if !txn.state.can_abort() {
788 return Err(TransactionResult::InvalidTransactionState {
789 current: txn.state,
790 expected: "Ongoing or PrepareCommit",
791 });
792 }
793
794 txn.state = TransactionState::PrepareAbort;
796 txn.touch();
797
798 Ok(txn.clone())
799 }
800
801 pub fn complete_abort(
803 &self,
804 txn_id: &TransactionId,
805 producer_id: ProducerId,
806 ) -> TransactionResult {
807 let mut transactions = self
808 .transactions
809 .write()
810 .expect("transaction manager lock poisoned");
811 let mut producer_txns = self
812 .producer_transactions
813 .write()
814 .expect("transaction manager lock poisoned");
815
816 let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
817 Some(t) => t,
818 None => return TransactionResult::InvalidTransactionId,
819 };
820
821 if txn.state != TransactionState::PrepareAbort {
822 return TransactionResult::InvalidTransactionState {
823 current: txn.state,
824 expected: "PrepareAbort",
825 };
826 }
827
828 txn.state = TransactionState::CompleteAbort;
829
830 transactions.remove(&(producer_id, txn_id.clone()));
832 producer_txns.remove(&producer_id);
833
834 self.stats.record_abort();
835 TransactionResult::Ok
836 }
837
838 pub fn get_transaction(
840 &self,
841 txn_id: &TransactionId,
842 producer_id: ProducerId,
843 ) -> Option<Transaction> {
844 let transactions = self
845 .transactions
846 .read()
847 .expect("transaction manager lock poisoned");
848 transactions.get(&(producer_id, txn_id.clone())).cloned()
849 }
850
851 pub fn has_active_transaction(&self, producer_id: ProducerId) -> bool {
853 let producer_txns = self
854 .producer_transactions
855 .read()
856 .expect("transaction manager lock poisoned");
857 producer_txns.contains_key(&producer_id)
858 }
859
860 pub fn get_active_transaction_id(&self, producer_id: ProducerId) -> Option<TransactionId> {
862 let producer_txns = self
863 .producer_transactions
864 .read()
865 .expect("transaction manager lock poisoned");
866 producer_txns.get(&producer_id).cloned()
867 }
868
869 pub fn cleanup_timed_out_transactions(&self) -> Vec<Transaction> {
871 let mut timed_out = Vec::new();
872 let mut transactions = self
873 .transactions
874 .write()
875 .expect("transaction manager lock poisoned");
876 let mut producer_txns = self
877 .producer_transactions
878 .write()
879 .expect("transaction manager lock poisoned");
880
881 let keys_to_remove: Vec<_> = transactions
882 .iter()
883 .filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
884 .map(|(k, _)| k.clone())
885 .collect();
886
887 for key in keys_to_remove {
888 if let Some(mut txn) = transactions.remove(&key) {
889 txn.state = TransactionState::Dead;
890 producer_txns.remove(&txn.producer_id);
891 self.stats.record_timeout();
892 timed_out.push(txn);
893 }
894 }
895
896 timed_out
897 }
898
899 pub fn active_count(&self) -> usize {
901 let transactions = self
902 .transactions
903 .read()
904 .expect("transaction manager lock poisoned");
905 transactions
906 .values()
907 .filter(|t| !t.state.is_terminal())
908 .count()
909 }
910}
911
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned transaction id.
    fn tid(raw: &str) -> TransactionId {
        raw.to_string()
    }

    /// A default coordinator on which producer 1 (epoch 0) has begun `txn-1`.
    fn coordinator_with_txn1() -> TransactionCoordinator {
        let coordinator = TransactionCoordinator::new();
        coordinator.begin_transaction(tid("txn-1"), 1, 0, None);
        coordinator
    }

    /// Registers `topic-1`/0 on `txn-1` and records one write at offset 100.
    fn write_one_record(coordinator: &TransactionCoordinator) {
        let partition = TransactionPartition::new("topic-1", 0);
        coordinator.add_partitions_to_transaction(&tid("txn-1"), 1, 0, vec![partition.clone()]);
        coordinator.add_write_to_transaction(&tid("txn-1"), 1, 0, partition, 0, 100);
    }

    #[test]
    fn test_transaction_state_transitions() {
        use super::TransactionState as S;

        for state in &[S::Empty, S::CompleteCommit, S::CompleteAbort, S::Dead] {
            assert!(state.is_terminal(), "{:?} should be terminal", state);
        }
        for state in &[S::Ongoing, S::PrepareCommit, S::PrepareAbort] {
            assert!(!state.is_terminal(), "{:?} should not be terminal", state);
        }

        assert!(S::Ongoing.can_commit());
        for state in &[S::Empty, S::PrepareCommit] {
            assert!(!state.can_commit(), "{:?} should not allow commit", state);
        }

        for state in &[S::Ongoing, S::PrepareCommit, S::PrepareAbort] {
            assert!(state.can_abort(), "{:?} should allow abort", state);
        }
        assert!(!S::Empty.can_abort());
    }

    #[test]
    fn test_begin_transaction() {
        let coordinator = TransactionCoordinator::new();

        let outcome = coordinator.begin_transaction(tid("txn-1"), 1, 0, None);
        assert_eq!(outcome, TransactionResult::Ok);

        let txn = coordinator
            .get_transaction(&tid("txn-1"), 1)
            .expect("txn-1 should be registered");
        assert_eq!(txn.state, TransactionState::Ongoing);
        assert_eq!(txn.producer_id, 1);
        assert_eq!(txn.producer_epoch, 0);

        let stats = coordinator.stats();
        assert_eq!(stats.transactions_started(), 1);
        assert_eq!(stats.active_transactions(), 1);
    }

    #[test]
    fn test_concurrent_transaction_rejection() {
        let coordinator = coordinator_with_txn1();

        let outcome = coordinator.begin_transaction(tid("txn-2"), 1, 0, None);
        assert_eq!(outcome, TransactionResult::ConcurrentTransaction);
    }

    #[test]
    fn test_add_partitions_to_transaction() {
        let coordinator = coordinator_with_txn1();
        let partitions = vec![
            TransactionPartition::new("topic-1", 0),
            TransactionPartition::new("topic-1", 1),
            TransactionPartition::new("topic-2", 0),
        ];

        let outcome = coordinator.add_partitions_to_transaction(&tid("txn-1"), 1, 0, partitions);
        assert_eq!(outcome, TransactionResult::Ok);

        let txn = coordinator.get_transaction(&tid("txn-1"), 1).unwrap();
        assert_eq!(txn.partitions.len(), 3);
    }

    #[test]
    fn test_add_write_to_transaction() {
        let coordinator = coordinator_with_txn1();
        let id = tid("txn-1");
        let partition = TransactionPartition::new("topic-1", 0);
        coordinator.add_partitions_to_transaction(&id, 1, 0, vec![partition.clone()]);

        let outcome = coordinator.add_write_to_transaction(&id, 1, 0, partition, 0, 100);
        assert_eq!(outcome, TransactionResult::Ok);

        let txn = coordinator.get_transaction(&id, 1).unwrap();
        assert_eq!(txn.pending_writes.len(), 1);
        let write = &txn.pending_writes[0];
        assert_eq!(write.offset, 100);
        assert_eq!(write.sequence, 0);
    }

    #[test]
    fn test_write_to_non_registered_partition() {
        let coordinator = coordinator_with_txn1();

        let outcome = coordinator.add_write_to_transaction(
            &tid("txn-1"),
            1,
            0,
            TransactionPartition::new("topic-1", 0),
            0,
            100,
        );

        assert!(matches!(
            outcome,
            TransactionResult::PartitionNotInTransaction { .. }
        ));
    }

    #[test]
    fn test_commit_transaction() {
        let coordinator = coordinator_with_txn1();
        write_one_record(&coordinator);

        let prepared = coordinator
            .prepare_commit(&tid("txn-1"), 1, 0)
            .expect("prepare_commit should succeed");
        assert_eq!(prepared.state, TransactionState::PrepareCommit);

        assert_eq!(
            coordinator.complete_commit(&tid("txn-1"), 1),
            TransactionResult::Ok
        );

        assert!(coordinator.get_transaction(&tid("txn-1"), 1).is_none());
        assert!(!coordinator.has_active_transaction(1));

        let stats = coordinator.stats();
        assert_eq!(stats.transactions_committed(), 1);
        assert_eq!(stats.active_transactions(), 0);
    }

    #[test]
    fn test_abort_transaction() {
        let coordinator = coordinator_with_txn1();
        write_one_record(&coordinator);

        assert!(coordinator.prepare_abort(&tid("txn-1"), 1, 0).is_ok());
        assert_eq!(
            coordinator.complete_abort(&tid("txn-1"), 1),
            TransactionResult::Ok
        );

        assert!(coordinator.get_transaction(&tid("txn-1"), 1).is_none());
        assert_eq!(coordinator.stats().transactions_aborted(), 1);
    }

    #[test]
    fn test_producer_fencing() {
        let coordinator = coordinator_with_txn1();

        // Epoch 1 does not match the epoch 0 that txn-1 was begun with.
        let outcome = coordinator.add_partitions_to_transaction(
            &tid("txn-1"),
            1,
            1,
            vec![TransactionPartition::new("topic-1", 0)],
        );

        assert_eq!(
            outcome,
            TransactionResult::ProducerFenced {
                expected_epoch: 0,
                received_epoch: 1,
            }
        );
    }

    #[test]
    fn test_transaction_timeout() {
        let coordinator = TransactionCoordinator::with_timeout(Duration::from_millis(1));
        coordinator.begin_transaction(tid("txn-1"), 1, 0, None);

        // Idle past the 1 ms timeout.
        std::thread::sleep(Duration::from_millis(5));

        let outcome = coordinator.add_partitions_to_transaction(
            &tid("txn-1"),
            1,
            0,
            vec![TransactionPartition::new("topic-1", 0)],
        );
        assert_eq!(outcome, TransactionResult::TransactionTimeout);
    }

    #[test]
    fn test_cleanup_timed_out_transactions() {
        let coordinator = TransactionCoordinator::with_timeout(Duration::from_millis(1));
        coordinator.begin_transaction(tid("txn-1"), 1, 0, None);
        coordinator.begin_transaction(tid("txn-2"), 2, 0, None);

        std::thread::sleep(Duration::from_millis(5));

        let reaped = coordinator.cleanup_timed_out_transactions();
        assert_eq!(reaped.len(), 2);

        assert_eq!(coordinator.active_count(), 0);
        assert_eq!(coordinator.stats().transactions_timed_out(), 2);
    }

    #[test]
    fn test_add_offsets_to_transaction() {
        let coordinator = coordinator_with_txn1();
        let offsets = vec![
            (TransactionPartition::new("input-topic", 0), 42),
            (TransactionPartition::new("input-topic", 1), 100),
        ];

        let outcome = coordinator.add_offsets_to_transaction(
            &tid("txn-1"),
            1,
            0,
            "consumer-group-1".to_string(),
            offsets,
        );
        assert_eq!(outcome, TransactionResult::Ok);

        let txn = coordinator.get_transaction(&tid("txn-1"), 1).unwrap();
        assert_eq!(txn.offset_commits.len(), 1);
        let commit = &txn.offset_commits[0];
        assert_eq!(commit.group_id, "consumer-group-1");
        assert_eq!(commit.offsets.len(), 2);
    }

    #[test]
    fn test_invalid_state_transitions() {
        let coordinator = coordinator_with_txn1();
        coordinator.prepare_commit(&tid("txn-1"), 1, 0).unwrap();

        // A transaction in PrepareCommit no longer accepts partitions.
        let outcome = coordinator.add_partitions_to_transaction(
            &tid("txn-1"),
            1,
            0,
            vec![TransactionPartition::new("topic-1", 0)],
        );
        assert!(matches!(
            outcome,
            TransactionResult::InvalidTransactionState { .. }
        ));
    }

    #[test]
    fn test_abort_from_prepare_commit() {
        let coordinator = coordinator_with_txn1();
        coordinator.prepare_commit(&tid("txn-1"), 1, 0).unwrap();

        assert!(coordinator.prepare_abort(&tid("txn-1"), 1, 0).is_ok());
        assert_eq!(
            coordinator.complete_abort(&tid("txn-1"), 1),
            TransactionResult::Ok
        );
    }

    #[test]
    fn test_transaction_partition_hash() {
        let first = TransactionPartition::new("topic", 0);
        let duplicate = TransactionPartition::new("topic", 0);
        let other = TransactionPartition::new("topic", 1);

        assert_eq!(first, duplicate);
        assert_ne!(first, other);

        let set: HashSet<TransactionPartition> =
            vec![first, duplicate, other].into_iter().collect();
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_resume_same_transaction() {
        let coordinator = coordinator_with_txn1();

        // Beginning the same id at the same epoch resumes rather than restarts.
        let outcome = coordinator.begin_transaction(tid("txn-1"), 1, 0, None);
        assert_eq!(outcome, TransactionResult::Ok);

        assert_eq!(coordinator.active_count(), 1);
        assert_eq!(coordinator.stats().transactions_started(), 1);
    }

    #[test]
    fn test_stats_snapshot() {
        let coordinator = coordinator_with_txn1();
        coordinator.prepare_commit(&tid("txn-1"), 1, 0).unwrap();
        coordinator.complete_commit(&tid("txn-1"), 1);

        let snapshot = TransactionStatsSnapshot::from(coordinator.stats());
        assert_eq!(snapshot.transactions_started, 1);
        assert_eq!(snapshot.transactions_committed, 1);
        assert_eq!(snapshot.active_transactions, 0);
    }
}