// kafkit-client 0.1.9
//
// Kafka 4.0+ pure Rust client.
// Documentation
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{Context, Result, anyhow, bail};
use kafka_protocol::error::ParseResponseErrorCode;
use kafka_protocol::messages::MetadataRequest;
use kafka_protocol::messages::metadata_request::MetadataRequestTopic;
use kafka_protocol::messages::metadata_response::MetadataResponse;
use kafka_protocol::protocol::StrBytes;
use tracing::{debug, instrument};
use uuid::Uuid;

use crate::config::{SaslConfig, SecurityProtocol, TlsConfig};
use crate::constants::METADATA_VERSION_CAP;
use crate::network::{TcpConnector, connect_to_any_bootstrap};

/// In-memory cache of cluster metadata: known brokers plus per-topic
/// partition layout, indexed both by topic name and by topic id.
///
/// A default-constructed cache is empty and has never been refreshed.
#[derive(Default)]
pub struct MetadataCache {
    /// Broker addresses keyed by broker node id.
    brokers: HashMap<i32, BrokerAddress>,
    /// Topic metadata keyed by topic name.
    topics_by_name: HashMap<String, TopicMetadata>,
    /// Reverse index from topic id to topic name.
    topics_by_id: HashMap<Uuid, String>,
    /// When a response was last merged successfully; `None` if the cache has
    /// never been filled or has been invalidated since.
    last_refresh: Option<Instant>,
}

impl MetadataCache {
    pub fn merge_response(&mut self, response: MetadataResponse) -> Result<()> {
        for broker in response.brokers {
            self.brokers.insert(
                broker.node_id.0,
                BrokerAddress {
                    host: broker.host.to_string(),
                    port: u16::try_from(broker.port)
                        .with_context(|| format!("invalid broker port {}", broker.port))?,
                },
            );
        }

        for topic in response.topics {
            let Some(name) = topic.name.as_ref().map(|name| name.0.to_string()) else {
                continue;
            };

            if let Some(error) = topic.error_code.err() {
                return Err(anyhow!("metadata error for topic '{name}': {error}"));
            }

            let partitions = topic
                .partitions
                .into_iter()
                .filter_map(|partition| {
                    partition.error_code.ok()?;
                    Some((
                        partition.partition_index,
                        PartitionMetadata {
                            leader_id: partition.leader_id.0,
                            leader_epoch: partition.leader_epoch,
                            replica_nodes: partition
                                .replica_nodes
                                .into_iter()
                                .map(|broker| broker.0)
                                .collect(),
                            isr_nodes: partition
                                .isr_nodes
                                .into_iter()
                                .map(|broker| broker.0)
                                .collect(),
                            offline_replicas: partition
                                .offline_replicas
                                .into_iter()
                                .map(|broker| broker.0)
                                .collect(),
                        },
                    ))
                })
                .collect();

            if !topic.topic_id.is_nil() {
                if let Some(previous) = self.topics_by_name.get(&name)
                    && previous.topic_id != topic.topic_id
                    && !previous.topic_id.is_nil()
                {
                    self.topics_by_id.remove(&previous.topic_id);
                }
                self.topics_by_id.insert(topic.topic_id, name.clone());
            }

            self.topics_by_name.insert(
                name,
                TopicMetadata {
                    topic_id: topic.topic_id,
                    partitions,
                },
            );
        }

        self.last_refresh = Some(Instant::now());
        Ok(())
    }

    pub fn needs_refresh(&self, topic: &str, max_age: Duration) -> bool {
        if !self.topics_by_name.contains_key(topic) {
            return true;
        }

        match self.last_refresh {
            Some(last_refresh) => last_refresh.elapsed() >= max_age,
            None => true,
        }
    }

    pub fn needs_any_refresh(&self, topics: Vec<String>, max_age: Duration) -> bool {
        topics
            .iter()
            .any(|topic| self.needs_refresh(topic, max_age))
    }

    pub fn is_stale(&self, max_age: Duration) -> bool {
        match self.last_refresh {
            Some(last_refresh) => last_refresh.elapsed() >= max_age,
            None => true,
        }
    }

    pub fn leader_for(&self, topic: &str, partition: i32) -> Option<i32> {
        self.topics_by_name
            .get(topic)?
            .partitions
            .get(&partition)
            .map(|partition| partition.leader_id)
    }

