// kafkit-client 0.1.7
//
// Kafka 4.0+ pure Rust client.
// Documentation
use std::collections::{BTreeMap, HashMap};
use std::time::Duration;

use anyhow::{Context, Result as AnyResult};
use kafka_protocol::error::{ParseResponseErrorCode, ResponseError};
use kafka_protocol::messages::ShareFetchRequest;
use kafka_protocol::messages::share_fetch_request::{FetchPartition, FetchTopic};
use kafka_protocol::messages::share_fetch_response::ShareFetchResponse;
use kafka_protocol::protocol::StrBytes;
use kafka_protocol::records::RecordBatchDecoder;
use tokio::time::sleep;
use tracing::debug;
use uuid::Uuid;

use crate::constants::SHARE_FETCH_VERSION_CAP;
use crate::network::duration_to_i32_ms;
use crate::types::{ConsumerRecord, RecordHeader, TopicPartitionKey};

use super::acknowledgements::{Acknowledgements, share_acknowledgement_commits};
use super::coordinator::is_transport_error;
use super::types::{ShareAssignment, ShareRecord, TopicIdPartitionKey};
use super::{KafkaShareConsumer, SHARE_SESSION_OPEN_EPOCH};

impl KafkaShareConsumer {
    /// Sends a ShareFetch request to the broker `leader_id` for `assignments`,
    /// piggybacking any acknowledgements currently pending for those
    /// assignments, and returns the records decoded from the response.
    ///
    /// The request is attempted up to `config.max_retries` times (at least
    /// once), sleeping `config.retry_backoff` between attempts. `timeout` and
    /// `fetch_records` are forwarded to the request builder.
    ///
    /// On success the share session epoch for the leader is advanced and the
    /// acknowledgement commit callback is invoked without an error.
    ///
    /// # Errors
    ///
    /// Returns the last error observed once all attempts are exhausted, or the
    /// request-construction error if the request cannot be built. In both
    /// cases the acknowledgement commit callback is invoked with the error
    /// text and the acknowledgements are put back into the pending set so they
    /// are not lost.
    pub(super) async fn fetch_from_leader(
        &mut self,
        leader_id: i32,
        assignments: Vec<ShareAssignment>,
        timeout: Duration,
        fetch_records: bool,
    ) -> AnyResult<Vec<ShareRecord>> {
        // From here on the acknowledgements are owned by this call; every
        // failure path below must hand them back via `restore_acknowledgements`.
        let acknowledgements = self.take_acks_for_assignments(&assignments);
        let acknowledgement_commits = share_acknowledgement_commits(&acknowledgements, None);
        let client_id = self.config.client_id.clone();
        // Always try at least once, even when retries are configured as zero.
        let max_attempts = self.config.max_retries.max(1);
        let mut last_error = None;

        for attempt in 0..max_attempts {
            // A leader without a stored epoch starts a fresh share session.
            // NOTE(review): acknowledgements are still attached when the epoch
            // falls back to the open epoch (e.g. after a session reset below);
            // confirm against KIP-932 that the broker accepts that.
            let session_epoch = *self
                .share_session_epochs
                .get(&leader_id)
                .unwrap_or(&SHARE_SESSION_OPEN_EPOCH);
            let request = match self.build_share_fetch_request(
                assignments.clone(),
                acknowledgements.clone(),
                session_epoch,
                timeout,
                fetch_records,
            ) {
                Ok(request) => request,
                Err(error) => {
                    // Building the request does not involve the broker, so a
                    // retry cannot succeed. Fall through to the shared failure
                    // handling instead of returning early, which would drop
                    // the acknowledgements taken above.
                    last_error = Some(error);
                    break;
                }
            };
            let result = async {
                let connection = self.leader_connection(leader_id).await?;
                let version =
                    connection.version_with_cap::<ShareFetchRequest>(SHARE_FETCH_VERSION_CAP)?;
                connection
                    .send_request::<ShareFetchRequest>(&client_id, version, &request)
                    .await
            }
            .await
            .and_then(|response: ShareFetchResponse| self.process_share_fetch_response(response));

            match result {
                Ok(records) => {
                    self.advance_share_session_epoch(leader_id, session_epoch);
                    self.invoke_acknowledgement_commit_callback(acknowledgement_commits.clone());
                    return Ok(records);
                }
                Err(error) => {
                    // Transport failures and share-session errors invalidate
                    // both the connection and the stored session epoch.
                    if is_transport_error(&error) || is_share_session_reset_error(&error) {
                        self.drop_leader_connection(leader_id);
                    }
                    last_error = Some(error);
                    // No point in backing off after the final attempt.
                    if attempt + 1 < max_attempts {
                        sleep(self.config.retry_backoff).await;
                    }
                }
            }
        }

        let error = last_error.expect("share fetch attempt count is at least one");
        self.invoke_acknowledgement_commit_callback(share_acknowledgement_commits(
            &acknowledgements,
            Some(error.to_string()),
        ));
        self.restore_acknowledgements(acknowledgements);
        Err(error)
    }

