use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::time::Duration;
use bytes::Bytes;
use tokio::sync::{Notify, RwLock};
use tracing::{debug, info, warn};
use crate::PartitionId;
use crate::auth::AuthConfig;
use crate::error::{ErrorCode, KrafkaError, Result};
use crate::metadata::ClusterMetadata;
use crate::network::{BrokerConnection, ConnectionConfig, ConnectionPool};
use crate::protocol::{
AddOffsetsToTxnRequest, AddOffsetsToTxnResponse, AddPartitionsToTxnRequest,
AddPartitionsToTxnResponse, ApiKey, Compression, EndTxnRequest, EndTxnResponse,
FindCoordinatorRequest, FindCoordinatorResponse, InitProducerIdRequest, InitProducerIdResponse,
ProducePartitionData, ProduceRequest, ProduceResponse, ProduceTopicData, RecordBatchBuilder,
TxnOffsetCommitRequest, TxnOffsetCommitResponse, VersionedDecode, VersionedEncode, versions,
};
use super::barrier::InFlightBarrier;
use super::config::Acks;
use super::idempotent::ProducerIdentity;
use super::partitioner::{DefaultPartitioner, Partitioner};
use super::record::{ProducerRecord, RecordMetadata, RoutedRecord, TopicHandle};
use super::retry::RetryPolicy;
/// Lifecycle state of a transactional producer.
///
/// The discriminants are fixed (`repr(u8)`) because the state is stored in an
/// `AtomicU8` and converted back with [`TransactionState::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TransactionState {
    /// `init_transactions` has not completed successfully yet.
    Uninitialized = 0,
    /// Initialized and idle; a transaction may be started.
    Ready = 1,
    /// A transaction is currently open.
    InTransaction = 2,
    /// A commit is in progress.
    Committing = 3,
    /// An abort is in progress.
    Aborting = 4,
    /// The producer hit an unrecoverable error (also set on close).
    FatalError = 5,
    /// `init_transactions` is in progress.
    Initializing = 6,
}

impl TransactionState {
    /// Decodes a raw discriminant back into a state.
    ///
    /// Any value that does not correspond to a declared discriminant decodes
    /// to [`TransactionState::FatalError`].
    fn from_u8(v: u8) -> Self {
        // Indexed by discriminant, so position N holds the variant whose value is N.
        const BY_DISCRIMINANT: [TransactionState; 7] = [
            TransactionState::Uninitialized,
            TransactionState::Ready,
            TransactionState::InTransaction,
            TransactionState::Committing,
            TransactionState::Aborting,
            TransactionState::FatalError,
            TransactionState::Initializing,
        ];
        BY_DISCRIMINANT
            .get(usize::from(v))
            .copied()
            .unwrap_or(Self::FatalError)
    }
}
/// Configuration for a `TransactionalProducer`.
#[derive(Debug, Clone)]
pub struct TransactionalProducerConfig {
/// Bootstrap broker list; the builder rejects an empty value.
/// NOTE(review): presumably a comma-separated `host:port` list — confirm against the builder.
pub bootstrap_servers: String,
/// Client identifier; passed to the produce request size validation.
pub client_id: String,
/// Transactional id sent with every transactional request; the builder rejects an empty value.
pub transactional_id: String,
/// Transaction timeout in milliseconds, sent in the InitProducerId request.
pub transaction_timeout_ms: i32,
/// Request timeout; converted to milliseconds for the produce request's `timeout_ms`.
pub request_timeout: Duration,
/// Upper bound on an encoded produce request, enforced before sending.
pub max_request_size: usize,
/// Compression applied to record batches.
pub compression: Compression,
/// NOTE(review): presumably the maximum age of cached cluster metadata; not used in this part of the file — confirm.
pub metadata_max_age: Duration,
/// Optional authentication settings (set through the builder's `auth`/`sasl_*` methods).
pub auth: Option<AuthConfig>,
/// Optional proxy settings; only present with the `socks5` feature.
#[cfg(feature = "socks5")]
pub proxy: Option<crate::network::ProxyConfig>,
}
impl Default for TransactionalProducerConfig {
    /// Builds the baseline configuration.
    ///
    /// `bootstrap_servers` and `transactional_id` start empty and must be
    /// supplied by the caller; everything else gets a usable default
    /// (60 s transaction timeout, 30 s request timeout, no compression,
    /// 300 s metadata max age, no authentication).
    fn default() -> Self {
        let request_timeout = Duration::from_secs(30);
        let metadata_max_age = Duration::from_secs(300);
        Self {
            bootstrap_servers: String::default(),
            client_id: String::from("krafka-txn-producer"),
            transactional_id: String::default(),
            transaction_timeout_ms: 60_000,
            request_timeout,
            max_request_size: crate::protocol::MAX_MESSAGE_SIZE,
            compression: Compression::None,
            metadata_max_age,
            auth: None,
            #[cfg(feature = "socks5")]
            proxy: None,
        }
    }
}
/// Registration state of one topic-partition within the current transaction.
#[derive(Debug, Clone)]
enum PartitionAddState {
/// An add is in flight; the `Notify` is signalled (via `notify_waiters`)
/// when that add is confirmed or cancelled.
Pending(Arc<Notify>),
/// The add has been confirmed.
Added,
}
/// Outcome of `TransactionPartitions::begin_add`.
enum BeginAddResult {
/// The partition is already recorded as added; nothing to do.
AlreadyAdded,
/// Another caller has a pending add for this partition; wait on the `Notify` and re-check.
Wait(Arc<Notify>),
/// The partition was just recorded as pending with this `Notify`; the caller
/// must perform the add and then confirm or cancel it.
NeedAdd(Arc<Notify>),
}
/// Tracks which topic-partitions are pending or added for the current transaction.
#[derive(Debug, Default)]
struct TransactionPartitions {
// topic name -> partition -> add state
partitions: std::collections::HashMap<
String,
std::collections::HashMap<PartitionId, PartitionAddState>,
>,
}
impl TransactionPartitions {
    /// Looks up `topic`/`partition` and decides what the caller has to do.
    ///
    /// If the partition is unknown it is recorded as `Pending` with a fresh
    /// `Notify`, and that same `Notify` is handed back in `NeedAdd`.
    fn begin_add(&mut self, topic: &str, partition: PartitionId) -> BeginAddResult {
        let existing = self
            .partitions
            .get(topic)
            .and_then(|by_partition| by_partition.get(&partition))
            .cloned();
        match existing {
            Some(PartitionAddState::Added) => BeginAddResult::AlreadyAdded,
            Some(PartitionAddState::Pending(waiters)) => BeginAddResult::Wait(waiters),
            None => {
                let waiters = Arc::new(Notify::new());
                let by_partition = self.partitions.entry(topic.to_string()).or_default();
                by_partition.insert(partition, PartitionAddState::Pending(Arc::clone(&waiters)));
                BeginAddResult::NeedAdd(waiters)
            }
        }
    }

    /// Marks the partition as `Added` and wakes every task waiting on `notify`.
    fn confirm_add(&mut self, topic: &str, partition: PartitionId, notify: &Notify) {
        let by_partition = self.partitions.entry(topic.to_string()).or_default();
        by_partition.insert(partition, PartitionAddState::Added);
        notify.notify_waiters();
    }

    /// Forgets the partition (dropping the topic entry once it is empty) and
    /// wakes every task waiting on `notify` so they can re-evaluate.
    fn cancel_add(&mut self, topic: &str, partition: PartitionId, notify: &Notify) {
        let topic_now_empty = match self.partitions.get_mut(topic) {
            Some(by_partition) => {
                by_partition.remove(&partition);
                by_partition.is_empty()
            }
            None => false,
        };
        if topic_now_empty {
            self.partitions.remove(topic);
        }
        notify.notify_waiters();
    }

    /// Drops all tracked partitions.
    fn clear(&mut self) {
        self.partitions.clear();
    }

    /// Returns `true` when no topic is tracked.
    #[cfg(test)]
    fn is_empty(&self) -> bool {
        self.partitions.is_empty()
    }
}
/// Guard for a partition recorded as `Pending`.
///
/// Unless it is defused through `confirm` or `cancel`, dropping the guard
/// cancels the pending add so that waiters are woken.
struct PendingAddGuard {
// Shared partition table the pending entry lives in.
txn_partitions: Arc<RwLock<TransactionPartitions>>,
// Topic of the pending entry (used by the `Drop` cleanup).
topic: TopicHandle,
// Partition of the pending entry (used by the `Drop` cleanup).
partition: PartitionId,
// Notifier that waiters for this pending add are parked on.
notify: Arc<Notify>,
// When `true`, `Drop` performs no cleanup.
defused: bool,
}
impl PendingAddGuard {
    /// Records the partition as added and wakes all waiters, consuming the guard.
    ///
    /// The guard is defused only after the write lock has been acquired. If
    /// this future is dropped while still waiting for the lock, the guard is
    /// therefore still armed and its `Drop` cancels the pending entry; the
    /// pending entry is not left behind with waiters that are never notified.
    async fn confirm(mut self, topic: &str, partition: PartitionId) {
        let mut txn_partitions = self.txn_partitions.write().await;
        // No await point past here, so the update below cannot be interrupted.
        self.defused = true;
        txn_partitions.confirm_add(topic, partition, &self.notify);
    }