    pub fn partitions_for(&self, topic: &str) -> Option<Vec<(i32, PartitionMetadata)>> {
        let mut partitions = self
            .topics_by_name
            .get(topic)?
            .partitions
            .iter()
            .map(|(partition, metadata)| (*partition, metadata.clone()))
            .collect::<Vec<_>>();
        partitions.sort_by_key(|(partition, _)| *partition);
        Some(partitions)
    }

    pub fn topic_names(&self) -> Vec<String> {
        let mut topics = self.topics_by_name.keys().cloned().collect::<Vec<_>>();
        topics.sort();
        topics
    }

    pub fn partition(&self, topic: &str, partition: i32) -> Option<&PartitionMetadata> {
        self.topics_by_name.get(topic)?.partitions.get(&partition)
    }

    pub fn broker(&self, broker_id: i32) -> Option<&BrokerAddress> {
        self.brokers.get(&broker_id)
    }

    pub fn topic_name(&self, topic_id: &Uuid) -> Option<&String> {
        self.topics_by_id.get(topic_id)
    }

    pub fn topic_id(&self, topic: &str) -> Option<Uuid> {
        self.topics_by_name
            .get(topic)
            .map(|metadata| metadata.topic_id)
    }

    pub fn contains_broker(&self, broker_id: i32) -> bool {
        self.brokers.contains_key(&broker_id)
    }

    pub fn contains_topic(&self, topic: &str) -> bool {
        self.topics_by_name.contains_key(topic)
    }

    pub fn last_refresh(&self) -> Option<Instant> {
        self.last_refresh
    }

    pub fn invalidate_topic(&mut self, topic: &str) {
        if let Some(metadata) = self.topics_by_name.remove(topic)
            && !metadata.topic_id.is_nil()
        {
            self.topics_by_id.remove(&metadata.topic_id);
        }
        self.last_refresh = None;
    }

    pub fn invalidate_all(&mut self) {
        self.topics_by_name.clear();
        self.topics_by_id.clear();
        self.last_refresh = None;
    }
}

/// Cached metadata for a single topic.
struct TopicMetadata {
    /// Topic id as received in the metadata response; stored even when nil.
    topic_id: Uuid,
    /// Partition metadata keyed by partition index.
    partitions: HashMap<i32, PartitionMetadata>,
}

/// Cached metadata for a single partition, copied field-for-field from the
/// partition entry of a metadata response. All broker references are broker
/// node ids.
#[derive(Debug, Clone)]
pub struct PartitionMetadata {
    /// Node id of the partition leader.
    pub leader_id: i32,
    /// Leader epoch reported for the partition.
    pub leader_epoch: i32,
    /// Node ids from the response's `replica_nodes` list.
    pub replica_nodes: Vec<i32>,
    /// Node ids from the response's `isr_nodes` list.
    pub isr_nodes: Vec<i32>,
    /// Node ids from the response's `offline_replicas` list.
    pub offline_replicas: Vec<i32>,
}

/// Network location of a broker: a host (name or IP literal) and a TCP port.
#[derive(Debug, Clone)]
pub struct BrokerAddress {
    host: String,
    port: u16,
}

impl BrokerAddress {
    /// Creates an address from a host and a port.
    pub fn new(host: String, port: u16) -> Self {
        Self { host, port }
    }

