use crate::idempotent::{ProducerEpoch, ProducerId};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime};
/// Producer-chosen identifier naming a transactional session.
pub type TransactionId = String;
/// Idle timeout applied when a producer does not supply one at begin.
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(60);
/// Upper bound on transactions in the `Ongoing` state across all producers.
pub const MAX_PENDING_TRANSACTIONS: usize = 5;
/// Lifecycle state of a transaction: two-phase commit/abort with terminal
/// Complete states, plus `Dead` for timed-out transactions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionState {
    /// No transaction in progress.
    Empty,
    /// Open and accepting partitions, writes, and offset commits.
    Ongoing,
    /// Commit decision made and logged; completion pending.
    PrepareCommit,
    /// Abort decision made and logged; completion pending.
    PrepareAbort,
    /// Commit finished (terminal).
    CompleteCommit,
    /// Abort finished (terminal).
    CompleteAbort,
    /// Timed out / unrecoverable (terminal).
    Dead,
}
impl TransactionState {
pub fn is_terminal(&self) -> bool {
matches!(
self,
TransactionState::Empty
| TransactionState::CompleteCommit
| TransactionState::CompleteAbort
| TransactionState::Dead
)
}
pub fn is_active(&self) -> bool {
matches!(self, TransactionState::Ongoing)
}
pub fn can_commit(&self) -> bool {
matches!(self, TransactionState::Ongoing)
}
pub fn can_abort(&self) -> bool {
matches!(
self,
TransactionState::Ongoing
| TransactionState::PrepareCommit
| TransactionState::PrepareAbort
)
}
}
/// Outcome of a coordinator operation; `Ok` on success, otherwise a
/// caller-actionable rejection reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionResult {
    /// Operation succeeded.
    Ok,
    /// No transaction is known under the given (producer, txn_id) key.
    InvalidTransactionId,
    /// The transaction exists but is in a state that forbids the operation.
    InvalidTransactionState {
        current: TransactionState,
        expected: &'static str,
    },
    /// The request carried a stale producer epoch.
    ProducerFenced {
        expected_epoch: ProducerEpoch,
        received_epoch: ProducerEpoch,
    },
    /// The transaction exceeded its idle timeout and was marked dead.
    TransactionTimeout,
    /// The `MAX_PENDING_TRANSACTIONS` cap on ongoing transactions was hit.
    TooManyTransactions,
    /// The producer already owns a transaction with a different id.
    ConcurrentTransaction,
    /// A write targeted a partition never registered with the transaction.
    PartitionNotInTransaction { topic: String, partition: u32 },
    /// The durable transaction log could not be appended to.
    LogWriteError(String),
}
/// A (topic, partition) pair participating in a transaction.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransactionPartition {
    pub topic: String,
    pub partition: u32,
}
impl TransactionPartition {
    /// Builds a topic/partition pair; accepts anything convertible to `String`.
    pub fn new(topic: impl Into<String>, partition: u32) -> Self {
        let topic = topic.into();
        Self { topic, partition }
    }
}
/// One uncommitted write recorded against a transaction: where it landed
/// (partition + log offset) and the producer sequence number that made it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingWrite {
    pub partition: TransactionPartition,
    pub sequence: i32,
    pub offset: u64,
    // Wall-clock capture time; serialized via the project helper module.
    #[serde(with = "crate::serde_utils::system_time")]
    pub timestamp: SystemTime,
}
/// Consumer-group offsets buffered inside a transaction (committed or
/// discarded together with it).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionOffsetCommit {
    pub group_id: String,
    pub offsets: Vec<(TransactionPartition, i64)>,
}
/// A single in-flight transaction tracked by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transaction {
    pub txn_id: TransactionId,
    pub producer_id: ProducerId,
    pub producer_epoch: ProducerEpoch,
    /// Current lifecycle state.
    pub state: TransactionState,
    /// Partitions explicitly registered before any writes are accepted.
    pub partitions: HashSet<TransactionPartition>,
    /// Writes recorded so far, in arrival order.
    pub pending_writes: Vec<PendingWrite>,
    /// Buffered consumer-group offset commits.
    pub offset_commits: Vec<TransactionOffsetCommit>,
    #[serde(with = "crate::serde_utils::system_time")]
    pub started_at: SystemTime,
    /// Idle timeout after which the transaction may be reaped.
    #[serde(with = "crate::serde_utils::duration")]
    pub timeout: Duration,
    // Monotonic timestamp of the last touch. `Instant` cannot be serialized,
    // so this is skipped and comes back as `None` after deserialization,
    // which disables timeout detection until the next activity.
    #[serde(skip)]
    pub last_activity: Option<Instant>,
}
impl Transaction {
    /// Creates a transaction that starts in `Ongoing` with empty partition,
    /// write, and offset-commit sets, and a fresh activity timestamp.
    pub fn new(
        txn_id: TransactionId,
        producer_id: ProducerId,
        producer_epoch: ProducerEpoch,
        timeout: Duration,
    ) -> Self {
        Self {
            txn_id,
            producer_id,
            producer_epoch,
            state: TransactionState::Ongoing,
            partitions: HashSet::new(),
            pending_writes: Vec::new(),
            offset_commits: Vec::new(),
            started_at: SystemTime::now(),
            timeout,
            last_activity: Some(Instant::now()),
        }
    }
    /// Whether the idle time since the last activity exceeds `timeout`.
    /// A transaction with no `last_activity` (e.g. freshly deserialized)
    /// never reports as timed out.
    pub fn is_timed_out(&self) -> bool {
        match self.last_activity {
            Some(at) => at.elapsed() > self.timeout,
            None => false,
        }
    }
    /// Resets the idle clock to now.
    pub fn touch(&mut self) {
        self.last_activity = Some(Instant::now());
    }
    /// Registers a partition and counts it as activity.
    pub fn add_partition(&mut self, partition: TransactionPartition) {
        self.partitions.insert(partition);
        self.touch();
    }
    /// Records one uncommitted write and counts it as activity.
    pub fn add_write(&mut self, partition: TransactionPartition, sequence: i32, offset: u64) {
        let write = PendingWrite {
            partition,
            sequence,
            offset,
            timestamp: SystemTime::now(),
        };
        self.pending_writes.push(write);
        self.touch();
    }
    /// Buffers a consumer-group offset commit and counts it as activity.
    pub fn add_offset_commit(
        &mut self,
        group_id: String,
        offsets: Vec<(TransactionPartition, i64)>,
    ) {
        let commit = TransactionOffsetCommit { group_id, offsets };
        self.offset_commits.push(commit);
        self.touch();
    }
    /// Number of uncommitted writes buffered so far.
    pub fn write_count(&self) -> usize {
        self.pending_writes.len()
    }
    /// Iterates over every partition registered with this transaction.
    pub fn affected_partitions(&self) -> impl Iterator<Item = &TransactionPartition> {
        self.partitions.iter()
    }
}
/// Control-marker kind written to data partitions when a transaction ends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransactionMarker {
    Commit,
    Abort,
}
/// Consumer isolation level; defaults to `ReadUncommitted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum IsolationLevel {
    /// Consumers see all records, including those from open/aborted transactions.
    #[default]
    ReadUncommitted,
    /// Consumers see only records from committed transactions.
    ReadCommitted,
}
impl IsolationLevel {
    /// Canonical wire name for this level.
    pub fn as_str(&self) -> &'static str {
        if let Self::ReadCommitted = self {
            "read_committed"
        } else {
            "read_uncommitted"
        }
    }
    /// Decodes the wire byte; any value other than 1 falls back to
    /// `ReadUncommitted`.
    pub fn from_u8(value: u8) -> Self {
        if value == 1 {
            Self::ReadCommitted
        } else {
            Self::ReadUncommitted
        }
    }
    /// Encodes as the wire byte (0 = uncommitted, 1 = committed).
    pub fn as_u8(&self) -> u8 {
        match self {
            Self::ReadUncommitted => 0,
            Self::ReadCommitted => 1,
        }
    }
}
impl std::str::FromStr for IsolationLevel {
    type Err = String;
    /// Case-insensitive parse of the two wire names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        if lowered == "read_uncommitted" {
            Ok(Self::ReadUncommitted)
        } else if lowered == "read_committed" {
            Ok(Self::ReadCommitted)
        } else {
            Err(format!("unknown isolation level: {}", s))
        }
    }
}
impl std::fmt::Display for IsolationLevel {
    /// Renders the same wire name as [`IsolationLevel::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// An aborted producer's contiguous offset range within a partition log,
/// used to filter records for read_committed consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AbortedTransaction {
    pub producer_id: ProducerId,
    pub first_offset: u64,
    pub last_offset: u64,
}
/// Thread-safe list of aborted offset ranges, kept sorted by `first_offset`.
#[derive(Debug, Default)]
pub struct AbortedTransactionIndex {
    aborted: RwLock<Vec<AbortedTransaction>>,
}
impl AbortedTransactionIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }
    /// Records an aborted offset range, keeping the list ordered by
    /// `first_offset`.
    ///
    /// Uses binary search + positional insert instead of re-sorting the
    /// whole vector on every call (the previous push-then-`sort_by_key` was
    /// O(n log n) per insert). The resulting order is identical: the stable
    /// sort placed a new entry after existing entries with an equal key, and
    /// `partition_point(<=)` yields exactly that position.
    pub fn record_abort(&self, producer_id: ProducerId, first_offset: u64, last_offset: u64) {
        let entry = AbortedTransaction {
            producer_id,
            first_offset,
            last_offset,
        };
        let mut aborted = self.aborted.write();
        let idx = aborted.partition_point(|a| a.first_offset <= first_offset);
        aborted.insert(idx, entry);
    }
    /// Returns every aborted range whose `first_offset` lies within
    /// `[start_offset, end_offset]` (both bounds inclusive).
    pub fn get_aborted_in_range(
        &self,
        start_offset: u64,
        end_offset: u64,
    ) -> Vec<AbortedTransaction> {
        let aborted = self.aborted.read();
        aborted
            .iter()
            .filter(|a| a.first_offset >= start_offset && a.first_offset <= end_offset)
            .cloned()
            .collect()
    }
    /// Whether `offset` falls inside any aborted range of `producer_id`.
    pub fn is_aborted(&self, producer_id: ProducerId, offset: u64) -> bool {
        let aborted = self.aborted.read();
        aborted.iter().any(|a| {
            a.producer_id == producer_id && a.first_offset <= offset && offset <= a.last_offset
        })
    }
    /// Drops every entry whose range starts before `offset`.
    ///
    /// NOTE(review): an entry that starts before `offset` but *ends* at or
    /// after it is also dropped, losing abort information for the surviving
    /// tail — confirm with callers that truncation never lands mid-range.
    pub fn truncate_before(&self, offset: u64) {
        let mut aborted = self.aborted.write();
        aborted.retain(|a| a.first_offset >= offset);
    }
    /// Number of recorded aborted ranges.
    pub fn len(&self) -> usize {
        self.aborted.read().len()
    }
    /// True when no aborted ranges are recorded.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Lock-free counters for transaction lifecycle events. The first four are
