use std::collections::BTreeMap;
use anyhow::{Context, Result, bail};
use kafka_protocol::error::{ParseResponseErrorCode, ResponseError};
use kafka_protocol::messages::find_coordinator_response::FindCoordinatorResponse;
use kafka_protocol::messages::txn_offset_commit_request::{
TxnOffsetCommitRequestPartition, TxnOffsetCommitRequestTopic,
};
use kafka_protocol::messages::{
FindCoordinatorRequest, ProducerId, TransactionalId, TxnOffsetCommitRequest,
};
use kafka_protocol::protocol::StrBytes;
use super::state::{ProducerIdentity, TransactionCoordinator};
use crate::constants::{FIND_COORDINATOR_GROUP_KEY_TYPE, FIND_COORDINATOR_TRANSACTION_KEY_TYPE};
use crate::network::BrokerConnection;
use crate::types::{CommitOffset, ConsumerGroupMetadata};
use crate::{ConsumerGroupMetadataError, ProducerError};
/// Outcome of classifying a broker error returned during a transactional operation.
///
/// Produced by [`classify_transactional_error`]. The derives let callers log the
/// disposition, copy it freely and compare it with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionFailureDisposition {
    /// Returned for every error not listed as fatal by [`classify_transactional_error`].
    AbortOnly,
    /// Returned for errors in the fatal set (e.g. `ProducerFenced`, `InvalidProducerEpoch`).
    Fatal,
}
/// Maps a broker [`ResponseError`] onto a [`TransactionFailureDisposition`].
///
/// A fixed set of producer-identity, authorization and fencing errors is classified as
/// `Fatal`; every other error code yields `AbortOnly`.
pub fn classify_transactional_error(error: ResponseError) -> TransactionFailureDisposition {
    // Errors in this list are classified Fatal; anything not listed falls through to AbortOnly.
    let is_fatal = matches!(
        error,
        ResponseError::ProducerFenced
            | ResponseError::TransactionalIdAuthorizationFailed
            | ResponseError::InvalidProducerIdMapping
            | ResponseError::InvalidProducerEpoch
            | ResponseError::OutOfOrderSequenceNumber
            | ResponseError::GroupAuthorizationFailed
            | ResponseError::FencedInstanceId
            | ResponseError::UnsupportedForMessageFormat
    );
    if is_fatal {
        TransactionFailureDisposition::Fatal
    } else {
        TransactionFailureDisposition::AbortOnly
    }
}
/// Verifies that `connection` reports a finalized `transaction.version` feature level of
/// at least 2.
///
/// # Errors
///
/// * [`ProducerError::MissingTransactionVersionFeature`] when the feature is not finalized.
/// * [`ProducerError::UnsupportedTransactionVersion`] when the finalized level is below 2.
pub fn ensure_transaction_v2_feature(
    connection: &BrokerConnection,
) -> std::result::Result<(), ProducerError> {
    match connection.finalized_feature_level("transaction.version") {
        None => Err(ProducerError::MissingTransactionVersionFeature),
        Some(level) if level < 2 => Err(ProducerError::UnsupportedTransactionVersion { level }),
        Some(_) => Ok(()),
    }
}
/// Builds a `FindCoordinator` request that looks up the transaction coordinator for
/// `transactional_id`.
///
/// For `version >= 4` the id goes into the batched `coordinator_keys` list; older versions
/// use the single `key` field. The transaction key type is set in both cases.
pub fn build_find_coordinator_request(
    transactional_id: &str,
    version: i16,
) -> FindCoordinatorRequest {
    let lookup_key = StrBytes::from_string(transactional_id.to_owned());
    let request =
        FindCoordinatorRequest::default().with_key_type(FIND_COORDINATOR_TRANSACTION_KEY_TYPE);
    if version < 4 {
        request.with_key(lookup_key)
    } else {
        request.with_coordinator_keys(vec![lookup_key])
    }
}
/// Builds a `FindCoordinator` request that looks up the group coordinator for `group_id`.
///
/// For `version >= 4` the group id goes into the batched `coordinator_keys` list; older
/// versions use the single `key` field. The group key type is set in both cases.
pub fn build_group_find_coordinator_request(
    group_id: &str,
    version: i16,
) -> FindCoordinatorRequest {
    let group_key = StrBytes::from_string(group_id.to_owned());
    let base = FindCoordinatorRequest::default().with_key_type(FIND_COORDINATOR_GROUP_KEY_TYPE);
    match version {
        4.. => base.with_coordinator_keys(vec![group_key]),
        _ => base.with_key(group_key),
    }
}
pub fn validate_group_metadata(
group_metadata: &ConsumerGroupMetadata,
) -> std::result::Result<(), ConsumerGroupMetadataError> {
if group_metadata.group_id.trim().is_empty() {
return Err(ConsumerGroupMetadataError::EmptyGroupId);
}
if group_metadata.generation_id > 0 && group_metadata.member_id.is_empty() {
return Err(ConsumerGroupMetadataError::MissingMemberId);
}
Ok(())
}
pub fn build_txn_offset_commit_request(
transactional_id: &str,
producer: ProducerIdentity,
offsets: &[CommitOffset],
group_metadata: &ConsumerGroupMetadata,
) -> TxnOffsetCommitRequest {
let mut topics = BTreeMap::<String, Vec<TxnOffsetCommitRequestPartition>>::new();
for offset in offsets {
topics.entry(offset.topic.clone()).or_default().push(
TxnOffsetCommitRequestPartition::default()
.with_partition_index(offset.partition)
.with_committed_offset(offset.offset)
.with_committed_leader_epoch(-1)
.with_committed_metadata(None),
);
}
TxnOffsetCommitRequest::default()
.with_transactional_id(TransactionalId(StrBytes::from_string(
transactional_id.to_owned(),
)))
.with_group_id(StrBytes::from_string(group_metadata.group_id.clone()).into())
.with_producer_id(ProducerId(producer.id))
.with_producer_epoch(producer.epoch)
.with_generation_id(group_metadata.generation_id)
.with_member_id(StrBytes::from_string(group_metadata.member_id.clone()))
.with_group_instance_id(
group_metadata
.group_instance_id
.clone()
.map(StrBytes::from_string),
)
.with_topics(
topics
.into_iter()
.map(|(topic, partitions)| {
TxnOffsetCommitRequestTopic::default()
.with_name(StrBytes::from_string(topic).into())
.with_partitions(partitions)
})
.collect(),
)
}
pub fn parse_find_coordinator_response(
response: FindCoordinatorResponse,
version: i16,
) -> Result<TransactionCoordinator> {
if version >= 4 {
let coordinator = response
.coordinators
.into_iter()
.next()
.context("FindCoordinator returned no coordinators")?;
if let Some(error) = coordinator.error_code.err() {
bail!("FindCoordinator failed: {error}");
}
let port = u16::try_from(coordinator.port)
.with_context(|| format!("invalid coordinator port {}", coordinator.port))?;
return Ok(TransactionCoordinator {
broker_id: *coordinator.node_id,
address: format!("{}:{}", coordinator.host, port),
});
}
if let Some(error) = response.error_code.err() {
bail!("FindCoordinator failed: {error}");
}
let port = u16::try_from(response.port)
.with_context(|| format!("invalid coordinator port {}", response.port))?;
Ok(TransactionCoordinator {
broker_id: *response.node_id,
address: format!("{}:{}", response.host, port),
})
}
/// Returns the broker error carried by a `FindCoordinator` response, if any.
///
/// For `version >= 4` only the first coordinator entry is inspected; a response with no
/// entries yields `None`. Older versions use the top-level `error_code`.
pub fn find_coordinator_error(
    response: &FindCoordinatorResponse,
    version: i16,
) -> Option<ResponseError> {
    let error_code = if version < 4 {
        Some(response.error_code)
    } else {
        response.coordinators.first().map(|entry| entry.error_code)
    };
    error_code.and_then(|code| code.err())
}
#[cfg(test)]
mod tests {
    // Unit tests for the request builders, response parsers and validation helpers above.
    use super::*;
    use kafka_protocol::messages::BrokerId;
    use kafka_protocol::messages::find_coordinator_response::Coordinator;