    /// Forgets the cached connection to `leader_id` together with the share
    /// session epoch stored for that leader.
    ///
    /// Removing an id that is not present is a no-op.
    pub(super) fn drop_leader_connection(&mut self, leader_id: i32) {
        self.leader_connections.remove(&leader_id);
        // Without a stored epoch, `fetch_from_leader` falls back to
        // `SHARE_SESSION_OPEN_EPOCH` on its next request to this leader.
        self.share_session_epochs.remove(&leader_id);
    }

    /// Reports whether at least one of `assignments` has a non-empty set of
    /// pending acknowledgements.
    pub(super) fn has_acks_for_assignments(&self, assignments: &[ShareAssignment]) -> bool {
        for assignment in assignments {
            let lookup = TopicIdPartitionKey {
                topic_id: assignment.topic_id,
                topic: assignment.topic.clone(),
                partition: assignment.partition,
            };
            if let Some(pending) = self.pending_acks.get(&lookup) {
                if !pending.is_empty() {
                    // One partition with work to do is enough.
                    return true;
                }
            }
        }
        false
    }

    /// Merges `acknowledgements` back into the pending set, extending whatever
    /// is already queued for the same topic-partition key.
    pub(super) fn restore_acknowledgements(
        &mut self,
        acknowledgements: HashMap<TopicIdPartitionKey, Acknowledgements>,
    ) {
        acknowledgements.into_iter().for_each(|(key, restored)| {
            let pending = self.pending_acks.entry(key).or_default();
            pending.extend(restored);
        });
    }

    /// Assembles the `ShareFetchRequest` for `assignments`.
    ///
    /// Partitions are grouped by topic id, and each partition carries the
    /// acknowledgement batches found for it in `acknowledgements` (none when
    /// the map has no entry). When `fetch_records` is `false` the wait time and
    /// the byte/record limits are all sent as zero.
    ///
    /// # Errors
    ///
    /// Fails when `fetch_records` is set and `timeout` cannot be converted by
    /// `duration_to_i32_ms`.
    fn build_share_fetch_request(
        &self,
        assignments: Vec<ShareAssignment>,
        acknowledgements: HashMap<TopicIdPartitionKey, Acknowledgements>,
        session_epoch: i32,
        timeout: Duration,
        fetch_records: bool,
    ) -> AnyResult<ShareFetchRequest> {
        // Ordered map so the topics are emitted in a stable, sorted order.
        let mut partitions_by_topic = BTreeMap::<Uuid, Vec<FetchPartition>>::new();
        for assignment in assignments {
            let lookup = TopicIdPartitionKey {
                topic_id: assignment.topic_id,
                topic: assignment.topic.clone(),
                partition: assignment.partition,
            };
            let pending = acknowledgements.get(&lookup).cloned().unwrap_or_default();
            let fetch_partition = FetchPartition::default()
                .with_partition_index(assignment.partition)
                .with_partition_max_bytes(self.config.partition_max_bytes)
                .with_acknowledgement_batches(pending.into_share_fetch_batches());
            partitions_by_topic
                .entry(assignment.topic_id)
                .or_default()
                .push(fetch_partition);
        }

        // The timeout is only converted (and can only fail) when records are
        // actually being requested.
        let (max_wait_ms, min_bytes, max_bytes, max_records) = if fetch_records {
            (
                duration_to_i32_ms(timeout)?,
                self.config.fetch_min_bytes,
                self.config.fetch_max_bytes,
                self.max_poll_records,
            )
        } else {
            (0, 0, 0, 0)
        };

        let topics = partitions_by_topic
            .into_iter()
            .map(|(topic_id, partitions)| {
                FetchTopic::default()
                    .with_topic_id(topic_id)
                    .with_partitions(partitions)
            })
            .collect();

        let group_id = StrBytes::from_string(self.config.group_id.clone());
        let member_id = StrBytes::from_string(self.member_id.clone());

        let request = ShareFetchRequest::default()
            .with_group_id(Some(group_id.into()))
            .with_member_id(Some(member_id))
            .with_share_session_epoch(session_epoch)
            .with_max_wait_ms(max_wait_ms)
            .with_min_bytes(min_bytes)
            .with_max_bytes(max_bytes)
            .with_max_records(max_records)
            .with_batch_size(self.max_poll_records)
            .with_topics(topics)
            .with_forgotten_topics_data(Vec::new());
        Ok(request)
    }

