kafkit-client 0.1.2

Kafka 4.0+ pure Rust client.
Documentation
use std::collections::HashMap;
use std::time::Duration;

use anyhow::{Context, Result as AnyResult, bail};
use kafka_protocol::error::ParseResponseErrorCode;
use kafka_protocol::messages::find_coordinator_response::FindCoordinatorResponse;
use kafka_protocol::messages::share_group_heartbeat_response::Assignment;
use kafka_protocol::messages::{
    FindCoordinatorRequest, ShareGroupHeartbeatRequest, ShareGroupHeartbeatResponse,
};
use kafka_protocol::protocol::StrBytes;
use tokio::time::sleep;
use tracing::trace;

use crate::constants::{
    FIND_COORDINATOR_GROUP_KEY_TYPE, FIND_COORDINATOR_VERSION_CAP,
    SHARE_GROUP_HEARTBEAT_VERSION_CAP,
};
use crate::metadata::{BrokerAddress, MetadataRefresh, refresh_metadata};
use crate::network::{BrokerConnection, connect_to_any_bootstrap};
use crate::types::TopicPartitionKey;

use super::types::ShareAssignment;
use super::{KafkaShareConsumer, SHARE_COORDINATOR_RETRY_ATTEMPTS, SHARE_MEMBER_LEAVE_EPOCH};

impl KafkaShareConsumer {
    pub(super) async fn heartbeat_with_retries(
        &mut self,
        include_subscription: bool,
    ) -> AnyResult<()> {
        let mut last_error = None;
        let attempts = self.share_coordinator_retry_attempts();
        for attempt in 0..attempts {
            match self.heartbeat(include_subscription).await {
                Ok(()) => return Ok(()),
                Err(error) => {
                    let is_coordinator_error = is_coordinator_error(&error);
                    if is_coordinator_error {
                        self.coordinator = None;
                    }
                    last_error = Some(error);
                    if !is_coordinator_error {
                        break;
                    }
                    if attempt + 1 < attempts {
                        sleep(self.config.retry_backoff).await;
                    }
                }
            }
        }
        Err(last_error.expect("share heartbeat attempt count is at least one"))
    }

    /// Tells the coordinator that this member is leaving the share group.
    ///
    /// The request is a share-group heartbeat carrying
    /// `SHARE_MEMBER_LEAVE_EPOCH` as the member epoch and an empty list of
    /// subscribed topics. The response body is discarded, so an error code
    /// inside it is not inspected here.
    ///
    /// # Errors
    ///
    /// Fails if the coordinator cannot be located or connected, if the request
    /// version cannot be negotiated, or if sending the request fails.
    pub(super) async fn leave_group(&mut self) -> AnyResult<()> {
        self.ensure_coordinator().await?;
        let coordinator = self
            .coordinator
            .as_mut()
            .context("share group coordinator is not connected")?;
        let api_version = coordinator
            .version_with_cap::<ShareGroupHeartbeatRequest>(SHARE_GROUP_HEARTBEAT_VERSION_CAP)?;

        let group_id = StrBytes::from_string(self.config.group_id.clone());
        let member_id = StrBytes::from_string(self.member_id.clone());
        let leave_request = ShareGroupHeartbeatRequest::default()
            .with_group_id(group_id.into())
            .with_member_id(member_id)
            .with_member_epoch(SHARE_MEMBER_LEAVE_EPOCH)
            .with_subscribed_topic_names(Some(Vec::new()));

        // Only transport-level failures are surfaced; the reply is dropped.
        let _ = coordinator
            .send_request::<ShareGroupHeartbeatRequest>(
                &self.config.client_id,
                api_version,
                &leave_request,
            )
            .await?;
        Ok(())
    }

    /// Groups a clone of every current assignment under its `leader_id`.
    ///
    /// Returns a map from leader id to the assignments led by that broker.
    /// Leaders without assignments do not appear in the map.
    pub(super) fn assignments_by_leader(&self) -> HashMap<i32, Vec<ShareAssignment>> {
        self.assignments.values().fold(
            HashMap::<i32, Vec<ShareAssignment>>::new(),
            |mut by_leader, share| {
                by_leader
                    .entry(share.leader_id)
                    .or_default()
                    .push(share.clone());
                by_leader
            },
        )
    }

