use std::collections::{BTreeMap, HashMap};
use std::time::Duration;
use anyhow::{Context, Result as AnyResult, bail};
use kafka_protocol::error::ParseResponseErrorCode;
use kafka_protocol::messages::ShareFetchRequest;
use kafka_protocol::messages::share_fetch_request::{FetchPartition, FetchTopic};
use kafka_protocol::messages::share_fetch_response::ShareFetchResponse;
use kafka_protocol::protocol::StrBytes;
use kafka_protocol::records::RecordBatchDecoder;
use tokio::time::sleep;
use tracing::debug;
use uuid::Uuid;
use crate::constants::SHARE_FETCH_VERSION_CAP;
use crate::network::duration_to_i32_ms;
use crate::types::{ConsumerRecord, RecordHeader, TopicPartitionKey};
use super::acknowledgements::{Acknowledgements, share_acknowledgement_commits};
use super::types::{ShareAssignment, ShareRecord, TopicIdPartitionKey};
use super::{KafkaShareConsumer, SHARE_SESSION_OPEN_EPOCH};
impl KafkaShareConsumer {
    /// Sends a `ShareFetch` request for `assignments` to broker `leader_id`, carrying any
    /// acknowledgements currently pending for those partitions.
    ///
    /// When `fetch_records` is `false` the request sets max-wait, min/max bytes and max
    /// records to zero, so it only delivers acknowledgements.
    ///
    /// The request is sent up to `config.max_retries` times (at least once), sleeping
    /// `config.retry_backoff` between failed attempts. On success the leader's share-session
    /// epoch is advanced and the acknowledgement-commit callback is invoked with no error.
    ///
    /// # Errors
    ///
    /// Returns an error if the protocol version cannot be determined, the request cannot be
    /// built, a leader connection cannot be obtained, or every attempt fails. Any failure
    /// that happens after the pending acknowledgements were taken reports them to the commit
    /// callback with the error text and puts them back into `pending_acks`, so they are not
    /// lost.
    pub(super) async fn fetch_from_leader(
        &mut self,
        leader_id: i32,
        assignments: Vec<ShareAssignment>,
        timeout: Duration,
        fetch_records: bool,
    ) -> AnyResult<Vec<ShareRecord>> {
        let version = self
            .leader_connection(leader_id)
            .await?
            .version_with_cap::<ShareFetchRequest>(SHARE_FETCH_VERSION_CAP)?;
        let session_epoch = self
            .share_session_epochs
            .get(&leader_id)
            .copied()
            .unwrap_or(SHARE_SESSION_OPEN_EPOCH);

        // From here on the acknowledgements live only in this local map, so every failure
        // path must hand them back through `abandon_share_fetch` instead of using `?`.
        let acknowledgements = self.take_acks_for_assignments(&assignments);
        let acknowledgement_commits = share_acknowledgement_commits(&acknowledgements, None);
        let request = match self.build_share_fetch_request(
            assignments,
            &acknowledgements,
            session_epoch,
            timeout,
            fetch_records,
        ) {
            Ok(request) => request,
            Err(error) => return Err(self.abandon_share_fetch(acknowledgements, error)),
        };

        let client_id = self.config.client_id.clone();
        let attempts = self.config.max_retries.max(1);
        let mut last_error = None;
        for attempt in 0..attempts {
            let connection = match self.leader_connection(leader_id).await {
                Ok(connection) => connection,
                Err(error) => return Err(self.abandon_share_fetch(acknowledgements, error)),
            };
            let result = connection
                .send_request::<ShareFetchRequest>(&client_id, version, &request)
                .await
                .and_then(|response: ShareFetchResponse| {
                    self.process_share_fetch_response(response)
                });
            match result {
                Ok(records) => {
                    self.advance_share_session_epoch(leader_id, session_epoch);
                    self.invoke_acknowledgement_commit_callback(acknowledgement_commits);
                    return Ok(records);
                }
                Err(error) => {
                    last_error = Some(error);
                    // No point sleeping after the final attempt.
                    if attempt + 1 < attempts {
                        sleep(self.config.retry_backoff).await;
                    }
                }
            }
        }
        let error = last_error.expect("share fetch attempt count is at least one");
        Err(self.abandon_share_fetch(acknowledgements, error))
    }

    /// Fails a share fetch whose acknowledgements were already taken from `pending_acks`.
    ///
    /// Reports `acknowledgements` to the commit callback with `error`'s text, merges them
    /// back into `pending_acks` so a later request can resend them, and returns `error` for
    /// the caller to propagate.
    fn abandon_share_fetch(
        &mut self,
        acknowledgements: HashMap<TopicIdPartitionKey, Acknowledgements>,
        error: anyhow::Error,
    ) -> anyhow::Error {
        self.invoke_acknowledgement_commit_callback(share_acknowledgement_commits(
            &acknowledgements,
            Some(error.to_string()),
        ));
        self.restore_acknowledgements(acknowledgements);
        error
    }