    /// Removes the pending entry and wakes all waiters, consuming the guard.
    ///
    /// As with `confirm`, the guard stays armed until the write lock is held
    /// so that cancellation during the lock wait still cleans up via `Drop`.
    async fn cancel(mut self, topic: &str, partition: PartitionId) {
        let mut txn_partitions = self.txn_partitions.write().await;
        // No await point past here, so the update below cannot be interrupted.
        self.defused = true;
        txn_partitions.cancel_add(topic, partition, &self.notify);
    }
}
impl Drop for PendingAddGuard {
    /// Cancels the pending add if the guard was never confirmed or cancelled.
    ///
    /// Cleanup strategy, in order:
    /// 1. take the write lock without waiting and cancel inline;
    /// 2. otherwise, if a Tokio runtime is current, spawn a task that waits
    ///    for the lock and cancels;
    /// 3. otherwise (no runtime), block on the lock and cancel.
    fn drop(&mut self) {
        if self.defused {
            return;
        }
        let topic = self.topic.clone();
        let partition = self.partition;
        let notify = self.notify.clone();
        if let Ok(mut tp) = self.txn_partitions.try_write() {
            tp.cancel_add(&topic, partition, &notify);
        } else if let Ok(handle) = tokio::runtime::Handle::try_current() {
            // The lock is contended and `drop` cannot await, so hand the
            // cleanup to the runtime.
            let txn_partitions = self.txn_partitions.clone();
            handle.spawn(async move {
                let mut tp = txn_partitions.write().await;
                tp.cancel_add(&topic, partition, &notify);
            });
        } else {
            // Only reached when no runtime is current, so blocking here
            // cannot stall an async worker.
            let mut tp = self.txn_partitions.blocking_write();
            tp.cancel_add(&topic, partition, &notify);
        }
    }
}
/// Kafka producer that sends records inside transactions.
pub struct TransactionalProducer {
// Producer configuration (transactional id, timeouts, compression, ...).
config: TransactionalProducerConfig,
// Cluster metadata: broker list, partition counts, leader connections, refresh.
metadata: Arc<ClusterMetadata>,
// Connection pool used for coordinator connections; closed on `close`.
pool: Arc<ConnectionPool>,
// Chooses a partition when a record does not specify one.
partitioner: Arc<dyn Partitioner>,
// Current `TransactionState`, stored as its `u8` discriminant.
state: AtomicU8,
// Set after an UnknownProducerId failure; while set, sends and commits are
// rejected until `abort_transaction` runs.
abort_required: AtomicBool,
// Cached transaction coordinator broker id; `None` forces a new lookup.
coordinator_id: RwLock<Option<i32>>,
// Partitions registered (or being registered) with the current transaction.
txn_partitions: Arc<RwLock<TransactionPartitions>>,
// Producer id/epoch and per-partition sequence bookkeeping.
identity: ProducerIdentity,
// Retry count and backoff calculation for coordinator and produce requests.
retry_policy: RetryPolicy,
// Tracks in-flight sends so that closing can wait for them.
in_flight_barrier: Arc<InFlightBarrier>,
}
impl TransactionalProducer {
pub fn builder() -> TransactionalProducerBuilder {
TransactionalProducerBuilder::default()
}
/// Returns the current transaction state.
#[inline]
pub fn state(&self) -> TransactionState {
    let raw = self.state.load(Ordering::SeqCst);
    TransactionState::from_u8(raw)
}
/// Returns `(producer_id, producer_epoch)` once both are non-negative.
///
/// # Errors
/// Returns an invalid-state error while either value is still negative,
/// i.e. the identity has not been initialized.
fn checked_transactional_identity(&self) -> Result<(i64, i16)> {
    let identity = (
        self.identity.producer_id(),
        self.identity.producer_epoch(),
    );
    match identity {
        (id, epoch) if id >= 0 && epoch >= 0 => Ok(identity),
        _ => Err(KrafkaError::invalid_state(
            "transactional producer identity not initialized",
        )),
    }
}
/// Reports whether the current transaction must be aborted before any
/// further work is accepted.
#[inline]
fn abort_required(&self) -> bool {
    let flag = &self.abort_required;
    flag.load(Ordering::SeqCst)
}
/// Rejects `operation` while an abort is outstanding.
///
/// # Errors
/// Returns a `TransactionAbortable` broker error naming `operation` when the
/// abort-required flag is set.
fn ensure_transaction_can_continue(&self, operation: &str) -> Result<()> {
    if !self.abort_required() {
        return Ok(());
    }
    let message =
        format!("cannot {operation}: abort_transaction() is required before continuing");
    Err(KrafkaError::broker(ErrorCode::TransactionAbortable, message))
}
/// Sets the abort-required flag and builds the `TransactionAbortable` error
/// reported to the caller in place of the original UnknownProducerId failure.
fn mark_unknown_producer_id_abort_required(&self, operation: &str) -> KrafkaError {
    self.abort_required.store(true, Ordering::SeqCst);
    let message = format!(
        "{operation} failed with UnknownProducerId; abort_transaction() is required before continuing"
    );
    KrafkaError::broker(ErrorCode::TransactionAbortable, message)
}
fn is_unknown_producer_id_error(error: &KrafkaError) -> bool {
matches!(
error,
KrafkaError::Broker {
code: ErrorCode::UnknownProducerId,
..
}
)
}
fn is_abortable_transaction_error(error: &KrafkaError) -> bool {
matches!(
error,
KrafkaError::Broker {
code: ErrorCode::TransactionAbortable,
..
}
)
}
/// Returns the transaction coordinator's broker id together with a pooled
/// connection to it.
///
/// The coordinator id is taken from the cache when present; otherwise it is
/// discovered through `find_coordinator` and cached.
///
/// # Errors
/// Fails if discovery fails, if the coordinator id is not in the metadata
/// broker list, or if the connection cannot be obtained.
async fn coordinator_connection(&self) -> Result<(i32, Arc<BrokerConnection>)> {
    let cached = *self.coordinator_id.read().await;
    let coordinator_id = if let Some(id) = cached {
        id
    } else {
        let discovered = self.find_coordinator().await?;
        *self.coordinator_id.write().await = Some(discovered);
        debug!("Auto-discovered transaction coordinator: broker {}", discovered);
        discovered
    };

    let brokers = self.metadata.brokers();
    let Some(broker) = brokers.iter().find(|b| b.id == coordinator_id) else {
        return Err(KrafkaError::protocol("coordinator not found in metadata"));
    };
    let conn = self
        .pool
        .get_connection_by_id(broker.id, broker.address())
        .await?;
    Ok((coordinator_id, conn))
}
/// Decides whether `err` should invalidate the cached coordinator.
///
/// True for network errors, timeouts, and the broker codes `NotCoordinator`,
/// `CoordinatorNotAvailable` and `CoordinatorLoadInProgress`.
fn needs_coordinator_refresh(err: &KrafkaError) -> bool {
    matches!(
        err,
        KrafkaError::Network(_)
            | KrafkaError::Timeout { .. }
            | KrafkaError::Broker {
                code: ErrorCode::NotCoordinator
                    | ErrorCode::CoordinatorNotAvailable
                    | ErrorCode::CoordinatorLoadInProgress,
                ..
            }
    )
}
/// Clears the cached coordinator id so the next request rediscovers it.
async fn invalidate_coordinator(&self) {
    let mut cached = self.coordinator_id.write().await;
    *cached = None;
}
/// Runs `op` until it succeeds, fails terminally, or the retry budget is spent.
///
/// * UnknownProducerId is returned immediately, never retried.
/// * Coordinator-related failures invalidate the cached coordinator and retry.
/// * Other retriable failures retry without touching the cache.
/// * Anything else, or any failure on the last attempt, is returned as-is.
///
/// A backoff from the retry policy is slept before every attempt except the
/// first. `op_name` is used only for logging and the exhaustion message.
async fn retry_with_coordinator<F, Fut>(&self, op_name: &str, op: F) -> Result<()>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<()>>,
{
    let max_retries = self.retry_policy.max_retries;
    for attempt in 0..=max_retries {
        if attempt > 0 {
            let backoff = self.retry_policy.calculate_backoff(attempt);
            tokio::time::sleep(backoff).await;
        }

        let err = match op().await {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };
        if Self::is_unknown_producer_id_error(&err) {
            return Err(err);
        }

        let retries_left = attempt < max_retries;
        if retries_left && Self::needs_coordinator_refresh(&err) {
            warn!(
                attempt,
                error = %err,
                op_name,
                "Coordinator error, refreshing and retrying"
            );
            self.invalidate_coordinator().await;
        } else if retries_left && err.is_retriable() {
            warn!(
                attempt,
                error = %err,
                op_name,
                "Retriable error, retrying"
            );
        } else {
            return Err(err);
        }
    }
    Err(KrafkaError::protocol(format!(
        "{op_name} retry loop exhausted after {max_retries} retries"
    )))
}
/// Unconditionally overwrites the transaction state.
fn set_state(&self, state: TransactionState) {
    let raw = state as u8;
    self.state.store(raw, Ordering::SeqCst);
}
/// Atomically moves the state from `expected` to `new`.
///
/// Returns `Err` with the state actually observed when it was not `expected`;
/// in that case nothing is changed.
fn try_transition(
    &self,
    expected: TransactionState,
    new: TransactionState,
) -> std::result::Result<(), TransactionState> {
    let swapped = self.state.compare_exchange(
        expected as u8,
        new as u8,
        Ordering::SeqCst,
        Ordering::SeqCst,
    );
    match swapped {
        Ok(_) => Ok(()),
        Err(actual) => Err(TransactionState::from_u8(actual)),
    }
}
/// Initializes the transactional producer identity.
///
/// Requires the `Uninitialized` state. On failure the state is reset to
/// `Uninitialized` so the call can be attempted again.
///
/// # Errors
/// Returns an invalid-state error when called in any other state, or the
/// error from the InitProducerId exchange.
pub async fn init_transactions(&self) -> Result<()> {
    self.try_transition(
        TransactionState::Uninitialized,
        TransactionState::Initializing,
    )
    .map_err(|actual| {
        KrafkaError::invalid_state(format!(
            "init_transactions can only be called once (state={:?})",
            actual
        ))
    })?;

    let outcome = self.do_init_transactions().await;
    if outcome.is_err() {
        self.set_state(TransactionState::Uninitialized);
    }
    outcome
}
/// Performs the InitProducerId exchange with the transaction coordinator,
/// retrying through `retry_with_coordinator`.
///
/// On success the returned producer id/epoch are stored in the identity, the
/// abort-required flag is cleared, and the state becomes `Ready`.
async fn do_init_transactions(&self) -> Result<()> {
    self.retry_with_coordinator("InitProducerId", || async {
        let (_, conn) = self.coordinator_connection().await?;
        let negotiated = conn
            .negotiate_api_version(
                ApiKey::InitProducerId,
                versions::INIT_PRODUCER_ID_MAX,
                versions::INIT_PRODUCER_ID_MIN,
            )
            .await;
        let Some(version) = negotiated else {
            return Err(KrafkaError::protocol(
                "no mutually supported InitProducerId API version",
            ));
        };

        let request = InitProducerIdRequest::transactional(
            &self.config.transactional_id,
            self.config.transaction_timeout_ms,
        );
        let mut payload = conn
            .send_request(ApiKey::InitProducerId, version, |buf| {
                request.encode_versioned(version, buf)
            })
            .await?;
        let response = InitProducerIdResponse::decode_versioned(version, &mut payload)?;
        if !response.is_ok() {
            return Err(KrafkaError::broker(
                response.error_code,
                "failed to initialize producer ID",
            ));
        }

        self.identity
            .initialize(response.producer_id, response.producer_epoch);
        self.abort_required.store(false, Ordering::SeqCst);
        self.set_state(TransactionState::Ready);
        info!(
            "Transactional producer initialized: PID={}, epoch={}",
            response.producer_id, response.producer_epoch
        );
        Ok(())
    })
    .await
}
/// Asks the first broker in the metadata for the transaction coordinator of
/// this producer's transactional id and returns the coordinator's node id.
///
/// # Errors
/// Fails when no brokers are known, no FindCoordinator version can be
/// negotiated, the request fails, or the response carries an error code.
async fn find_coordinator(&self) -> Result<i32> {
    let brokers = self.metadata.brokers();
    let Some(broker) = brokers.first() else {
        return Err(KrafkaError::protocol("no brokers available"));
    };
    let conn = self
        .pool
        .get_connection_by_id(broker.id, broker.address())
        .await?;

    let request = FindCoordinatorRequest::for_transaction(&self.config.transactional_id);
    let negotiated = conn
        .negotiate_api_version(
            ApiKey::FindCoordinator,
            versions::FIND_COORDINATOR_MAX,
            versions::FIND_COORDINATOR_MIN,
        )
        .await;
    let Some(version) = negotiated else {
        return Err(KrafkaError::protocol(
            "no mutually supported FindCoordinator API version; \
transactional coordinator lookup requires v1+",
        ));
    };

    let mut payload = conn
        .send_request(ApiKey::FindCoordinator, version, |buf| {
            request.encode_versioned(version, buf)
        })
        .await?;
    let response = FindCoordinatorResponse::decode_versioned(version, &mut payload)?;
    if !response.error_code.is_ok() {
        return Err(KrafkaError::broker(
            response.error_code,
            "failed to find transaction coordinator",
        ));
    }
    debug!(
        "Found transaction coordinator: broker {} at {}:{}",
        response.node_id, response.host, response.port
    );
    Ok(response.node_id)
}
/// Opens a new transaction by moving the state from `Ready` to `InTransaction`.
///
/// # Errors
/// Returns an invalid-state error naming the observed state when the
/// producer is not `Ready`.
pub fn begin_transaction(&self) -> Result<()> {
    self.try_transition(TransactionState::Ready, TransactionState::InTransaction)
        .map_err(|actual| {
            KrafkaError::invalid_state(format!(
                "cannot begin transaction in state {:?}",
                actual
            ))
        })?;
    debug!("Transaction started");
    Ok(())
}
/// Convenience wrapper around `send_record`: copies `value` (and `key`, when
/// given) into a new `ProducerRecord` for `topic` and sends it.
pub async fn send(
    &self,
    topic: &str,
    key: Option<&[u8]>,
    value: &[u8],
) -> Result<RecordMetadata> {
    let base = ProducerRecord::new(topic, Bytes::copy_from_slice(value));
    let record = match key {
        Some(k) => base.with_key(Bytes::copy_from_slice(k)),
        None => base,
    };
    self.send_record(record).await
}
/// Sends one record as part of the currently open transaction.
///
/// Steps: register with the in-flight barrier, check state and preconditions,
/// resolve the target partition, make sure that partition has been added to
/// the transaction (performing or awaiting the add as needed), then produce.
///
/// # Errors
/// Fails when the barrier refuses the operation, the producer is not
/// `InTransaction`, an abort is required, the record is invalid, the identity
/// is not initialized, the topic is unknown, the partition cannot be added to
/// the transaction, or the produce itself fails.
pub async fn send_record(&self, record: ProducerRecord) -> Result<RecordMetadata> {
// Held for the whole call so that closing can wait for this send.
let _operation_guard = self.in_flight_barrier.start("transactional producer")?;
let current = self.state();
if current != TransactionState::InTransaction {
return Err(KrafkaError::invalid_state(format!(
"cannot send in state {:?}",
current
)));
}
self.ensure_transaction_can_continue("send records")?;
record.validate()?;
// Only checked here; the values are re-read where they are needed.
let _identity = self.checked_transactional_identity()?;
let routed = record.into_routed_parts();
let topic = routed.topic;
let record = routed.record;
// Use the record's explicit partition, otherwise ask the partitioner.
let partition = match routed.partition {
Some(p) => p,
None => {
let partition_count = self
.metadata
.partition_count(topic.as_ref())
.ok_or_else(|| KrafkaError::invalid_state(format!("unknown topic: {topic}")))?;
self.partitioner
.partition(topic.as_ref(), record.key_bytes(), partition_count)
}
};
// Ensure the partition is part of the transaction before producing to it.
loop {
let mut txn_partitions = self.txn_partitions.write().await;
match txn_partitions.begin_add(topic.as_ref(), partition) {
BeginAddResult::AlreadyAdded => break,
BeginAddResult::Wait(notify) => {
// Enable the waiter while the lock is still held so a
// notification sent right after the unlock is not missed.
let notified = notify.notified();
tokio::pin!(notified);
notified.as_mut().enable();
drop(txn_partitions);
notified.await;
// Loop again: the other add may have been confirmed or cancelled.
}
BeginAddResult::NeedAdd(notify) => {
// Release the lock before the network round trip.
drop(txn_partitions);
// If this future is dropped mid-add, the guard's `Drop`
// cancels the pending entry and wakes waiters.
let guard = PendingAddGuard {
txn_partitions: self.txn_partitions.clone(),
topic: topic.clone(),
partition,
notify,
defused: false,
};
match self.add_partition_to_txn(topic.as_ref(), partition).await {
Ok(()) => {
guard.confirm(topic.as_ref(), partition).await;
}
Err(e) => {
guard.cancel(topic.as_ref(), partition).await;
return Err(e);
}
}
break;
}
}
}
self.send_to_partition(topic, partition, record).await
}
/// Registers `topic`/`partition` with the open transaction via
/// AddPartitionsToTxn, retrying through `retry_with_coordinator`.
///
/// # Errors
/// Returns the first per-partition broker error from the response, or a
/// protocol error if the response reports failure without one. An
/// UnknownProducerId failure sets the abort-required flag and is reported as
/// `TransactionAbortable`.
async fn add_partition_to_txn(&self, topic: &str, partition: PartitionId) -> Result<()> {
    let outcome = self
        .retry_with_coordinator("AddPartitionsToTxn", || async {
            let (_, conn) = self.coordinator_connection().await?;
            let (producer_id, producer_epoch) = self.checked_transactional_identity()?;
            let negotiated = conn
                .negotiate_api_version(
                    ApiKey::AddPartitionsToTxn,
                    versions::ADD_PARTITIONS_TO_TXN_MAX,
                    versions::ADD_PARTITIONS_TO_TXN_MIN,
                )
                .await;
            let Some(version) = negotiated else {
                return Err(KrafkaError::protocol(
                    "no mutually supported AddPartitionsToTxn API version",
                ));
            };

            let request = AddPartitionsToTxnRequest::new(
                &self.config.transactional_id,
                producer_id,
                producer_epoch,
            )
            .add_partition(topic, partition);
            let mut payload = conn
                .send_request(ApiKey::AddPartitionsToTxn, version, |buf| {
                    request.encode_versioned(version, buf)
                })
                .await?;
            let response = AddPartitionsToTxnResponse::decode_versioned(version, &mut payload)?;

            if response.is_ok() {
                debug!("Added partition {}-{} to transaction", topic, partition);
                return Ok(());
            }

            // Report the first partition-level error, if the broker gave one.
            let first_failure = response
                .results
                .iter()
                .flat_map(|topic_result| topic_result.partitions.iter())
                .find(|partition_result| !partition_result.error_code.is_ok());
            Err(match first_failure {
                Some(failed) => KrafkaError::broker(
                    failed.error_code,
                    format!("failed to add {}-{} to transaction", topic, partition),
                ),
                None => KrafkaError::protocol(format!(
                    "failed to add {}-{} to transaction: response indicated error but no per-partition error found",
                    topic, partition
                )),
            })
        })
        .await;

    match outcome {
        Err(error) if Self::is_unknown_producer_id_error(&error) => {
            Err(self.mark_unknown_producer_id_abort_required("AddPartitionsToTxn"))
        }
        other => other,
    }
}
/// Produces `record` to `topic`/`partition` as a transactional batch, with retries.
///
/// The produce request is built once and reused across retries; it is rebuilt
/// with a fresh sequence number only after an OutOfOrderSequenceNumber error.
///
/// # Errors
/// * UnknownProducerId sets the abort-required flag and is reported as
///   `TransactionAbortable`.
/// * Broker codes listed in `is_fatal_transaction_error` move the producer to
///   `FatalError` before the error is returned.
/// * Non-retriable errors, and any error on the last attempt, are returned as-is.
async fn send_to_partition(
&self,
topic: TopicHandle,
partition: PartitionId,
record: RoutedRecord,
) -> Result<RecordMetadata> {
let retry_policy = &self.retry_policy;
let max_retries = retry_policy.max_retries;
let (producer_id, producer_epoch) = self.checked_transactional_identity()?;
// Sequence number for this batch, obtained from the producer identity.
let mut sequence = self.next_sequence(topic.as_ref(), partition).await?;
let mut request = match self.build_produce_request(
topic.as_ref(),
partition,
&record,
producer_id,
producer_epoch,
sequence,
) {
Ok(req) => req,
Err(e) => {
// Nothing was sent, so roll the sequence back; the rollback
// result is deliberately ignored in favour of the build error.
let _ = self.identity.rollback_sequence(topic.as_ref(), partition);
return Err(e);
}
};
for attempt in 0..=max_retries {
// One complete produce attempt; any `?` inside only exits this block.
let send_result: Result<RecordMetadata> = async {
let conn = self
.metadata
.get_leader_connection(topic.as_ref(), partition)
.await?;
let version = conn
.negotiate_api_version(
ApiKey::Produce,
versions::PRODUCE_MAX,
versions::PRODUCE_MIN,
)
.await
.ok_or_else(|| {
KrafkaError::protocol(
"no mutually supported Produce API version; \
transactional produce requires v3+",
)
})?;
// Size depends on the negotiated version, so validate per attempt.
super::validate_produce_request_size(
&self.config.client_id,
self.config.max_request_size,
version,
&request,
)?;
let response = conn
.send_request(ApiKey::Produce, version, |buf| {
request.encode_versioned(version, buf)
})
.await?;
let mut buf = response;
let produce_response = ProduceResponse::decode_versioned(version, &mut buf)?;
// Find the entry for our partition (matched by partition index only).
for topic_response in &produce_response.responses {
for partition_response in &topic_response.partition_responses {
if partition_response.index == partition {
if !partition_response.error_code.is_ok() {
if is_fatal_transaction_error(partition_response.error_code) {
self.set_state(TransactionState::FatalError);
}
return Err(KrafkaError::broker(
partition_response.error_code,
format!("produce failed for {topic}-{partition}"),
));
}
self.identity
.acknowledge(topic.as_ref(), partition, sequence);
return Ok(RecordMetadata {
topic: topic.to_string(),
partition,
offset: partition_response.base_offset,
timestamp: partition_response.log_append_time_ms,
});
}
}
}
Err(KrafkaError::protocol("partition not found in response"))
}
.await;
match send_result {
Ok(metadata) => return Ok(metadata),
Err(e) => {
// Never retried: the transaction has to be aborted first.
if Self::is_unknown_producer_id_error(&e) {
return Err(
self.mark_unknown_producer_id_abort_required("transactional produce")
);
}
if let KrafkaError::Broker { code, .. } = &e
&& *code == ErrorCode::OutOfOrderSequenceNumber
{
if attempt >= max_retries {
return Err(e);
}
warn!(
topic = %topic,
partition = partition,
"OutOfOrderSequenceNumber, resetting sequence and rebuilding batch"
);
// Reset the partition's sequence, take a new one, and rebuild
// the batch so it carries the new sequence number.
self.identity.reset_sequence(topic.as_ref(), partition);
sequence = self.next_sequence(topic.as_ref(), partition).await?;
request = self.build_produce_request(
topic.as_ref(),
partition,
&record,
producer_id,
producer_epoch,
sequence,
)?;
tokio::time::sleep(retry_policy.calculate_backoff(attempt + 1)).await;
continue;
}
if !e.is_retriable() || attempt >= max_retries {
return Err(e);
}
debug!(
topic = %topic,
partition = partition,
attempt = attempt + 1,
"Transient error in txn send, retrying: {}",
e
);
// Best effort: a failed metadata refresh is logged, not returned.
if should_refresh_metadata_after_txn_send_error(&e)
&& let Err(refresh_err) = self
.metadata
.refresh_for_topics(Some(&[topic.as_ref()]))
.await
{
debug!(error = %refresh_err, "Metadata refresh failed during txn retry");
}
tokio::time::sleep(retry_policy.calculate_backoff(attempt + 1)).await;
}
}
}
// Every branch on the final attempt returns, so this is a defensive fallback.
Err(KrafkaError::protocol(format!(
"transactional produce retry loop exhausted after {max_retries} retries"
)))
}
/// Builds a single-partition transactional produce request for `record`.
///
/// The batch is stamped with the producer id, epoch and `sequence`, marked
/// transactional, and uses the configured compression. The request is sent
/// with `Acks::All` and the configured request timeout.
///
/// # Errors
/// Returns the error from encoding the record batch.
fn build_produce_request(
    &self,
    topic: &str,
    partition: PartitionId,
    record: &RoutedRecord,
    producer_id: i64,
    producer_epoch: i16,
    sequence: i32,
) -> Result<ProduceRequest> {
    let builder = RecordBatchBuilder::new()
        .compression(self.config.compression)
        .producer(producer_id, producer_epoch, sequence)
        .transactional(true);
    let builder = match record.timestamp {
        Some(ts) => builder.base_timestamp(ts),
        None => builder,
    };
    let batch = record.append_to_batch_builder(builder).build();
    let encoded = batch.encode()?;

    let partition_data = ProducePartitionData {
        index: partition,
        records: encoded,
    };
    let topic_data = ProduceTopicData {
        name: topic.to_string(),
        topic_id: None,
        partition_data: vec![partition_data],
    };
    Ok(ProduceRequest {
        transactional_id: Some(self.config.transactional_id.clone()),
        acks: Acks::All.to_i16(),
        timeout_ms: crate::util::duration_to_millis_i32(self.config.request_timeout),
        topic_data: vec![topic_data],
    })
}
/// Adds consumer offsets for `group_id` to the open transaction.
///
/// Two phases:
/// 1. AddOffsetsToTxn to the transaction coordinator (via `retry_with_coordinator`);
/// 2. TxnOffsetCommit to the group coordinator, which is looked up again on
///    every attempt.
///
/// `offsets` maps `(topic, partition)` to the offset to commit.
///
/// # Errors
/// Fails when the producer is not `InTransaction`, an abort is required, the
/// identity is not initialized, or either phase fails. UnknownProducerId in
/// either phase sets the abort-required flag and is reported as
/// `TransactionAbortable`.
pub async fn send_offsets_to_transaction(
&self,
offsets: std::collections::HashMap<(String, PartitionId), i64>,
group_id: &str,
) -> Result<()> {
let current = self.state();
if current != TransactionState::InTransaction {
return Err(KrafkaError::invalid_state(format!(
"cannot send offsets in state {:?}",
current
)));
}
self.ensure_transaction_can_continue("send offsets")?;
let (producer_id, producer_epoch) = self.checked_transactional_identity()?;
// Phase 1: register the consumer group with the transaction.
let add_offsets_result = self
.retry_with_coordinator("AddOffsetsToTxn", || async {
let (_coordinator_id, conn) = self.coordinator_connection().await?;
let add_request = AddOffsetsToTxnRequest::new(
&self.config.transactional_id,
producer_id,
producer_epoch,
group_id,
);
let aot_version = conn
.negotiate_api_version(
ApiKey::AddOffsetsToTxn,
versions::ADD_OFFSETS_TO_TXN_MAX,
versions::ADD_OFFSETS_TO_TXN_MIN,
)
.await
.ok_or_else(|| {
KrafkaError::protocol("no mutually supported AddOffsetsToTxn API version")
})?;
let response_bytes = conn
.send_request(ApiKey::AddOffsetsToTxn, aot_version, |buf| {
add_request.encode_versioned(aot_version, buf)
})
.await?;
let mut buf = response_bytes;
let add_response =
AddOffsetsToTxnResponse::decode_versioned(aot_version, &mut buf)?;
if !add_response.is_ok() {
return Err(KrafkaError::broker(
add_response.error_code,
"failed to add offsets to transaction",
));
}
Ok(())
})
.await;
match add_offsets_result {
Err(error) if Self::is_unknown_producer_id_error(&error) => {
return Err(self.mark_unknown_producer_id_abort_required("AddOffsetsToTxn"));
}
Err(error) => return Err(error),
Ok(()) => {}
}
// Phase 2: build the commit request once; it is reused across retries.
let mut commit_request = TxnOffsetCommitRequest::new(
&self.config.transactional_id,
group_id,
producer_id,
producer_epoch,
);
for ((topic, partition), offset) in offsets {
commit_request = commit_request.add_offset(&topic, partition, offset, None);
}
let max_retries = self.retry_policy.max_retries;
for attempt in 0..=max_retries {
if attempt > 0 {
tokio::time::sleep(self.retry_policy.calculate_backoff(attempt)).await;
}
// One complete commit attempt; any `?` inside only exits this block.
let result: Result<()> = async {
// The group coordinator is not cached: it is looked up on each attempt.
let (group_node_id, group_host, group_port) =
self.find_group_coordinator(group_id).await?;
let group_addr = format!("{group_host}:{group_port}");
let group_conn = self
.pool
.get_connection_by_id(group_node_id, &group_addr)
.await?;
let toc_version = group_conn
.negotiate_api_version(
ApiKey::TxnOffsetCommit,
versions::TXN_OFFSET_COMMIT_MAX,
versions::TXN_OFFSET_COMMIT_MIN,
)
.await
.ok_or_else(|| {
KrafkaError::protocol("no mutually supported TxnOffsetCommit API version")
})?;
let response_bytes = group_conn
.send_request(ApiKey::TxnOffsetCommit, toc_version, |buf| {
commit_request.encode_versioned(toc_version, buf)
})
.await?;
let mut buf = response_bytes;
let commit_response =
TxnOffsetCommitResponse::decode_versioned(toc_version, &mut buf)?;
if !commit_response.is_ok() {
// Report the first partition-level error found in the response.
for topic_result in &commit_response.topics {
for part_result in &topic_result.partitions {
if !part_result.error_code.is_ok() {
return Err(KrafkaError::broker(
part_result.error_code,
format!(
"failed to commit offset for {}-{} in transaction",
topic_result.name, part_result.partition
),
));
}
}
}
return Err(KrafkaError::protocol(
"failed to commit offsets in transaction",
));
}
Ok(())
}
.await;
// Never retried: the transaction has to be aborted first.
if let Err(error) = &result
&& Self::is_unknown_producer_id_error(error)
{
return Err(self.mark_unknown_producer_id_abort_required("TxnOffsetCommit"));
}
match &result {
Ok(()) => {
debug!("Added offsets to transaction for group {}", group_id);
return Ok(());
}
Err(e) if Self::needs_coordinator_refresh(e) && attempt < max_retries => {
warn!(
attempt,
error = %e,
"TxnOffsetCommit group coordinator error, re-discovering and retrying"
);
}
Err(e) if e.is_retriable() && attempt < max_retries => {
warn!(
attempt,
error = %e,
"TxnOffsetCommit retriable error, retrying"
);
}
Err(_) => return result,
}
}
// Every branch on the final attempt returns, so this is a defensive fallback.
Err(KrafkaError::protocol(format!(
"TxnOffsetCommit retry loop exhausted after {max_retries} retries"
)))
}
/// Asks the first broker in the metadata for the coordinator of consumer
/// group `group_id` and returns its `(node_id, host, port)`.
///
/// # Errors
/// Fails when no brokers are known, no FindCoordinator version can be
/// negotiated, the request fails, or the response carries an error code.
async fn find_group_coordinator(&self, group_id: &str) -> Result<(i32, String, i32)> {
    let brokers = self.metadata.brokers();
    let Some(broker) = brokers.first() else {
        return Err(KrafkaError::protocol("no brokers available"));
    };
    let conn = self
        .pool
        .get_connection_by_id(broker.id, broker.address())
        .await?;

    let request = FindCoordinatorRequest::for_group(group_id);
    let negotiated = conn
        .negotiate_api_version(
            ApiKey::FindCoordinator,
            versions::FIND_COORDINATOR_MAX,
            versions::FIND_COORDINATOR_MIN,
        )
        .await;
    let Some(version) = negotiated else {
        return Err(KrafkaError::protocol(
            "no mutually supported FindCoordinator API version",
        ));
    };

    let mut payload = conn
        .send_request(ApiKey::FindCoordinator, version, |buf| {
            request.encode_versioned(version, buf)
        })
        .await?;
    let response = FindCoordinatorResponse::decode_versioned(version, &mut payload)?;
    if !response.error_code.is_ok() {
        return Err(KrafkaError::broker(
            response.error_code,
            "failed to find group coordinator",
        ));
    }
    Ok((response.node_id, response.host, response.port))
}
/// Commits the open transaction.
///
/// State handling after the EndTxn exchange:
/// * success: state becomes `Ready` and the tracked partitions are cleared;
/// * `TransactionAbortable` or retriable error: state is moved back from
///   `Committing` to `InTransaction` (if it is still `Committing`);
/// * any other error: state becomes `FatalError`.
///
/// # Errors
/// Fails when an abort is required, the producer is not `InTransaction`, or
/// the EndTxn exchange fails. UnknownProducerId sets the abort-required flag
/// and is reported as `TransactionAbortable`.
pub async fn commit_transaction(&self) -> Result<()> {
self.ensure_transaction_can_continue("commit transaction")?;
if let Err(actual) = self.try_transition(
TransactionState::InTransaction,
TransactionState::Committing,
) {
return Err(KrafkaError::invalid_state(format!(
"cannot commit in state {:?}",
actual
)));
}
let result = match self.end_transaction(true).await {
Err(error) if Self::is_unknown_producer_id_error(&error) => {
Err(self.mark_unknown_producer_id_abort_required("commit_transaction"))
}
other => other,
};
match &result {
Ok(()) => {
self.set_state(TransactionState::Ready);
self.txn_partitions.write().await.clear();
info!("Transaction committed");
}
Err(e) if Self::is_abortable_transaction_error(e) => {
// Compare-and-swap rather than a plain store: `abort_transaction`
// may already have taken the state from `Committing` to `Aborting`.
match self.try_transition(
TransactionState::Committing,
TransactionState::InTransaction,
) {
Ok(()) => {
warn!("Transaction commit failed (abort required): {}", e);
}
Err(actual) => {
warn!(
"Transaction commit failed (abort required): {}; \
state is now {:?} (concurrent abort may be in progress)",
e, actual
);
}
}
}
Err(e) => {
if e.is_retriable() {
// Return to `InTransaction` so the caller can retry or abort.
match self.try_transition(
TransactionState::Committing,
TransactionState::InTransaction,
) {
Ok(()) => {
warn!("Transaction commit failed (retriable): {}", e);
}
Err(actual) => {
warn!(
"Transaction commit failed (retriable): {}; \
state is now {:?} (concurrent abort may be in progress)",
e, actual
);
}
}
} else {
self.set_state(TransactionState::FatalError);
warn!("Transaction commit failed (fatal): {}", e);
}
}
}
result
}
/// Aborts the open transaction.
///
/// Accepted from `InTransaction` or `Committing`. If the abort-required flag
/// was set, the producer identity is re-initialized after the abort request
/// (also when that request itself fails with UnknownProducerId).
///
/// On success the state becomes `Ready` and the tracked partitions are
/// cleared; on any failure the state becomes `FatalError`.
///
/// # Errors
/// Returns an invalid-state error from any other state, or the error from the
/// EndTxn / InitProducerId exchange.
pub async fn abort_transaction(&self) -> Result<()> {
let transition = self
.try_transition(TransactionState::InTransaction, TransactionState::Aborting)
.or_else(|_| {
self.try_transition(TransactionState::Committing, TransactionState::Aborting)
});
if let Err(actual) = transition {
return Err(KrafkaError::invalid_state(format!(
"cannot abort in state {:?}",
actual
)));
}
// Read and clear the flag in one atomic step.
let needs_reinitialize = self.abort_required.swap(false, Ordering::SeqCst);
let result = if needs_reinitialize {
match self.end_transaction(false).await {
Ok(()) => self.do_init_transactions().await,
Err(error) if Self::is_unknown_producer_id_error(&error) => {
debug!(
"Abort observed UnknownProducerId after transactional error; reinitializing producer identity"
);
self.do_init_transactions().await
}
Err(error) => Err(error),
}
} else {
self.end_transaction(false).await
};
match &result {
Ok(()) => {
self.set_state(TransactionState::Ready);
self.txn_partitions.write().await.clear();
info!("Transaction aborted");
}
Err(_) => {
self.set_state(TransactionState::FatalError);
warn!("Transaction abort failed, producer is now in fatal error state");
}
}
result
}
/// Sends EndTxn to the transaction coordinator, committing when `commit` is
/// `true` and aborting otherwise; retried through `retry_with_coordinator`.
///
/// # Errors
/// Returns a broker error when the response reports failure, or any
/// connection, negotiation or decode error.
async fn end_transaction(&self, commit: bool) -> Result<()> {
    self.retry_with_coordinator("EndTxn", || async {
        let (_, conn) = self.coordinator_connection().await?;
        let (producer_id, producer_epoch) = self.checked_transactional_identity()?;
        let negotiated = conn
            .negotiate_api_version(ApiKey::EndTxn, versions::END_TXN_MAX, versions::END_TXN_MIN)
            .await;
        let Some(version) = negotiated else {
            return Err(KrafkaError::protocol(
                "no mutually supported EndTxn API version",
            ));
        };

        let txn_id = &self.config.transactional_id;
        let request = match commit {
            true => EndTxnRequest::commit(txn_id, producer_id, producer_epoch),
            false => EndTxnRequest::abort(txn_id, producer_id, producer_epoch),
        };
        let mut payload = conn
            .send_request(ApiKey::EndTxn, version, |buf| {
                request.encode_versioned(version, buf)
            })
            .await?;
        let response = EndTxnResponse::decode_versioned(version, &mut payload)?;
        if response.is_ok() {
            return Ok(());
        }

        let message = if commit {
            "failed to commit transaction"
        } else {
            "failed to abort transaction"
        };
        Err(KrafkaError::broker(response.error_code, message))
    })
    .await
}
/// Returns the configured transactional id.
#[inline]
pub fn transactional_id(&self) -> &str {
    self.config.transactional_id.as_str()
}
/// Returns the current producer id (negative until initialized).
#[inline]
pub fn producer_id(&self) -> i64 {
    let identity = &self.identity;
    identity.producer_id()
}
/// Returns the current producer epoch (negative until initialized).
#[inline]
pub fn producer_epoch(&self) -> i16 {
    let identity = &self.identity;
    identity.producer_epoch()
}
/// Obtains the next sequence number for `topic`/`partition` from the
/// producer identity.
async fn next_sequence(&self, topic: &str, partition: PartitionId) -> Result<i32> {
    let identity = &self.identity;
    identity.next_sequence(topic, partition)
}
/// Closes the producer without a deadline, discarding any error from the
/// shutdown sequence.
pub async fn close(&self) {
    self.close_inner(None).await.ok();
}
/// Closes the producer, giving the graceful part of the shutdown at most
/// `timeout` to finish.
pub async fn close_with_timeout(&self, timeout: Duration) -> Result<()> {
    let deadline = Some(timeout);
    self.close_inner(deadline).await
}
/// Shared shutdown path for `close` and `close_with_timeout`.
///
/// Graceful phase (bounded by `timeout` when given): wait for in-flight
/// operations, then abort the transaction if one is still open.
/// Teardown phase (always runs once closing has begun): mark the producer
/// `FatalError` and close every pooled connection.
///
/// Returns `Ok(())` immediately when the barrier reports that there is
/// nothing to begin closing.
///
/// # Errors
/// Returns a timeout error when the graceful phase exceeds `timeout`, or the
/// error from aborting the open transaction. Teardown has already run in
/// both cases.
async fn close_inner(&self, timeout: Option<Duration>) -> Result<()> {
    let Some(target) = self.in_flight_barrier.begin_close() else {
        return Ok(());
    };
    let graceful_close = async {
        self.in_flight_barrier.wait_for(target).await;
        if self.state() == TransactionState::InTransaction {
            warn!("Closing transactional producer with active transaction — aborting");
            self.abort_transaction().await?;
        }
        Ok::<(), KrafkaError>(())
    };
    // Do not return early on timeout: the teardown below must still run,
    // otherwise the connection pool would stay open after a timed-out close.
    let close_result = match timeout {
        Some(limit) => match tokio::time::timeout(limit, graceful_close).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(KrafkaError::timeout("transactional producer close")),
        },
        None => graceful_close.await,
    };
    self.set_state(TransactionState::FatalError);
    self.pool.close_all().await;
    info!(
        "TransactionalProducer closed: txn.id={}",
        self.config.transactional_id
    );
    close_result
}
/// Reports whether closing has started, as tracked by the in-flight barrier.
#[inline]
pub fn is_closed(&self) -> bool {
    let barrier = &self.in_flight_barrier;
    barrier.is_closing()
}
}
/// Returns `true` for broker error codes after which the producer is moved to
/// the `FatalError` state.
fn is_fatal_transaction_error(error_code: ErrorCode) -> bool {
    match error_code {
        ErrorCode::InvalidProducerEpoch
        | ErrorCode::ProducerFenced
        | ErrorCode::TransactionalIdAuthorizationFailed
        | ErrorCode::InvalidTxnState
        | ErrorCode::TransactionCoordinatorFenced => true,
        _ => false,
    }
}
/// Decides whether a failed transactional send should trigger a metadata
/// refresh: only for retriable errors other than OutOfOrderSequenceNumber.
fn should_refresh_metadata_after_txn_send_error(error: &KrafkaError) -> bool {
    if !error.is_retriable() {
        return false;
    }
    let is_sequence_error = matches!(
        error,
        KrafkaError::Broker {
            code: ErrorCode::OutOfOrderSequenceNumber,
            ..
        }
    );
    !is_sequence_error
}
/// Builder for a `TransactionalProducer`.
#[must_use = "builders do nothing until .build() is called"]
#[derive(Default)]
pub struct TransactionalProducerBuilder {
// Configuration accumulated by the setter methods.
config: TransactionalProducerConfig,
// Custom partitioner, if one was supplied.
// NOTE(review): presumably `build` falls back to a default partitioner when this is `None` — confirm.
partitioner: Option<Arc<dyn Partitioner>>,
}
impl TransactionalProducerBuilder {
    /// Sets the bootstrap server list. The string is stored as given and only
    /// parsed in [`build`](Self::build).
    pub fn bootstrap_servers(mut self, servers: impl Into<String>) -> Self {
        self.config.bootstrap_servers = servers.into();
        self
    }

