use crate::client::Client;
use crate::error::{Error, Result};
use rivven_protocol::MessageData;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
/// A (topic, partition) pair identifying a single assigned partition.
///
/// The topic name is an `Arc<str>` so the pair can be cloned cheaply when
/// assignments are fanned out to fetch lists and rebalance listeners.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TopicPartition {
    /// Topic name (shared, immutable).
    pub topic: Arc<str>,
    /// Zero-based partition index within the topic.
    pub partition: u32,
}
/// Callback hooks invoked around a consumer-group rebalance.
///
/// During re-discovery the consumer first calls [`on_partitions_revoked`]
/// with the previously held assignment, then negotiates a new one and calls
/// [`on_partitions_assigned`] with the result.
///
/// [`on_partitions_revoked`]: RebalanceListener::on_partitions_revoked
/// [`on_partitions_assigned`]: RebalanceListener::on_partitions_assigned
#[async_trait::async_trait]
pub trait RebalanceListener: Send + Sync {
    /// Called with the partitions being given up, before a new assignment
    /// is negotiated.
    async fn on_partitions_revoked(&self, partitions: &[TopicPartition]);
    /// Called with the freshly negotiated assignment (only when non-empty).
    async fn on_partitions_assigned(&self, partitions: &[TopicPartition]);
}
/// Immutable configuration for a [`Consumer`], normally produced via
/// [`ConsumerConfig::builder`].
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// Broker addresses, tried in order when connecting and reconnecting.
    pub bootstrap_servers: Vec<String>,
    /// Consumer-group id used for offset storage and group coordination.
    pub group_id: String,
    /// Topics to subscribe to.
    pub topics: Vec<String>,
    /// Explicit topic -> partitions assignment. When non-empty the consumer
    /// skips group coordination entirely and consumes exactly these
    /// partitions.
    pub partitions: HashMap<String, Vec<u32>>,
    /// Maximum number of records requested per partition per fetch.
    pub max_poll_records: u32,
    /// Long-poll wait budget in milliseconds; 0 disables the long-poll path.
    pub max_poll_interval_ms: u64,
    /// Auto-commit period; `None` disables auto-commit.
    pub auto_commit_interval: Option<Duration>,
    /// Isolation level sent with fetches; 0 means "unset" (nothing is sent),
    /// 1 is read-committed (see `ConsumerConfigBuilder::read_committed`).
    pub isolation_level: u8,
    /// Optional SCRAM credentials for broker authentication.
    pub auth: Option<ConsumerAuthConfig>,
    /// How often partition assignments are re-discovered during polling.
    pub metadata_refresh_interval: Duration,
    /// Initial reconnect backoff in milliseconds (doubles per attempt).
    pub reconnect_backoff_ms: u64,
    /// Upper bound on the reconnect backoff in milliseconds.
    pub reconnect_backoff_max_ms: u64,
    /// Number of reconnect attempts before giving up.
    pub max_reconnect_attempts: u32,
    /// Group session timeout in milliseconds, sent on join.
    pub session_timeout_ms: u32,
    /// Maximum time the group waits for members during a rebalance.
    pub rebalance_timeout_ms: u32,
    /// Background heartbeat period in milliseconds.
    pub heartbeat_interval_ms: u64,
    /// TLS configuration used when connecting to brokers.
    #[cfg(feature = "tls")]
    pub tls_config: Option<rivven_core::tls::TlsConfig>,
    /// Server name (SNI) presented during the TLS handshake.
    #[cfg(feature = "tls")]
    pub tls_server_name: Option<String>,
}
/// SCRAM credentials used to authenticate broker connections.
///
/// Deliberately does not derive `Debug`; the manual impl below redacts the
/// password so credentials never leak into logs.
#[derive(Clone)]
pub struct ConsumerAuthConfig {
    /// SCRAM username.
    pub username: String,
    /// SCRAM password; shown as `[REDACTED]` in `Debug` output.
    pub password: String,
}
impl std::fmt::Debug for ConsumerAuthConfig {
    /// Formats the credentials with the password redacted, so debug/log
    /// output never exposes the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ConsumerAuthConfig");
        dbg.field("username", &self.username);
        dbg.field("password", &"[REDACTED]");
        dbg.finish()
    }
}
/// Fluent builder for [`ConsumerConfig`].
///
/// Fields mirror [`ConsumerConfig`]; see [`ConsumerConfigBuilder::new`] for
/// the default each one starts from. `group_id` is optional here and falls
/// back to `"default-group"` at build time.
pub struct ConsumerConfigBuilder {
    bootstrap_servers: Vec<String>,
    group_id: Option<String>,
    topics: Vec<String>,
    partitions: HashMap<String, Vec<u32>>,
    max_poll_records: u32,
    max_poll_interval_ms: u64,
    auto_commit_interval: Option<Duration>,
    isolation_level: u8,
    auth: Option<ConsumerAuthConfig>,
    metadata_refresh_interval: Duration,
    reconnect_backoff_ms: u64,
    reconnect_backoff_max_ms: u64,
    max_reconnect_attempts: u32,
    session_timeout_ms: u32,
    rebalance_timeout_ms: u32,
    heartbeat_interval_ms: u64,
    #[cfg(feature = "tls")]
    tls_config: Option<rivven_core::tls::TlsConfig>,
    #[cfg(feature = "tls")]
    tls_server_name: Option<String>,
}
impl ConsumerConfigBuilder {
    /// Creates a builder with defaults: a single localhost broker,
    /// auto-commit every 5s, read-uncommitted, 10s session timeout and a 3s
    /// heartbeat.
    pub fn new() -> Self {
        Self {
            bootstrap_servers: vec!["127.0.0.1:9092".to_string()],
            group_id: None,
            topics: Vec::new(),
            partitions: HashMap::new(),
            max_poll_records: 500,
            max_poll_interval_ms: 5000,
            auto_commit_interval: Some(Duration::from_secs(5)),
            isolation_level: 0,
            auth: None,
            metadata_refresh_interval: Duration::from_secs(300),
            reconnect_backoff_ms: 100,
            reconnect_backoff_max_ms: 10_000,
            max_reconnect_attempts: 10,
            session_timeout_ms: 10_000,
            rebalance_timeout_ms: 30_000,
            heartbeat_interval_ms: 3_000,
            #[cfg(feature = "tls")]
            tls_config: None,
            #[cfg(feature = "tls")]
            tls_server_name: None,
        }
    }
    /// Replaces the server list with a single bootstrap server.
    pub fn bootstrap_server(mut self, server: impl Into<String>) -> Self {
        self.bootstrap_servers = vec![server.into()];
        self
    }
    /// Replaces the full bootstrap-server list.
    pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
        self.bootstrap_servers = servers;
        self
    }
    /// Sets the consumer-group id (defaults to `"default-group"`).
    pub fn group_id(mut self, group: impl Into<String>) -> Self {
        self.group_id = Some(group.into());
        self
    }
    /// Replaces the topic subscription list.
    pub fn topics(mut self, topics: Vec<String>) -> Self {
        self.topics = topics;
        self
    }
    /// Adds a single topic to the subscription list.
    pub fn topic(mut self, topic: impl Into<String>) -> Self {
        self.topics.push(topic.into());
        self
    }
    /// Explicitly assigns partitions for a topic; any explicit assignment
    /// disables group coordination for the built consumer.
    pub fn assign(mut self, topic: impl Into<String>, partitions: Vec<u32>) -> Self {
        self.partitions.insert(topic.into(), partitions);
        self
    }
    /// Sets the maximum records requested per partition per fetch.
    pub fn max_poll_records(mut self, n: u32) -> Self {
        self.max_poll_records = n;
        self
    }
    /// Sets the long-poll wait budget in milliseconds (0 disables it).
    pub fn max_poll_interval_ms(mut self, ms: u64) -> Self {
        self.max_poll_interval_ms = ms;
        self
    }
    /// Sets the auto-commit period; `None` disables auto-commit.
    pub fn auto_commit_interval(mut self, interval: Option<Duration>) -> Self {
        self.auto_commit_interval = interval;
        self
    }
    /// Enables auto-commit with the default 5s period, or disables it.
    pub fn enable_auto_commit(mut self, enabled: bool) -> Self {
        // Re-enabling resets the period to the 5s default rather than
        // restoring any previously configured interval.
        self.auto_commit_interval = if enabled {
            Some(Duration::from_secs(5))
        } else {
            None
        };
        self
    }
    /// Sets the raw isolation level byte sent with fetches.
    pub fn isolation_level(mut self, level: u8) -> Self {
        self.isolation_level = level;
        self
    }
    /// Convenience for `isolation_level(1)` (read committed).
    pub fn read_committed(mut self) -> Self {
        self.isolation_level = 1;
        self
    }
    /// Sets SCRAM credentials for broker authentication.
    pub fn auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.auth = Some(ConsumerAuthConfig {
            username: username.into(),
            password: password.into(),
        });
        self
    }
    /// Sets how often assignments are re-discovered during polling.
    pub fn metadata_refresh_interval(mut self, interval: Duration) -> Self {
        self.metadata_refresh_interval = interval;
        self
    }
    /// Sets the initial reconnect backoff in milliseconds.
    pub fn reconnect_backoff_ms(mut self, ms: u64) -> Self {
        self.reconnect_backoff_ms = ms;
        self
    }
    /// Sets the maximum reconnect backoff in milliseconds.
    pub fn reconnect_backoff_max_ms(mut self, ms: u64) -> Self {
        self.reconnect_backoff_max_ms = ms;
        self
    }
    /// Sets the number of reconnect attempts before giving up.
    pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }
    /// Sets the group session timeout in milliseconds.
    pub fn session_timeout_ms(mut self, ms: u32) -> Self {
        self.session_timeout_ms = ms;
        self
    }
    /// Sets the rebalance timeout in milliseconds.
    pub fn rebalance_timeout_ms(mut self, ms: u32) -> Self {
        self.rebalance_timeout_ms = ms;
        self
    }
    /// Sets the background heartbeat period in milliseconds. Values larger
    /// than 1/3 of the session timeout are clamped at build time.
    pub fn heartbeat_interval_ms(mut self, ms: u64) -> Self {
        self.heartbeat_interval_ms = ms;
        self
    }
    /// Enables TLS with the given config and SNI server name.
    #[cfg(feature = "tls")]
    pub fn tls(
        mut self,
        tls_config: rivven_core::tls::TlsConfig,
        server_name: impl Into<String>,
    ) -> Self {
        self.tls_config = Some(tls_config);
        self.tls_server_name = Some(server_name.into());
        self
    }
    /// Finalizes the configuration.
    ///
    /// The heartbeat interval is clamped to at most 1/3 of the session
    /// timeout (so the broker sees several heartbeats per session window)
    /// and to at least 1 ms: the consumer drives its heartbeat loop with
    /// `tokio::time::interval`, which panics on a zero period, so a value
    /// of 0 (from a tiny session timeout or an explicit 0) must never
    /// escape the builder.
    pub fn build(self) -> ConsumerConfig {
        let max_heartbeat = ((self.session_timeout_ms as u64) / 3).max(1);
        let heartbeat_interval_ms = if self.heartbeat_interval_ms > max_heartbeat {
            tracing::warn!(
                configured = self.heartbeat_interval_ms,
                clamped_to = max_heartbeat,
                session_timeout_ms = self.session_timeout_ms,
                "heartbeat_interval_ms exceeds 1/3 of session_timeout_ms, clamping"
            );
            max_heartbeat
        } else {
            // Guard against an explicit 0, which would panic the heartbeat
            // task's tokio interval.
            self.heartbeat_interval_ms.max(1)
        };
        ConsumerConfig {
            bootstrap_servers: self.bootstrap_servers,
            group_id: self.group_id.unwrap_or_else(|| "default-group".into()),
            topics: self.topics,
            partitions: self.partitions,
            max_poll_records: self.max_poll_records,
            max_poll_interval_ms: self.max_poll_interval_ms,
            auto_commit_interval: self.auto_commit_interval,
            isolation_level: self.isolation_level,
            auth: self.auth,
            metadata_refresh_interval: self.metadata_refresh_interval,
            reconnect_backoff_ms: self.reconnect_backoff_ms,
            reconnect_backoff_max_ms: self.reconnect_backoff_max_ms,
            max_reconnect_attempts: self.max_reconnect_attempts,
            session_timeout_ms: self.session_timeout_ms,
            rebalance_timeout_ms: self.rebalance_timeout_ms,
            heartbeat_interval_ms,
            #[cfg(feature = "tls")]
            tls_config: self.tls_config,
            #[cfg(feature = "tls")]
            tls_server_name: self.tls_server_name,
        }
    }
}
impl Default for ConsumerConfigBuilder {
fn default() -> Self {
Self::new()
}
}
impl ConsumerConfig {
    /// Returns a builder pre-populated with the default settings.
    pub fn builder() -> ConsumerConfigBuilder {
        ConsumerConfigBuilder::default()
    }
}
/// A single message returned from [`Consumer::poll`], tagged with its origin.
#[derive(Debug, Clone)]
pub struct ConsumerRecord {
    /// Topic the record was fetched from.
    pub topic: Arc<str>,
    /// Partition the record was fetched from.
    pub partition: u32,
    /// Offset of this record within the partition.
    pub offset: u64,
    /// Decoded message payload from the protocol layer.
    pub data: MessageData,
}
/// A consumer client: fetches records, tracks and commits offsets, and —
/// when no explicit partition assignment is configured — participates in
/// consumer-group coordination with a background heartbeat task.
pub struct Consumer {
    // Active broker connection used for fetches, commits, and group calls.
    client: Client,
    config: ConsumerConfig,
    // Next offset to fetch, per (topic, partition).
    offsets: HashMap<(Arc<str>, u32), u64>,
    // Current assignment as topic -> partitions.
    assignments: HashMap<String, Vec<u32>>,
    // Flattened assignment used for pipelined fetches; rotated in the
    // long-poll path so partitions take turns.
    assignment_list: Vec<(Arc<str>, u32)>,
    // When the last (auto-)commit attempt was made.
    last_commit: Instant,
    // When assignments were last re-discovered (for periodic refresh).
    last_discovery: Instant,
    // True once the first discovery has completed.
    initialized: bool,
    // Group member id from the coordinator; empty before the first join.
    member_id: String,
    // Group generation returned by the most recent join.
    generation_id: u32,
    // True when this member was elected leader and computes assignments.
    is_leader: bool,
    // Timestamp of the most recent heartbeat-related event.
    last_heartbeat: Instant,
    // True when using group coordination (no explicit partitions given).
    uses_coordination: bool,
    // Set by the background heartbeat task to request a group rejoin.
    needs_rejoin: Arc<AtomicBool>,
    // Optional user callbacks invoked around rebalances.
    rebalance_listener: Option<Arc<dyn RebalanceListener>>,
    // Handle of the background heartbeat task; aborted on respawn,
    // reconnect, and close.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
}
impl Consumer {
    /// Connects to the first reachable bootstrap server, authenticates when
    /// credentials are configured, and performs the initial assignment
    /// discovery (group join, or metadata lookup for explicit assignments).
    ///
    /// # Errors
    /// Returns `Error::ConnectionError` when no bootstrap server is
    /// configured or reachable; authentication and discovery failures are
    /// propagated as-is.
    pub async fn new(config: ConsumerConfig) -> Result<Self> {
        let servers = &config.bootstrap_servers;
        if servers.is_empty() {
            return Err(Error::ConnectionError(
                "No bootstrap servers configured".to_string(),
            ));
        }
        let mut last_error = None;
        let mut client = None;
        // Try each bootstrap server in order until one connects.
        for server in servers {
            #[cfg(feature = "tls")]
            let connect_result = if let (Some(ref tls_cfg), Some(ref sni)) =
                (&config.tls_config, &config.tls_server_name)
            {
                Client::connect_tls(server, tls_cfg, sni).await
            } else {
                Client::connect(server).await
            };
            #[cfg(not(feature = "tls"))]
            let connect_result = Client::connect(server).await;
            match connect_result {
                Ok(c) => {
                    client = Some(c);
                    break;
                }
                Err(e) => {
                    warn!(server = %server, error = %e, "Failed to connect to bootstrap server");
                    last_error = Some(e);
                }
            }
        }
        // If every server failed, surface the last connection error.
        let mut client = client.ok_or_else(|| {
            last_error.unwrap_or_else(|| {
                Error::ConnectionError("No bootstrap servers available".to_string())
            })
        })?;
        if let Some(ref auth) = config.auth {
            client
                .authenticate_scram(&auth.username, &auth.password)
                .await?;
        }
        // No explicit partition assignment => rely on group coordination
        // (join/sync/heartbeat) to obtain assignments.
        let uses_coordination = config.partitions.is_empty();
        let mut consumer = Self {
            client,
            config,
            offsets: HashMap::new(),
            assignments: HashMap::new(),
            assignment_list: Vec::new(),
            last_commit: Instant::now(),
            last_discovery: Instant::now(),
            initialized: false,
            member_id: String::new(),
            generation_id: 0,
            is_leader: false,
            last_heartbeat: Instant::now(),
            uses_coordination,
            needs_rejoin: Arc::new(AtomicBool::new(false)),
            rebalance_listener: None,
            heartbeat_handle: None,
        };
        consumer.discover_assignments().await?;
        info!(
            group_id = %consumer.config.group_id,
            topics = ?consumer.config.topics,
            partitions = ?consumer.assignments,
            "Consumer initialized"
        );
        Ok(consumer)
    }
    /// Registers callbacks to be invoked around future rebalances.
    pub fn set_rebalance_listener(&mut self, listener: Arc<dyn RebalanceListener>) {
        self.rebalance_listener = Some(listener);
    }
    /// (Re)starts the background heartbeat task on its own connection.
    ///
    /// Any previous task is aborted first. The task periodically sends
    /// heartbeats for the current (group, generation, member) and sets
    /// `needs_rejoin` when the broker signals a rebalance or the heartbeat
    /// fails; `poll` picks the flag up and rejoins.
    async fn spawn_heartbeat_task(&mut self) {
        if let Some(handle) = self.heartbeat_handle.take() {
            handle.abort();
        }
        // Nothing to heartbeat for: not in a group yet, or using explicit
        // partition assignment.
        if self.member_id.is_empty() || !self.uses_coordination {
            return;
        }
        let group_id = self.config.group_id.clone();
        let member_id = self.member_id.clone();
        let generation_id = self.generation_id;
        let interval = Duration::from_millis(self.config.heartbeat_interval_ms);
        let needs_rejoin = self.needs_rejoin.clone();
        let servers = self.config.bootstrap_servers.clone();
        let auth = self.config.auth.clone();
        self.heartbeat_handle = Some(tokio::spawn(async move {
            // Dedicated connection so heartbeats don't contend with the
            // consumer's fetch/commit traffic.
            // NOTE(review): this always uses plain `Client::connect`; the TLS
            // config used in `new()` is not applied here — confirm whether
            // TLS-enabled deployments need `connect_tls` on this path too.
            let mut hb_client = None;
            for server in &servers {
                match Client::connect(server).await {
                    Ok(mut c) => {
                        if let Some(ref auth) = auth {
                            if let Err(e) =
                                c.authenticate_scram(&auth.username, &auth.password).await
                            {
                                warn!(
                                    server = %server,
                                    error = %e,
                                    "Heartbeat connection auth failed, trying next server"
                                );
                                continue;
                            }
                        }
                        hb_client = Some(c);
                        break;
                    }
                    Err(e) => {
                        warn!(
                            server = %server,
                            error = %e,
                            "Heartbeat connection failed, trying next server"
                        );
                    }
                }
            }
            let Some(mut client) = hb_client else {
                warn!("Could not establish heartbeat connection to any server, signaling rejoin");
                needs_rejoin.store(true, Ordering::Release);
                return;
            };
            let mut ticker = tokio::time::interval(interval);
            // The first tick of a tokio interval completes immediately;
            // consume it so the first heartbeat waits a full period.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                match client.heartbeat(&group_id, generation_id, &member_id).await {
                    // Per the log message, code 27 means a rebalance is in
                    // progress — NOTE(review): confirm against the protocol's
                    // error-code constants.
                    Ok(27) => {
                        info!(
                            group_id = %group_id,
                            "Background heartbeat: rebalance in progress, signaling rejoin"
                        );
                        needs_rejoin.store(true, Ordering::Release);
                    }
                    // Any other code: heartbeat accepted, nothing to do.
                    Ok(_) => {
                    }
                    Err(e) => {
                        warn!(
                            group_id = %group_id,
                            error = %e,
                            "Background heartbeat failed, signaling rejoin"
                        );
                        needs_rejoin.store(true, Ordering::Release);
                    }
                }
            }
        }));
    }
    /// Rebuilds the partition assignment and offset table.
    ///
    /// Sequence: notify the listener that the old assignment is revoked,
    /// obtain a new assignment (group coordination or metadata), drop
    /// offsets for partitions no longer owned, load committed offsets for
    /// newly owned partitions (defaulting to 0), rebuild the flattened
    /// fetch list, and notify the listener of the new assignment.
    async fn discover_assignments(&mut self) -> Result<()> {
        let old_tps: Vec<TopicPartition> = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| TopicPartition {
                    topic: arc.clone(),
                    partition: p,
                })
            })
            .collect();
        if !old_tps.is_empty() {
            if let Some(ref listener) = self.rebalance_listener {
                listener.on_partitions_revoked(&old_tps).await;
            }
        }
        if self.uses_coordination {
            self.discover_via_coordination().await?;
            // New generation/member id => heartbeat task must be restarted
            // with the fresh identifiers.
            self.spawn_heartbeat_task().await;
        } else {
            self.discover_via_metadata().await?;
        }
        // Drop offset state for partitions we no longer own.
        let owned_keys: std::collections::HashSet<(Arc<str>, u32)> = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| (arc.clone(), p))
            })
            .collect();
        self.offsets.retain(|k, _| owned_keys.contains(k));
        // Seed offsets for newly owned partitions from the committed
        // offsets stored under this group id; fall back to 0.
        for (topic, partitions) in &self.assignments {
            for &partition in partitions {
                let key: (Arc<str>, u32) = (Arc::from(topic.as_str()), partition);
                if self.offsets.contains_key(&key) {
                    continue;
                }
                match self
                    .client
                    .get_offset(&self.config.group_id, topic, partition)
                    .await
                {
                    Ok(Some(offset)) => {
                        debug!(
                            topic = %topic,
                            partition,
                            offset,
                            "Resumed from committed offset"
                        );
                        self.offsets.insert(key, offset);
                    }
                    Ok(None) => {
                        self.offsets.insert(key, 0);
                    }
                    Err(e) => {
                        // Best effort: a failed lookup starts the partition
                        // at 0 rather than failing discovery.
                        debug!(
                            topic = %topic,
                            partition,
                            error = %e,
                            "Failed to load committed offset, starting from 0"
                        );
                        self.offsets.insert(key, 0);
                    }
                }
            }
        }
        self.initialized = true;
        // Flattened copy used by the fetch paths in poll_inner.
        self.assignment_list = self
            .assignments
            .iter()
            .flat_map(|(t, ps)| {
                let arc: Arc<str> = Arc::from(t.as_str());
                ps.iter().map(move |&p| (arc.clone(), p))
            })
            .collect();
        if let Some(ref listener) = self.rebalance_listener {
            let new_tps: Vec<TopicPartition> = self
                .assignment_list
                .iter()
                .map(|(t, p)| TopicPartition {
                    topic: t.clone(),
                    partition: *p,
                })
                .collect();
            if !new_tps.is_empty() {
                listener.on_partitions_assigned(&new_tps).await;
            }
        }
        Ok(())
    }
    /// Joins the consumer group and syncs to receive this member's
    /// partition assignment.
    ///
    /// The elected leader computes assignments for the whole group and
    /// submits them via sync; followers submit an empty plan and receive
    /// their share back.
    async fn discover_via_coordination(&mut self) -> Result<()> {
        let (generation_id, _protocol_type, member_id, leader_id, members) = self
            .client
            .join_group(
                &self.config.group_id,
                &self.member_id,
                self.config.session_timeout_ms,
                self.config.rebalance_timeout_ms,
                "consumer",
                self.config.topics.clone(),
            )
            .await?;
        self.member_id = member_id.clone();
        self.generation_id = generation_id;
        self.is_leader = member_id == leader_id;
        info!(
            group_id = %self.config.group_id,
            member_id = %self.member_id,
            generation_id,
            is_leader = self.is_leader,
            member_count = members.len(),
            "Joined consumer group"
        );
        let group_assignments = if self.is_leader {
            self.compute_range_assignments(&members).await?
        } else {
            Vec::new()
        };
        let my_assignments = self
            .client
            .sync_group(
                &self.config.group_id,
                generation_id,
                &self.member_id,
                group_assignments,
            )
            .await?;
        self.assignments.clear();
        for (topic, partitions) in my_assignments {
            debug!(
                topic = %topic,
                partitions = ?partitions,
                "Received partition assignment"
            );
            self.assignments.insert(topic, partitions);
        }
        self.last_heartbeat = Instant::now();
        Ok(())
    }
    /// Range-style assignment: for each topic, the subscribed members
    /// (sorted by id) receive contiguous partition ranges, with the
    /// remainder partitions going one each to the first members.
    ///
    /// Returns `member_id -> [(topic, partitions)]` for sync_group. Topics
    /// whose metadata lookup fails are skipped with a warning.
    async fn compute_range_assignments(
        &mut self,
        members: &[(String, Vec<String>)],
    ) -> Result<Vec<(String, Vec<(String, Vec<u32>)>)>> {
        // Union of every member's subscriptions, deduplicated.
        let mut all_topics: Vec<String> = members
            .iter()
            .flat_map(|(_, subs)| subs.iter().cloned())
            .collect();
        all_topics.sort();
        all_topics.dedup();
        // Every member gets an entry, even if it ends up with no partitions.
        let mut result_map: HashMap<String, Vec<(String, Vec<u32>)>> = members
            .iter()
            .map(|(mid, _)| (mid.clone(), Vec::new()))
            .collect();
        for topic in &all_topics {
            // Only members subscribed to this topic share its partitions;
            // sorting makes the assignment deterministic.
            let mut subscribed: Vec<&str> = members
                .iter()
                .filter(|(_, subs)| subs.iter().any(|s| s == topic))
                .map(|(mid, _)| mid.as_str())
                .collect();
            subscribed.sort();
            let partition_count = match self.client.get_metadata(topic.as_str()).await {
                Ok((_name, count)) => count,
                Err(e) => {
                    warn!(topic = %topic, error = %e, "Failed to get metadata for assignment");
                    continue;
                }
            };
            if subscribed.is_empty() || partition_count == 0 {
                continue;
            }
            let n_members = subscribed.len() as u32;
            let per_member = partition_count / n_members;
            let remainder = partition_count % n_members;
            let mut offset = 0u32;
            for (i, mid) in subscribed.iter().enumerate() {
                // The first `remainder` members take one extra partition.
                let extra = if (i as u32) < remainder { 1 } else { 0 };
                let count = per_member + extra;
                let partitions: Vec<u32> = (offset..offset + count).collect();
                offset += count;
                if let Some(entry) = result_map.get_mut(*mid) {
                    entry.push((topic.clone(), partitions));
                }
            }
        }
        Ok(result_map.into_iter().collect())
    }
    /// Builds assignments without group coordination: explicit partitions
    /// where configured, otherwise all partitions reported by metadata.
    /// Topics whose metadata lookup fails are left out until the next
    /// discovery pass.
    async fn discover_via_metadata(&mut self) -> Result<()> {
        for topic in &self.config.topics {
            if let Some(explicit) = self.config.partitions.get(topic) {
                self.assignments.insert(topic.clone(), explicit.clone());
            } else {
                match self.client.get_metadata(topic.as_str()).await {
                    Ok((_name, partition_count)) => {
                        let partitions: Vec<u32> = (0..partition_count).collect();
                        self.assignments.insert(topic.clone(), partitions);
                    }
                    Err(e) => {
                        warn!(
                            topic = %topic,
                            error = %e,
                            "Failed to discover partitions, will retry on next poll"
                        );
                    }
                }
            }
        }
        Ok(())
    }
    /// Fetches the next batch of records, transparently reconnecting once
    /// on a connection-level error.
    ///
    /// # Errors
    /// Propagates the retried poll's error, or the reconnect failure.
    pub async fn poll(&mut self) -> Result<Vec<ConsumerRecord>> {
        match self.poll_inner().await {
            Ok(records) => Ok(records),
            Err(e) if Self::is_connection_error(&e) => {
                warn!(error = %e, "Connection error during poll, attempting reconnect");
                self.reconnect().await?;
                self.poll_inner().await
            }
            Err(e) => Err(e),
        }
    }
    /// One poll pass: handle rejoin/refresh housekeeping, do a pipelined
    /// fetch across all assigned partitions, fall back to a single-partition
    /// long poll when nothing was fetched, then auto-commit if due.
    async fn poll_inner(&mut self) -> Result<Vec<ConsumerRecord>> {
        if !self.initialized {
            self.discover_assignments().await?;
        }
        // Rejoin when the background heartbeat flagged a rebalance.
        // NOTE(review): the flag is cleared *after* rediscovery, so a signal
        // raised by the freshly spawned heartbeat task during discovery
        // could be wiped here — consider clearing before rejoining.
        if self.needs_rejoin.load(Ordering::Acquire) && self.uses_coordination {
            info!(
                group_id = %self.config.group_id,
                "Rejoining group due to rebalance signal"
            );
            self.discover_assignments().await?;
            self.needs_rejoin.store(false, Ordering::Release);
        }
        // Periodic refresh; failure keeps the existing assignment and still
        // resets the timer so we don't retry in a hot loop.
        if self.last_discovery.elapsed() >= self.config.metadata_refresh_interval {
            if let Err(e) = self.discover_assignments().await {
                warn!(error = %e, "Failed to re-discover assignments, continuing with existing");
            }
            self.last_discovery = Instant::now();
        }
        let mut records = Vec::new();
        // 0 means "unset": nothing is sent to the broker.
        let isolation_level = if self.config.isolation_level > 0 {
            Some(self.config.isolation_level)
        } else {
            None
        };
        if !self.assignment_list.is_empty() {
            // One pipelined fetch covering every assigned partition, each
            // starting at its tracked next-offset.
            let fetches: Vec<(&str, u32, u64, u32, Option<u8>)> = self
                .assignment_list
                .iter()
                .map(|(topic, partition)| {
                    let key = (topic.clone(), *partition);
                    let offset = self.offsets.get(&key).copied().unwrap_or(0);
                    (
                        &**topic,
                        *partition,
                        offset,
                        self.config.max_poll_records,
                        isolation_level,
                    )
                })
                .collect();
            let results = self.client.consume_pipelined(&fetches).await?;
            // Results arrive in the same order as `fetches`, which mirrors
            // `assignment_list`.
            for (i, result) in results.into_iter().enumerate() {
                let (topic, partition) = &self.assignment_list[i];
                match result {
                    Ok(messages) if !messages.is_empty() => {
                        let key = (topic.clone(), *partition);
                        // Advance to one past the highest offset fetched.
                        let max_offset = messages.iter().map(|m| m.offset).max().unwrap_or(0);
                        self.offsets.insert(key, max_offset + 1);
                        records.extend(messages.into_iter().map(|data| ConsumerRecord {
                            topic: topic.clone(),
                            partition: *partition,
                            offset: data.offset,
                            data,
                        }));
                    }
                    Err(e) => {
                        // Coordination errors are detected by substring
                        // match on the error text and turned into a rejoin
                        // on the next poll; other errors skip the partition.
                        let err_str = e.to_string();
                        if err_str.contains("UNKNOWN_MEMBER_ID")
                            || err_str.contains("ILLEGAL_GENERATION")
                            || err_str.contains("REBALANCE_IN_PROGRESS")
                        {
                            warn!(
                                topic = %topic,
                                partition = partition,
                                error = %e,
                                "Rebalance signal in fetch response, will rejoin group"
                            );
                            self.needs_rejoin.store(true, Ordering::Release);
                        } else {
                            warn!(
                                topic = %topic,
                                partition = partition,
                                error = %e,
                                "Pipelined fetch error, skipping partition"
                            );
                        }
                    }
                    // Ok with an empty batch: nothing available right now.
                    _ => {} }
            }
        }
        // Nothing fetched: long-poll a single partition so callers block
        // for data instead of spinning. Rotating the list gives each
        // partition a turn at being the long-polled one.
        if records.is_empty() && self.config.max_poll_interval_ms > 0 {
            if !self.assignment_list.is_empty() {
                self.assignment_list.rotate_left(1);
            }
            if let Some((topic, partition)) = self.assignment_list.first() {
                let key = (topic.clone(), *partition);
                let offset = self.offsets.get(&key).copied().unwrap_or(0);
                // Under coordination, cap the wait near the heartbeat period
                // (but at least 500ms) — presumably so poll returns often
                // enough to notice rejoin signals promptly.
                let max_wait = if self.uses_coordination {
                    self.config.max_poll_interval_ms.min(
                        self.config
                            .heartbeat_interval_ms
                            .saturating_sub(500)
                            .max(500),
                    )
                } else {
                    self.config.max_poll_interval_ms
                };
                let messages = self
                    .client
                    .consume_long_poll(
                        topic.to_string(),
                        *partition,
                        offset,
                        self.config.max_poll_records,
                        isolation_level,
                        max_wait,
                    )
                    .await?;
                if !messages.is_empty() {
                    let max_offset = messages.iter().map(|m| m.offset).max().unwrap_or(offset);
                    self.offsets.insert(key, max_offset + 1);
                    records.extend(messages.into_iter().map(|data| ConsumerRecord {
                        topic: topic.clone(),
                        partition: *partition,
                        offset: data.offset,
                        data,
                    }));
                }
            }
        }
        // Auto-commit when the configured interval has elapsed; failures
        // are logged but never fail the poll.
        if let Some(interval) = self.config.auto_commit_interval {
            if self.last_commit.elapsed() >= interval {
                if let Err(e) = self.commit_inner().await {
                    warn!(error = %e, "Auto-commit failed");
                }
            }
        }
        Ok(records)
    }
    /// Commits the current offsets, transparently reconnecting once on a
    /// connection-level error.
    pub async fn commit(&mut self) -> Result<()> {
        match self.commit_inner().await {
            Ok(()) => Ok(()),
            Err(e) if Self::is_connection_error(&e) => {
                warn!(error = %e, "Connection error during commit, attempting reconnect");
                self.reconnect().await?;
                self.commit_inner().await
            }
            Err(e) => Err(e),
        }
    }
    /// Commits all tracked offsets in one pipelined request.
    ///
    /// Per-partition failures are logged and collected; the first error is
    /// returned after all results are processed. `last_commit` is reset even
    /// on failure so a failing broker isn't hammered every poll.
    async fn commit_inner(&mut self) -> Result<()> {
        if self.offsets.is_empty() {
            return Ok(());
        }
        let commits: Vec<(String, u32, u64)> = self
            .offsets
            .iter()
            .map(|((topic, partition), offset)| (topic.to_string(), *partition, *offset))
            .collect();
        let mut errors = Vec::new();
        // A poisoned client has a desynchronized stream; committing through
        // it could mis-attribute responses, so force a reconnect instead.
        if self.client.is_poisoned() {
            return Err(Error::ConnectionError(
                "Client stream is desynchronized — reconnect required".into(),
            ));
        }
        {
            let results = self
                .client
                .commit_offsets_pipelined(&self.config.group_id, &commits)
                .await;
            match results {
                Ok(per_partition) => {
                    // Responses are positionally aligned with `commits`.
                    for (i, result) in per_partition.into_iter().enumerate() {
                        if let Err(e) = result {
                            let (topic, partition, offset) = &commits[i];
                            warn!(
                                topic = %topic, partition, offset, error = %e,
                                "Failed to commit offset"
                            );
                            errors.push(e);
                        }
                    }
                }
                Err(e) => {
                    errors.push(e);
                }
            }
        }
        self.last_commit = Instant::now();
        if errors.is_empty() {
            debug!(
                group_id = %self.config.group_id,
                partitions = self.offsets.len(),
                "Offsets committed"
            );
            Ok(())
        } else {
            Err(errors.into_iter().next().expect("errors is non-empty"))
        }
    }
    /// Sets the next fetch offset for one partition.
    pub fn seek(&mut self, topic: impl Into<String>, partition: u32, offset: u64) {
        let arc: Arc<str> = Arc::from(topic.into());
        self.offsets.insert((arc, partition), offset);
    }
    /// Resets all assigned partitions of `topic` to offset 0. No-op when
    /// the topic is not currently assigned.
    pub fn seek_to_beginning(&mut self, topic: &str) {
        if let Some(partitions) = self.assignments.get(topic) {
            let arc: Arc<str> = Arc::from(topic);
            for &p in partitions {
                self.offsets.insert((arc.clone(), p), 0);
            }
        }
    }
    /// Returns the next fetch offset for a partition, if tracked.
    pub fn position(&self, topic: &str, partition: u32) -> Option<u64> {
        self.offsets
            .get(&(Arc::<str>::from(topic), partition))
            .copied()
    }
    /// Returns the current topic -> partitions assignment.
    pub fn assignments(&self) -> &HashMap<String, Vec<u32>> {
        &self.assignments
    }
    /// Returns the configured consumer-group id.
    pub fn group_id(&self) -> &str {
        &self.config.group_id
    }
    /// Re-establishes the broker connection with exponential backoff,
    /// cycling round-robin through the bootstrap servers, then rejoins the
    /// group when using coordination.
    ///
    /// NOTE(review): reconnects use plain `Client::connect`; the TLS config
    /// honored in `new()` is not applied here — confirm intended behavior
    /// for TLS deployments.
    async fn reconnect(&mut self) -> Result<()> {
        // The heartbeat task holds stale member/generation ids once the
        // main connection is gone; stop it until we rejoin.
        if let Some(handle) = self.heartbeat_handle.take() {
            handle.abort();
        }
        let mut backoff = Duration::from_millis(self.config.reconnect_backoff_ms);
        let max_backoff = Duration::from_millis(self.config.reconnect_backoff_max_ms);
        let servers = &self.config.bootstrap_servers;
        if servers.is_empty() {
            return Err(Error::ConnectionError(
                "No bootstrap servers configured".to_string(),
            ));
        }
        for attempt in 1..=self.config.max_reconnect_attempts {
            // Round-robin over the server list across attempts.
            let server = &servers[(attempt as usize - 1) % servers.len()];
            info!(
                attempt,
                server = %server,
                "Attempting to reconnect"
            );
            match Client::connect(server).await {
                Ok(mut new_client) => {
                    if let Some(ref auth) = self.config.auth {
                        if let Err(e) = new_client
                            .authenticate_scram(&auth.username, &auth.password)
                            .await
                        {
                            warn!(error = %e, attempt, "Re-authentication failed during reconnect");
                            tokio::time::sleep(backoff).await;
                            backoff = (backoff * 2).min(max_backoff);
                            continue;
                        }
                    }
                    self.client = new_client;
                    info!("Consumer reconnected successfully to {}", server);
                    // Best-effort rejoin; a failure here leaves the old
                    // assignment in place until the next poll retries.
                    if self.uses_coordination {
                        if let Err(e) = self.discover_assignments().await {
                            warn!(error = %e, "Failed to rejoin group after reconnect");
                        }
                    }
                    return Ok(());
                }
                Err(e) => {
                    warn!(error = %e, attempt, server = %server, "Reconnect attempt failed");
                    tokio::time::sleep(backoff).await;
                    backoff = (backoff * 2).min(max_backoff);
                }
            }
        }
        Err(Error::ConnectionError(format!(
            "Failed to reconnect to any of {:?} after {} attempts",
            servers, self.config.max_reconnect_attempts
        )))
    }
    /// Classifies errors that warrant a reconnect-and-retry.
    ///
    /// NOTE(review): `ProtocolError` and `ResponseTooLarge` are included —
    /// presumably because they can leave the stream desynchronized, making
    /// a fresh connection the only safe recovery; confirm.
    fn is_connection_error(e: &Error) -> bool {
        matches!(
            e,
            Error::ConnectionError(_)
                | Error::IoError(_, _)
                | Error::Timeout
                | Error::TimeoutWithMessage(_)
                | Error::ProtocolError(_)
                | Error::ResponseTooLarge(_, _)
        )
    }
    /// Gracefully shuts down: stops the heartbeat task, commits offsets if
    /// auto-commit is enabled, and leaves the consumer group (best effort).
    ///
    /// NOTE(review): a failed final commit returns early via `?`, skipping
    /// `leave_group` — consider making the commit best-effort as well.
    pub async fn close(mut self) -> Result<()> {
        if let Some(handle) = self.heartbeat_handle.take() {
            handle.abort();
        }
        if self.config.auto_commit_interval.is_some() {
            self.commit().await?;
        }
        if self.uses_coordination && !self.member_id.is_empty() {
            if let Err(e) = self
                .client
                .leave_group(&self.config.group_id, &self.member_id)
                .await
            {
                warn!(
                    error = %e,
                    group_id = %self.config.group_id,
                    member_id = %self.member_id,
                    "Failed to leave group gracefully"
                );
            } else {
                info!(
                    group_id = %self.config.group_id,
                    member_id = %self.member_id,
                    "Left consumer group"
                );
            }
        }
        info!(group_id = %self.config.group_id, "Consumer closed");
        Ok(())
    }
}