/// monotonically increasing; `active_transactions` is a gauge that is
/// incremented on start and decremented on commit/abort/timeout.
#[derive(Debug, Default)]
pub struct TransactionStats {
    transactions_started: AtomicU64,
    transactions_committed: AtomicU64,
    transactions_aborted: AtomicU64,
    transactions_timed_out: AtomicU64,
    active_transactions: AtomicU64,
}
impl TransactionStats {
pub fn new() -> Self {
Self::default()
}
pub fn record_start(&self) {
self.transactions_started.fetch_add(1, Ordering::Relaxed);
self.active_transactions.fetch_add(1, Ordering::Relaxed);
}
pub fn record_commit(&self) {
self.transactions_committed.fetch_add(1, Ordering::Relaxed);
self.active_transactions.fetch_sub(1, Ordering::Relaxed);
}
pub fn record_abort(&self) {
self.transactions_aborted.fetch_add(1, Ordering::Relaxed);
self.active_transactions.fetch_sub(1, Ordering::Relaxed);
}
pub fn record_timeout(&self) {
self.transactions_timed_out.fetch_add(1, Ordering::Relaxed);
self.active_transactions.fetch_sub(1, Ordering::Relaxed);
}
pub fn transactions_started(&self) -> u64 {
self.transactions_started.load(Ordering::Relaxed)
}
pub fn transactions_committed(&self) -> u64 {
self.transactions_committed.load(Ordering::Relaxed)
}
pub fn transactions_aborted(&self) -> u64 {
self.transactions_aborted.load(Ordering::Relaxed)
}
pub fn transactions_timed_out(&self) -> u64 {
self.transactions_timed_out.load(Ordering::Relaxed)
}
pub fn active_transactions(&self) -> u64 {
self.active_transactions.load(Ordering::Relaxed)
}
}
/// Serializable point-in-time copy of [`TransactionStats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionStatsSnapshot {
    pub transactions_started: u64,
    pub transactions_committed: u64,
    pub transactions_aborted: u64,
    pub transactions_timed_out: u64,
    pub active_transactions: u64,
}
impl From<&TransactionStats> for TransactionStatsSnapshot {
fn from(stats: &TransactionStats) -> Self {
Self {
transactions_started: stats.transactions_started(),
transactions_committed: stats.transactions_committed(),
transactions_aborted: stats.transactions_aborted(),
transactions_timed_out: stats.transactions_timed_out(),
active_transactions: stats.active_transactions(),
}
}
}
/// One durable record in the transaction log; replayed in order by
/// `TransactionCoordinator::recover`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionLogEntry {
    /// A transaction was opened.
    Begin {
        txn_id: TransactionId,
        producer_id: ProducerId,
        producer_epoch: ProducerEpoch,
        timeout_ms: u64,
    },
    /// A partition was registered with the transaction.
    AddPartition {
        txn_id: TransactionId,
        producer_id: ProducerId,
        partition: TransactionPartition,
    },
    /// An uncommitted write landed at `offset` in `partition`.
    RecordWrite {
        txn_id: TransactionId,
        producer_id: ProducerId,
        partition: TransactionPartition,
        sequence: i32,
        offset: u64,
    },
    /// Commit decision made (phase one).
    PrepareCommit {
        txn_id: TransactionId,
        producer_id: ProducerId,
    },
    /// Commit finished (phase two); the transaction is forgotten on replay.
    CompleteCommit {
        txn_id: TransactionId,
        producer_id: ProducerId,
    },
    /// Abort decision made (phase one).
    PrepareAbort {
        txn_id: TransactionId,
        producer_id: ProducerId,
    },
    /// Abort finished (phase two); write range enters the aborted index.
    CompleteAbort {
        txn_id: TransactionId,
        producer_id: ProducerId,
    },
    /// Transaction reaped by idle timeout; treated like an abort on replay.
    TimedOut {
        txn_id: TransactionId,
        producer_id: ProducerId,
    },
    /// Consumer-group offsets buffered inside the transaction.
    OffsetCommit {
        txn_id: TransactionId,
        producer_id: ProducerId,
        group_id: String,
        offsets: Vec<(TransactionPartition, i64)>,
    },
}
/// Append-only, CRC-framed transaction log backed by a file.
/// `writer == None` denotes a no-op log (persistence disabled).
pub struct TransactionLog {
    path: PathBuf,
    writer: parking_lot::Mutex<Option<std::io::BufWriter<std::fs::File>>>,
}
impl TransactionLog {
    /// Opens (creating if necessary) an append-only log file at `path`,
    /// creating parent directories as needed.
    ///
    /// # Errors
    /// Propagates I/O errors from directory creation or file open.
    pub fn open(path: impl AsRef<Path>) -> crate::Result<Self> {
        let path = path.as_ref().to_path_buf();
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let file = std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)?;
        Ok(Self {
            path,
            writer: parking_lot::Mutex::new(Some(std::io::BufWriter::new(file))),
        })
    }
    /// A log that silently discards every append (persistence disabled).
    pub fn noop() -> Self {
        Self {
            path: PathBuf::new(),
            writer: parking_lot::Mutex::new(None),
        }
    }
    /// Appends one record framed as: CRC32 of the payload (4 bytes,
    /// big-endian) + payload length (4 bytes, big-endian) + the
    /// postcard-serialized payload; then flushes and fsyncs so the record is
    /// durable before returning. A silent no-op (returning `Ok`) for a
    /// `noop()` log.
    ///
    /// # Errors
    /// Serialization failures (wrapped as `Error::Other`) or I/O errors.
    pub fn append(&self, entry: &TransactionLogEntry) -> crate::Result<()> {
        let mut guard = self.writer.lock();
        let writer = match guard.as_mut() {
            Some(w) => w,
            None => return Ok(()),
        };
        let data = postcard::to_allocvec(entry).map_err(|e| crate::Error::Other(e.to_string()))?;
        let mut hasher = crc32fast::Hasher::new();
        hasher.update(&data);
        let crc = hasher.finalize();
        writer.write_all(&crc.to_be_bytes())?;
        writer.write_all(&(data.len() as u32).to_be_bytes())?;
        writer.write_all(&data)?;
        writer.flush()?;
        // sync_data (not sync_all): the record bytes must reach disk;
        // metadata such as mtime may lag.
        writer.get_ref().sync_data()?;
        Ok(())
    }
    /// Reads every valid record from `path`, stopping silently at the first
    /// truncated frame, CRC mismatch, or undecodable payload — a torn tail
    /// left by a crash is tolerated rather than treated as an error.
    ///
    /// # Errors
    /// Only underlying file-read errors; a missing file yields an empty Vec.
    pub fn read_all(path: impl AsRef<Path>) -> crate::Result<Vec<TransactionLogEntry>> {
        let path = path.as_ref();
        if !path.exists() {
            return Ok(Vec::new());
        }
        let data = std::fs::read(path)?;
        let mut entries = Vec::new();
        let mut pos = 0;
        // Frame layout: 4-byte CRC + 4-byte length + payload.
        while pos + 8 <= data.len() {
            let crc = u32::from_be_bytes(data[pos..pos + 4].try_into().unwrap());
            let len = u32::from_be_bytes(data[pos + 4..pos + 8].try_into().unwrap()) as usize;
            pos += 8;
            if pos + len > data.len() {
                // Truncated payload: stop replay here.
                break;
            }
            let payload = &data[pos..pos + len];
            let mut hasher = crc32fast::Hasher::new();
            hasher.update(payload);
            if hasher.finalize() != crc {
                // Corrupt record: stop replay here.
                break;
            }
            match postcard::from_bytes::<TransactionLogEntry>(payload) {
                Ok(entry) => entries.push(entry),
                Err(_) => break,
            }
            pos += len;
        }
        Ok(entries)
    }
    /// Truncates the log file to zero length and reopens it for appending.
    /// A no-op for a `noop()` log.
    ///
    /// NOTE(review): if the reopen fails, the writer stays `None`, so later
    /// appends silently become no-ops — confirm callers treat the returned
    /// error as fatal.
    ///
    /// # Errors
    /// Propagates the file-open error.
    pub fn truncate(&self) -> crate::Result<()> {
        let mut guard = self.writer.lock();
        if guard.is_none() {
            return Ok(());
        }
        // Drop the old handle before reopening with truncate.
        *guard = None;
        let file = std::fs::OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&self.path)?;
        *guard = Some(std::io::BufWriter::new(file));
        Ok(())
    }
}
/// Central transaction coordinator: tracks in-flight transactions keyed by
/// (producer id, transaction id), enforces one transaction per producer,
/// persists every state change to `txn_log`, and feeds the aborted index.
pub struct TransactionCoordinator {
    /// All tracked transactions.
    transactions: RwLock<HashMap<(ProducerId, TransactionId), Transaction>>,
    /// The single transaction id each producer currently owns.
    producer_transactions: RwLock<HashMap<ProducerId, TransactionId>>,
    /// Timeout applied when `begin_transaction` receives `None`.
    default_timeout: Duration,
    stats: TransactionStats,
    aborted_index: AbortedTransactionIndex,
    txn_log: TransactionLog,
}
impl Default for TransactionCoordinator {
fn default() -> Self {
Self::new()
}
}
impl TransactionCoordinator {
pub fn new() -> Self {
Self {
transactions: RwLock::new(HashMap::new()),
producer_transactions: RwLock::new(HashMap::new()),
default_timeout: DEFAULT_TRANSACTION_TIMEOUT,
stats: TransactionStats::new(),
aborted_index: AbortedTransactionIndex::new(),
txn_log: TransactionLog::noop(),
}
}
pub fn with_timeout(timeout: Duration) -> Self {
Self {
transactions: RwLock::new(HashMap::new()),
producer_transactions: RwLock::new(HashMap::new()),
default_timeout: timeout,
stats: TransactionStats::new(),
aborted_index: AbortedTransactionIndex::new(),
txn_log: TransactionLog::noop(),
}
}
pub fn with_persistence(path: impl AsRef<Path>) -> crate::Result<Self> {
let txn_log = TransactionLog::open(path)?;
Ok(Self {
transactions: RwLock::new(HashMap::new()),
producer_transactions: RwLock::new(HashMap::new()),
default_timeout: DEFAULT_TRANSACTION_TIMEOUT,
stats: TransactionStats::new(),
aborted_index: AbortedTransactionIndex::new(),
txn_log,
})
}
    /// Rebuilds coordinator state by replaying the transaction log at
    /// `path`, then reopens the same file for continued appends.
    ///
    /// Replay semantics visible below:
    /// - `Begin` recreates the transaction via `Transaction::new`, so the
    ///   idle timer restarts now (the monotonic clock is process-local).
    /// - `AddPartition` / `RecordWrite` / `Prepare*` / `OffsetCommit`
    ///   mutate the matching transaction if it is still present.
    /// - `CompleteCommit` removes the transaction.
    /// - `CompleteAbort` / `TimedOut` record the write's offset range in
    ///   the aborted index, then remove the transaction.
    ///
    /// Stats counters are NOT reconstructed from the log — they restart at
    /// zero. Transactions still present after replay are in-doubt and are
    /// reported with a warning.
    ///
    /// # Errors
    /// Propagates I/O errors from reading or reopening the log file.
    pub fn recover(path: impl AsRef<Path>) -> crate::Result<Self> {
        let path = path.as_ref();
        let entries = TransactionLog::read_all(path)?;
        let txn_log = TransactionLog::open(path)?;
        let coord = Self {
            transactions: RwLock::new(HashMap::new()),
            producer_transactions: RwLock::new(HashMap::new()),
            default_timeout: DEFAULT_TRANSACTION_TIMEOUT,
            stats: TransactionStats::new(),
            aborted_index: AbortedTransactionIndex::new(),
            txn_log,
        };
        let mut transactions = coord.transactions.write();
        let mut producer_txns = coord.producer_transactions.write();
        for entry in entries {
            match entry {
                TransactionLogEntry::Begin {
                    txn_id,
                    producer_id,
                    producer_epoch,
                    timeout_ms,
                } => {
                    let txn = Transaction::new(
                        txn_id.clone(),
                        producer_id,
                        producer_epoch,
                        Duration::from_millis(timeout_ms),
                    );
                    transactions.insert((producer_id, txn_id.clone()), txn);
                    producer_txns.insert(producer_id, txn_id);
                }
                TransactionLogEntry::AddPartition {
                    txn_id,
                    producer_id,
                    partition,
                } => {
                    if let Some(txn) = transactions.get_mut(&(producer_id, txn_id)) {
                        txn.partitions.insert(partition);
                    }
                }
                TransactionLogEntry::RecordWrite {
                    txn_id,
                    producer_id,
                    partition,
                    sequence,
                    offset,
                } => {
                    if let Some(txn) = transactions.get_mut(&(producer_id, txn_id)) {
                        // Original write timestamp is not logged; replay uses now.
                        txn.pending_writes.push(PendingWrite {
                            partition,
                            sequence,
                            offset,
                            timestamp: SystemTime::now(),
                        });
                    }
                }
                TransactionLogEntry::PrepareCommit {
                    txn_id,
                    producer_id,
                } => {
                    if let Some(txn) = transactions.get_mut(&(producer_id, txn_id)) {
                        txn.state = TransactionState::PrepareCommit;
                    }
                }
                TransactionLogEntry::CompleteCommit {
                    txn_id,
                    producer_id,
                } => {
                    transactions.remove(&(producer_id, txn_id.clone()));
                    producer_txns.remove(&producer_id);
                }
                TransactionLogEntry::PrepareAbort {
                    txn_id,
                    producer_id,
                } => {
                    if let Some(txn) = transactions.get_mut(&(producer_id, txn_id)) {
                        txn.state = TransactionState::PrepareAbort;
                    }
                }
                TransactionLogEntry::CompleteAbort {
                    txn_id,
                    producer_id,
                } => {
                    // Re-derive the aborted offset range from the replayed writes.
                    if let Some(txn) = transactions.get(&(producer_id, txn_id.clone())) {
                        let first = txn.pending_writes.iter().map(|w| w.offset).min();
                        let last = txn.pending_writes.iter().map(|w| w.offset).max();
                        if let (Some(f), Some(l)) = (first, last) {
                            coord.aborted_index.record_abort(producer_id, f, l);
                        }
                    }
                    transactions.remove(&(producer_id, txn_id.clone()));
                    producer_txns.remove(&producer_id);
                }
                TransactionLogEntry::TimedOut {
                    txn_id,
                    producer_id,
                } => {
                    // Timed-out transactions are treated like aborts.
                    if let Some(txn) = transactions.get(&(producer_id, txn_id.clone())) {
                        let first = txn.pending_writes.iter().map(|w| w.offset).min();
                        let last = txn.pending_writes.iter().map(|w| w.offset).max();
                        if let (Some(f), Some(l)) = (first, last) {
                            coord.aborted_index.record_abort(producer_id, f, l);
                        }
                    }
                    transactions.remove(&(producer_id, txn_id.clone()));
                    producer_txns.remove(&producer_id);
                }
                TransactionLogEntry::OffsetCommit {
                    txn_id,
                    producer_id,
                    group_id,
                    offsets,
                } => {
                    if let Some(txn) = transactions.get_mut(&(producer_id, txn_id)) {
                        txn.add_offset_commit(group_id, offsets);
                    }
                }
            }
        }
        // Release the write guards before calling active_count(), which
        // takes its own read lock.
        drop(transactions);
        drop(producer_txns);
        let active = coord.active_count();
        if active > 0 {
            tracing::warn!(
                "Transaction coordinator recovered {} in-doubt transactions from log",
                active
            );
        }
        Ok(coord)
    }
    /// Live lifecycle counters for this coordinator.
    pub fn stats(&self) -> &TransactionStats {
        &self.stats
    }
    /// Starts (or idempotently re-acknowledges) a transaction for
    /// `producer_id` under `txn_id`.
    ///
    /// Behavior visible in this body:
    /// - A producer owns at most one transaction id; beginning a different
    ///   id returns `ConcurrentTransaction`.
    /// - Re-begin of the same id: an epoch mismatch is `ProducerFenced`; an
    ///   already-`Ongoing` transaction returns `Ok` without a new entry.
    /// - At most `MAX_PENDING_TRANSACTIONS` may be `Ongoing` at once.
    /// - The `Begin` record is appended to the log *before* the in-memory
    ///   insert, so a `LogWriteError` registers nothing.
    pub fn begin_transaction(
        &self,
        txn_id: TransactionId,
        producer_id: ProducerId,
        producer_epoch: ProducerEpoch,
        timeout: Option<Duration>,
    ) -> TransactionResult {
        // Both maps locked together so the per-producer index and the
        // transaction table stay consistent.
        let mut transactions = self.transactions.write();
        let mut producer_txns = self.producer_transactions.write();
        if let Some(existing_txn_id) = producer_txns.get(&producer_id) {
            if existing_txn_id != &txn_id {
                return TransactionResult::ConcurrentTransaction;
            }
            if let Some(txn) = transactions.get(&(producer_id, txn_id.clone())) {
                if txn.producer_epoch != producer_epoch {
                    return TransactionResult::ProducerFenced {
                        expected_epoch: txn.producer_epoch,
                        received_epoch: producer_epoch,
                    };
                }
                if txn.state.is_active() {
                    // Duplicate begin of a live transaction: succeed quietly.
                    return TransactionResult::Ok;
                }
            }
        }
        // Cap counts only Ongoing transactions (prepare states excluded).
        let active_count = transactions
            .values()
            .filter(|t| t.state.is_active())
            .count();
        if active_count >= MAX_PENDING_TRANSACTIONS {
            return TransactionResult::TooManyTransactions;
        }
        let txn = Transaction::new(
            txn_id.clone(),
            producer_id,
            producer_epoch,
            timeout.unwrap_or(self.default_timeout),
        );
        // Durability first: register in memory only once Begin is on disk.
        if let Err(e) = self.txn_log.append(&TransactionLogEntry::Begin {
            txn_id: txn_id.clone(),
            producer_id,
            producer_epoch,
            timeout_ms: timeout.unwrap_or(self.default_timeout).as_millis() as u64,
        }) {
            tracing::error!(producer_id, "Transaction log write failed on begin: {}", e);
            return TransactionResult::LogWriteError(e.to_string());
        }
        transactions.insert((producer_id, txn_id.clone()), txn);
        producer_txns.insert(producer_id, txn_id);
        self.stats.record_start();
        TransactionResult::Ok
    }
pub fn add_partitions_to_transaction(
&self,
txn_id: &TransactionId,
producer_id: ProducerId,
producer_epoch: ProducerEpoch,
partitions: Vec<TransactionPartition>,
) -> TransactionResult {
let mut transactions = self.transactions.write();
let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
Some(t) => t,
None => return TransactionResult::InvalidTransactionId,
};
if txn.producer_epoch != producer_epoch {
return TransactionResult::ProducerFenced {
expected_epoch: txn.producer_epoch,
received_epoch: producer_epoch,
};
}
if !txn.state.is_active() {
return TransactionResult::InvalidTransactionState {
current: txn.state,
expected: "Ongoing",
};
}
if txn.is_timed_out() {
txn.state = TransactionState::Dead;
self.stats.record_timeout();
return TransactionResult::TransactionTimeout;
}
for partition in &partitions {
if let Err(e) = self.txn_log.append(&TransactionLogEntry::AddPartition {
txn_id: txn_id.clone(),
producer_id,
partition: partition.clone(),
}) {
tracing::error!(
producer_id,
"Transaction log write failed on add_partition: {}",
e
);
return TransactionResult::LogWriteError(e.to_string());
}
}
for partition in partitions {
txn.add_partition(partition);
}
TransactionResult::Ok
}
pub fn validate_transaction_write(
&self,
txn_id: &TransactionId,
producer_id: ProducerId,
producer_epoch: ProducerEpoch,
partition: &TransactionPartition,
) -> TransactionResult {
let mut transactions = self.transactions.write();
let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
Some(t) => t,
None => return TransactionResult::InvalidTransactionId,
};
if txn.producer_epoch != producer_epoch {
return TransactionResult::ProducerFenced {
expected_epoch: txn.producer_epoch,
received_epoch: producer_epoch,
};
}
if !txn.state.is_active() {
return TransactionResult::InvalidTransactionState {
current: txn.state,
expected: "Ongoing",
};
}
if txn.is_timed_out() {
txn.state = TransactionState::Dead;
self.stats.record_timeout();
return TransactionResult::TransactionTimeout;
}
if !txn.partitions.contains(partition) {
return TransactionResult::PartitionNotInTransaction {
topic: partition.topic.clone(),
partition: partition.partition,
};
}
TransactionResult::Ok
}
    /// Records a write that landed at `offset` in `partition` against the
    /// transaction.
    ///
    /// Guards, in order: unknown id, stale epoch, non-`Ongoing` state, idle
    /// timeout (marks the transaction `Dead`), unregistered partition. The
    /// `RecordWrite` log entry is appended before the in-memory update, so
    /// a `LogWriteError` records nothing.
    pub fn add_write_to_transaction(
        &self,
        txn_id: &TransactionId,
        producer_id: ProducerId,
        producer_epoch: ProducerEpoch,
        partition: TransactionPartition,
        sequence: i32,
        offset: u64,
    ) -> TransactionResult {
        let mut transactions = self.transactions.write();
        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
            Some(t) => t,
            None => return TransactionResult::InvalidTransactionId,
        };
        if txn.producer_epoch != producer_epoch {
            return TransactionResult::ProducerFenced {
                expected_epoch: txn.producer_epoch,
                received_epoch: producer_epoch,
            };
        }
        if !txn.state.is_active() {
            return TransactionResult::InvalidTransactionState {
                current: txn.state,
                expected: "Ongoing",
            };
        }
        if txn.is_timed_out() {
            txn.state = TransactionState::Dead;
            self.stats.record_timeout();
            return TransactionResult::TransactionTimeout;
        }
        if !txn.partitions.contains(&partition) {
            return TransactionResult::PartitionNotInTransaction {
                topic: partition.topic,
                partition: partition.partition,
            };
        }
        // Durability first: log the write before mutating in-memory state.
        if let Err(e) = self.txn_log.append(&TransactionLogEntry::RecordWrite {
            txn_id: txn_id.clone(),
            producer_id,
            partition: partition.clone(),
            sequence,
            offset,
        }) {
            tracing::error!(
                producer_id,
                offset,
                "Transaction log write failed on record_write: {}",
                e
            );
            return TransactionResult::LogWriteError(e.to_string());
        }
        txn.add_write(partition, sequence, offset);
        TransactionResult::Ok
    }
    /// Buffers consumer-group `offsets` inside the transaction, to take
    /// effect only if it commits.
    ///
    /// Guards, in order: unknown id, stale epoch, non-`Ongoing` state, idle
    /// timeout (marks the transaction `Dead`). The `OffsetCommit` log entry
    /// is appended before the in-memory update.
    pub fn add_offsets_to_transaction(
        &self,
        txn_id: &TransactionId,
        producer_id: ProducerId,
        producer_epoch: ProducerEpoch,
        group_id: String,
        offsets: Vec<(TransactionPartition, i64)>,
    ) -> TransactionResult {
        let mut transactions = self.transactions.write();
        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
            Some(t) => t,
            None => return TransactionResult::InvalidTransactionId,
        };
        if txn.producer_epoch != producer_epoch {
            return TransactionResult::ProducerFenced {
                expected_epoch: txn.producer_epoch,
                received_epoch: producer_epoch,
            };
        }
        if !txn.state.is_active() {
            return TransactionResult::InvalidTransactionState {
                current: txn.state,
                expected: "Ongoing",
            };
        }
        if txn.is_timed_out() {
            txn.state = TransactionState::Dead;
            self.stats.record_timeout();
            return TransactionResult::TransactionTimeout;
        }
        if let Err(e) = self.txn_log.append(&TransactionLogEntry::OffsetCommit {
            txn_id: txn_id.clone(),
            producer_id,
            group_id: group_id.clone(),
            offsets: offsets.clone(),
        }) {
            tracing::error!(
                producer_id,
                "Transaction log write failed on offset_commit: {}",
                e
            );
            return TransactionResult::LogWriteError(e.to_string());
        }
        txn.add_offset_commit(group_id, offsets);
        TransactionResult::Ok
    }
    /// First phase of commit: moves an `Ongoing` transaction to
    /// `PrepareCommit`, persists the decision, and returns a snapshot of
    /// the transaction so the caller can write commit markers.
    ///
    /// The state is flipped before the log append and reverted to `Ongoing`
    /// if the append fails; that revert is exact because `can_commit` only
    /// admits `Ongoing`.
    ///
    /// # Errors
    /// `InvalidTransactionId`, `ProducerFenced`, `InvalidTransactionState`,
    /// `TransactionTimeout`, or `LogWriteError`, wrapped in `Err`.
    pub fn prepare_commit(
        &self,
        txn_id: &TransactionId,
        producer_id: ProducerId,
        producer_epoch: ProducerEpoch,
    ) -> Result<Transaction, TransactionResult> {
        let mut transactions = self.transactions.write();
        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
            Some(t) => t,
            None => return Err(TransactionResult::InvalidTransactionId),
        };
        if txn.producer_epoch != producer_epoch {
            return Err(TransactionResult::ProducerFenced {
                expected_epoch: txn.producer_epoch,
                received_epoch: producer_epoch,
            });
        }
        if !txn.state.can_commit() {
            return Err(TransactionResult::InvalidTransactionState {
                current: txn.state,
                expected: "Ongoing",
            });
        }
        if txn.is_timed_out() {
            txn.state = TransactionState::Dead;
            self.stats.record_timeout();
            return Err(TransactionResult::TransactionTimeout);
        }
        txn.state = TransactionState::PrepareCommit;
        txn.touch();
        if let Err(e) = self.txn_log.append(&TransactionLogEntry::PrepareCommit {
            txn_id: txn_id.clone(),
            producer_id,
        }) {
            tracing::error!(
                producer_id,
                "Transaction log write failed on prepare_commit: {}",
                e
            );
            // Roll back: the only state that can reach here is Ongoing.
            txn.state = TransactionState::Ongoing;
            return Err(TransactionResult::LogWriteError(e.to_string()));
        }
        Ok(txn.clone())
    }
    /// Second phase of commit: logs `CompleteCommit` and forgets the
    /// transaction.
    ///
    /// No epoch check here — the decision was already validated and made
    /// durable in `prepare_commit`. Requires the `PrepareCommit` state. On
    /// `LogWriteError` the transaction is kept so the caller can retry.
    pub fn complete_commit(
        &self,
        txn_id: &TransactionId,
        producer_id: ProducerId,
    ) -> TransactionResult {
        let mut transactions = self.transactions.write();
        let mut producer_txns = self.producer_transactions.write();
        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
            Some(t) => t,
            None => return TransactionResult::InvalidTransactionId,
        };
        if txn.state != TransactionState::PrepareCommit {
            return TransactionResult::InvalidTransactionState {
                current: txn.state,
                expected: "PrepareCommit",
            };
        }
        if let Err(e) = self.txn_log.append(&TransactionLogEntry::CompleteCommit {
            txn_id: txn_id.clone(),
            producer_id,
        }) {
            tracing::error!(
                producer_id,
                "Transaction log write failed on complete_commit: {}",
                e
            );
            return TransactionResult::LogWriteError(e.to_string());
        }
        // Terminal state assignment is immediately followed by removal, so
        // it is not observable afterwards.
        txn.state = TransactionState::CompleteCommit;
        transactions.remove(&(producer_id, txn_id.clone()));
        producer_txns.remove(&producer_id);
        self.stats.record_commit();
        TransactionResult::Ok
    }