    /// Refreshes cluster metadata only when the metadata cache reports that a
    /// refresh is needed.
    ///
    /// The decision is delegated to `needs_any_refresh`, which is given the
    /// subscribed topics and the configured `metadata_max_age`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the metadata refresh itself.
    pub(super) async fn refresh_metadata_if_needed(&mut self) -> AnyResult<()> {
        let refresh_due = self
            .metadata
            .needs_any_refresh(self.subscriptions.clone(), self.config.metadata_max_age);
        if !refresh_due {
            return Ok(());
        }
        // Same request as the unconditional refresh, so reuse it.
        self.refresh_metadata().await
    }

    /// Unconditionally refreshes the metadata cache for the subscribed topics.
    ///
    /// Connection settings (bootstrap servers, client id, timeout, security
    /// protocol, TLS and SASL) are taken from the consumer configuration.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported by the metadata refresh.
    pub(super) async fn refresh_metadata(&mut self) -> AnyResult<()> {
        let config = &self.config;
        let refresh = MetadataRefresh {
            bootstrap_servers: &config.bootstrap_servers,
            client_id: &config.client_id,
            request_timeout: config.request_timeout,
            security_protocol: config.security_protocol,
            tls: &config.tls,
            sasl: &config.sasl,
            metadata: &mut self.metadata,
            topics: &self.subscriptions,
        };
        refresh_metadata(refresh).await?;
        Ok(())
    }

    pub(super) async fn ensure_coordinator_with_retries(&mut self) -> AnyResult<()> {
        let mut last_error = None;
        let attempts = self.share_coordinator_retry_attempts();
        for attempt in 0..attempts {
            match self.ensure_coordinator().await {
                Ok(()) => return Ok(()),
                Err(error) => {
                    let is_coordinator_error = is_coordinator_error(&error);
                    last_error = Some(error);
                    if !is_coordinator_error {
                        break;
                    }
                    if attempt + 1 < attempts {
                        sleep(self.config.retry_backoff).await;
                    }
                }
            }
        }
        Err(last_error.expect("share coordinator lookup attempt count is at least one"))
    }

    /// Returns the cached connection to broker `leader_id`, opening one first
    /// if none is cached.
    ///
    /// When the broker is absent from the metadata cache, metadata is
    /// refreshed once before the lookup is repeated.
    ///
    /// # Errors
    ///
    /// Fails if the metadata refresh fails, if the broker is still unknown
    /// afterwards, or if the connection cannot be established.
    pub(super) async fn leader_connection(
        &mut self,
        leader_id: i32,
    ) -> AnyResult<&mut BrokerConnection> {
        let already_connected = self.leader_connections.contains_key(&leader_id);
        if !already_connected {
            let broker_known = self.metadata.broker(leader_id).is_some();
            if !broker_known {
                self.refresh_metadata().await?;
            }
            let broker = self
                .metadata
                .broker(leader_id)
                .with_context(|| format!("metadata missing broker {leader_id}"))?;
            let config = &self.config;
            let fresh = BrokerConnection::connect_with_transport(
                &broker.address(),
                &config.client_id,
                config.request_timeout,
                config.security_protocol,
                &config.tls,
                &config.sasl,
            )
            .await?;
            self.leader_connections.insert(leader_id, fresh);
        }
        // Looked up after the optional insert so the returned borrow is the
        // only outstanding borrow of the map.
        let connection = self
            .leader_connections
            .get_mut(&leader_id)
            .expect("leader connection inserted above");
        Ok(connection)
    }

