kafkit-client 0.1.4

Kafka 4.0+ pure Rust client.
Documentation
use std::collections::BTreeMap;

use anyhow::{Context, Result, bail};
use kafka_protocol::error::{ParseResponseErrorCode, ResponseError};
use kafka_protocol::messages::find_coordinator_response::FindCoordinatorResponse;
use kafka_protocol::messages::txn_offset_commit_request::{
    TxnOffsetCommitRequestPartition, TxnOffsetCommitRequestTopic,
};
use kafka_protocol::messages::{
    FindCoordinatorRequest, ProducerId, TransactionalId, TxnOffsetCommitRequest,
};
use kafka_protocol::protocol::StrBytes;

use super::state::{ProducerIdentity, TransactionCoordinator};
use crate::constants::{FIND_COORDINATOR_GROUP_KEY_TYPE, FIND_COORDINATOR_TRANSACTION_KEY_TYPE};
use crate::network::BrokerConnection;
use crate::types::{CommitOffset, ConsumerGroupMetadata};
use crate::{ConsumerGroupMetadataError, ProducerError};

/// How a transactional producer should treat a broker error, as decided by
/// [`classify_transactional_error`].
///
/// Derives the usual value traits so callers can compare (`==`), copy, hash and log a
/// disposition instead of being limited to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionFailureDisposition {
    /// The error is treated as abortable: by its name, the current transaction is aborted
    /// and the producer is not permanently failed (enforcement lives in callers not shown here).
    AbortOnly,
    /// The error is treated as fatal for the producer (enforcement lives in callers not shown here).
    Fatal,
}

/// Maps a broker [`ResponseError`] onto a [`TransactionFailureDisposition`].
///
/// Six errors are classified as `Fatal`: producer fencing, transactional-id or group
/// authorization failure, invalid producer-id mapping, fenced instance id, and unsupported
/// message format. Every other error, including ones not listed here, falls back to `AbortOnly`.
pub fn classify_transactional_error(error: ResponseError) -> TransactionFailureDisposition {
    let is_fatal = matches!(
        error,
        ResponseError::ProducerFenced
            | ResponseError::TransactionalIdAuthorizationFailed
            | ResponseError::InvalidProducerIdMapping
            | ResponseError::GroupAuthorizationFailed
            | ResponseError::FencedInstanceId
            | ResponseError::UnsupportedForMessageFormat
    );

    if is_fatal {
        TransactionFailureDisposition::Fatal
    } else {
        TransactionFailureDisposition::AbortOnly
    }
}

/// Checks that the broker has finalized the `transaction.version` feature at level 2 or above.
///
/// # Errors
///
/// * [`ProducerError::MissingTransactionVersionFeature`] when the connection reports no
///   finalized level for `transaction.version`.
/// * [`ProducerError::UnsupportedTransactionVersion`] (carrying the reported level) when the
///   level is below 2.
pub fn ensure_transaction_v2_feature(
    connection: &BrokerConnection,
) -> std::result::Result<(), ProducerError> {
    match connection.finalized_feature_level("transaction.version") {
        None => Err(ProducerError::MissingTransactionVersionFeature),
        Some(level) if level < 2 => Err(ProducerError::UnsupportedTransactionVersion { level }),
        Some(_) => Ok(()),
    }
}

/// Builds a FindCoordinator request that locates the transaction coordinator for
/// `transactional_id`.
///
/// `version` is the FindCoordinator API version the request will be sent with. From v4 onward
/// the key goes into the batched `coordinator_keys` array; for earlier versions it goes into the
/// single `key` field.
pub fn build_find_coordinator_request(
    transactional_id: &str,
    version: i16,
) -> FindCoordinatorRequest {
    build_find_coordinator_request_for_key(
        transactional_id,
        FIND_COORDINATOR_TRANSACTION_KEY_TYPE,
        version,
    )
}

/// Builds a FindCoordinator request that locates the group coordinator for `group_id`.
///
/// Version handling is identical to [`build_find_coordinator_request`]; only the key type differs.
pub fn build_group_find_coordinator_request(
    group_id: &str,
    version: i16,
) -> FindCoordinatorRequest {
    build_find_coordinator_request_for_key(group_id, FIND_COORDINATOR_GROUP_KEY_TYPE, version)
}

