use std::collections::HashMap;
use std::time::Duration;
use anyhow::{Context, Result as AnyResult, bail};
use kafka_protocol::error::ParseResponseErrorCode;
use kafka_protocol::messages::find_coordinator_response::FindCoordinatorResponse;
use kafka_protocol::messages::share_group_heartbeat_response::Assignment;
use kafka_protocol::messages::{
FindCoordinatorRequest, ShareGroupHeartbeatRequest, ShareGroupHeartbeatResponse,
};
use kafka_protocol::protocol::StrBytes;
use tokio::time::sleep;
use tracing::trace;
use crate::constants::{
FIND_COORDINATOR_GROUP_KEY_TYPE, FIND_COORDINATOR_VERSION_CAP,
SHARE_GROUP_HEARTBEAT_VERSION_CAP,
};
use crate::metadata::{BrokerAddress, MetadataRefresh, refresh_metadata};
use crate::network::{BrokerConnection, connect_to_any_bootstrap};
use crate::types::TopicPartitionKey;
use super::types::ShareAssignment;
use super::{KafkaShareConsumer, SHARE_COORDINATOR_RETRY_ATTEMPTS, SHARE_MEMBER_LEAVE_EPOCH};
impl KafkaShareConsumer {
pub(super) async fn heartbeat_with_retries(
&mut self,
include_subscription: bool,
) -> AnyResult<()> {
let mut last_error = None;
let attempts = self.share_coordinator_retry_attempts();
for attempt in 0..attempts {
match self.heartbeat(include_subscription).await {
Ok(()) => return Ok(()),
Err(error) => {
let is_coordinator_error = is_coordinator_error(&error);
if is_coordinator_error {
self.coordinator = None;
}
last_error = Some(error);
if !is_coordinator_error {
break;
}
if attempt + 1 < attempts {
sleep(self.config.retry_backoff).await;
}
}
}
}
Err(last_error.expect("share heartbeat attempt count is at least one"))
}
/// Tells the coordinator that this member is leaving the share group.
///
/// Sends a `ShareGroupHeartbeat` whose member epoch is `SHARE_MEMBER_LEAVE_EPOCH`
/// and whose subscription list is empty. The response body is discarded, so an
/// error code inside it is not reported; only failures to look up the coordinator,
/// negotiate the version, or send the request are returned.
pub(super) async fn leave_group(&mut self) -> AnyResult<()> {
    self.ensure_coordinator().await?;
    // Copy the identifiers out first so the coordinator can be borrowed once below.
    let group_id = StrBytes::from_string(self.config.group_id.clone());
    let member_id = StrBytes::from_string(self.member_id.clone());
    let client_id = self.config.client_id.clone();
    let coordinator = self
        .coordinator
        .as_mut()
        .context("share group coordinator is not connected")?;
    let version = coordinator
        .version_with_cap::<ShareGroupHeartbeatRequest>(SHARE_GROUP_HEARTBEAT_VERSION_CAP)?;
    let request = ShareGroupHeartbeatRequest::default()
        .with_group_id(group_id.into())
        .with_member_id(member_id)
        .with_member_epoch(SHARE_MEMBER_LEAVE_EPOCH)
        .with_subscribed_topic_names(Some(Vec::new()));
    let _ = coordinator
        .send_request::<ShareGroupHeartbeatRequest>(&client_id, version, &request)
        .await?;
    Ok(())
}
/// Groups a copy of every current assignment under its partition leader id.
///
/// The returned map is keyed by `ShareAssignment::leader_id`; each value holds
/// clones of the assignments that name that leader.
pub(super) fn assignments_by_leader(&self) -> HashMap<i32, Vec<ShareAssignment>> {
    self.assignments.values().fold(
        HashMap::new(),
        |mut by_leader: HashMap<i32, Vec<ShareAssignment>>, entry| {
            by_leader
                .entry(entry.leader_id)
                .or_default()
                .push(entry.clone());
            by_leader
        },
    )
}
/// Refreshes the cached metadata only when the metadata cache asks for it.
///
/// The decision is delegated to `needs_any_refresh`, which is given the current
/// subscriptions and `config.metadata_max_age`. When it returns `false` nothing
/// is sent and the call succeeds immediately.
pub(super) async fn refresh_metadata_if_needed(&mut self) -> AnyResult<()> {
    let stale = self
        .metadata
        .needs_any_refresh(self.subscriptions.clone(), self.config.metadata_max_age);
    if !stale {
        return Ok(());
    }
    // The unconditional refresh builds exactly the same request from the same fields.
    self.refresh_metadata().await
}
/// Unconditionally refreshes the cached metadata for the subscribed topics.
///
/// Connection settings come from `self.config`; the result is written into
/// `self.metadata` by the crate-level `refresh_metadata` function.
pub(super) async fn refresh_metadata(&mut self) -> AnyResult<()> {
    let config = &self.config;
    let params = MetadataRefresh {
        bootstrap_servers: &config.bootstrap_servers,
        client_id: &config.client_id,
        request_timeout: config.request_timeout,
        security_protocol: config.security_protocol,
        tls: &config.tls,
        sasl: &config.sasl,
        tcp_connector: &config.tcp_connector,
        metadata: &mut self.metadata,
        topics: &self.subscriptions,
    };
    // A bare call resolves to the imported free function, not to this method.
    refresh_metadata(params).await?;
    Ok(())
}
pub(super) async fn ensure_coordinator_with_retries(&mut self) -> AnyResult<()> {
let mut last_error = None;
let attempts = self.share_coordinator_retry_attempts();
for attempt in 0..attempts {
match self.ensure_coordinator().await {
Ok(()) => return Ok(()),
Err(error) => {
let is_coordinator_error = is_coordinator_error(&error);
last_error = Some(error);
if !is_coordinator_error {
break;
}
if attempt + 1 < attempts {
sleep(self.config.retry_backoff).await;
}
}
}
}
Err(last_error.expect("share coordinator lookup attempt count is at least one"))
}
/// Returns the connection to broker `leader_id`, opening and caching it on first use.
///
/// When the broker is not present in the cached metadata, the metadata is
/// refreshed once before the broker is looked up again.
///
/// # Errors
///
/// Fails if the metadata refresh fails, if the broker is still unknown afterwards,
/// or if the connection cannot be established.
pub(super) async fn leader_connection(
    &mut self,
    leader_id: i32,
) -> AnyResult<&mut BrokerConnection> {
    let already_connected = self.leader_connections.contains_key(&leader_id);
    if !already_connected {
        let broker_unknown = self.metadata.broker(leader_id).is_none();
        if broker_unknown {
            self.refresh_metadata().await?;
        }
        let broker = self
            .metadata
            .broker(leader_id)
            .with_context(|| format!("metadata missing broker {leader_id}"))?;
        let config = &self.config;
        let opened = BrokerConnection::connect_with_transport(
            &broker.address(),
            &config.client_id,
            config.request_timeout,
            config.security_protocol,
            &config.tls,
            &config.sasl,
            &config.tcp_connector,
        )
        .await?;
        self.leader_connections.insert(leader_id, opened);
    }
    // Either the entry existed on entry or it was inserted just above.
    let cached = self
        .leader_connections
        .get_mut(&leader_id)
        .expect("leader connection inserted above");
    Ok(cached)
}
/// Sends one `ShareGroupHeartbeat` to the coordinator and applies the response.
///
/// The request carries the configured group id, the current member id and epoch,
/// and the optional rack id. The subscribed topic names are included only when
/// `include_subscription` is `true`; otherwise that field is left as `None`.
///
/// # Errors
///
/// Fails if the coordinator cannot be reached, the version cannot be negotiated,
/// the request cannot be sent, or the response is rejected by
/// `handle_heartbeat_response`.
async fn heartbeat(&mut self, include_subscription: bool) -> AnyResult<()> {
    self.ensure_coordinator().await?;
    let version = self
        .coordinator
        .as_mut()
        .context("share group coordinator is not connected")?
        .version_with_cap::<ShareGroupHeartbeatRequest>(SHARE_GROUP_HEARTBEAT_VERSION_CAP)?;
    let subscribed_topic_names = if include_subscription {
        Some(
            self.subscriptions
                .iter()
                .map(|topic| StrBytes::from_string(topic.clone()))
                .map(Into::into)
                .collect(),
        )
    } else {
        None
    };
    let rack_id = self.config.rack_id.clone().map(StrBytes::from_string);
    let request = ShareGroupHeartbeatRequest::default()
        .with_group_id(StrBytes::from_string(self.config.group_id.clone()).into())
        .with_member_id(StrBytes::from_string(self.member_id.clone()))
        .with_member_epoch(self.member_epoch)
        .with_rack_id(rack_id)
        .with_subscribed_topic_names(subscribed_topic_names);
    let client_id = self.config.client_id.clone();
    let coordinator = self
        .coordinator
        .as_mut()
        .expect("coordinator checked above");
    let response: ShareGroupHeartbeatResponse = coordinator
        .send_request::<ShareGroupHeartbeatRequest>(&client_id, version, &request)
        .await?;
    self.handle_heartbeat_response(response)
}
fn handle_heartbeat_response(
&mut self,
response: ShareGroupHeartbeatResponse,
) -> AnyResult<()> {
if let Some(error) = response.error_code.err() {
bail!("share group heartbeat failed: {error}");
}
if let Some(member_id) = response.member_id {
self.member_id = member_id.to_string();
}
self.member_epoch = response.member_epoch;
if response.heartbeat_interval_ms > 0 {
self.heartbeat_interval = Duration::from_millis(response.heartbeat_interval_ms as u64);
}
if let Some(assignment) = response.assignment {
self.apply_assignment(assignment)?;
}
Ok(())
}
/// Replaces the current assignments with the ones listed in `assignment`.
///
/// Each topic id is resolved to a topic name, and each partition to its leader id
/// and leader epoch, using the cached metadata. The new map is built completely
/// before it is stored, so `self.assignments` is left unchanged when resolution
/// fails.
///
/// # Errors
///
/// Fails when the cached metadata lacks a topic id or a topic partition named in
/// the assignment.
fn apply_assignment(&mut self, assignment: Assignment) -> AnyResult<()> {
    let mut resolved = HashMap::new();
    for topic in assignment.topic_partitions {
        let topic_id = topic.topic_id;
        let topic_name = self
            .metadata
            .topic_name(&topic_id)
            .cloned()
            .with_context(|| format!("metadata missing topic id {}", topic_id))?;
        for partition in topic.partitions {
            let leader = self
                .metadata
                .partition(&topic_name, partition)
                .with_context(|| format!("metadata missing {topic_name}:{partition}"))?;
            let share = ShareAssignment {
                topic_id,
                topic: topic_name.clone(),
                partition,
                leader_id: leader.leader_id,
                leader_epoch: leader.leader_epoch,
            };
            resolved.insert(TopicPartitionKey::new(topic_name.clone(), partition), share);
        }
    }
    self.assignments = resolved;
    Ok(())
}
/// Makes sure `self.coordinator` holds a connection to the share group coordinator.
///
/// Returns immediately when a connection is already cached. Otherwise it connects
/// to one of the bootstrap servers, sends a `FindCoordinator` request keyed by the
/// configured group id, and connects to the broker named in the response.
///
/// # Errors
///
/// Fails if no bootstrap connection can be made, the version cannot be negotiated,
/// the lookup request fails or returns an error, or the coordinator connection
/// cannot be established.
async fn ensure_coordinator(&mut self) -> AnyResult<()> {
    if self.coordinator.is_some() {
        return Ok(());
    }
    let config = &self.config;
    let mut bootstrap = connect_to_any_bootstrap(
        &config.bootstrap_servers,
        &config.client_id,
        config.request_timeout,
        config.security_protocol,
        &config.tls,
        &config.sasl,
        &config.tcp_connector,
    )
    .await?;
    let version =
        bootstrap.version_with_cap::<FindCoordinatorRequest>(FIND_COORDINATOR_VERSION_CAP)?;
    let group_key = StrBytes::from_string(config.group_id.clone());
    let lookup =
        FindCoordinatorRequest::default().with_key_type(FIND_COORDINATOR_GROUP_KEY_TYPE);
    // Versions below 4 set the single `key` field; later versions use `coordinator_keys`.
    let request = if version < 4 {
        lookup.with_key(group_key)
    } else {
        lookup.with_coordinator_keys(vec![group_key])
    };
    let response = bootstrap
        .send_request::<FindCoordinatorRequest>(&config.client_id, version, &request)
        .await?;
    let coordinator = parse_find_coordinator_response(response, version)?;
    let connection = BrokerConnection::connect_with_transport(
        &coordinator.address(),
        &config.client_id,
        config.request_timeout,
        config.security_protocol,
        &config.tls,
        &config.sasl,
        &config.tcp_connector,
    )
    .await?;
    trace!(coordinator = %coordinator.address(), "connected to share group coordinator");
    self.coordinator = Some(connection);
    Ok(())
}
/// Returns how many attempts the coordinator retry loops may make.
///
/// The result is the largest of `config.max_retries`,
/// `SHARE_COORDINATOR_RETRY_ATTEMPTS` and `1`, so it is never zero.
fn share_coordinator_retry_attempts(&self) -> usize {
    let configured = self.config.max_retries;
    let floor = SHARE_COORDINATOR_RETRY_ATTEMPTS.max(1);
    configured.max(floor)
}
}
fn parse_find_coordinator_response(
response: FindCoordinatorResponse,
version: i16,
) -> AnyResult<BrokerAddress> {
if version >= 4 {
let coordinator = response
.coordinators
.into_iter()
.next()
.context("find coordinator returned no coordinator entries")?;
if let Some(error) = coordinator.error_code.err() {
bail!("find coordinator failed: {error}");
}
return Ok(BrokerAddress::new(
coordinator.host.to_string(),
u16::try_from(coordinator.port)
.with_context(|| format!("invalid coordinator port {}", coordinator.port))?,
));
}
if let Some(error) = response.error_code.err() {
bail!("find coordinator failed: {error}");
}
Ok(BrokerAddress::new(
response.host.to_string(),
u16::try_from(response.port)
.with_context(|| format!("invalid coordinator port {}", response.port))?,
))
}
/// Reports whether `error` looks like a coordinator-related failure.
///
/// The classification is textual: an error matches when its rendered message
/// contains `NotCoordinator`, `CoordinatorNotAvailable` or
/// `CoordinatorLoadInProgress`. Every error in the `anyhow` chain is inspected,
/// not only the outermost one, so wrapping a failure with `.context(...)`
/// somewhere upstream does not hide it from the retry loops.
fn is_coordinator_error(error: &anyhow::Error) -> bool {
    const COORDINATOR_ERROR_MARKERS: [&str; 3] = [
        "NotCoordinator",
        "CoordinatorNotAvailable",
        "CoordinatorLoadInProgress",
    ];
    // `chain()` starts at the outermost error, so every message the previous
    // top-level-only check matched is still matched.
    error.chain().any(|cause| {
        let message = cause.to_string();
        COORDINATOR_ERROR_MARKERS
            .iter()
            .any(|marker| message.contains(*marker))
    })
}