pub fn prepare_abort(
&self,
txn_id: &TransactionId,
producer_id: ProducerId,
producer_epoch: ProducerEpoch,
) -> Result<Transaction, TransactionResult> {
let mut transactions = self.transactions.write();
let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
Some(t) => t,
None => return Err(TransactionResult::InvalidTransactionId),
};
if txn.producer_epoch != producer_epoch {
return Err(TransactionResult::ProducerFenced {
expected_epoch: txn.producer_epoch,
received_epoch: producer_epoch,
});
}
if !txn.state.can_abort() {
return Err(TransactionResult::InvalidTransactionState {
current: txn.state,
expected: "Ongoing or PrepareCommit",
});
}
txn.state = TransactionState::PrepareAbort;
txn.touch();
if let Err(e) = self.txn_log.append(&TransactionLogEntry::PrepareAbort {
txn_id: txn_id.clone(),
producer_id,
}) {
tracing::error!(
producer_id,
"Transaction log write failed on prepare_abort: {}",
e
);
txn.state = TransactionState::Ongoing;
return Err(TransactionResult::LogWriteError(e.to_string()));
}
Ok(txn.clone())
}
    /// Second phase of abort: logs `CompleteAbort`, records the
    /// transaction's write offset range in the aborted index (used to
    /// filter reads — presumably for read_committed consumers; confirm with
    /// fetch-path call sites), and forgets the transaction.
    ///
    /// Requires the `PrepareAbort` state; no epoch check (decision already
    /// durable). On `LogWriteError` the transaction is kept for retry.
    pub fn complete_abort(
        &self,
        txn_id: &TransactionId,
        producer_id: ProducerId,
    ) -> TransactionResult {
        let mut transactions = self.transactions.write();
        let mut producer_txns = self.producer_transactions.write();
        let txn = match transactions.get_mut(&(producer_id, txn_id.clone())) {
            Some(t) => t,
            None => return TransactionResult::InvalidTransactionId,
        };
        if txn.state != TransactionState::PrepareAbort {
            return TransactionResult::InvalidTransactionState {
                current: txn.state,
                expected: "PrepareAbort",
            };
        }
        if let Err(e) = self.txn_log.append(&TransactionLogEntry::CompleteAbort {
            txn_id: txn_id.clone(),
            producer_id,
        }) {
            tracing::error!(
                producer_id,
                "Transaction log write failed on complete_abort: {}",
                e
            );
            return TransactionResult::LogWriteError(e.to_string());
        }
        txn.state = TransactionState::CompleteAbort;
        // Index the [min, max] offset range of this transaction's writes.
        let first = txn.pending_writes.iter().map(|w| w.offset).min();
        let last = txn.pending_writes.iter().map(|w| w.offset).max();
        if let (Some(f), Some(l)) = (first, last) {
            self.aborted_index.record_abort(producer_id, f, l);
        }
        transactions.remove(&(producer_id, txn_id.clone()));
        producer_txns.remove(&producer_id);
        self.stats.record_abort();
        TransactionResult::Ok
    }