    #[test]
    fn validate_group_metadata_requires_group_id() {
        let error = validate_group_metadata(&ConsumerGroupMetadata {
            group_id: " ".to_owned(),
            generation_id: 0,
            member_id: String::new(),
            group_instance_id: None,
        })
        .unwrap_err();
        assert!(matches!(error, ConsumerGroupMetadataError::EmptyGroupId));
    }

    #[test]
    fn validate_group_metadata_requires_member_id_for_active_generation() {
        let error = validate_group_metadata(&ConsumerGroupMetadata {
            group_id: "group-a".to_owned(),
            generation_id: 3,
            member_id: String::new(),
            group_instance_id: None,
        })
        .unwrap_err();
        assert!(matches!(error, ConsumerGroupMetadataError::MissingMemberId));
    }

    #[test]
    fn validate_group_metadata_accepts_valid_metadata() {
        // The member-id requirement only applies once generation_id > 0.
        assert!(
            validate_group_metadata(&ConsumerGroupMetadata {
                group_id: "group-a".to_owned(),
                generation_id: 0,
                member_id: String::new(),
                group_instance_id: None,
            })
            .is_ok()
        );
        assert!(
            validate_group_metadata(&ConsumerGroupMetadata {
                group_id: "group-a".to_owned(),
                generation_id: 3,
                member_id: "member-a".to_owned(),
                group_instance_id: Some("instance-a".to_owned()),
            })
            .is_ok()
        );
    }

