use std::collections::{BTreeMap, HashMap};
use std::time::{Duration, Instant};
use anyhow::{Context, Result as AnyResult};
use kafka_protocol::error::{ParseResponseErrorCode, ResponseError};
use kafka_protocol::messages::ShareFetchRequest;
use kafka_protocol::messages::share_fetch_request::{FetchPartition, FetchTopic};
use kafka_protocol::messages::share_fetch_response::ShareFetchResponse;
use kafka_protocol::protocol::StrBytes;
use kafka_protocol::records::RecordBatchDecoder;
use tokio::time::sleep;
use tracing::debug;
use uuid::Uuid;
use crate::constants::SHARE_FETCH_VERSION_CAP;
use crate::network::duration_to_i32_ms;
use crate::telemetry;
use crate::types::{ConsumerRecord, RecordHeader, TopicPartitionKey};
use super::acknowledgements::{
Acknowledgements, share_acknowledgement_commit, share_acknowledgement_commits,
};
use super::coordinator::is_transport_error;
use super::types::{ShareAcknowledgementCommit, ShareAssignment, ShareRecord, TopicIdPartitionKey};
use super::{KafkaShareConsumer, SHARE_SESSION_OPEN_EPOCH};
/// Outcome of processing a single `ShareFetchResponse`.
struct ShareFetchProcessingResult {
/// Records decoded from the response that fall inside an acquired offset range.
records: Vec<ShareRecord>,
/// Per-partition acknowledgement outcomes, later handed to the acknowledgement commit callback.
acknowledgement_commits: Vec<ShareAcknowledgementCommit>,
}
impl KafkaShareConsumer {
/// Performs a ShareFetch round trip against `leader_id` for `assignments`.
///
/// Acknowledgements queued for the given assignments are taken out of `pending_acks`
/// and piggybacked on the request. The request is attempted up to
/// `config.max_retries` times (always at least once), sleeping `config.retry_backoff`
/// between attempts. Every attempt re-reads the share session epoch for the leader,
/// falling back to `SHARE_SESSION_OPEN_EPOCH` when none is stored.
///
/// On success the completion is reported to telemetry, the share session epoch is
/// advanced, the acknowledgement commit callback is invoked with the per-partition
/// outcomes, and the fetched records are returned.
///
/// # Errors
///
/// * If the request cannot be built, the taken acknowledgements are put back into
///   `pending_acks` and the build error is returned immediately (no retry, since the
///   same inputs would fail again).
/// * If every attempt fails, the acknowledgement commit callback is invoked with the
///   last error's message, a failed completion is reported to telemetry, the
///   acknowledgements are put back into `pending_acks`, and the last error is returned.
pub(super) async fn fetch_from_leader(
    &mut self,
    leader_id: i32,
    assignments: Vec<ShareAssignment>,
    timeout: Duration,
    fetch_records: bool,
) -> AnyResult<Vec<ShareRecord>> {
    let acknowledgements = self.take_acks_for_assignments(&assignments);
    // Cloned so the id can be borrowed while `self` is mutably borrowed for the connection.
    let client_id = self.config.client_id.clone();
    let started_at = Instant::now();
    let max_attempts = self.config.max_retries.max(1);
    let mut last_error = None;
    for attempt in 0..max_attempts {
        // NOTE(review): after a session reset the retry carries the same acknowledgements
        // together with the open epoch; confirm the broker accepts acknowledgements on a
        // session-opening request.
        let session_epoch = *self
            .share_session_epochs
            .get(&leader_id)
            .unwrap_or(&SHARE_SESSION_OPEN_EPOCH);
        let request = match self.build_share_fetch_request(
            assignments.clone(),
            acknowledgements.clone(),
            session_epoch,
            timeout,
            fetch_records,
        ) {
            Ok(request) => request,
            Err(error) => {
                // The acknowledgements were already removed from `pending_acks`; returning
                // without restoring them would silently drop them.
                self.restore_acknowledgements(acknowledgements);
                return Err(error);
            }
        };
        let result = async {
            let connection = self.leader_connection(leader_id).await?;
            let version =
                connection.version_with_cap::<ShareFetchRequest>(SHARE_FETCH_VERSION_CAP)?;
            connection
                .send_request::<ShareFetchRequest>(&client_id, version, &request)
                .await
        }
        .await
        .and_then(|response: ShareFetchResponse| {
            self.process_share_fetch_response(response, &acknowledgements)
        });
        match result {
            Ok(processed) => {
                let record_count = processed.records.len();
                telemetry::record_share_fetch_completed(
                    &self.config.client_id,
                    &self.config.group_id,
                    leader_id,
                    record_count,
                    started_at.elapsed(),
                    true,
                );
                self.advance_share_session_epoch(leader_id, session_epoch);
                self.invoke_acknowledgement_commit_callback(processed.acknowledgement_commits);
                return Ok(processed.records);
            }
            Err(error) => {
                // Transport failures and share-session resets invalidate the cached
                // connection and epoch, so the next attempt starts from a clean slate.
                if is_transport_error(&error) || is_share_session_reset_error(&error) {
                    self.drop_leader_connection(leader_id);
                }
                last_error = Some(error);
                // No point in backing off after the final attempt.
                if attempt + 1 < max_attempts {
                    sleep(self.config.retry_backoff).await;
                }
            }
        }
    }
    let error = last_error.expect("share fetch attempt count is at least one");
    self.invoke_acknowledgement_commit_callback(share_acknowledgement_commits(
        &acknowledgements,
        Some(error.to_string()),
    ));
    telemetry::record_share_fetch_completed(
        &self.config.client_id,
        &self.config.group_id,
        leader_id,
        0,
        started_at.elapsed(),
        false,
    );
    self.restore_acknowledgements(acknowledgements);
    Err(error)
}
/// Forgets the cached connection and the share session epoch stored for `leader_id`.
///
/// With no epoch stored, `fetch_from_leader` falls back to `SHARE_SESSION_OPEN_EPOCH`
/// on its next attempt against this leader.
pub(super) fn drop_leader_connection(&mut self, leader_id: i32) {
    // The two maps are independent of each other; both entries are simply discarded.
    let _ = self.share_session_epochs.remove(&leader_id);
    let _ = self.leader_connections.remove(&leader_id);
}
/// Returns `true` when at least one of `assignments` has a non-empty set of
/// acknowledgements queued in `pending_acks`.
pub(super) fn has_acks_for_assignments(&self, assignments: &[ShareAssignment]) -> bool {
    for assignment in assignments {
        let lookup = TopicIdPartitionKey {
            topic_id: assignment.topic_id,
            topic: assignment.topic.clone(),
            partition: assignment.partition,
        };
        match self.pending_acks.get(&lookup) {
            Some(queued) if !queued.is_empty() => return true,
            _ => {}
        }
    }
    false
}
/// Merges `acknowledgements` back into `pending_acks`, extending whatever is already
/// queued for the same partition key.
pub(super) fn restore_acknowledgements(
    &mut self,
    acknowledgements: HashMap<TopicIdPartitionKey, Acknowledgements>,
) {
    acknowledgements.into_iter().for_each(|(key, restored)| {
        let queued = self.pending_acks.entry(key).or_default();
        queued.extend(restored);
    });
}
fn build_share_fetch_request(
&self,
assignments: Vec<ShareAssignment>,
acknowledgements: HashMap<TopicIdPartitionKey, Acknowledgements>,
session_epoch: i32,
timeout: Duration,
fetch_records: bool,
) -> AnyResult<ShareFetchRequest> {
let mut topics = BTreeMap::<Uuid, Vec<FetchPartition>>::new();
for assignment in assignments {
let key = TopicIdPartitionKey {
topic_id: assignment.topic_id,
topic: assignment.topic.clone(),
partition: assignment.partition,
};
let acknowledgement_batches = acknowledgements
.get(&key)
.cloned()
.unwrap_or_default()
.into_share_fetch_batches();
topics.entry(assignment.topic_id).or_default().push(
FetchPartition::default()
.with_partition_index(assignment.partition)
.with_partition_max_bytes(self.config.partition_max_bytes)
.with_acknowledgement_batches(acknowledgement_batches),
);
}
Ok(ShareFetchRequest::default()
.with_group_id(Some(
StrBytes::from_string(self.config.group_id.clone()).into(),
))
.with_member_id(Some(StrBytes::from_string(self.member_id.clone())))
.with_share_session_epoch(session_epoch)
.with_max_wait_ms(if fetch_records {
duration_to_i32_ms(timeout)?
} else {
0
})
.with_min_bytes(if fetch_records {
self.config.fetch_min_bytes
} else {
0
})
.with_max_bytes(if fetch_records {
self.config.fetch_max_bytes
} else {
0
})
.with_max_records(if fetch_records {
self.max_poll_records
} else {
0
})
.with_batch_size(self.max_poll_records)
.with_topics(
topics
.into_iter()
.map(|(topic_id, partitions)| {
FetchTopic::default()
.with_topic_id(topic_id)
.with_partitions(partitions)
})
.collect(),
)
.with_forgotten_topics_data(Vec::new()))
}
/// Removes from `pending_acks` the acknowledgements queued for each of `assignments`
/// and returns them keyed by partition. Assignments with nothing queued are omitted.
fn take_acks_for_assignments(
    &mut self,
    assignments: &[ShareAssignment],
) -> HashMap<TopicIdPartitionKey, Acknowledgements> {
    assignments
        .iter()
        .filter_map(|assignment| {
            let key = TopicIdPartitionKey {
                topic_id: assignment.topic_id,
                topic: assignment.topic.clone(),
                partition: assignment.partition,
            };
            self.pending_acks.remove(&key).map(|taken| (key, taken))
        })
        .collect()
}
/// Turns a `ShareFetchResponse` into delivered records and acknowledgement outcomes.
///
/// `acknowledgements` are the acknowledgements that were sent with the request. For
/// every partition in the response that had acknowledgements sent, one commit outcome
/// is produced carrying the partition's acknowledge error, if any. Acknowledgements
/// whose partition does not appear in the response are reported with a
/// "missing from broker response" error.
///
/// Records are only delivered when their offset lies inside one of the partition's
/// acquired ranges; control records are skipped. Acquired offsets for which no record
/// was delivered are queued in `pending_acks` with a `None` entry
/// (presumably the gap marker; confirm against `Acknowledgements`).
///
/// Partitions that report an error are logged at debug level and contribute no
/// records; their leader information is still applied.
///
/// # Errors
///
/// Fails when the response carries a top-level error code, when a topic id in the
/// response is unknown to the metadata cache, or when a record batch cannot be decoded.
fn process_share_fetch_response(
    &mut self,
    response: ShareFetchResponse,
    acknowledgements: &HashMap<TopicIdPartitionKey, Acknowledgements>,
) -> AnyResult<ShareFetchProcessingResult> {
    if let Some(error) = response.error_code.err() {
        return Err(error.into());
    }
    let mut fetched = Vec::new();
    // Entries are removed as partitions are seen; whatever is left was not answered.
    let mut remaining_acknowledgements = acknowledgements.clone();
    let mut acknowledgement_commits = Vec::new();
    for topic in response.responses {
        let topic_name = self
            .metadata
            .topic_name(&topic.topic_id)
            .cloned()
            .with_context(|| format!("metadata missing topic id {}", topic.topic_id))?;
        for partition in topic.partitions {
            let key = TopicIdPartitionKey {
                topic_id: topic.topic_id,
                topic: topic_name.clone(),
                partition: partition.partition_index,
            };
            if let Some(acks) = remaining_acknowledgements.remove(&key) {
                let ack_error = partition.acknowledge_error_code.err().map(|error| {
                    acknowledgement_error_message(error, &partition.acknowledge_error_message)
                });
                if let Some(commit) = share_acknowledgement_commit(&key, &acks, ack_error) {
                    acknowledgement_commits.push(commit);
                }
            }
            if let Some(error) = partition.error_code.err() {
                debug!(
                    topic = %topic_name,
                    partition = partition.partition_index,
                    error = %error,
                    "share fetch partition returned error"
                );
                self.update_share_fetch_leader(&topic_name, &partition);
                continue;
            }
            let acquired = partition
                .acquired_records
                .iter()
                .map(|range| (range.first_offset, range.last_offset, range.delivery_count))
                .collect::<Vec<_>>();
            self.update_share_fetch_leader(&topic_name, &partition);
            let Some(mut bytes) = partition.records else {
                continue;
            };
            if bytes.is_empty() || acquired.is_empty() {
                continue;
            }
            let batches = RecordBatchDecoder::decode_all(&mut bytes)?;
            // Every acquired offset starts out unseen and is flipped once a record for it
            // is delivered.
            let mut seen_offsets = BTreeMap::<i64, bool>::new();
            for (first, last, _) in &acquired {
                for offset in *first..=*last {
                    seen_offsets.insert(offset, false);
                }
            }
            for batch in batches {
                for record in batch.records {
                    if record.control {
                        continue;
                    }
                    let Some(delivery_count) =
                        delivery_count_for_offset(&acquired, record.offset)
                    else {
                        continue;
                    };
                    seen_offsets.insert(record.offset, true);
                    fetched.push(ShareRecord {
                        record: ConsumerRecord {
                            topic: topic_name.clone(),
                            partition: partition.partition_index,
                            offset: record.offset,
                            timestamp: record.timestamp,
                            headers: record
                                .headers
                                .into_iter()
                                .map(|(key, value)| RecordHeader {
                                    key: key.to_string(),
                                    value,
                                })
                                .collect(),
                            key: record.key,
                            value: record.value,
                        },
                        delivery_count,
                    });
                }
            }
            let mut gap_offsets = seen_offsets
                .into_iter()
                .filter(|&(_, seen)| !seen)
                .map(|(offset, _)| offset)
                .peekable();
            // Only touch `pending_acks` when there is a gap, and resolve the partition's
            // entry once instead of rebuilding the key for every gap offset.
            if gap_offsets.peek().is_some() {
                let pending = self.pending_acks.entry(key).or_default();
                for offset in gap_offsets {
                    pending.offsets.insert(offset, None);
                }
            }
        }
    }
    for (key, acks) in remaining_acknowledgements {
        if let Some(commit) = share_acknowledgement_commit(
            &key,
            &acks,
            Some("share fetch acknowledgement missing from broker response".to_owned()),
        ) {
            acknowledgement_commits.push(commit);
        }
    }
    debug!(records = fetched.len(), "share fetch returned records");
    Ok(ShareFetchProcessingResult {
        records: fetched,
        acknowledgement_commits,
    })
}
fn update_share_fetch_leader(
&mut self,
topic_name: &str,
partition: &kafka_protocol::messages::share_fetch_response::PartitionData,
) {
let key = TopicPartitionKey::new(topic_name.to_owned(), partition.partition_index);
if self
.metadata
.broker(partition.current_leader.leader_id)
.is_some()
&& let Some(assigned) = self.assignments.get_mut(&key)
{
assigned.leader_epoch = partition.current_leader.leader_epoch;
assigned.leader_id = partition.current_leader.leader_id;
}
}
}
/// Reports whether any cause in `error`'s chain is a `ResponseError` that invalidates
/// the current share session: `ShareSessionNotFound`, `InvalidShareSessionEpoch` or
/// `ShareSessionLimitReached`.
pub(super) fn is_share_session_reset_error(error: &anyhow::Error) -> bool {
    error
        .chain()
        .filter_map(|cause| cause.downcast_ref::<ResponseError>())
        .any(|response_error| {
            matches!(
                response_error,
                ResponseError::ShareSessionNotFound
                    | ResponseError::InvalidShareSessionEpoch
                    | ResponseError::ShareSessionLimitReached
            )
        })
}
/// Looks up the delivery count for `offset` among the acquired ranges.
///
/// Each entry of `acquired` is `(first_offset, last_offset, delivery_count)` with both
/// bounds inclusive. The first range containing `offset` wins; `None` is returned when
/// no range covers it.
fn delivery_count_for_offset(acquired: &[(i64, i64, i16)], offset: i64) -> Option<i16> {
    for &(first, last, delivery_count) in acquired {
        if (first..=last).contains(&offset) {
            return Some(delivery_count);
        }
    }
    None
}
/// Chooses the text reported for a failed acknowledgement: the broker-supplied
/// `message` when it is present and non-empty, otherwise the display form of `error`.
fn acknowledgement_error_message(error: ResponseError, message: &Option<StrBytes>) -> String {
    match message.as_ref().map(ToString::to_string) {
        Some(text) if !text.is_empty() => text,
        _ => error.to_string(),
    }
}
#[cfg(test)]
mod tests {
use kafka_protocol::messages::metadata_response::MetadataResponseTopic;
use kafka_protocol::messages::share_fetch_response::{
PartitionData, ShareFetchableTopicResponse,
};
use kafka_protocol::messages::{MetadataResponse, TopicName};
use super::*;
use crate::config::ConsumerConfig;
use crate::consumer::share::types::{AcknowledgeType, ShareAcquireMode};
use crate::metadata::MetadataCache;
/// An offset inside an acquired range yields that range's delivery count; an offset
/// between ranges yields nothing.
#[test]
fn delivery_count_matches_acquired_range() {
    let acquired = [(5, 8, 1), (12, 12, 3)];
    let expectations = [(6, Some(1)), (12, Some(3)), (9, None)];
    for (offset, expected) in expectations {
        assert_eq!(delivery_count_for_offset(&acquired, offset), expected);
    }
}
/// Each partition present in the response gets its own commit outcome: a clean one
/// for partition 0 and the broker's error message for partition 1.
#[test]
fn share_fetch_acknowledgements_complete_per_partition() {
    let topic_id = Uuid::from_u128(1);
    let acknowledgements = acknowledgements_for(topic_id, &[(0, 10), (1, 20)]);
    let mut consumer = consumer_with_topic(topic_id);

    let accepted = PartitionData::default().with_partition_index(0);
    let rejected = PartitionData::default()
        .with_partition_index(1)
        .with_acknowledge_error_code(ResponseError::InvalidRecordState.code())
        .with_acknowledge_error_message(Some(StrBytes::from_string(
            "bad state".to_owned(),
        )));
    let topic_response = ShareFetchableTopicResponse::default()
        .with_topic_id(topic_id)
        .with_partitions(vec![accepted, rejected]);
    let response = ShareFetchResponse::default().with_responses(vec![topic_response]);

    let processed = consumer
        .process_share_fetch_response(response, &acknowledgements)
        .unwrap();
    let mut commits = processed.acknowledgement_commits;
    commits.sort_by_key(|commit| commit.partition);

    assert_eq!(commits.len(), 2);
    let first = &commits[0];
    assert_eq!(first.partition, 0);
    assert_eq!(first.error, None);
    let second = &commits[1];
    assert_eq!(second.partition, 1);
    assert_eq!(second.error.as_deref(), Some("bad state"));
}
/// When the response omits a partition that had acknowledgements sent, only that
/// partition's commit carries the "missing from broker response" error.
#[test]
fn share_fetch_acknowledgements_missing_from_response_fail_only_that_partition() {
    let topic_id = Uuid::from_u128(1);
    let acknowledgements = acknowledgements_for(topic_id, &[(0, 10), (1, 20)]);
    let mut consumer = consumer_with_topic(topic_id);

    // Only partition 0 is answered; partition 1 is absent from the response.
    let answered = PartitionData::default().with_partition_index(0);
    let topic_response = ShareFetchableTopicResponse::default()
        .with_topic_id(topic_id)
        .with_partitions(vec![answered]);
    let response = ShareFetchResponse::default().with_responses(vec![topic_response]);

    let processed = consumer
        .process_share_fetch_response(response, &acknowledgements)
        .unwrap();
    let mut commits = processed.acknowledgement_commits;
    commits.sort_by_key(|commit| commit.partition);

    assert_eq!(commits.len(), 2);
    let answered_commit = &commits[0];
    assert_eq!(answered_commit.partition, 0);
    assert_eq!(answered_commit.error, None);
    let missing_commit = &commits[1];
    assert_eq!(missing_commit.partition, 1);
    assert!(matches!(
        missing_commit.error.as_deref(),
        Some(error) if error.contains("missing from broker response")
    ));
}
/// Builds a share consumer with no connections, assignments or pending
/// acknowledgements, whose metadata cache maps `topic_id` to the topic `"topic-a"`.
fn consumer_with_topic(topic_id: Uuid) -> KafkaShareConsumer {
    let topic_name = TopicName(StrBytes::from_string("topic-a".to_owned()));
    let topic = MetadataResponseTopic::default()
        .with_name(Some(topic_name))
        .with_topic_id(topic_id);
    let metadata_response = MetadataResponse::default().with_topics(vec![topic]);

    let mut metadata = MetadataCache::default();
    metadata.merge_response(metadata_response).unwrap();

    KafkaShareConsumer {
        config: ConsumerConfig::new("localhost:9092", "group-a"),
        metadata,
        coordinator: None,
        leader_connections: HashMap::new(),
        subscriptions: Vec::new(),
        assignments: HashMap::new(),
        member_id: "member-a".to_owned(),
        member_epoch: 1,
        heartbeat_interval: Duration::from_secs(5),
        share_session_epochs: HashMap::new(),
        pending_acks: HashMap::new(),
        share_acquire_mode: ShareAcquireMode::BatchOptimized,
        max_poll_records: 10,
        next_record_limit_leader_index: 0,
        acknowledgement_commit_callback: None,
    }
}
/// Builds one `Accept` acknowledgement per `(partition, offset)` pair, keyed by the
/// partition of topic `"topic-a"` with the given `topic_id`.
fn acknowledgements_for(
    topic_id: Uuid,
    partitions: &[(i32, i64)],
) -> HashMap<TopicIdPartitionKey, Acknowledgements> {
    partitions
        .iter()
        .map(|&(partition, offset)| {
            let mut acks = Acknowledgements::default();
            acks.add(offset, AcknowledgeType::Accept);
            let key = TopicIdPartitionKey {
                topic_id,
                topic: "topic-a".to_owned(),
                partition,
            };
            (key, acks)
        })
        .collect()
}
}