kafkit-client 0.1.2

Kafka 4.0+ pure Rust client.
Documentation
mod assignment;
mod commits;
mod fetch;
mod group;
mod handlers;
mod metadata;
mod poll;

use std::collections::{BTreeSet, HashMap};
use std::time::Duration;

use anyhow::{Context, Result as AnyResult, anyhow, bail};
use bytes::Buf;
use kafka_protocol::error::ParseResponseErrorCode;
use kafka_protocol::messages::consumer_group_heartbeat_response::Assignment as HeartbeatAssignment;
use kafka_protocol::messages::fetch_response::FetchResponse;
use kafka_protocol::messages::list_offsets_response::ListOffsetsResponse;
use kafka_protocol::messages::offset_commit_response::OffsetCommitResponse;
use kafka_protocol::messages::offset_fetch_response::OffsetFetchResponse;
use kafka_protocol::messages::{
    ConsumerGroupHeartbeatRequest, ConsumerGroupHeartbeatResponse, FindCoordinatorRequest,
    ListOffsetsRequest, OffsetCommitRequest, OffsetFetchRequest,
};
use kafka_protocol::protocol::StrBytes;
use kafka_protocol::records::RecordBatchDecoder;
use tokio::sync::mpsc;
use tokio::time::{Instant, sleep};
use tracing::{Instrument, debug, trace, warn};

use super::ConsumerRuntimeEvent;
use super::protocol::{
    CoordinatorLookupResult, assignment_snapshot_by_topic_id, build_fetch_request,
    build_list_offsets_request, build_list_offsets_request_with_timestamps,
    build_offset_commit_request, group_topic_partitions, is_retriable_error,
    parse_find_coordinator_response,
};
use super::scheduler::{ConsumerNetworkAction, ConsumerRequestManagers};
use super::state::{
    CommitKind, ConsumerAssignmentState, ConsumerConnectionState, ConsumerLifecycleState,
    ConsumerPollState, HeartbeatState, PendingCommit, PendingPoll,
};
use crate::config::{AutoOffsetReset, ConsumerConfig};
use crate::constants::{
    FETCH_VERSION_CAP, FIND_COORDINATOR_GROUP_KEY_TYPE, FIND_COORDINATOR_VERSION_CAP,
    HEARTBEAT_VERSION_CAP, LIST_OFFSETS_EARLIEST, LIST_OFFSETS_LATEST, LIST_OFFSETS_VERSION_CAP,
    OFFSET_COMMIT_VERSION_CAP, OFFSET_FETCH_VERSION_CAP,
};
use crate::metadata::{MetadataRefresh, refresh_metadata};
use crate::network::{BrokerConnection, connect_to_any_bootstrap, duration_to_i32_ms};
use crate::types::{
    AssignedPartition, CommitOffset, ConsumerGroupMetadata, ConsumerRecord, ConsumerRecords,
    RecordHeader, TopicPartition, TopicPartitionInfo, TopicPartitionKey, TopicPartitionOffset,
    TopicPartitionOffsetAndTimestamp, TopicPartitionTimestamp,
};
use crate::{ConsumerError, Error, Result as ClientResult};

/// The topic selection stored as the runtime's `group_subscription`.
///
/// Defaults to [`ConsumerSubscription::None`].
#[derive(Debug, Clone, Default)]
pub enum ConsumerSubscription {
    /// No group subscription is active.
    #[default]
    None,
    /// An explicit set of topic names.
    Topics(BTreeSet<String>),
    /// A compiled regex; `subscribed_topic_names` filters the topic names known to
    /// the local metadata through it.
    Regex {
        regex: regex::Regex,
    },
    /// A pattern kept as its raw string. It is exposed through `subscription_pattern`
    /// and is never expanded into topic names by `subscribed_topic_names`.
    Pattern(String),
}