    /// Sets the client ID handed to the connection configuration.
    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
        self.config.client_id = client_id.into();
        self
    }

    /// Sets the transactional ID. [`build`](Self::build) rejects an empty value.
    pub fn transactional_id(mut self, txn_id: impl Into<String>) -> Self {
        self.config.transactional_id = txn_id.into();
        self
    }

    /// Sets the transaction timeout in milliseconds. [`build`](Self::build)
    /// rejects values `<= 0`.
    pub fn transaction_timeout_ms(mut self, timeout: i32) -> Self {
        self.config.transaction_timeout_ms = timeout;
        self
    }

    /// Sets the request timeout handed to the connection configuration.
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.config.request_timeout = timeout;
        self
    }

    /// Sets the maximum request size in bytes. [`build`](Self::build) rejects `0`.
    pub fn max_request_size(mut self, bytes: usize) -> Self {
        self.config.max_request_size = bytes;
        self
    }

    /// Sets the compression stored in the producer configuration.
    pub fn compression(mut self, compression: Compression) -> Self {
        self.config.compression = compression;
        self
    }

    /// Installs a custom partitioner. When never called, [`build`](Self::build)
    /// uses `DefaultPartitioner::new()`.
    pub fn partitioner(mut self, partitioner: impl Partitioner + 'static) -> Self {
        self.partitioner = Some(Arc::new(partitioner));
        self
    }

    /// Sets the authentication configuration, replacing any previous one.
    pub fn auth(mut self, auth: AuthConfig) -> Self {
        self.config.auth = Some(auth);
        self
    }

    /// Sets the SOCKS5 proxy configuration (only with the `socks5` feature).
    #[cfg(feature = "socks5")]
    pub fn proxy(mut self, proxy: crate::network::ProxyConfig) -> Self {
        self.config.proxy = Some(proxy);
        self
    }

    /// Configures SASL/PLAIN authentication, replacing any previous auth setting.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `AuthConfig::sasl_plain` when it rejects
    /// the credentials.
    pub fn sasl_plain(mut self, username: &str, password: &str) -> crate::Result<Self> {
        self.config.auth = Some(AuthConfig::sasl_plain(username, password)?);
        Ok(self)
    }

    /// Configures SASL/SCRAM-SHA-256 authentication, replacing any previous auth setting.
    pub fn sasl_scram_sha256(mut self, username: &str, password: &str) -> Self {
        self.config.auth = Some(AuthConfig::sasl_scram_sha256(username, password));
        self
    }

    /// Configures SASL/SCRAM-SHA-512 authentication, replacing any previous auth setting.
    pub fn sasl_scram_sha512(mut self, username: &str, password: &str) -> Self {
        self.config.auth = Some(AuthConfig::sasl_scram_sha512(username, password));
        self
    }

    /// Validates the configuration and constructs a `TransactionalProducer`.
    ///
    /// All configuration checks, including parsing of the bootstrap address
    /// list, run before any connection-pool work, so an invalid configuration
    /// fails without TLS initialisation or a started idle evictor. After
    /// validation this builds the connection pool, starts its idle evictor,
    /// creates the cluster metadata and performs one metadata refresh.
    ///
    /// The returned producer starts in `TransactionState::Uninitialized`.
    ///
    /// # Errors
    ///
    /// - a config error when `bootstrap_servers` or `transactional_id` is empty,
    ///   `transaction_timeout_ms <= 0`, or `max_request_size == 0`;
    /// - any error from parsing the bootstrap servers, `init_tls`, or the
    ///   initial metadata refresh.
    pub async fn build(self) -> Result<TransactionalProducer> {
        if self.config.bootstrap_servers.is_empty() {
            return Err(KrafkaError::config("bootstrap.servers is required"));
        }
        if self.config.transactional_id.is_empty() {
            return Err(KrafkaError::config("transactional_id is required"));
        }
        if self.config.transaction_timeout_ms <= 0 {
            return Err(KrafkaError::config("transaction_timeout_ms must be > 0"));
        }
        if self.config.max_request_size == 0 {
            return Err(KrafkaError::config("max_request_size must be >= 1"));
        }
        // Parse the address list up front: a malformed list must fail before
        // TLS setup and before the idle evictor is started.
        let bootstrap_servers =
            crate::util::parse_bootstrap_servers(&self.config.bootstrap_servers)?;

        let mut pool_config_builder = ConnectionConfig::builder()
            .client_id(&self.config.client_id)
            .request_timeout(self.config.request_timeout);
        if let Some(ref auth) = self.config.auth {
            pool_config_builder = pool_config_builder.auth(auth.clone());
        }
        #[cfg(feature = "socks5")]
        if let Some(ref proxy) = self.config.proxy {
            pool_config_builder = pool_config_builder.proxy(proxy.clone());
        }
        let mut pool_config = pool_config_builder.build();
        pool_config.init_tls().await?;

        let pool = Arc::new(ConnectionPool::new(pool_config));
        pool.start_idle_evictor();

        let metadata = Arc::new(ClusterMetadata::new(
            bootstrap_servers,
            pool.clone(),
            self.config.metadata_max_age,
        ));
        metadata.refresh().await?;

        info!(
            "TransactionalProducer created with transactional.id={}",
            self.config.transactional_id
        );
        Ok(TransactionalProducer {
            config: self.config,
            metadata,
            pool,
            partitioner: self
                .partitioner
                .unwrap_or_else(|| Arc::new(DefaultPartitioner::new())),
            state: AtomicU8::new(TransactionState::Uninitialized as u8),
            abort_required: AtomicBool::new(false),
            coordinator_id: RwLock::new(None),
            txn_partitions: Arc::new(RwLock::new(TransactionPartitions::default())),
            identity: ProducerIdentity::new(),
            retry_policy: RetryPolicy::default(),
            in_flight_barrier: Arc::new(InFlightBarrier::new()),
        })
    }
}
// Unit tests for the transactional producer: state encoding, error
// classification, builder validation and partition bookkeeping.
// `unwrap`/`expect`/`panic` are allowed here because a failing test should abort.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::metadata::ClusterMetadata;
    use crate::network::ConnectionPool;

    /// Builds a producer pointed at `localhost:9092` that starts in
    /// `InTransaction` with an uninitialized `ProducerIdentity` and the given
    /// `abort_required` flag.
    fn in_transaction_producer(abort_required: bool) -> TransactionalProducer {
        let pool = Arc::new(ConnectionPool::new(ConnectionConfig::default()));
        let metadata = Arc::new(ClusterMetadata::new(
            vec!["localhost:9092".to_string()],
            pool.clone(),
            Duration::from_secs(300),
        ));
        TransactionalProducer {
            config: TransactionalProducerConfig {
                bootstrap_servers: "localhost:9092".to_string(),
                transactional_id: "txn-test".to_string(),
                ..TransactionalProducerConfig::default()
            },
            metadata,
            pool,
            partitioner: Arc::new(DefaultPartitioner::new()),
            state: AtomicU8::new(TransactionState::InTransaction as u8),
            abort_required: AtomicBool::new(abort_required),
            coordinator_id: RwLock::new(None),
            txn_partitions: Arc::new(RwLock::new(TransactionPartitions::default())),
            identity: ProducerIdentity::new(),
            retry_policy: RetryPolicy::default(),
            in_flight_barrier: Arc::new(InFlightBarrier::new()),
        }
    }

    #[test]
    fn test_transaction_state() {
        assert_eq!(
            TransactionState::from_u8(0),
            TransactionState::Uninitialized
        );
        assert_eq!(TransactionState::from_u8(1), TransactionState::Ready);
        assert_eq!(
            TransactionState::from_u8(2),
            TransactionState::InTransaction
        );
        assert_eq!(TransactionState::from_u8(3), TransactionState::Committing);
        assert_eq!(TransactionState::from_u8(4), TransactionState::Aborting);
        assert_eq!(TransactionState::from_u8(5), TransactionState::FatalError);
        // Unknown discriminants decode as FatalError.
        assert_eq!(TransactionState::from_u8(99), TransactionState::FatalError);
    }

    #[test]
    fn test_transactional_producer_config_default() {
        let config = TransactionalProducerConfig::default();
        assert_eq!(config.client_id, "krafka-txn-producer");
        assert_eq!(config.transaction_timeout_ms, 60000);
        assert_eq!(config.max_request_size, crate::protocol::MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_transaction_partitions() {
        let mut partitions = TransactionPartitions::default();
        assert!(partitions.is_empty());
        let result = partitions.begin_add("topic1", 0);
        let notify = match result {
            BeginAddResult::NeedAdd(n) => n,
            _ => panic!("expected NeedAdd"),
        };
        assert!(!partitions.is_empty());
        // A second request for the same partition while the add is pending must wait.
        assert!(matches!(
            partitions.begin_add("topic1", 0),
            BeginAddResult::Wait(_)
        ));
        partitions.confirm_add("topic1", 0, &notify);
        assert!(matches!(
            partitions.begin_add("topic1", 0),
            BeginAddResult::AlreadyAdded
        ));
        assert!(matches!(
            partitions.begin_add("topic1", 1),
            BeginAddResult::NeedAdd(_)
        ));
        partitions.clear();
        assert!(partitions.is_empty());
    }

    #[test]
    fn test_is_fatal_transaction_error() {
        assert!(is_fatal_transaction_error(ErrorCode::InvalidProducerEpoch));
        assert!(is_fatal_transaction_error(ErrorCode::ProducerFenced));
        assert!(is_fatal_transaction_error(
            ErrorCode::TransactionCoordinatorFenced
        ));
        assert!(is_fatal_transaction_error(
            ErrorCode::TransactionalIdAuthorizationFailed
        ));
        assert!(is_fatal_transaction_error(ErrorCode::InvalidTxnState));
        assert!(!is_fatal_transaction_error(ErrorCode::None));
        assert!(!is_fatal_transaction_error(ErrorCode::UnknownServerError));
    }

    #[test]
    fn test_needs_coordinator_refresh() {
        assert!(TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::broker(ErrorCode::NotCoordinator, "test")
        ));
        assert!(TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::broker(ErrorCode::CoordinatorNotAvailable, "test")
        ));
        assert!(TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::broker(ErrorCode::CoordinatorLoadInProgress, "test")
        ));
        assert!(TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::network(std::io::Error::new(
                std::io::ErrorKind::ConnectionRefused,
                "refused"
            ))
        ));
        assert!(TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::timeout("test operation")
        ));
        assert!(!TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::broker(ErrorCode::InvalidProducerEpoch, "test")
        ));
        assert!(!TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::broker(ErrorCode::TransactionCoordinatorFenced, "test")
        ));
        assert!(!TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::protocol("test")
        ));
        assert!(!TransactionalProducer::needs_coordinator_refresh(
            &KrafkaError::invalid_state("test")
        ));
    }

    #[test]
    fn test_should_refresh_metadata_after_txn_send_error() {
        assert!(!should_refresh_metadata_after_txn_send_error(
            &KrafkaError::broker(ErrorCode::OutOfOrderSequenceNumber, "sequence mismatch")
        ));
        assert!(should_refresh_metadata_after_txn_send_error(
            &KrafkaError::broker(ErrorCode::LeaderNotAvailable, "leader moved")
        ));
        assert!(should_refresh_metadata_after_txn_send_error(
            &KrafkaError::timeout("produce")
        ));
        assert!(!should_refresh_metadata_after_txn_send_error(
            &KrafkaError::broker(ErrorCode::InvalidProducerEpoch, "fenced")
        ));
    }

    #[tokio::test]
    async fn test_builder_missing_bootstrap() {
        let result = TransactionalProducer::builder()
            .transactional_id("my-txn")
            .build()
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_record_requires_initialized_transactional_identity() {
        let producer = in_transaction_producer(false);
        let record = ProducerRecord::new("topic", Bytes::from_static(b"value")).with_partition(0);
        let err = producer.send_record(record).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("transactional producer identity not initialized"),
            "expected invalid identity guard, got: {err}"
        );
    }

    #[tokio::test]
    async fn test_builder_missing_txn_id() {
        let result = TransactionalProducer::builder()
            .bootstrap_servers("localhost:9092")
            .build()
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_mark_unknown_producer_id_requires_abort() {
        let producer = in_transaction_producer(false);
        let error = producer.mark_unknown_producer_id_abort_required("transactional produce");
        assert!(matches!(
            error,
            KrafkaError::Broker {
                code: ErrorCode::TransactionAbortable,
                ..
            }
        ));
        assert!(producer.abort_required());
        let gate_error = producer
            .ensure_transaction_can_continue("commit transaction")
            .unwrap_err();
        assert!(matches!(
            gate_error,
            KrafkaError::Broker {
                code: ErrorCode::TransactionAbortable,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn test_commit_transaction_rejects_abort_required() {
        let producer = in_transaction_producer(true);
        let error = producer.commit_transaction().await.unwrap_err();
        assert!(matches!(
            error,
            KrafkaError::Broker {
                code: ErrorCode::TransactionAbortable,
                ..
            }
        ));
        // The rejected commit must leave the state untouched.
        assert_eq!(producer.state(), TransactionState::InTransaction);
    }

    #[test]
    fn test_try_transition_success() {
        let state = AtomicU8::new(TransactionState::Ready as u8);
        let result = state.compare_exchange(
            TransactionState::Ready as u8,
            TransactionState::InTransaction as u8,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert!(result.is_ok());
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::InTransaction
        );
    }

    #[test]
    fn test_try_transition_failure() {
        let state = AtomicU8::new(TransactionState::Uninitialized as u8);
        let result = state.compare_exchange(
            TransactionState::Ready as u8,
            TransactionState::InTransaction as u8,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert!(result.is_err());
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::Uninitialized
        );
    }

    #[test]
    fn test_txn_builder_no_auth_by_default() {
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9092")
            .transactional_id("txn-1");
        assert!(builder.config.auth.is_none());
    }

    #[test]
    fn test_txn_builder_sets_max_request_size() {
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9092")
            .transactional_id("txn-1")
            .max_request_size(65_536);
        assert_eq!(builder.config.max_request_size, 65_536);
    }

    #[test]
    fn test_txn_builder_sasl_plain() {
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9093")
            .transactional_id("txn-1")
            .sasl_plain("user", "pass")
            .unwrap();
        let auth = builder.config.auth.as_ref().unwrap();
        assert!(auth.requires_sasl());
        assert!(auth.plain_credentials.is_some());
    }

    #[test]
    fn test_txn_builder_sasl_scram_sha256() {
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9093")
            .transactional_id("txn-1")
            .sasl_scram_sha256("user", "pass");
        let auth = builder.config.auth.as_ref().unwrap();
        assert!(auth.requires_sasl());
        assert!(auth.scram_credentials.is_some());
    }

    #[test]
    fn test_txn_builder_sasl_scram_sha512() {
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9093")
            .transactional_id("txn-1")
            .sasl_scram_sha512("user", "pass");
        let auth = builder.config.auth.as_ref().unwrap();
        assert!(auth.requires_sasl());
        assert!(auth.scram_credentials.is_some());
    }

    #[test]
    fn test_txn_builder_auth_config() {
        use crate::auth::AuthConfig;
        let auth = AuthConfig::sasl_scram_sha256("admin", "secret");
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9093")
            .transactional_id("txn-1")
            .auth(auth);
        let auth = builder.config.auth.as_ref().unwrap();
        assert!(auth.requires_sasl());
        assert!(auth.scram_credentials.is_some());
    }

    #[test]
    fn test_txn_builder_initializes_producer_identity() {
        let builder = TransactionalProducer::builder()
            .bootstrap_servers("broker:9092")
            .transactional_id("txn-test");
        assert_eq!(builder.config.transactional_id, "txn-test");
    }

    #[test]
    fn test_txn_builder_requires_transactional_id() {
        let builder = TransactionalProducer::builder().bootstrap_servers("broker:9092");
        assert!(builder.config.transactional_id.is_empty());
    }

    #[tokio::test]
    async fn test_txn_builder_rejects_zero_timeout() {
        let result = TransactionalProducer::builder()
            .bootstrap_servers("localhost:9092")
            .transactional_id("txn-1")
            .transaction_timeout_ms(0)
            .build()
            .await;
        match result {
            Err(e) => assert!(e.to_string().contains("transaction_timeout_ms")),
            Ok(_) => panic!("expected error for transaction_timeout_ms=0"),
        }
    }

    #[tokio::test]
    async fn test_txn_builder_rejects_zero_max_request_size() {
        let result = TransactionalProducer::builder()
            .bootstrap_servers("localhost:9092")
            .transactional_id("txn-1")
            .max_request_size(0)
            .build()
            .await;
        match result {
            Err(e) => assert!(e.to_string().contains("max_request_size")),
            Ok(_) => panic!("expected error for max_request_size=0"),
        }
    }

    #[tokio::test]
    async fn test_txn_builder_rejects_negative_timeout() {
        let result = TransactionalProducer::builder()
            .bootstrap_servers("localhost:9092")
            .transactional_id("txn-1")
            .transaction_timeout_ms(-1)
            .build()
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_transaction_state_initializing_from_u8() {
        assert_eq!(TransactionState::from_u8(6), TransactionState::Initializing);
    }

    #[test]
    fn test_transaction_state_initializing_value() {
        assert_eq!(TransactionState::Initializing as u8, 6);
    }

    #[test]
    fn test_transaction_state_initializing_round_trip() {
        let state = TransactionState::Initializing;
        let val = state as u8;
        assert_eq!(
            TransactionState::from_u8(val),
            TransactionState::Initializing
        );
    }

    #[test]
    fn test_transaction_state_unknown_maps_to_fatal() {
        assert_eq!(TransactionState::from_u8(7), TransactionState::FatalError);
        assert_eq!(TransactionState::from_u8(255), TransactionState::FatalError);
    }

    #[test]
    fn test_try_transition_uninitialized_to_initializing() {
        let state = AtomicU8::new(TransactionState::Uninitialized as u8);
        let result = state.compare_exchange(
            TransactionState::Uninitialized as u8,
            TransactionState::Initializing as u8,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert!(result.is_ok());
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::Initializing
        );
    }

    #[test]
    fn test_try_transition_initializing_blocks_second_init() {
        let state = AtomicU8::new(TransactionState::Initializing as u8);
        let result = state.compare_exchange(
            TransactionState::Uninitialized as u8,
            TransactionState::Initializing as u8,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert!(result.is_err());
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::Initializing
        );
    }

    #[test]
    fn test_commit_fatal_error_state_machine() {
        let state = AtomicU8::new(TransactionState::Committing as u8);
        let error = KrafkaError::broker(ErrorCode::InvalidProducerEpoch, "epoch fenced");
        assert!(!error.is_retriable());
        if error.is_retriable() {
            state.store(TransactionState::InTransaction as u8, Ordering::SeqCst);
        } else {
            state.store(TransactionState::FatalError as u8, Ordering::SeqCst);
        }
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::FatalError
        );
    }

    #[test]
    fn test_commit_retriable_error_reverts_to_in_transaction() {
        let state = AtomicU8::new(TransactionState::Committing as u8);
        let error = KrafkaError::broker(ErrorCode::CoordinatorNotAvailable, "coordinator down");
        assert!(error.is_retriable());
        if error.is_retriable() {
            state.store(TransactionState::InTransaction as u8, Ordering::SeqCst);
        } else {
            state.store(TransactionState::FatalError as u8, Ordering::SeqCst);
        }
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::InTransaction
        );
    }

    #[test]
    fn test_txn_close_sets_fatal_error_state() {
        let state = AtomicU8::new(TransactionState::Ready as u8);
        state.store(TransactionState::FatalError as u8, Ordering::SeqCst);
        assert_eq!(
            TransactionState::from_u8(state.load(Ordering::SeqCst)),
            TransactionState::FatalError
        );
    }

    #[test]
    fn test_out_of_order_sequence_is_retriable() {
        let error = KrafkaError::broker(ErrorCode::OutOfOrderSequenceNumber, "sequence mismatch");
        assert!(error.is_retriable());
    }

    #[test]
    fn test_producer_record_with_timestamp() {
        use crate::producer::ProducerRecord;
        let record = ProducerRecord::new("topic", b"value".to_vec()).with_timestamp(1234567890);
        assert_eq!(record.timestamp, Some(1234567890));
    }

    #[test]
    fn test_transaction_partitions_state_machine() {
        let mut tp = TransactionPartitions::default();
        let result = tp.begin_add("topic", 0);
        let notify = match result {
            BeginAddResult::NeedAdd(n) => n,
            _ => panic!("expected NeedAdd"),
        };
        let result2 = tp.begin_add("topic", 0);
        assert!(matches!(result2, BeginAddResult::Wait(_)));
        tp.confirm_add("topic", 0, &notify);
        assert!(matches!(
            tp.begin_add("topic", 0),
            BeginAddResult::AlreadyAdded
        ));
        let result3 = tp.begin_add("topic", 1);
        let notify2 = match result3 {
            BeginAddResult::NeedAdd(n) => n,
            _ => panic!("expected NeedAdd"),
        };
        // Cancelling a pending add makes the partition eligible for a fresh add.
        tp.cancel_add("topic", 1, &notify2);
        assert!(matches!(
            tp.begin_add("topic", 1),
            BeginAddResult::NeedAdd(_)
        ));
        tp.clear();
        assert!(tp.is_empty());
    }

    #[test]
    fn test_transactional_producer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TransactionalProducer>();
    }
}