    /// Removes the pending acknowledgements of every partition in
    /// `assignments` from `self.pending_acks` and returns them keyed by
    /// topic-partition. Partitions with nothing pending are left out.
    fn take_acks_for_assignments(
        &mut self,
        assignments: &[ShareAssignment],
    ) -> HashMap<TopicIdPartitionKey, Acknowledgements> {
        assignments
            .iter()
            .filter_map(|assignment| {
                let key = TopicIdPartitionKey {
                    topic_id: assignment.topic_id,
                    topic: assignment.topic.clone(),
                    partition: assignment.partition,
                };
                let taken = self.pending_acks.remove(&key)?;
                Some((key, taken))
            })
            .collect()
    }

    /// Converts a `ShareFetchResponse` into the list of acquired records.
    ///
    /// For every partition in the response this:
    /// - decodes the record batches and keeps the non-control records whose
    ///   offset lies inside one of the partition's acquired ranges, tagging
    ///   each with that range's delivery count;
    /// - queues an entry in `self.pending_acks` for every acquired offset for
    ///   which no record was kept;
    /// - refreshes the cached leader id/epoch of the assignment when the
    ///   reported leader is a broker known to the metadata.
    ///
    /// # Errors
    ///
    /// Returns an error when the response, a partition, or a partition's
    /// acknowledgement result carries an error code, when a topic id is not
    /// present in the metadata, or when the record batches fail to decode.
    /// An error aborts processing immediately: records already collected from
    /// earlier partitions are not returned.
    fn process_share_fetch_response(
        &mut self,
        response: ShareFetchResponse,
    ) -> AnyResult<Vec<ShareRecord>> {
        // A top-level error applies to the whole response.
        if let Some(error) = response.error_code.err() {
            return Err(error.into());
        }

        let mut fetched = Vec::new();
        for topic in response.responses {
            // The response identifies topics by id only; records are exposed by name.
            let topic_name = self
                .metadata
                .topic_name(&topic.topic_id)
                .cloned()
                .with_context(|| format!("metadata missing topic id {}", topic.topic_id))?;
            for partition in topic.partitions {
                if let Some(error) = partition.error_code.err() {
                    return Err(error.into());
                }
                if let Some(error) = partition.acknowledge_error_code.err() {
                    return Err(error.into());
                }
                // (first_offset, last_offset, delivery_count) per acquired range.
                let acquired = partition
                    .acquired_records
                    .iter()
                    .map(|range| (range.first_offset, range.last_offset, range.delivery_count))
                    .collect::<Vec<_>>();
                // Nothing to deliver for this partition. Note that these early
                // `continue`s also skip the leader refresh at the end of the loop body.
                let Some(mut bytes) = partition.records else {
                    continue;
                };
                if bytes.is_empty() || acquired.is_empty() {
                    continue;
                }

                let batches = RecordBatchDecoder::decode_all(&mut bytes)?;
                // Every acquired offset starts as "not seen"; it is flipped to
                // `true` once a matching record is kept below.
                let mut seen_offsets = BTreeMap::<i64, bool>::new();
                for (first, last, _) in &acquired {
                    for offset in *first..=*last {
                        seen_offsets.insert(offset, false);
                    }
                }

                for batch in batches {
                    for record in batch.records {
                        // Control records are never handed to the caller.
                        if record.control {
                            continue;
                        }
                        // Records outside every acquired range are dropped.
                        let Some(delivery_count) =
                            delivery_count_for_offset(&acquired, record.offset)
                        else {
                            continue;
                        };
                        seen_offsets.insert(record.offset, true);
                        fetched.push(ShareRecord {
                            record: ConsumerRecord {
                                topic: topic_name.clone(),
                                partition: partition.partition_index,
                                offset: record.offset,
                                timestamp: record.timestamp,
                                headers: record
                                    .headers
                                    .into_iter()
                                    .map(|(key, value)| RecordHeader {
                                        key: key.to_string(),
                                        value,
                                    })
                                    .collect(),
                                key: record.key,
                                value: record.value,
                            },
                            delivery_count,
                        });
                    }
                }

                // Acquired offsets with no delivered record still get a pending
                // acknowledgement entry, recorded with `None`.
                // NOTE(review): `None` presumably denotes a gap acknowledgement;
                // confirm against the `Acknowledgements` type.
                for (offset, seen) in seen_offsets {
                    if !seen {
                        let key = TopicIdPartitionKey {
                            topic_id: topic.topic_id,
                            topic: topic_name.clone(),
                            partition: partition.partition_index,
                        };
                        self.pending_acks
                            .entry(key)
                            .or_default()
                            .offsets
                            .insert(offset, None);
                    }
                }

                // Adopt the leader reported by the broker, but only when that
                // broker is known to the metadata and the partition is still assigned.
                let key = TopicPartitionKey::new(topic_name.clone(), partition.partition_index);
                if self
                    .metadata
                    .broker(partition.current_leader.leader_id)
                    .is_some()
                    && let Some(assigned) = self.assignments.get_mut(&key)
                {
                    assigned.leader_epoch = partition.current_leader.leader_epoch;
                    assigned.leader_id = partition.current_leader.leader_id;
                }
            }
        }
        debug!(records = fetched.len(), "share fetch returned records");
        Ok(fetched)
    }
}