/// State owned by the consumer's background loop (driven by `ConsumerRuntime::run`).
pub struct ConsumerRuntime {
    /// Client configuration; `run` reads `client_id` and `group_id` from it.
    pub config: ConsumerConfig,
    /// Connection-side state, including the `metadata` consulted for topic names.
    pub connections: ConsumerConnectionState,
    /// Holds the current `group_subscription` and the `assignment` map.
    pub assignment_state: ConsumerAssignmentState,
    /// Heartbeat bookkeeping such as `member_epoch`.
    pub heartbeat_state: HeartbeatState,
    /// Poll-related state. NOTE(review): its contents are not visible from this file.
    pub poll_state: ConsumerPollState,
    /// Queued `runtime_events`, `pending_commits`, the `shutting_down` flag and `close_reply`.
    pub lifecycle: ConsumerLifecycleState,
}

impl ConsumerRuntime {
    /// Builds a runtime around `config`, with every piece of mutable state starting
    /// from its initial value.
    pub fn new(config: ConsumerConfig) -> Self {
        let connections = ConsumerConnectionState::default();
        let assignment_state = ConsumerAssignmentState::default();
        let heartbeat_state = HeartbeatState::new();
        let poll_state = ConsumerPollState::default();
        let lifecycle = ConsumerLifecycleState::new();

        Self {
            config,
            connections,
            assignment_state,
            heartbeat_state,
            poll_state,
            lifecycle,
        }
    }

    /// Drives the consumer until shutdown completes, consuming the runtime.
    ///
    /// Each iteration waits for either the next application event on `rx` or the
    /// wake-up delay computed by `ConsumerRequestManagers::next_wakeup`, then runs
    /// one `run_once` pass. The whole loop executes inside a `consumer_runtime`
    /// tracing span tagged with the client and group ids.
    ///
    /// Error handling per pass:
    /// * retriable errors are logged and the loop simply goes around again;
    /// * non-retriable errors fail the pending poll and all pending commits, and,
    ///   if a shutdown is already in progress, finish the shutdown with that error
    ///   and stop the loop.
    ///
    /// When `rx` is closed (all senders dropped) the runtime switches to shutting
    /// down. After the loop ends, any commits still pending are failed and an
    /// outstanding close reply is completed with "consumer runtime stopped".
    pub async fn run(mut self, mut rx: mpsc::Receiver<ConsumerRuntimeEvent>) {
        let client_id = self.config.client_id.clone();
        let group_id = self.config.group_id.clone();
        async move {
            let mut running = true;
            // Set once `rx.recv()` has yielded `None`. A closed channel makes `recv()`
            // resolve immediately on every subsequent call; with `biased` ordering that
            // branch would win each iteration, `sleep(wake)` would never run, and the
            // loop would spin (and retry retriable errors with no delay) until shutdown
            // finishes. Disabling the branch lets the wake-up timer pace the loop.
            let mut events_closed = false;

            while running {
                let wake = ConsumerRequestManagers::next_wakeup(&self);
                tokio::select! {
                    biased;
                    maybe_event = rx.recv(), if !events_closed => {
                        match maybe_event {
                            Some(event) => self.lifecycle.runtime_events.push_back(event),
                            None => {
                                debug!("consumer application event channel closed");
                                events_closed = true;
                                self.lifecycle.shutting_down = true;
                            }
                        }
                    }
                    _ = sleep(wake) => {}
                }

                match self.run_once(&mut rx).await {
                    Ok(should_continue) => running = should_continue,
                    Err(error) => {
                        if is_retriable_error(&error) {
                            debug!(error = %error, "consumer network action failed with retriable error");
                            continue;
                        }
                        let client_error = into_client_error(error);
                        warn!(
                            error = %client_error,
                            shutting_down = self.lifecycle.shutting_down,
                            "consumer runtime encountered non-retriable error"
                        );
                        // Render the text first: `fail_pending_poll` takes the error by value.
                        let error_text = format!("{client_error:#}");
                        self.fail_pending_poll(client_error);
                        self.fail_pending_commits(&error_text);
                        if self.lifecycle.shutting_down {
                            self.finish_shutdown_with_error(&error_text);
                            running = false;
                        }
                    }
                }
            }

            self.fail_pending_commits("consumer runtime stopped");
            if self.lifecycle.close_reply.is_some() {
                self.finish_shutdown_with_error("consumer runtime stopped");
            }
        }
        .instrument(tracing::debug_span!(
            "consumer_runtime",
            %client_id,
            %group_id
        ))
        .await;
    }