    /// Returns `true` if `pending_acks` holds a non-empty acknowledgement set for at least
    /// one partition in `assignments`.
    pub(super) fn has_acks_for_assignments(&self, assignments: &[ShareAssignment]) -> bool {
        assignments.iter().any(|assignment| {
            self.pending_acks
                .get(&share_ack_key(assignment))
                .is_some_and(|acks| !acks.is_empty())
        })
    }

    /// Merges `acknowledgements` back into `pending_acks`, creating entries as needed and
    /// extending any acknowledgements already recorded for the same partition.
    pub(super) fn restore_acknowledgements(
        &mut self,
        acknowledgements: HashMap<TopicIdPartitionKey, Acknowledgements>,
    ) {
        for (key, acks) in acknowledgements {
            self.pending_acks.entry(key).or_default().extend(acks);
        }
    }

    /// Builds the `ShareFetch` request for `assignments`.
    ///
    /// Each partition carries the acknowledgement batches found for it in `acknowledgements`
    /// (none if absent). Partitions are grouped per topic id, and topics are emitted in
    /// ascending topic-id order. `forgotten_topics_data` is always empty, and `batch_size`
    /// is always `max_poll_records`.
    ///
    /// # Errors
    ///
    /// Fails if `fetch_records` is set and `timeout` cannot be converted to milliseconds by
    /// `duration_to_i32_ms`.
    fn build_share_fetch_request(
        &self,
        assignments: Vec<ShareAssignment>,
        acknowledgements: &HashMap<TopicIdPartitionKey, Acknowledgements>,
        session_epoch: i32,
        timeout: Duration,
        fetch_records: bool,
    ) -> AnyResult<ShareFetchRequest> {
        let mut topics = BTreeMap::<Uuid, Vec<FetchPartition>>::new();
        for assignment in assignments {
            let acknowledgement_batches = acknowledgements
                .get(&share_ack_key(&assignment))
                .cloned()
                .unwrap_or_default()
                .into_share_fetch_batches();
            topics.entry(assignment.topic_id).or_default().push(
                FetchPartition::default()
                    .with_partition_index(assignment.partition)
                    .with_partition_max_bytes(self.config.partition_max_bytes)
                    .with_acknowledgement_batches(acknowledgement_batches),
            );
        }

        // An acknowledgement-only request must neither wait nor ask for any data.
        let (max_wait_ms, min_bytes, max_bytes, max_records) = if fetch_records {
            (
                duration_to_i32_ms(timeout)?,
                self.config.fetch_min_bytes,
                self.config.fetch_max_bytes,
                self.max_poll_records,
            )
        } else {
            (0, 0, 0, 0)
        };

        Ok(ShareFetchRequest::default()
            .with_group_id(Some(
                StrBytes::from_string(self.config.group_id.clone()).into(),
            ))
            .with_member_id(Some(StrBytes::from_string(self.member_id.clone())))
            .with_share_session_epoch(session_epoch)
            .with_max_wait_ms(max_wait_ms)
            .with_min_bytes(min_bytes)
            .with_max_bytes(max_bytes)
            .with_max_records(max_records)
            .with_batch_size(self.max_poll_records)
            .with_topics(
                topics
                    .into_iter()
                    .map(|(topic_id, partitions)| {
                        FetchTopic::default()
                            .with_topic_id(topic_id)
                            .with_partitions(partitions)
                    })
                    .collect(),
            )
            .with_forgotten_topics_data(Vec::new()))
    }

    /// Removes and returns the pending acknowledgements of every partition in
    /// `assignments`. Partitions with no pending entry are absent from the result.
    fn take_acks_for_assignments(
        &mut self,
        assignments: &[ShareAssignment],
    ) -> HashMap<TopicIdPartitionKey, Acknowledgements> {
        assignments
            .iter()
            .filter_map(|assignment| {
                let key = share_ack_key(assignment);
                self.pending_acks.remove(&key).map(|acks| (key, acks))
            })
            .collect()
    }