    #[test]
    fn classify_transactional_errors_distinguishes_fatal_from_abort_only() {
        assert!(matches!(
            classify_transactional_error(ResponseError::ProducerFenced),
            TransactionFailureDisposition::Fatal
        ));
        assert!(matches!(
            classify_transactional_error(ResponseError::InvalidProducerEpoch),
            TransactionFailureDisposition::Fatal
        ));
        assert!(matches!(
            classify_transactional_error(ResponseError::OutOfOrderSequenceNumber),
            TransactionFailureDisposition::Fatal
        ));
        assert!(matches!(
            classify_transactional_error(ResponseError::TransactionalIdAuthorizationFailed),
            TransactionFailureDisposition::Fatal
        ));
        assert!(matches!(
            classify_transactional_error(ResponseError::UnknownServerError),
            TransactionFailureDisposition::AbortOnly
        ));
        assert!(matches!(
            classify_transactional_error(ResponseError::CoordinatorNotAvailable),
            TransactionFailureDisposition::AbortOnly
        ));
    }

    #[test]
    fn find_coordinator_requests_follow_version_shape() {
        let old = build_find_coordinator_request("tx-a", 3);
        assert_eq!(old.key.to_string(), "tx-a");
        assert!(old.coordinator_keys.is_empty());
        assert_eq!(old.key_type, FIND_COORDINATOR_TRANSACTION_KEY_TYPE);
        let modern = build_find_coordinator_request("tx-a", 4);
        assert!(modern.key.is_empty());
        assert_eq!(modern.coordinator_keys[0].to_string(), "tx-a");
        assert_eq!(modern.key_type, FIND_COORDINATOR_TRANSACTION_KEY_TYPE);
        let group = build_group_find_coordinator_request("group-a", 4);
        assert!(group.key.is_empty());
        assert_eq!(group.key_type, FIND_COORDINATOR_GROUP_KEY_TYPE);
        assert_eq!(group.coordinator_keys[0].to_string(), "group-a");
        let old_group = build_group_find_coordinator_request("group-a", 3);
        assert_eq!(old_group.key.to_string(), "group-a");
        assert!(old_group.coordinator_keys.is_empty());
        assert_eq!(old_group.key_type, FIND_COORDINATOR_GROUP_KEY_TYPE);
    }