/// Returns `true` when any error in `error`'s cause chain is a share-session
/// `ResponseError` (`ShareSessionNotFound`, `InvalidShareSessionEpoch`, or
/// `ShareSessionLimitReached`).
pub(super) fn is_share_session_reset_error(error: &anyhow::Error) -> bool {
    for cause in error.chain() {
        // Only protocol-level response errors are of interest here.
        let Some(response_error) = cause.downcast_ref::<ResponseError>() else {
            continue;
        };
        match response_error {
            ResponseError::ShareSessionNotFound
            | ResponseError::InvalidShareSessionEpoch
            | ResponseError::ShareSessionLimitReached => return true,
            _ => {}
        }
    }
    false
}

/// Looks up the delivery count for `offset` among the acquired ranges.
///
/// Each entry of `acquired` is `(first_offset, last_offset, delivery_count)`
/// with both bounds inclusive. The first range containing `offset` wins;
/// `None` is returned when no range contains it.
fn delivery_count_for_offset(acquired: &[(i64, i64, i16)], offset: i64) -> Option<i16> {
    for &(first, last, delivery_count) in acquired {
        if (first..=last).contains(&offset) {
            return Some(delivery_count);
        }
    }
    None
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Offsets inside an acquired range report that range's delivery count;
    /// offsets between ranges report nothing.
    #[test]
    fn delivery_count_matches_acquired_range() {
        let acquired: [(i64, i64, i16); 2] = [(5, 8, 1), (12, 12, 3)];
        let cases: [(i64, Option<i16>); 3] = [(6, Some(1)), (12, Some(3)), (9, None)];

        for (offset, expected) in cases {
            assert_eq!(delivery_count_for_offset(&acquired, offset), expected);
        }
    }
}