    /// Performs a single pass of the runtime loop.
    ///
    /// In order: drains and processes queued runtime events, possibly enqueues an
    /// auto-commit, executes every network action the request managers ask for,
    /// reconciles a pending assignment, and tries to complete the outstanding poll.
    ///
    /// Returns the value of `finish_shutdown_if_ready`, which `run` uses as its
    /// "keep looping" flag. Any error from the awaited steps is propagated.
    async fn run_once(&mut self, rx: &mut mpsc::Receiver<ConsumerRuntimeEvent>) -> AnyResult<bool> {
        self.drain_runtime_events(rx);
        self.process_runtime_events().await?;
        self.maybe_enqueue_auto_commit();

        let queued_actions = ConsumerRequestManagers::poll(self);
        for queued_action in queued_actions {
            self.execute_network_action(queued_action).await?;
        }

        self.maybe_reconcile_pending_assignment().await?;
        self.maybe_complete_poll();

        let keep_running = self.finish_shutdown_if_ready();
        Ok(keep_running)
    }

    /// Returns `true` unless the group subscription is `ConsumerSubscription::None`.
    pub(crate) fn has_group_subscription(&self) -> bool {
        match &self.assignment_state.group_subscription {
            ConsumerSubscription::None => false,
            ConsumerSubscription::Topics(_)
            | ConsumerSubscription::Regex { .. }
            | ConsumerSubscription::Pattern(_) => true,
        }
    }

    /// Returns `true` when the group subscription is a `Regex` or a `Pattern`.
    pub(crate) fn is_pattern_subscription(&self) -> bool {
        match &self.assignment_state.group_subscription {
            ConsumerSubscription::Regex { .. } | ConsumerSubscription::Pattern(_) => true,
            ConsumerSubscription::Topics(_) | ConsumerSubscription::None => false,
        }
    }

    /// Resolves the group subscription to concrete topic names.
    ///
    /// * `Topics` yields a copy of the stored set.
    /// * `Regex` yields every topic name from the local metadata that the regex matches.
    /// * `Pattern` and `None` yield an empty set.
    pub(crate) fn subscribed_topic_names(&self) -> BTreeSet<String> {
        match &self.assignment_state.group_subscription {
            ConsumerSubscription::Pattern(_) | ConsumerSubscription::None => BTreeSet::new(),
            ConsumerSubscription::Topics(topics) => topics.clone(),
            ConsumerSubscription::Regex { regex, .. } => {
                let mut matching = BTreeSet::new();
                for topic in self.connections.metadata.topic_names() {
                    if regex.is_match(&topic) {
                        matching.insert(topic);
                    }
                }
                matching
            }
        }
    }

    /// Returns the raw pattern string for a `Pattern` subscription, `None` for every
    /// other kind of subscription.
    pub(crate) fn subscription_pattern(&self) -> Option<&str> {
        let subscription = &self.assignment_state.group_subscription;
        match subscription {
            ConsumerSubscription::None
            | ConsumerSubscription::Topics(_)
            | ConsumerSubscription::Regex { .. } => None,
            ConsumerSubscription::Pattern(raw) => Some(raw.as_str()),
        }
    }

    /// Returns the keys of the current assignment map as a set of `TopicPartition`s.
    pub(crate) fn current_assignment(&self) -> BTreeSet<TopicPartition> {
        let mut partitions = BTreeSet::new();
        for assigned in self.assignment_state.assignment.keys() {
            partitions.insert(TopicPartition::new(
                assigned.topic.clone(),
                assigned.partition,
            ));
        }
        partitions
    }
}