/// Shared builder for both coordinator lookups. It sets `key_type` and places `key` in whichever
/// field the given request `version` expects.
fn build_find_coordinator_request_for_key(
    key: &str,
    key_type: i8,
    version: i16,
) -> FindCoordinatorRequest {
    // First FindCoordinator version that carries keys in `coordinator_keys` instead of `key`.
    const BATCHED_KEYS_MIN_VERSION: i16 = 4;

    let key = StrBytes::from_string(key.to_owned());
    let request = FindCoordinatorRequest::default().with_key_type(key_type);
    if version >= BATCHED_KEYS_MIN_VERSION {
        request.with_coordinator_keys(vec![key])
    } else {
        request.with_key(key)
    }
}

pub fn validate_group_metadata(
    group_metadata: &ConsumerGroupMetadata,
) -> std::result::Result<(), ConsumerGroupMetadataError> {
    if group_metadata.group_id.trim().is_empty() {
        return Err(ConsumerGroupMetadataError::EmptyGroupId);
    }
    if group_metadata.generation_id > 0 && group_metadata.member_id.is_empty() {
        return Err(ConsumerGroupMetadataError::MissingMemberId);
    }
    Ok(())
}

pub fn build_txn_offset_commit_request(
    transactional_id: &str,
    producer: ProducerIdentity,
    offsets: &[CommitOffset],
    group_metadata: &ConsumerGroupMetadata,
) -> TxnOffsetCommitRequest {
    let mut topics = BTreeMap::<String, Vec<TxnOffsetCommitRequestPartition>>::new();
    for offset in offsets {
        topics.entry(offset.topic.clone()).or_default().push(
            TxnOffsetCommitRequestPartition::default()
                .with_partition_index(offset.partition)
                .with_committed_offset(offset.offset)
                .with_committed_leader_epoch(-1)
                .with_committed_metadata(None),
        );
    }

    TxnOffsetCommitRequest::default()
        .with_transactional_id(TransactionalId(StrBytes::from_string(
            transactional_id.to_owned(),
        )))
        .with_group_id(StrBytes::from_string(group_metadata.group_id.clone()).into())
        .with_producer_id(ProducerId(producer.id))
        .with_producer_epoch(producer.epoch)
        .with_generation_id(group_metadata.generation_id)
        .with_member_id(StrBytes::from_string(group_metadata.member_id.clone()))
        .with_group_instance_id(
            group_metadata
                .group_instance_id
                .clone()
                .map(StrBytes::from_string),
        )
        .with_topics(
            topics
                .into_iter()
                .map(|(topic, partitions)| {
                    TxnOffsetCommitRequestTopic::default()
                        .with_name(StrBytes::from_string(topic).into())
                        .with_partitions(partitions)
                })
                .collect(),
        )
}

pub fn parse_find_coordinator_response(
    response: FindCoordinatorResponse,
    version: i16,
) -> Result<TransactionCoordinator> {
    if version >= 4 {
        let coordinator = response
            .coordinators
            .into_iter()
            .next()
            .context("FindCoordinator returned no coordinators")?;
        if let Some(error) = coordinator.error_code.err() {
            bail!("FindCoordinator failed: {error}");
        }
        let port = u16::try_from(coordinator.port)
            .with_context(|| format!("invalid coordinator port {}", coordinator.port))?;
        return Ok(TransactionCoordinator {
            broker_id: *coordinator.node_id,
            address: format!("{}:{}", coordinator.host, port),
        });
    }

    if let Some(error) = response.error_code.err() {
        bail!("FindCoordinator failed: {error}");
    }
    let port = u16::try_from(response.port)
        .with_context(|| format!("invalid coordinator port {}", response.port))?;
    Ok(TransactionCoordinator {
        broker_id: *response.node_id,
        address: format!("{}:{}", response.host, port),
    })
}

/// Returns the broker error carried by a FindCoordinator response, without consuming it.
///
/// For `version >= 4` the first entry of `coordinators` is inspected, and an empty list yields
/// `None`. Older versions use the top-level `error_code`.
pub fn find_coordinator_error(
    response: &FindCoordinatorResponse,
    version: i16,
) -> Option<ResponseError> {
    let error_code = if version >= 4 {
        response.coordinators.first()?.error_code
    } else {
        response.error_code
    };
    error_code.err()
}

#[cfg(test)]
mod tests {
    use super::*;
    use kafka_protocol::messages::BrokerId;
    use kafka_protocol::messages::find_coordinator_response::Coordinator;

    #[test]
    fn validate_group_metadata_requires_group_id() {
        let error = validate_group_metadata(&ConsumerGroupMetadata {
            group_id: "   ".to_owned(),
            generation_id: 0,
            member_id: String::new(),
            group_instance_id: None,
        })
        .unwrap_err();

        assert!(matches!(error, ConsumerGroupMetadataError::EmptyGroupId));
    }