    #[test]
    fn txn_offset_commit_request_groups_offsets_by_topic() {
        let request = build_txn_offset_commit_request(
            "tx-a",
            ProducerIdentity { id: 42, epoch: 3 },
            &[
                CommitOffset {
                    topic: "topic-a".to_owned(),
                    partition: 0,
                    offset: 7,
                },
                CommitOffset {
                    topic: "topic-b".to_owned(),
                    partition: 1,
                    offset: 11,
                },
                CommitOffset {
                    topic: "topic-a".to_owned(),
                    partition: 2,
                    offset: 13,
                },
            ],
            &ConsumerGroupMetadata {
                group_id: "group-a".to_owned(),
                generation_id: 5,
                member_id: "member-a".to_owned(),
                group_instance_id: Some("instance-a".to_owned()),
            },
        );
        assert_eq!(request.transactional_id.0.to_string(), "tx-a");
        assert_eq!(request.producer_id.0, 42);
        assert_eq!(request.producer_epoch, 3);
        assert_eq!(request.group_id.0.to_string(), "group-a");
        assert_eq!(request.generation_id, 5);
        assert_eq!(request.member_id.to_string(), "member-a");
        assert_eq!(request.group_instance_id.unwrap().to_string(), "instance-a");
        assert_eq!(request.topics.len(), 2);
        // Topics are emitted in ascending name order.
        assert_eq!(request.topics[0].name.to_string(), "topic-a");
        assert_eq!(request.topics[1].name.to_string(), "topic-b");
        let topic_a = request
            .topics
            .iter()
            .find(|topic| topic.name.to_string() == "topic-a")
            .unwrap();
        assert_eq!(topic_a.partitions.len(), 2);
        // Partitions keep their input order within a topic.
        let indexes: Vec<i32> = topic_a
            .partitions
            .iter()
            .map(|partition| partition.partition_index)
            .collect();
        assert_eq!(indexes, vec![0, 2]);
        assert_eq!(topic_a.partitions[0].committed_offset, 7);
        assert_eq!(topic_a.partitions[1].committed_offset, 13);
        assert_eq!(topic_a.partitions[0].committed_leader_epoch, -1);
        assert!(topic_a.partitions[0].committed_metadata.is_none());
    }

    #[test]
    fn parse_find_coordinator_response_handles_old_and_modern_errors() {
        let old = FindCoordinatorResponse::default()
            .with_node_id(BrokerId(2))
            .with_host(StrBytes::from_static_str("broker-a"))
            .with_port(9092);
        let coordinator = parse_find_coordinator_response(old.clone(), 3).unwrap();
        assert_eq!(coordinator.broker_id, 2);
        assert_eq!(coordinator.address, "broker-a:9092");
        assert!(parse_find_coordinator_response(old.with_port(-1), 3).is_err());
        let modern = FindCoordinatorResponse::default().with_coordinators(vec![
            Coordinator::default()
                .with_node_id(BrokerId(3))
                .with_host(StrBytes::from_static_str("broker-b"))
                .with_port(9093),
        ]);
        let coordinator = parse_find_coordinator_response(modern, 4).unwrap();
        assert_eq!(coordinator.broker_id, 3);
        assert_eq!(coordinator.address, "broker-b:9093");
        let error = FindCoordinatorResponse::default()
            .with_error_code(ResponseError::CoordinatorNotAvailable.code());
        assert!(find_coordinator_error(&error, 3).is_some());
        assert!(parse_find_coordinator_response(error, 3).is_err());
        let error = FindCoordinatorResponse::default().with_coordinators(vec![
            Coordinator::default().with_error_code(ResponseError::CoordinatorNotAvailable.code()),
        ]);
        assert!(find_coordinator_error(&error, 4).is_some());
        assert!(parse_find_coordinator_response(error, 4).is_err());
        // An empty v4 response carries no error code, but parsing it still fails.
        assert!(find_coordinator_error(&FindCoordinatorResponse::default(), 4).is_none());
        assert!(parse_find_coordinator_response(FindCoordinatorResponse::default(), 4).is_err());
    }
}