    /// Renders the address as a `host:port` string.
    ///
    /// A host that contains `:` is treated as a bare IPv6 literal and wrapped
    /// in brackets (`[::1]:9092`). Without the brackets the host/port boundary
    /// is ambiguous and the string does not parse as a socket address. Hosts
    /// that are already bracketed, host names, and IPv4 literals are emitted
    /// unchanged.
    pub fn address(&self) -> String {
        let is_bare_ipv6 = self.host.contains(':') && !self.host.starts_with('[');
        if is_bare_ipv6 {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}

/// Borrowed inputs for one [`refresh_metadata`] call: how to reach a
/// bootstrap broker, which cache to update, and which topics to ask about.
pub struct MetadataRefresh<'a> {
    /// Bootstrap server addresses, forwarded to `connect_to_any_bootstrap`.
    pub bootstrap_servers: &'a [String],
    /// Client id used when connecting and when sending the metadata request.
    pub client_id: &'a str,
    /// Timeout forwarded to `connect_to_any_bootstrap`.
    // NOTE(review): presumably bounds connection setup and/or each request —
    // confirm against the network module.
    pub request_timeout: Duration,
    /// Security protocol forwarded to `connect_to_any_bootstrap`.
    pub security_protocol: SecurityProtocol,
    /// TLS settings forwarded to `connect_to_any_bootstrap`.
    pub tls: &'a TlsConfig,
    /// SASL settings forwarded to `connect_to_any_bootstrap`.
    pub sasl: &'a SaslConfig,
    /// Connector forwarded to `connect_to_any_bootstrap`.
    pub tcp_connector: &'a Arc<dyn TcpConnector>,
    /// Cache the metadata response is merged into.
    pub metadata: &'a mut MetadataCache,
    /// Topics to request. When empty, the request is sent without a topic
    /// list (the Kafka protocol treats a null list as "all topics").
    pub topics: &'a [String],
}

/// Fetches metadata from a bootstrap broker and merges it into
/// `refresh.metadata`.
///
/// The request names exactly the topics in `refresh.topics`, or carries no
/// topic list at all when that slice is empty. Auto topic creation and the
/// authorized-operations fields are switched off.
///
/// # Errors
///
/// Fails if no bootstrap connection can be established, if the request
/// version cannot be negotiated, if sending the request fails, if merging the
/// response fails, or if a requested topic is still missing from the cache
/// after the merge.
#[instrument(
    name = "metadata.refresh",
    level = "debug",
    skip(refresh),
    fields(
        bootstrap_server_count = refresh.bootstrap_servers.len(),
        client_id = %refresh.client_id,
        topic_count = refresh.topics.len()
    )
)]
pub async fn refresh_metadata(refresh: MetadataRefresh<'_>) -> Result<()> {
    let mut connection = connect_to_any_bootstrap(
        refresh.bootstrap_servers,
        refresh.client_id,
        refresh.request_timeout,
        refresh.security_protocol,
        refresh.tls,
        refresh.sasl,
        refresh.tcp_connector,
    )
    .await?;
    let api_version = connection.version_with_cap::<MetadataRequest>(METADATA_VERSION_CAP)?;

    // An empty topic slice means "send no topic list", not "send an empty one".
    let requested_topics = if refresh.topics.is_empty() {
        None
    } else {
        Some(
            refresh
                .topics
                .iter()
                .map(|topic| {
                    let name = StrBytes::from_string(topic.clone());
                    MetadataRequestTopic::default().with_name(Some(name.into()))
                })
                .collect(),
        )
    };
    let request = MetadataRequest::default()
        .with_topics(requested_topics)
        .with_allow_auto_topic_creation(false)
        .with_include_cluster_authorized_operations(false)
        .with_include_topic_authorized_operations(false);

    let response: MetadataResponse = connection
        .send_request::<MetadataRequest>(refresh.client_id, api_version, &request)
        .await?;
    refresh.metadata.merge_response(response)?;

    // Report the first requested topic that did not make it into the cache.
    if let Some(topic) = refresh
        .topics
        .iter()
        .find(|topic| !refresh.metadata.contains_topic(topic))
    {
        bail!("topic '{topic}' was not present in metadata response");
    }

    debug!("metadata refresh completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use kafka_protocol::messages::metadata_response::MetadataResponseTopic;
    use kafka_protocol::messages::{MetadataResponse, TopicName};

    use super::*;

    /// Builds a metadata response holding a single topic with the given name
    /// and topic id, and nothing else.
    fn metadata_response_with_topic_id(topic: &str, topic_id: Uuid) -> MetadataResponse {
        let name = TopicName(StrBytes::from_string(topic.to_owned()));
        let entry = MetadataResponseTopic::default()
            .with_name(Some(name))
            .with_topic_id(topic_id);
        MetadataResponse::default().with_topics(vec![entry])
    }

    /// A topic that reappears under a new id must lose its old id mapping and
    /// resolve through the new one.
    #[test]
    fn merge_response_replaces_stale_topic_id_for_recreated_topic() {
        let first_id = Uuid::from_u128(1);
        let second_id = Uuid::from_u128(2);
        let mut cache = MetadataCache::default();

        // First sighting: the topic is indexed under its original id.
        let initial = metadata_response_with_topic_id("orders", first_id);
        cache.merge_response(initial).unwrap();
        assert_eq!(
            cache.topic_name(&first_id).map(|name| name.as_str()),
            Some("orders")
        );

        // Same name, different id: the topic was recreated.
        let recreated = metadata_response_with_topic_id("orders", second_id);
        cache.merge_response(recreated).unwrap();

        assert!(cache.topic_name(&first_id).is_none());
        assert_eq!(
            cache.topic_name(&second_id).map(|name| name.as_str()),
            Some("orders")
        );
        assert_eq!(cache.topic_id("orders"), Some(second_id));
    }
}