    #[test]
    fn validate_group_metadata_requires_member_id_for_active_generation() {
        let error = validate_group_metadata(&ConsumerGroupMetadata {
            group_id: "group-a".to_owned(),
            generation_id: 3,
            member_id: String::new(),
            group_instance_id: None,
        })
        .unwrap_err();

        assert!(matches!(error, ConsumerGroupMetadataError::MissingMemberId));
    }

    #[test]
    fn classify_transactional_errors_distinguishes_fatal_from_abort_only() {
        assert!(matches!(
            classify_transactional_error(ResponseError::ProducerFenced),
            TransactionFailureDisposition::Fatal
        ));
        assert!(matches!(
            classify_transactional_error(ResponseError::UnknownServerError),
            TransactionFailureDisposition::AbortOnly
        ));
    }

    #[test]
    fn find_coordinator_requests_follow_version_shape() {
        let old = build_find_coordinator_request("tx-a", 3);
        assert_eq!(old.key.to_string(), "tx-a");
        assert!(old.coordinator_keys.is_empty());
        assert_eq!(old.key_type, FIND_COORDINATOR_TRANSACTION_KEY_TYPE);

        let modern = build_find_coordinator_request("tx-a", 4);
        assert!(modern.key.is_empty());
        assert_eq!(modern.coordinator_keys[0].to_string(), "tx-a");

        let group = build_group_find_coordinator_request("group-a", 4);
        assert_eq!(group.key_type, FIND_COORDINATOR_GROUP_KEY_TYPE);
        assert_eq!(group.coordinator_keys[0].to_string(), "group-a");
    }

    #[test]
    fn txn_offset_commit_request_groups_offsets_by_topic() {
        let request = build_txn_offset_commit_request(
            "tx-a",
            ProducerIdentity { id: 42, epoch: 3 },
            &[
                CommitOffset {
                    topic: "topic-a".to_owned(),
                    partition: 0,
                    offset: 7,
                },
                CommitOffset {
                    topic: "topic-b".to_owned(),
                    partition: 1,
                    offset: 11,
                },
                CommitOffset {
                    topic: "topic-a".to_owned(),
                    partition: 2,
                    offset: 13,
                },
            ],
            &ConsumerGroupMetadata {
                group_id: "group-a".to_owned(),
                generation_id: 5,
                member_id: "member-a".to_owned(),
                group_instance_id: Some("instance-a".to_owned()),
            },
        );

        assert_eq!(request.transactional_id.0.to_string(), "tx-a");
        assert_eq!(request.producer_id.0, 42);
        assert_eq!(request.producer_epoch, 3);
        assert_eq!(request.group_id.0.to_string(), "group-a");
        assert_eq!(request.group_instance_id.unwrap().to_string(), "instance-a");
        assert_eq!(request.topics.len(), 2);
        let topic_a = request
            .topics
            .iter()
            .find(|topic| topic.name.to_string() == "topic-a")
            .unwrap();
        assert_eq!(topic_a.partitions.len(), 2);
    }

    #[test]
    fn parse_find_coordinator_response_handles_old_and_modern_errors() {
        let old = FindCoordinatorResponse::default()
            .with_node_id(BrokerId(2))
            .with_host(StrBytes::from_static_str("broker-a"))
            .with_port(9092);
        let coordinator = parse_find_coordinator_response(old.clone(), 3).unwrap();
        assert_eq!(coordinator.broker_id, 2);
        assert_eq!(coordinator.address, "broker-a:9092");
        assert!(parse_find_coordinator_response(old.with_port(-1), 3).is_err());

        let modern = FindCoordinatorResponse::default().with_coordinators(vec![
            Coordinator::default()
                .with_node_id(BrokerId(3))
                .with_host(StrBytes::from_static_str("broker-b"))
                .with_port(9093),
        ]);
        let coordinator = parse_find_coordinator_response(modern, 4).unwrap();
        assert_eq!(coordinator.broker_id, 3);
        assert_eq!(coordinator.address, "broker-b:9093");

        let error = FindCoordinatorResponse::default()
            .with_error_code(ResponseError::CoordinatorNotAvailable.code());
        assert!(find_coordinator_error(&error, 3).is_some());
        assert!(parse_find_coordinator_response(error, 3).is_err());

        let error = FindCoordinatorResponse::default().with_coordinators(vec![
            Coordinator::default().with_error_code(ResponseError::CoordinatorNotAvailable.code()),
        ]);
        assert!(find_coordinator_error(&error, 4).is_some());
        assert!(parse_find_coordinator_response(error, 4).is_err());
        assert!(parse_find_coordinator_response(FindCoordinatorResponse::default(), 4).is_err());
    }
}