    /// Performs one share-group heartbeat round trip and applies the response.
    ///
    /// When `include_subscription` is true the request carries the subscribed
    /// topic names; otherwise that field is sent as `None`.
    ///
    /// # Errors
    ///
    /// Fails if the coordinator cannot be located or connected, if the request
    /// version cannot be negotiated, if sending fails, or if the response is
    /// rejected by `handle_heartbeat_response`.
    async fn heartbeat(&mut self, include_subscription: bool) -> AnyResult<()> {
        self.ensure_coordinator().await?;
        let coordinator = self
            .coordinator
            .as_mut()
            .context("share group coordinator is not connected")?;
        let api_version = coordinator
            .version_with_cap::<ShareGroupHeartbeatRequest>(SHARE_GROUP_HEARTBEAT_VERSION_CAP)?;

        let subscribed_topic_names = if include_subscription {
            Some(
                self.subscriptions
                    .iter()
                    .map(|topic| StrBytes::from_string(topic.clone()).into())
                    .collect(),
            )
        } else {
            None
        };
        let group_id = StrBytes::from_string(self.config.group_id.clone());
        let member_id = StrBytes::from_string(self.member_id.clone());
        let rack_id = self.config.rack_id.clone().map(StrBytes::from_string);
        let request = ShareGroupHeartbeatRequest::default()
            .with_group_id(group_id.into())
            .with_member_id(member_id)
            .with_member_epoch(self.member_epoch)
            .with_rack_id(rack_id)
            .with_subscribed_topic_names(subscribed_topic_names);

        let response: ShareGroupHeartbeatResponse = coordinator
            .send_request::<ShareGroupHeartbeatRequest>(
                &self.config.client_id,
                api_version,
                &request,
            )
            .await?;
        self.handle_heartbeat_response(response)
    }

    /// Applies a heartbeat response to the consumer's membership state.
    ///
    /// On success this updates, in order: the member id (when the response
    /// carries one), the member epoch, the heartbeat interval (only when the
    /// reported value is strictly positive) and the partition assignment
    /// (when present).
    ///
    /// # Errors
    ///
    /// Fails, before touching any state, if the response carries an error
    /// code. Fails if the assignment cannot be resolved against cached
    /// metadata; in that case the member id, epoch and interval have already
    /// been updated.
    fn handle_heartbeat_response(
        &mut self,
        response: ShareGroupHeartbeatResponse,
    ) -> AnyResult<()> {
        if let Some(error) = response.error_code.err() {
            bail!("share group heartbeat failed: {error}");
        }

        if let Some(assigned_id) = response.member_id {
            self.member_id = assigned_id.to_string();
        }
        self.member_epoch = response.member_epoch;

        // Zero and negative intervals are ignored, keeping the previous value.
        if let Ok(millis @ 1..) = u64::try_from(response.heartbeat_interval_ms) {
            self.heartbeat_interval = Duration::from_millis(millis);
        }

        match response.assignment {
            Some(assignment) => self.apply_assignment(assignment),
            None => Ok(()),
        }
    }

    /// Replaces the current assignments with those described by `assignment`.
    ///
    /// Each assigned topic id is resolved to a topic name through the metadata
    /// cache. Each partition's leader id and leader epoch are read from the
    /// cached partition metadata. The new map is built completely before it
    /// replaces the old one, so a failure leaves the previous assignments in
    /// place.
    ///
    /// # Errors
    ///
    /// Fails if the metadata cache lacks an assigned topic id or one of its
    /// partitions.
    // NOTE(review): no metadata refresh is attempted here when an entry is
    // missing — confirm callers refresh before heartbeating.
    fn apply_assignment(&mut self, assignment: Assignment) -> AnyResult<()> {
        let mut rebuilt = HashMap::new();
        for assigned_topic in assignment.topic_partitions {
            let topic_id = assigned_topic.topic_id;
            let topic_name = self
                .metadata
                .topic_name(&topic_id)
                .cloned()
                .with_context(|| format!("metadata missing topic id {}", topic_id))?;
            for partition in assigned_topic.partitions {
                let partition_metadata = self
                    .metadata
                    .partition(&topic_name, partition)
                    .with_context(|| format!("metadata missing {topic_name}:{partition}"))?;
                let key = TopicPartitionKey::new(topic_name.clone(), partition);
                let share = ShareAssignment {
                    topic_id,
                    topic: topic_name.clone(),
                    partition,
                    leader_id: partition_metadata.leader_id,
                    leader_epoch: partition_metadata.leader_epoch,
                };
                rebuilt.insert(key, share);
            }
        }
        self.assignments = rebuilt;
        Ok(())
    }