/// Maps a runtime event to a fixed snake_case label.
///
/// The match is deliberately exhaustive (no wildcard arm) so that adding a variant
/// to `ConsumerRuntimeEvent` forces this table to be updated.
fn consumer_event_name(event: &ConsumerRuntimeEvent) -> &'static str {
    let label = match event {
        // warm-up / wakeup / shutdown
        ConsumerRuntimeEvent::WarmUp { .. } => "warm_up",
        ConsumerRuntimeEvent::Wakeup => "wakeup",
        ConsumerRuntimeEvent::Shutdown { .. } => "shutdown",

        // subscribe / assign family
        ConsumerRuntimeEvent::Subscribe { .. } => "subscribe",
        ConsumerRuntimeEvent::SubscribePattern { .. } => "subscribe_pattern",
        ConsumerRuntimeEvent::SubscribeRegex { .. } => "subscribe_regex",
        ConsumerRuntimeEvent::Unsubscribe { .. } => "unsubscribe",
        ConsumerRuntimeEvent::Assign { .. } => "assign",

        // poll / pause / resume / commit
        ConsumerRuntimeEvent::Poll { .. } => "poll",
        ConsumerRuntimeEvent::Pause { .. } => "pause",
        ConsumerRuntimeEvent::Resume { .. } => "resume",
        ConsumerRuntimeEvent::Commit { .. } => "commit",

        // seek / position family
        ConsumerRuntimeEvent::Seek { .. } => "seek",
        ConsumerRuntimeEvent::SeekToBeginning { .. } => "seek_to_beginning",
        ConsumerRuntimeEvent::SeekToEnd { .. } => "seek_to_end",
        ConsumerRuntimeEvent::SeekToTimestamp { .. } => "seek_to_timestamp",
        ConsumerRuntimeEvent::Position { .. } => "position",

        // remaining variants
        ConsumerRuntimeEvent::GroupMetadata { .. } => "group_metadata",
        ConsumerRuntimeEvent::Assignment { .. } => "assignment",
        ConsumerRuntimeEvent::Committed { .. } => "committed",
        ConsumerRuntimeEvent::BeginningOffsets { .. } => "beginning_offsets",
        ConsumerRuntimeEvent::EndOffsets { .. } => "end_offsets",
        ConsumerRuntimeEvent::OffsetsForTimes { .. } => "offsets_for_times",
        ConsumerRuntimeEvent::PartitionsFor { .. } => "partitions_for",
        ConsumerRuntimeEvent::ListTopics { .. } => "list_topics",
    };
    label
}

/// Maps a network action to a fixed snake_case label.
///
/// Kept exhaustive so a new `ConsumerNetworkAction` variant cannot be added without
/// giving it a label here.
fn consumer_network_action_name(action: &ConsumerNetworkAction) -> &'static str {
    let label = match action {
        ConsumerNetworkAction::Fetch { .. } => "fetch",
        ConsumerNetworkAction::Commit => "commit",
        ConsumerNetworkAction::Heartbeat => "heartbeat",
        ConsumerNetworkAction::LeaveGroup => "leave_group",
        ConsumerNetworkAction::EnsureCoordinator => "ensure_coordinator",
        ConsumerNetworkAction::RefreshMetadata => "refresh_metadata",
    };
    label
}

pub(crate) fn into_client_error(error: anyhow::Error) -> Error {
    match error.downcast::<Error>() {
        Ok(error) => error,
        Err(error) => Error::Internal(error),
    }
}