    /// Converts a `ShareFetch` response into [`ShareRecord`]s.
    ///
    /// For each returned partition this:
    /// * updates the cached leader id/epoch of the matching assignment when the response
    ///   names a leader known to metadata;
    /// * fails on a partition or acknowledgement error code;
    /// * returns each non-control record whose offset lies in an acquired range, tagged with
    ///   that range's delivery count;
    /// * records every acquired offset that produced no such record in `pending_acks` with
    ///   a `None` acknowledgement, so it is sent back on a later request.
    ///
    /// # Errors
    ///
    /// Fails on a top-level or per-partition error code, on a topic id missing from metadata,
    /// or on undecodable record batches. State already updated for earlier partitions is kept.
    fn process_share_fetch_response(
        &mut self,
        response: ShareFetchResponse,
    ) -> AnyResult<Vec<ShareRecord>> {
        if let Some(error) = response.error_code.err() {
            bail!("share fetch failed: {error}");
        }
        let mut fetched = Vec::new();
        for topic in response.responses {
            let topic_name = self
                .metadata
                .topic_name(&topic.topic_id)
                .cloned()
                .with_context(|| format!("metadata missing topic id {}", topic.topic_id))?;
            for partition in topic.partitions {
                // The leader hint does not depend on the partition's error codes or records,
                // so apply it before either can end this iteration early. Otherwise an
                // errored or empty partition would discard it.
                let assignment_key =
                    TopicPartitionKey::new(topic_name.clone(), partition.partition_index);
                if self
                    .metadata
                    .broker(partition.current_leader.leader_id)
                    .is_some()
                    && let Some(assigned) = self.assignments.get_mut(&assignment_key)
                {
                    assigned.leader_epoch = partition.current_leader.leader_epoch;
                    assigned.leader_id = partition.current_leader.leader_id;
                }
                if let Some(error) = partition.error_code.err() {
                    bail!(
                        "share fetch failed for {}:{}: {}",
                        topic_name,
                        partition.partition_index,
                        error
                    );
                }
                if let Some(error) = partition.acknowledge_error_code.err() {
                    bail!(
                        "share acknowledgement failed for {}:{}: {}",
                        topic_name,
                        partition.partition_index,
                        error
                    );
                }
                let acquired = partition
                    .acquired_records
                    .iter()
                    .map(|range| (range.first_offset, range.last_offset, range.delivery_count))
                    .collect::<Vec<_>>();
                // Without both record bytes and acquired ranges nothing is delivered or queued.
                // NOTE(review): acquired offsets arriving without record bytes are left
                // unacknowledged here; confirm that is intended.
                let Some(mut bytes) = partition.records else {
                    continue;
                };
                if bytes.is_empty() || acquired.is_empty() {
                    continue;
                }
                let batches = RecordBatchDecoder::decode_all(&mut bytes)?;
                // Every acquired offset starts unseen. Offsets still unseen after decoding
                // are queued in `pending_acks` below.
                let mut seen_offsets = BTreeMap::<i64, bool>::new();
                for &(first, last, _) in &acquired {
                    seen_offsets.extend((first..=last).map(|offset| (offset, false)));
                }
                for batch in batches {
                    for record in batch.records {
                        if record.control {
                            continue;
                        }
                        let Some(delivery_count) =
                            delivery_count_for_offset(&acquired, record.offset)
                        else {
                            continue;
                        };
                        seen_offsets.insert(record.offset, true);
                        fetched.push(ShareRecord {
                            record: ConsumerRecord {
                                topic: topic_name.clone(),
                                partition: partition.partition_index,
                                offset: record.offset,
                                timestamp: record.timestamp,
                                headers: record
                                    .headers
                                    .into_iter()
                                    .map(|(key, value)| RecordHeader {
                                        key: key.to_string(),
                                        value,
                                    })
                                    .collect(),
                                key: record.key,
                                value: record.value,
                            },
                            delivery_count,
                        });
                    }
                }
                let mut unseen = seen_offsets
                    .into_iter()
                    .filter_map(|(offset, seen)| (!seen).then_some(offset))
                    .peekable();
                // Only create a `pending_acks` entry when there is something to record.
                if unseen.peek().is_some() {
                    let acks = self
                        .pending_acks
                        .entry(TopicIdPartitionKey {
                            topic_id: topic.topic_id,
                            topic: topic_name.clone(),
                            partition: partition.partition_index,
                        })
                        .or_default();
                    for offset in unseen {
                        acks.offsets.insert(offset, None);
                    }
                }
            }
        }
        debug!(records = fetched.len(), "share fetch returned records");
        Ok(fetched)
    }
}

/// Builds the `pending_acks` key identifying `assignment`'s topic-id/partition.
fn share_ack_key(assignment: &ShareAssignment) -> TopicIdPartitionKey {
    TopicIdPartitionKey {
        topic_id: assignment.topic_id,
        topic: assignment.topic.clone(),
        partition: assignment.partition,
    }
}
/// Looks up the delivery count of the acquired range that contains `offset`.
///
/// `acquired` holds inclusive `(first_offset, last_offset, delivery_count)` ranges. The
/// first range containing `offset` wins, and `None` is returned when no range covers it.
fn delivery_count_for_offset(acquired: &[(i64, i64, i16)], offset: i64) -> Option<i16> {
    acquired
        .iter()
        .find(|&&(first, last, _)| (first..=last).contains(&offset))
        .map(|&(_, _, delivery_count)| delivery_count)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Offsets inside an acquired range map to that range's delivery count; offsets in a
    /// gap between ranges map to `None`.
    #[test]
    fn delivery_count_matches_acquired_range() {
        let acquired = [(5, 8, 1), (12, 12, 3)];
        let cases = [(6, Some(1)), (12, Some(3)), (9, None)];
        for (offset, expected) in cases {
            assert_eq!(delivery_count_for_offset(&acquired, offset), expected);
        }
    }
}