    /// Locates and connects to the share-group coordinator unless a connection
    /// is already cached.
    ///
    /// The lookup connects to any bootstrap server, sends a `FindCoordinator`
    /// request keyed by the configured group id, and then opens a dedicated
    /// connection to the returned address. From negotiated version 4 upwards
    /// the group id is sent in `coordinator_keys`; older versions use the
    /// single `key` field.
    ///
    /// # Errors
    ///
    /// Fails if no bootstrap server can be reached, if version negotiation or
    /// the request fails, if the response reports an error or is malformed, or
    /// if the coordinator connection cannot be opened.
    async fn ensure_coordinator(&mut self) -> AnyResult<()> {
        if self.coordinator.is_some() {
            return Ok(());
        }

        let config = &self.config;
        let mut bootstrap = connect_to_any_bootstrap(
            &config.bootstrap_servers,
            &config.client_id,
            config.request_timeout,
            config.security_protocol,
            &config.tls,
            &config.sasl,
        )
        .await?;
        let version =
            bootstrap.version_with_cap::<FindCoordinatorRequest>(FIND_COORDINATOR_VERSION_CAP)?;

        // Both request shapes share the key type and differ only in which
        // field carries the group id.
        let group_key = StrBytes::from_string(config.group_id.clone());
        let lookup =
            FindCoordinatorRequest::default().with_key_type(FIND_COORDINATOR_GROUP_KEY_TYPE);
        let request = if version >= 4 {
            lookup.with_coordinator_keys(vec![group_key])
        } else {
            lookup.with_key(group_key)
        };

        let response = bootstrap
            .send_request::<FindCoordinatorRequest>(&config.client_id, version, &request)
            .await?;
        let coordinator = parse_find_coordinator_response(response, version)?;

        let connection = BrokerConnection::connect_with_transport(
            &coordinator.address(),
            &config.client_id,
            config.request_timeout,
            config.security_protocol,
            &config.tls,
            &config.sasl,
        )
        .await?;
        trace!(coordinator = %coordinator.address(), "connected to share group coordinator");
        self.coordinator = Some(connection);
        Ok(())
    }

    /// Number of attempts used by the coordinator retry loops.
    ///
    /// This is the largest of the configured `max_retries`,
    /// `SHARE_COORDINATOR_RETRY_ATTEMPTS` and 1, so it is never zero.
    fn share_coordinator_retry_attempts(&self) -> usize {
        let configured = self.config.max_retries;
        let floor = SHARE_COORDINATOR_RETRY_ATTEMPTS.max(1);
        configured.max(floor)
    }
}

fn parse_find_coordinator_response(
    response: FindCoordinatorResponse,
    version: i16,
) -> AnyResult<BrokerAddress> {
    if version >= 4 {
        let coordinator = response
            .coordinators
            .into_iter()
            .next()
            .context("find coordinator returned no coordinator entries")?;
        if let Some(error) = coordinator.error_code.err() {
            bail!("find coordinator failed: {error}");
        }
        return Ok(BrokerAddress::new(
            coordinator.host.to_string(),
            u16::try_from(coordinator.port)
                .with_context(|| format!("invalid coordinator port {}", coordinator.port))?,
        ));
    }

    if let Some(error) = response.error_code.err() {
        bail!("find coordinator failed: {error}");
    }
    Ok(BrokerAddress::new(
        response.host.to_string(),
        u16::try_from(response.port)
            .with_context(|| format!("invalid coordinator port {}", response.port))?,
    ))
}

/// Reports whether `error` looks like a coordinator error, meaning the
/// coordinator should be looked up again and the operation retried.
///
/// The check is textual: it looks for the names `NotCoordinator`,
/// `CoordinatorNotAvailable` and `CoordinatorLoadInProgress` in the rendered
/// error messages.
///
/// Every error in the chain is inspected, not just the outermost one. The
/// plain `Display` output of an `anyhow::Error` contains only the outermost
/// message, so a coordinator error wrapped by a `.context(...)` layer would
/// otherwise go unrecognised and never be retried. The chain starts with the
/// outermost error, so everything matched before is still matched.
fn is_coordinator_error(error: &anyhow::Error) -> bool {
    const COORDINATOR_ERROR_NAMES: [&str; 3] = [
        "NotCoordinator",
        "CoordinatorNotAvailable",
        "CoordinatorLoadInProgress",
    ];
    error.chain().any(|cause| {
        let message = cause.to_string();
        COORDINATOR_ERROR_NAMES
            .iter()
            .any(|name| message.contains(name))
    })
}