// Unit tests for heartbeat request construction, leave-group epochs and shutdown
// scheduling. They build a `ConsumerRuntime` directly and never open a connection.
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use super::*;
    use crate::consumer::scheduler::{ConsumerNetworkAction, ConsumerRequestManagers};

    // With member epoch 0 the request carries the topic names, the server assignor
    // and a positive rebalance timeout; once the epoch is non-zero those fields are
    // omitted (`None` / `-1`).
    #[test]
    fn heartbeat_request_sends_full_state_on_join_then_delta() {
        let config = ConsumerConfig::new("localhost:9092", "group-a")
            .with_server_assignor("uniform")
            .with_rack_id("rack-a");
        let mut runtime = ConsumerRuntime::new(config);
        runtime.assignment_state.group_subscription =
            ConsumerSubscription::Topics(BTreeSet::from(["topic-a".to_owned()]));

        let first = runtime.build_heartbeat_request(1).unwrap();
        assert_eq!(first.member_epoch, 0);
        assert!(first.subscribed_topic_names.is_some());
        assert!(first.server_assignor.is_some());
        assert!(first.rebalance_timeout_ms > 0);

        runtime.heartbeat_state.member_epoch = 3;
        let second = runtime.build_heartbeat_request(1).unwrap();
        assert!(second.subscribed_topic_names.is_none());
        assert!(second.server_assignor.is_none());
        assert_eq!(second.rebalance_timeout_ms, -1);
    }

    // A `Pattern` subscription is sent as `subscribed_topic_regex`, together with an
    // empty (but present) topic-name list.
    #[test]
    fn heartbeat_request_sends_regex_subscription() {
        let mut runtime = ConsumerRuntime::new(ConsumerConfig::new("localhost:9092", "group-a"));
        runtime.assignment_state.group_subscription =
            ConsumerSubscription::Pattern("topic-.*".to_owned());

        let request = runtime.build_heartbeat_request(1).unwrap();
        assert_eq!(request.subscribed_topic_names, Some(Vec::new()));
        assert_eq!(
            request
                .subscribed_topic_regex
                .as_ref()
                .map(ToString::to_string),
            Some("topic-.*".to_owned())
        );
    }

    // Changing the subscription on a member with a non-zero epoch must not reset the
    // epoch; the next heartbeat still carries it along with the new regex.
    #[test]
    fn subscription_changes_keep_live_member_epoch() {
        let mut runtime = ConsumerRuntime::new(ConsumerConfig::new("localhost:9092", "group-a"));
        runtime.heartbeat_state.member_epoch = 3;

        runtime
            .subscribe_pattern(crate::types::SubscriptionPattern::new("topic-.*"))
            .unwrap();

        assert_eq!(runtime.heartbeat_state.member_epoch, 3);
        let request = runtime.build_heartbeat_request(1).unwrap();
        assert_eq!(request.member_epoch, 3);
        assert_eq!(
            request
                .subscribed_topic_regex
                .as_ref()
                .map(ToString::to_string),
            Some("topic-.*".to_owned())
        );
    }

    // A runtime configured with an instance id leaves with epoch -2; one without
    // leaves with epoch -1.
    #[test]
    fn static_members_use_java_leave_epoch() {
        let static_runtime = ConsumerRuntime::new(
            ConsumerConfig::new("localhost:9092", "group-a").with_instance_id("instance-a"),
        );
        assert_eq!(static_runtime.leave_group_epoch(), -2);

        let dynamic_runtime =
            ConsumerRuntime::new(ConsumerConfig::new("localhost:9092", "group-a"));
        assert_eq!(dynamic_runtime.leave_group_epoch(), -1);
    }

    // While shutting down with a commit still queued, the scheduler asks only for
    // `EnsureCoordinator` and does not yet schedule `LeaveGroup`.
    #[test]
    fn shutdown_drains_commits_before_leave() {
        let mut runtime = ConsumerRuntime::new(ConsumerConfig::new("localhost:9092", "group-a"));
        runtime.lifecycle.shutting_down = true;
        runtime.heartbeat_state.member_epoch = 4;
        runtime.lifecycle.pending_commits.push_back(PendingCommit {
            offsets: vec![CommitOffset {
                topic: "topic-a".to_owned(),
                partition: 0,
                offset: 12,
            }],
            reply: None,
            kind: CommitKind::Manual,
        });

        let actions = ConsumerRequestManagers::poll(&runtime);
        assert_eq!(actions.len(), 1);
        assert!(matches!(
            actions[0],
            ConsumerNetworkAction::EnsureCoordinator
        ));
        assert!(
            actions
                .iter()
                .all(|action| !matches!(action, ConsumerNetworkAction::LeaveGroup))
        );
    }
}