pub fn get_transaction(
&self,
txn_id: &TransactionId,
producer_id: ProducerId,
) -> Option<Transaction> {
let transactions = self.transactions.read();
transactions.get(&(producer_id, txn_id.clone())).cloned()
}
pub fn has_active_transaction(&self, producer_id: ProducerId) -> bool {
let producer_txns = self.producer_transactions.read();
producer_txns.contains_key(&producer_id)
}
pub fn get_active_transaction_id(&self, producer_id: ProducerId) -> Option<TransactionId> {
let producer_txns = self.producer_transactions.read();
producer_txns.get(&producer_id).cloned()
}
pub fn cleanup_timed_out_transactions(&self) -> Vec<Transaction> {
let mut timed_out = Vec::new();
let mut transactions = self.transactions.write();
let mut producer_txns = self.producer_transactions.write();
let keys_to_remove: Vec<_> = transactions
.iter()
.filter(|(_, txn)| txn.is_timed_out() && !txn.state.is_terminal())
.map(|(k, _)| k.clone())
.collect();
for key in keys_to_remove {
if let Some(txn) = transactions.get(&key) {
if let Err(e) = self.txn_log.append(&TransactionLogEntry::TimedOut {
txn_id: txn.txn_id.clone(),
producer_id: txn.producer_id,
}) {
tracing::error!(txn.producer_id, "Transaction log write failed on timeout: {} — skipping cleanup, will retry", e);
continue;
}
}
if let Some(mut txn) = transactions.remove(&key) {
txn.state = TransactionState::Dead;
producer_txns.remove(&txn.producer_id);
let first = txn.pending_writes.iter().map(|w| w.offset).min();
let last = txn.pending_writes.iter().map(|w| w.offset).max();
if let (Some(f), Some(l)) = (first, last) {
self.aborted_index.record_abort(txn.producer_id, f, l);
}
self.stats.record_timeout();
self.stats.record_abort();
timed_out.push(txn);
}
}
timed_out
}
pub fn active_count(&self) -> usize {
let transactions = self.transactions.read();
transactions
.values()
.filter(|t| !t.state.is_terminal())
.count()
}
    /// Whether `offset` falls inside an aborted range of `producer_id`.
    pub fn is_aborted(&self, producer_id: ProducerId, offset: u64) -> bool {
        self.aborted_index.is_aborted(producer_id, offset)
    }
    /// Aborted ranges whose first offset lies in
    /// `[start_offset, end_offset]` (inclusive).
    pub fn get_aborted_in_range(
        &self,
        start_offset: u64,
        end_offset: u64,
    ) -> Vec<AbortedTransaction> {
        self.aborted_index
            .get_aborted_in_range(start_offset, end_offset)
    }
    /// Direct access to the aborted-transaction index.
    pub fn aborted_index(&self) -> &AbortedTransactionIndex {
        &self.aborted_index
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
#[test]
fn test_transaction_state_transitions() {
assert!(TransactionState::Empty.is_terminal());
assert!(TransactionState::CompleteCommit.is_terminal());
assert!(TransactionState::CompleteAbort.is_terminal());
assert!(TransactionState::Dead.is_terminal());
assert!(!TransactionState::Ongoing.is_terminal());
assert!(!TransactionState::PrepareCommit.is_terminal());
assert!(!TransactionState::PrepareAbort.is_terminal());
assert!(TransactionState::Ongoing.can_commit());
assert!(!TransactionState::Empty.can_commit());
assert!(!TransactionState::PrepareCommit.can_commit());
assert!(TransactionState::Ongoing.can_abort());
assert!(TransactionState::PrepareCommit.can_abort());
assert!(TransactionState::PrepareAbort.can_abort());
assert!(!TransactionState::Empty.can_abort());
}
#[test]
fn test_begin_transaction() {
let coordinator = TransactionCoordinator::new();
let result = coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
assert_eq!(result, TransactionResult::Ok);
let txn = coordinator.get_transaction(&"txn-1".to_string(), 1);
assert!(txn.is_some());
let txn = txn.unwrap();
assert_eq!(txn.state, TransactionState::Ongoing);
assert_eq!(txn.producer_id, 1);
assert_eq!(txn.producer_epoch, 0);
assert_eq!(coordinator.stats().transactions_started(), 1);
assert_eq!(coordinator.stats().active_transactions(), 1);
}
#[test]
fn test_concurrent_transaction_rejection() {
let coordinator = TransactionCoordinator::new();
coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
let result = coordinator.begin_transaction("txn-2".to_string(), 1, 0, None);
assert_eq!(result, TransactionResult::ConcurrentTransaction);
}
#[test]
fn test_add_partitions_to_transaction() {
let coordinator = TransactionCoordinator::new();
coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
let result = coordinator.add_partitions_to_transaction(
&"txn-1".to_string(),
1,
0,
vec![
TransactionPartition::new("topic-1", 0),
TransactionPartition::new("topic-1", 1),
TransactionPartition::new("topic-2", 0),
],
);
assert_eq!(result, TransactionResult::Ok);
let txn = coordinator
.get_transaction(&"txn-1".to_string(), 1)
.unwrap();
assert_eq!(txn.partitions.len(), 3);
}
#[test]
fn test_add_write_to_transaction() {
let coordinator = TransactionCoordinator::new();
coordinator.begin_transaction("txn-1".to_string(), 1, 0, None);
let partition = TransactionPartition::new("topic-1", 0);
coordinator.add_partitions_to_transaction(
&"txn-1".to_string(),
1,
0,
vec![partition.clone()],
);
let result =
coordinator.add_write_to_transaction(&"txn-1".to_string(), 1, 0, partition, 0, 100);
assert_eq!(result, TransactionResult::Ok);
let txn = coordinator
.get_transaction(&"txn-1".to_string(), 1)
.unwrap();
assert_eq!(txn.pending_writes.len(), 1);
assert_eq!(txn.pending_writes[0].offset, 100);
assert_eq!(txn.pending_writes[0].sequence, 0);
}
#[test]
fn test_write_to_non_registered_partition() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    // Writing to a partition that was never added to the transaction
    // must be rejected with PartitionNotInTransaction.
    let outcome = coord.add_write_to_transaction(
        &txn_id,
        1,
        0,
        TransactionPartition::new("topic-1", 0),
        0,
        100,
    );
    assert!(matches!(
        outcome,
        TransactionResult::PartitionNotInTransaction { .. }
    ));
}
#[test]
fn test_commit_transaction() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    let part = TransactionPartition::new("topic-1", 0);
    coord.add_partitions_to_transaction(&txn_id, 1, 0, vec![part.clone()]);
    coord.add_write_to_transaction(&txn_id, 1, 0, part, 0, 100);
    // Phase 1: prepare moves the transaction into PrepareCommit.
    let prepared = coord.prepare_commit(&txn_id, 1, 0);
    assert!(prepared.is_ok());
    assert_eq!(prepared.unwrap().state, TransactionState::PrepareCommit);
    // Phase 2: complete removes the transaction and updates counters.
    assert_eq!(coord.complete_commit(&txn_id, 1), TransactionResult::Ok);
    assert!(coord.get_transaction(&txn_id, 1).is_none());
    assert!(!coord.has_active_transaction(1));
    assert_eq!(coord.stats().transactions_committed(), 1);
    assert_eq!(coord.stats().active_transactions(), 0);
}
#[test]
fn test_abort_transaction() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    let part = TransactionPartition::new("topic-1", 0);
    coord.add_partitions_to_transaction(&txn_id, 1, 0, vec![part.clone()]);
    coord.add_write_to_transaction(&txn_id, 1, 0, part, 0, 100);
    // Two-phase abort mirrors commit: prepare, then complete.
    assert!(coord.prepare_abort(&txn_id, 1, 0).is_ok());
    assert_eq!(coord.complete_abort(&txn_id, 1), TransactionResult::Ok);
    assert!(coord.get_transaction(&txn_id, 1).is_none());
    assert_eq!(coord.stats().transactions_aborted(), 1);
}
#[test]
fn test_producer_fencing() {
    let coord = TransactionCoordinator::new();
    // Transaction opened at producer epoch 0.
    coord.begin_transaction(String::from("txn-1"), 1, 0, None);
    // Same producer id but epoch 1: the epoch mismatch is reported as
    // ProducerFenced carrying both epochs.
    let outcome = coord.add_partitions_to_transaction(
        &String::from("txn-1"),
        1,
        1,
        vec![TransactionPartition::new("topic-1", 0)],
    );
    assert!(matches!(
        outcome,
        TransactionResult::ProducerFenced {
            expected_epoch: 0,
            received_epoch: 1
        }
    ));
}
#[test]
fn test_transaction_timeout() {
    // A 1ms timeout guarantees the transaction has expired after the sleep.
    let coord = TransactionCoordinator::with_timeout(Duration::from_millis(1));
    coord.begin_transaction(String::from("txn-1"), 1, 0, None);
    std::thread::sleep(Duration::from_millis(5));
    // Any further operation on the expired transaction reports the timeout.
    let outcome = coord.add_partitions_to_transaction(
        &String::from("txn-1"),
        1,
        0,
        vec![TransactionPartition::new("topic-1", 0)],
    );
    assert_eq!(outcome, TransactionResult::TransactionTimeout);
}
#[test]
fn test_cleanup_timed_out_transactions() {
    // Two producers, both of whose transactions expire after the 1ms timeout.
    let coord = TransactionCoordinator::with_timeout(Duration::from_millis(1));
    coord.begin_transaction(String::from("txn-1"), 1, 0, None);
    coord.begin_transaction(String::from("txn-2"), 2, 0, None);
    std::thread::sleep(Duration::from_millis(5));
    // Cleanup reaps both and records them in the stats.
    let expired = coord.cleanup_timed_out_transactions();
    assert_eq!(expired.len(), 2);
    assert_eq!(coord.active_count(), 0);
    assert_eq!(coord.stats().transactions_timed_out(), 2);
}
#[test]
fn test_add_offsets_to_transaction() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    // Attach consumed offsets for one consumer group to the transaction.
    let offsets = vec![
        (TransactionPartition::new("input-topic", 0), 42),
        (TransactionPartition::new("input-topic", 1), 100),
    ];
    let outcome = coord.add_offsets_to_transaction(
        &txn_id,
        1,
        0,
        String::from("consumer-group-1"),
        offsets,
    );
    assert_eq!(outcome, TransactionResult::Ok);
    // One offset-commit record per group, holding both partition offsets.
    let txn = coord.get_transaction(&txn_id, 1).unwrap();
    assert_eq!(txn.offset_commits.len(), 1);
    assert_eq!(txn.offset_commits[0].group_id, "consumer-group-1");
    assert_eq!(txn.offset_commits[0].offsets.len(), 2);
}
#[test]
fn test_invalid_state_transitions() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    coord.prepare_commit(&txn_id, 1, 0).unwrap();
    // Once in PrepareCommit, adding partitions is no longer legal.
    let outcome = coord.add_partitions_to_transaction(
        &txn_id,
        1,
        0,
        vec![TransactionPartition::new("topic-1", 0)],
    );
    assert!(matches!(
        outcome,
        TransactionResult::InvalidTransactionState { .. }
    ));
}
#[test]
fn test_abort_from_prepare_commit() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    coord.prepare_commit(&txn_id, 1, 0).unwrap();
    // A transaction in PrepareCommit can still be switched to an abort
    // (consistent with TransactionState::can_abort including PrepareCommit).
    assert!(coord.prepare_abort(&txn_id, 1, 0).is_ok());
    assert_eq!(coord.complete_abort(&txn_id, 1), TransactionResult::Ok);
}
#[test]
fn test_transaction_partition_hash() {
    // Equality is (topic, partition)-wise.
    let a = TransactionPartition::new("topic", 0);
    let b = TransactionPartition::new("topic", 0);
    let c = TransactionPartition::new("topic", 1);
    assert_eq!(a, b);
    assert_ne!(a, c);
    // Hash agrees with Eq: the two equal partitions collapse into one entry.
    let mut seen = HashSet::new();
    seen.insert(a.clone());
    seen.insert(b);
    seen.insert(c);
    assert_eq!(seen.len(), 2);
}
#[test]
fn test_resume_same_transaction() {
    let coord = TransactionCoordinator::new();
    coord.begin_transaction(String::from("txn-1"), 1, 0, None);
    // Re-issuing begin for the identical (txn, producer, epoch) is accepted
    // without starting a second transaction or bumping the started counter.
    let outcome = coord.begin_transaction(String::from("txn-1"), 1, 0, None);
    assert_eq!(outcome, TransactionResult::Ok);
    assert_eq!(coord.active_count(), 1);
    assert_eq!(coord.stats().transactions_started(), 1);
}
#[test]
fn test_stats_snapshot() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    // Drive one transaction to a full commit.
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    coord.prepare_commit(&txn_id, 1, 0).unwrap();
    coord.complete_commit(&txn_id, 1);
    // The snapshot conversion captures the counter values at that moment.
    let snapshot: TransactionStatsSnapshot = coord.stats().into();
    assert_eq!(snapshot.transactions_started, 1);
    assert_eq!(snapshot.transactions_committed, 1);
    assert_eq!(snapshot.active_transactions, 0);
}
#[test]
fn test_isolation_level_from_u8() {
    assert_eq!(IsolationLevel::from_u8(0), IsolationLevel::ReadUncommitted);
    assert_eq!(IsolationLevel::from_u8(1), IsolationLevel::ReadCommitted);
    // Unknown bytes fall back to ReadUncommitted rather than erroring.
    for unknown in [2u8, 255] {
        assert_eq!(
            IsolationLevel::from_u8(unknown),
            IsolationLevel::ReadUncommitted
        );
    }
}
#[test]
fn test_isolation_level_as_u8() {
    // Wire encoding: 0 = read_uncommitted, 1 = read_committed.
    let cases = [
        (IsolationLevel::ReadUncommitted, 0u8),
        (IsolationLevel::ReadCommitted, 1u8),
    ];
    for (level, byte) in cases {
        assert_eq!(level.as_u8(), byte);
    }
}
#[test]
fn test_isolation_level_from_str() {
    // Both lower- and upper-case spellings parse to the same level.
    let cases = [
        ("read_uncommitted", IsolationLevel::ReadUncommitted),
        ("read_committed", IsolationLevel::ReadCommitted),
        ("READ_UNCOMMITTED", IsolationLevel::ReadUncommitted),
        ("READ_COMMITTED", IsolationLevel::ReadCommitted),
    ];
    for (input, expected) in cases {
        assert_eq!(IsolationLevel::from_str(input).unwrap(), expected);
    }
    // Anything else is a parse error.
    assert!(IsolationLevel::from_str("invalid").is_err());
}
#[test]
fn test_isolation_level_default() {
    // The default isolation level is ReadUncommitted.
    let level = IsolationLevel::default();
    assert_eq!(level, IsolationLevel::ReadUncommitted);
}
#[test]
fn test_aborted_transaction_index_basic() {
    let index = AbortedTransactionIndex::new();
    assert!(index.is_empty());
    assert_eq!(index.len(), 0);
    // Record producer 1's abort covering offsets 100..=200.
    index.record_abort(1, 100, 200);
    assert!(!index.is_empty());
    assert_eq!(index.len(), 1);
    // Both endpoints of the range are inclusive.
    assert!(index.is_aborted(1, 100));
    assert!(index.is_aborted(1, 150));
    assert!(index.is_aborted(1, 200));
    // Offsets outside the range are not aborted.
    assert!(!index.is_aborted(1, 201));
    assert!(!index.is_aborted(1, 50));
    // A different producer at the same offset is unaffected.
    assert!(!index.is_aborted(2, 100));
}
#[test]
fn test_aborted_transaction_index_multiple() {
    let index = AbortedTransactionIndex::new();
    index.record_abort(1, 100, 199);
    index.record_abort(1, 300, 399);
    index.record_abort(2, 200, 299);
    assert_eq!(index.len(), 3);
    // Producer 1 has two disjoint aborted ranges with a gap at 200..=299.
    assert!(index.is_aborted(1, 100));
    assert!(index.is_aborted(1, 150));
    assert!(!index.is_aborted(1, 250));
    assert!(index.is_aborted(1, 300));
    assert!(index.is_aborted(1, 399));
    assert!(!index.is_aborted(1, 400));
    // Producer 2 is aborted only within 200..=299.
    assert!(index.is_aborted(2, 200));
    assert!(index.is_aborted(2, 250));
    assert!(!index.is_aborted(2, 100));
    assert!(!index.is_aborted(2, 300));
}
#[test]
fn test_aborted_transaction_index_get_range() {
    let index = AbortedTransactionIndex::new();
    index.record_abort(1, 100, 199);
    index.record_abort(2, 200, 299);
    index.record_abort(1, 300, 399);
    // Only the abort whose first_offset lies in [150, 250] is returned;
    // producer 1's 100..=199 overlaps the window but starts before it.
    let hits = index.get_aborted_in_range(150, 250);
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].producer_id, 2);
    assert_eq!(hits[0].first_offset, 200);
    // A window covering every first offset returns all three aborts.
    let all = index.get_aborted_in_range(0, 500);
    assert_eq!(all.len(), 3);
    // A window past the last first offset returns nothing.
    let none = index.get_aborted_in_range(400, 500);
    assert_eq!(none.len(), 0);
}
#[test]
fn test_aborted_transaction_index_truncate() {
    let index = AbortedTransactionIndex::new();
    index.record_abort(1, 100, 199);
    index.record_abort(2, 200, 299);
    index.record_abort(1, 300, 399);
    assert_eq!(index.len(), 3);
    // Truncation drops the range that ends before offset 200 (100..=199)
    // and keeps the two that reach 200 or beyond.
    index.truncate_before(200);
    assert_eq!(index.len(), 2);
    assert!(!index.is_aborted(1, 150));
    assert!(index.is_aborted(2, 200));
    assert!(index.is_aborted(1, 300));
}
#[test]
fn test_coordinator_is_aborted() {
    let coord = TransactionCoordinator::new();
    let txn_id = String::from("txn-1");
    let part = TransactionPartition::new("test-topic", 0);
    coord.begin_transaction(txn_id.clone(), 1, 0, None);
    coord.add_partitions_to_transaction(&txn_id, 1, 0, vec![part.clone()]);
    coord.add_write_to_transaction(&txn_id, 1, 0, part, 0, 100);
    // While the transaction is still in flight nothing is marked aborted.
    assert!(!coord.is_aborted(1, 100));
    coord.prepare_abort(&txn_id, 1, 0).unwrap();
    coord.complete_abort(&txn_id, 1);
    // After the abort, the written offset is marked for this producer only;
    // neighbouring offsets and other producers stay clean.
    assert!(coord.is_aborted(1, 100));
    assert!(!coord.is_aborted(1, 150));
    assert!(!coord.is_aborted(1, 50));
    assert!(!coord.is_aborted(2, 100));
}
#[test]
fn test_transaction_log_round_trip() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("txn.log");
    // Append three entries, then close the log by leaving the scope.
    {
        let log = TransactionLog::open(&path).unwrap();
        log.append(&TransactionLogEntry::Begin {
            txn_id: "txn-1".to_string(),
            producer_id: 42,
            producer_epoch: 0,
            timeout_ms: 30000,
        })
        .unwrap();
        log.append(&TransactionLogEntry::AddPartition {
            txn_id: "txn-1".to_string(),
            producer_id: 42,
            partition: TransactionPartition::new("topic-a", 0),
        })
        .unwrap();
        log.append(&TransactionLogEntry::RecordWrite {
            txn_id: "txn-1".to_string(),
            producer_id: 42,
            partition: TransactionPartition::new("topic-a", 0),
            sequence: 0,
            offset: 100,
        })
        .unwrap();
    }
    // Entries come back in append order with their payloads intact.
    let entries = TransactionLog::read_all(&path).unwrap();
    assert_eq!(entries.len(), 3);
    assert!(matches!(
        &entries[0],
        TransactionLogEntry::Begin { txn_id, producer_id, .. }
            if txn_id == "txn-1" && *producer_id == 42
    ));
    assert!(matches!(
        &entries[1],
        TransactionLogEntry::AddPartition { partition, .. }
            if partition.topic == "topic-a" && partition.partition == 0
    ));
    assert!(matches!(
        &entries[2],
        TransactionLogEntry::RecordWrite { offset, .. } if *offset == 100
    ));
}
#[test]
fn test_transaction_log_crc_corruption_detection() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("txn.log");
    // Write two valid entries, then close the log.
    {
        let log = TransactionLog::open(&path).unwrap();
        log.append(&TransactionLogEntry::Begin {
            txn_id: "txn-1".to_string(),
            producer_id: 1,
            producer_epoch: 0,
            timeout_ms: 5000,
        })
        .unwrap();
        log.append(&TransactionLogEntry::PrepareCommit {
            txn_id: "txn-1".to_string(),
            producer_id: 1,
        })
        .unwrap();
    }
    // Flip one byte inside the file so a record's CRC no longer matches.
    let mut bytes = std::fs::read(&path).unwrap();
    assert!(bytes.len() > 10);
    bytes[10] ^= 0xFF;
    std::fs::write(&path, &bytes).unwrap();
    // The reader must reject the corrupted record rather than return
    // garbage, so at most one entry survives.
    let entries = TransactionLog::read_all(&path).unwrap();
    assert!(entries.len() <= 1);
}
#[test]
fn test_coordinator_with_persistence_commit_flow() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("txn.log");
    let coordinator = TransactionCoordinator::with_persistence(&path).unwrap();
    let txn_id = String::from("txn-1");
    let part = TransactionPartition::new("topic", 0);
    // Every step of the commit flow succeeds against the persistent log.
    assert_eq!(
        coordinator.begin_transaction(txn_id.clone(), 1, 0, None),
        TransactionResult::Ok
    );
    assert_eq!(
        coordinator.add_partitions_to_transaction(&txn_id, 1, 0, vec![part.clone()]),
        TransactionResult::Ok
    );
    assert_eq!(
        coordinator.add_write_to_transaction(&txn_id, 1, 0, part, 0, 500),
        TransactionResult::Ok
    );
    coordinator.prepare_commit(&txn_id, 1, 0).unwrap();
    assert_eq!(
        coordinator.complete_commit(&txn_id, 1),
        TransactionResult::Ok
    );
    // Five operations produced five log entries — presumably one WAL
    // record per coordinator mutation.
    let entries = TransactionLog::read_all(&path).unwrap();
    assert_eq!(entries.len(), 5);
}
#[test]
fn test_coordinator_recovery_from_crash() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("txn.log");
    let txn_id = String::from("txn-1");
    // Simulate a crash: the coordinator is dropped right after
    // prepare_commit, before complete_commit could run.
    {
        let coordinator = TransactionCoordinator::with_persistence(&path).unwrap();
        coordinator.begin_transaction(txn_id.clone(), 1, 0, None);
        coordinator.add_partitions_to_transaction(
            &txn_id,
            1,
            0,
            vec![TransactionPartition::new("topic", 0)],
        );
        coordinator.add_write_to_transaction(
            &txn_id,
            1,
            0,
            TransactionPartition::new("topic", 0),
            0,
            42,
        );
        coordinator.prepare_commit(&txn_id, 1, 0).unwrap();
    }
    // Recovery replays the WAL and restores the in-flight transaction.
    let recovered = TransactionCoordinator::recover(&path).unwrap();
    let txn = recovered
        .get_transaction(&txn_id, 1)
        .expect("Transaction should be recovered from WAL");
    assert_eq!(txn.state, TransactionState::PrepareCommit);
    assert_eq!(txn.pending_writes.len(), 1);
    assert_eq!(txn.pending_writes[0].offset, 42);
    // The recovered transaction can still be driven to completion.
    assert_eq!(
        recovered.complete_commit(&txn_id, 1),
        TransactionResult::Ok
    );
    assert_eq!(recovered.active_count(), 0);
}
#[test]
fn test_coordinator_recovery_abort_flow() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("txn.log");
    let txn_id = String::from("txn-a");
    // Run a full abort flow against a persistent coordinator, then drop it.
    {
        let coordinator = TransactionCoordinator::with_persistence(&path).unwrap();
        coordinator.begin_transaction(txn_id.clone(), 10, 0, None);
        coordinator.add_partitions_to_transaction(
            &txn_id,
            10,
            0,
            vec![TransactionPartition::new("t", 0)],
        );
        coordinator.add_write_to_transaction(
            &txn_id,
            10,
            0,
            TransactionPartition::new("t", 0),
            0,
            200,
        );
        coordinator.prepare_abort(&txn_id, 10, 0).unwrap();
        coordinator.complete_abort(&txn_id, 10);
    }
    // After recovery, no transaction is active and the aborted write
    // is still visible through the aborted-offset lookup.
    let recovered = TransactionCoordinator::recover(&path).unwrap();
    assert_eq!(recovered.active_count(), 0);
    assert!(recovered.is_aborted(10, 200));
}
#[test]
fn test_transaction_log_noop_is_silent() {
    // The no-op log accepts appends and reports success.
    let log = TransactionLog::noop();
    let entry = TransactionLogEntry::Begin {
        txn_id: "x".to_string(),
        producer_id: 1,
        producer_epoch: 0,
        timeout_ms: 1000,
    };
    assert!(log.append(&entry).is_ok());
}
#[test]
fn test_log_write_error_propagated() {
// NOTE(review): despite its name, this test never induces a log write
// failure — it only exercises the happy path on an in-memory coordinator.
// TODO: inject a failing TransactionLog and assert that
// TransactionResult::LogWriteError is actually returned.
let coord = TransactionCoordinator::new();
let result = coord.begin_transaction("txn-ok".to_string(), 1, 0, None);
assert_eq!(result, TransactionResult::Ok);
}
}