use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use crabka_client_core::Client;
use crabka_protocol::owned::heartbeat_request::HeartbeatRequest;
use crabka_protocol::owned::join_group_request::{JoinGroupRequest, JoinGroupRequestProtocol};
use crabka_protocol::owned::join_group_response::JoinGroupResponse;
use crabka_protocol::owned::leave_group_request::{LeaveGroupRequest, MemberIdentity};
use crabka_protocol::owned::metadata_request::MetadataRequest;
use crabka_protocol::owned::offset_commit_request::OffsetCommitRequest;
use crabka_protocol::owned::sync_group_request::{SyncGroupRequest, SyncGroupRequestAssignment};
use crabka_protocol::owned::sync_group_response::SyncGroupResponse;
use crabka_protocol::primitives::uuid::Uuid as WireUuid;
use crate::assignor::{Assignor, RebalanceProtocol};
use crate::builder::{
AutoOffsetReset, decode_assignment, decode_subscription, encode_assignment, encode_subscription,
};
use crate::error::ConsumerError;
use crate::offset_wire::{build_commit_topics, build_offset_fetch, id_to_name, parse_offset_fetch};
/// Kafka error code `COORDINATOR_LOAD_IN_PROGRESS`.
pub(crate) const COORDINATOR_LOAD_IN_PROGRESS: i16 = 14;
/// Kafka error code `COORDINATOR_NOT_AVAILABLE`.
pub(crate) const COORDINATOR_NOT_AVAILABLE: i16 = 15;
/// Kafka error code `NOT_COORDINATOR`.
pub(crate) const NOT_COORDINATOR: i16 = 16;
/// Retry budget passed to `with_coordinator_retry` by the group-membership calls in this file.
pub(crate) const COORDINATOR_RETRY_TIMEOUT: Duration = Duration::from_secs(30);

/// Returns `true` when `code` is one of the three coordinator error codes
/// declared above, i.e. a response that `with_coordinator_retry` will retry.
fn is_retriable_coordinator_code(code: i16) -> bool {
    [
        COORDINATOR_LOAD_IN_PROGRESS,
        COORDINATOR_NOT_AVAILABLE,
        NOT_COORDINATOR,
    ]
    .contains(&code)
}
/// Repeatedly invokes `make` until it yields a response that does not need a
/// coordinator retry, or until `timeout` has elapsed.
///
/// * `timeout` - retry budget, measured from the first attempt. It is only
///   checked after an attempt completes, so an attempt may start (and a
///   backoff sleep may finish) after the budget has run out.
/// * `code` - extracts the top-level error code from a response.
/// * `make` - builds a fresh request future for every attempt.
///
/// Retried outcomes are a response whose code satisfies
/// `is_retriable_coordinator_code`, and a `ClientError::Disconnected` send
/// failure. Attempts are separated by a sleep that starts at 100 ms and
/// doubles up to a 1 s cap.
///
/// # Errors
///
/// * Any error other than `ClientError::Disconnected` is returned immediately.
/// * `ConsumerError::CoordinatorUnavailable` if the connection is still
///   disconnected once the budget is spent.
///
/// A response that is still retriable when the budget is spent is returned as
/// `Ok`, so callers must inspect its error code themselves.
pub(crate) async fn with_coordinator_retry<R, F, Fut>(
    timeout: Duration,
    code: impl Fn(&R) -> i16,
    make: F,
) -> Result<R, ConsumerError>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<R, ConsumerError>>,
{
    const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
    const MAX_BACKOFF: Duration = Duration::from_secs(1);

    let started = tokio::time::Instant::now();
    let budget_spent = move || started.elapsed() >= timeout;
    let mut delay = INITIAL_BACKOFF;

    loop {
        match make().await {
            Ok(response) => {
                // Final answers are returned at once; retriable ones only when
                // there is no budget left to try again.
                if !is_retriable_coordinator_code(code(&response)) || budget_spent() {
                    return Ok(response);
                }
            }
            Err(ConsumerError::Client(crabka_client_core::ClientError::Disconnected)) => {
                if budget_spent() {
                    return Err(ConsumerError::CoordinatorUnavailable);
                }
            }
            Err(other) => return Err(other),
        }

        tokio::time::sleep(delay).await;
        delay = MAX_BACKOFF.min(delay * 2);
    }
}
/// Everything the background coordinator task (`run`) needs to keep this
/// consumer's group membership alive and to apply rebalance results.
pub(crate) struct CoordinatorState {
/// Client through which every request in this module is sent.
pub client: Client,
/// Consumer group id placed in every group request.
pub group_id: String,
/// Member id currently held by this consumer. An empty string means no id
/// is held: `leave_group` then skips the request, and `run` clears the field
/// to force a join from scratch.
pub member_id: String,
/// Group generation adopted from the last applied rebalance; `run` resets it
/// to `-1` whenever it clears `member_id`.
pub generation_id: i32,
/// Assignment strategy: supplies the protocol name advertised in JoinGroup
/// and decides which rebalance protocol `rejoin` follows.
pub assignor: Assignor,
/// Topic names encoded into the subscription and used by the group leader
/// to filter cluster metadata.
pub subscribed_topics: Vec<String>,
/// Currently owned `(topic, partition)` pairs, updated by `rejoin`.
/// NOTE(review): the `Arc<Mutex<..>>` suggests this is shared with other
/// parts of the consumer — confirm against the code that builds this struct.
pub assigned: Arc<Mutex<Vec<(String, i32)>>>,
/// Per-partition offset values: seeded by `prime_offsets` and read by
/// `commit_revoked` as the offsets to commit.
pub next_offsets: Arc<Mutex<HashMap<(String, i32), i64>>>,
/// Per-partition position records; this module only reads and writes their
/// `offset_epoch` field.
pub positions: Arc<Mutex<HashMap<(String, i32), crate::position::PartitionPosition>>>,
/// Topic name to topic id map. In this module it is filled from metadata by
/// the group leader inside `join_and_sync` and consulted when building
/// offset fetch and offset commit requests.
pub topic_ids: Arc<Mutex<HashMap<String, WireUuid>>>,
/// Sent as `session_timeout_ms` in JoinGroup (saturating at `i32::MAX`).
pub session_timeout: Duration,
/// Sent as `rebalance_timeout_ms` in JoinGroup (saturating at `i32::MAX`).
pub rebalance_timeout: Duration,
/// Period of the ticker that drives the `run` loop.
pub heartbeat_interval: Duration,
/// Policy `prime_offsets` applies to partitions without a committed offset.
pub auto_offset_reset: AutoOffsetReset,
/// Optional rack id handed to `encode_subscription`.
pub client_rack: Option<String>,
}
/// Classification of a single heartbeat attempt, produced by `heartbeat_once`
/// and acted upon by `run`.
enum HeartbeatOutcome {
/// The coordinator answered with error code 0.
Ok,
/// The coordinator answered with error code 27 or 22; `run` schedules a
/// rejoin and keeps the current member id.
NeedRejoin,
/// The coordinator answered with error code 25; `run` clears the member id,
/// resets the generation to -1 and schedules a rejoin.
RejoinFromScratch,
/// The send failed or returned any other error code; `run` takes no action
/// and heartbeats again on the next tick.
Transient,
}
/// Background task that keeps this consumer's group membership alive.
///
/// On every tick of a `state.heartbeat_interval` ticker (missed ticks are
/// skipped) the loop either sends one heartbeat or, when a rebalance is
/// pending, runs one `rejoin` attempt. Each await is raced against
/// `shutdown`, so cancellation interrupts an in-flight heartbeat or rejoin.
/// When the loop exits, a best-effort LeaveGroup is sent.
///
/// * `state` - coordinator state, owned by this task for its whole lifetime.
/// * `shutdown` - cancelling this token ends the loop.
pub(crate) async fn run(mut state: CoordinatorState, shutdown: CancellationToken) {
    // Kafka error code UNKNOWN_MEMBER_ID: the coordinator no longer knows the
    // member id we are presenting.
    const UNKNOWN_MEMBER_ID: i16 = 25;

    let mut ticker = tokio::time::interval(state.heartbeat_interval);
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    let mut needs_rejoin = false;
    loop {
        tokio::select! {
            () = shutdown.cancelled() => break,
            _ = ticker.tick() => {}
        }
        if needs_rejoin {
            tokio::select! {
                () = shutdown.cancelled() => break,
                result = rejoin(&mut state) => match result {
                    Ok(()) => needs_rejoin = false,
                    Err(ConsumerError::Server(UNKNOWN_MEMBER_ID)) => {
                        // Retrying with the same stale member id would be
                        // rejected on every tick. Apply the same reset as the
                        // heartbeat path so the next attempt joins as a new
                        // member.
                        tracing::warn!(
                            "coordinator rejected member id during rejoin; rejoining from scratch on next tick"
                        );
                        state.member_id.clear();
                        state.generation_id = -1;
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "rejoin failed; will retry on next tick");
                    }
                },
            }
        } else {
            tokio::select! {
                () = shutdown.cancelled() => break,
                outcome = heartbeat_once(&state) => match outcome {
                    HeartbeatOutcome::Ok | HeartbeatOutcome::Transient => {}
                    HeartbeatOutcome::NeedRejoin => needs_rejoin = true,
                    HeartbeatOutcome::RejoinFromScratch => {
                        state.member_id.clear();
                        state.generation_id = -1;
                        needs_rejoin = true;
                    }
                },
            }
        }
    }
    leave_group(&state).await;
}
/// Sends a best-effort LeaveGroup for the current member.
///
/// Does nothing when no member id is held. The request carries the member id
/// both in the top-level field and as the single entry of `members`. The send
/// is bounded to five seconds, and its outcome (success, error or timeout) is
/// deliberately ignored.
async fn leave_group(state: &CoordinatorState) {
    const LEAVE_TIMEOUT: Duration = Duration::from_secs(5);

    let member_id = &state.member_id;
    if member_id.is_empty() {
        return;
    }

    let request = LeaveGroupRequest {
        group_id: state.group_id.clone(),
        member_id: member_id.clone(),
        members: vec![MemberIdentity {
            member_id: member_id.clone(),
            ..Default::default()
        }],
        ..Default::default()
    };
    let _ = tokio::time::timeout(LEAVE_TIMEOUT, state.client.send(request)).await;
}
async fn heartbeat_once(state: &CoordinatorState) -> HeartbeatOutcome {
let result = state
.client
.send(HeartbeatRequest {
group_id: state.group_id.clone(),
generation_id: state.generation_id,
member_id: state.member_id.clone(),
..Default::default()
})
.await;
match result {
Ok(r) if r.error_code == 0 => HeartbeatOutcome::Ok,
Ok(r) if r.error_code == 27 || r.error_code == 22 => HeartbeatOutcome::NeedRejoin,
Ok(r) if r.error_code == 25 => HeartbeatOutcome::RejoinFromScratch,
Ok(r) => {
tracing::warn!(error_code = r.error_code, "unexpected heartbeat error");
HeartbeatOutcome::Transient
}
Err(e) => {
tracing::warn!(error = %e, "heartbeat send failed");
HeartbeatOutcome::Transient
}
}
}
/// Runs one rebalance and applies the resulting assignment to `state`.
///
/// The currently assigned partitions are snapshotted and advertised as owned
/// in `join_and_sync`. How the returned assignment is applied depends on the
/// assignor's rebalance protocol:
///
/// * `Eager` - offsets are primed for newly added partitions, the assignment
///   is replaced wholesale, and offset/position entries outside the new
///   assignment are dropped.
/// * `Cooperative`, nothing revoked - offsets are primed for the added
///   partitions and they are appended to the current assignment.
/// * `Cooperative`, something revoked - the revoked partitions are removed,
///   their offsets are committed (best effort) and forgotten, and a second
///   `join_and_sync` is run advertising the reduced ownership. Its assignment
///   then replaces the current one.
///
/// # Errors
///
/// Propagates any error from `join_and_sync` or `prime_offsets`. Mutations
/// made before the failing step are not rolled back.
async fn rejoin(state: &mut CoordinatorState) -> Result<(), ConsumerError> {
// Snapshot what we own now; the lock guard is released at the end of this statement.
let owned: Vec<(String, i32)> = state.assigned.lock().await.clone();
let (new_assignment, new_generation, _protocol_name) = join_and_sync(state, &owned).await?;
let old_set: HashSet<(String, i32)> = owned.iter().cloned().collect();
let new_set: HashSet<(String, i32)> = new_assignment.iter().cloned().collect();
// revoked = owned before but not assigned now; added = assigned now but not owned before.
let revoked: Vec<(String, i32)> = old_set.difference(&new_set).cloned().collect();
let added: Vec<(String, i32)> = new_set.difference(&old_set).cloned().collect();
match state.assignor.rebalance_protocol() {
RebalanceProtocol::Eager => {
// Prime before publishing the assignment: a failure here returns early
// and leaves `assigned` and `generation_id` untouched.
prime_offsets(state, &added).await?;
{
let mut a = state.assigned.lock().await;
a.clone_from(&new_assignment);
}
{
// Drop bookkeeping for partitions that are no longer ours.
let mut off = state.next_offsets.lock().await;
off.retain(|k, _| new_set.contains(k));
let mut pos = state.positions.lock().await;
pos.retain(|k, _| new_set.contains(k));
}
state.generation_id = new_generation;
}
RebalanceProtocol::Cooperative => {
if revoked.is_empty() {
// Nothing to give up: just pick up the additional partitions.
prime_offsets(state, &added).await?;
{
let mut a = state.assigned.lock().await;
for p in &added {
if !a.contains(p) {
a.push(p.clone());
}
}
}
state.generation_id = new_generation;
} else {
// First round: release the revoked partitions.
{
let mut a = state.assigned.lock().await;
a.retain(|p| !revoked.contains(p));
}
// The generation is adopted before committing, so the commit below is
// sent with the new generation id.
state.generation_id = new_generation;
commit_revoked(state, &revoked).await;
{
let mut off = state.next_offsets.lock().await;
let mut pos = state.positions.lock().await;
for p in &revoked {
off.remove(p);
pos.remove(p);
}
}
// Second round: join again advertising the reduced ownership.
let owned_after_revoke: Vec<(String, i32)> = state.assigned.lock().await.clone();
let (assignment2, gen2, _) = join_and_sync(state, &owned_after_revoke).await?;
let owned_after_revoke_set: HashSet<(String, i32)> =
owned_after_revoke.iter().cloned().collect();
// Only partitions we did not hold after the revocation need their offsets primed.
let added2: Vec<(String, i32)> = assignment2
.iter()
.filter(|p| !owned_after_revoke_set.contains(*p))
.cloned()
.collect();
prime_offsets(state, &added2).await?;
{
// NOTE(review): the second-round assignment replaces `assigned` wholesale.
// Partitions it drops are neither committed nor pruned from
// `next_offsets`/`positions` here — confirm this is intended.
let mut a = state.assigned.lock().await;
*a = assignment2;
}
state.generation_id = gen2;
}
}
}
Ok(())
}
async fn commit_revoked(state: &CoordinatorState, revoked: &[(String, i32)]) {
let revoked_set: HashSet<&(String, i32)> = revoked.iter().collect();
let offsets: HashMap<(String, i32), (i64, i32)> = {
let off = state.next_offsets.lock().await;
let pos = state.positions.lock().await;
off.iter()
.filter(|(k, v)| revoked_set.contains(k) && **v > 0 && **v != i64::MAX)
.map(|(k, v)| {
let epoch = pos.get(k).map_or(-1, |p| p.offset_epoch);
(k.clone(), (*v, epoch))
})
.collect()
};
if offsets.is_empty() {
return;
}
let topic_ids = state.topic_ids.lock().await.clone();
let topics = build_commit_topics(offsets, &topic_ids);
let res = state
.client
.send(OffsetCommitRequest {
group_id: state.group_id.clone(),
generation_id_or_member_epoch: state.generation_id,
member_id: state.member_id.clone(),
topics,
..Default::default()
})
.await;
match res {
Ok(_) => {}
Err(e) => {
tracing::warn!(error = %e, "revoke-time offset commit failed; partitions may re-deliver");
}
}
}
/// Performs one JoinGroup + SyncGroup exchange with the coordinator.
///
/// * `state` - coordinator state. `member_id` is updated with whatever id the
///   coordinator hands back. `generation_id` is only read here (it is encoded
///   into the subscription); the caller adopts the returned generation.
/// * `owned` - partitions advertised as currently owned in the subscription.
///
/// If this member is elected leader, it fetches cluster metadata, records the
/// topic ids of subscribed topics in `state.topic_ids`, runs the configured
/// assignor over every member's decoded subscription and ships the result in
/// SyncGroup. Followers send an empty assignment list.
///
/// Returns `(assignment for this member, generation id, protocol name)`.
///
/// # Errors
///
/// * `ConsumerError::RebalanceFailed` if the coordinator answers with code 79
///   but supplies no member id.
/// * `ConsumerError::Server(code)` for any non-zero JoinGroup or SyncGroup
///   error code that remains after the retry helper returns.
/// * Any error from `with_coordinator_retry` or from the metadata request.
#[allow(clippy::too_many_lines)]
async fn join_and_sync(
state: &mut CoordinatorState,
owned: &[(String, i32)],
) -> Result<(Vec<(String, i32)>, i32, String), ConsumerError> {
// Durations that do not fit in an i32 of milliseconds saturate at i32::MAX.
let session_timeout_ms = i32::try_from(state.session_timeout.as_millis()).unwrap_or(i32::MAX);
let rebalance_timeout_ms =
i32::try_from(state.rebalance_timeout.as_millis()).unwrap_or(i32::MAX);
let subscription_bytes = encode_subscription(
&state.subscribed_topics,
owned,
state.generation_id,
state.client_rack.as_deref(),
);
let protocol_name = state.assignor.protocol_name().to_string();
// First JoinGroup attempt, using whatever member id we currently hold (possibly empty).
let r1 = with_coordinator_retry(
COORDINATOR_RETRY_TIMEOUT,
|r: &JoinGroupResponse| r.error_code,
|| {
// The retry helper calls this closure once per attempt, so each attempt
// gets its own owned copies of the request fields.
let group_id = state.group_id.clone();
let member_id = state.member_id.clone();
let protocol_name = protocol_name.clone();
let subscription_bytes = subscription_bytes.clone();
let client = &state.client;
async move {
client
.send(JoinGroupRequest {
group_id,
protocol_type: "consumer".into(),
member_id,
session_timeout_ms,
rebalance_timeout_ms,
protocols: vec![JoinGroupRequestProtocol {
name: protocol_name,
metadata: subscription_bytes,
..Default::default()
}],
..Default::default()
})
.await
.map_err(ConsumerError::from)
}
},
)
.await?;
let join_resp = if r1.error_code == 0 {
r1
} else if r1.error_code == 79 {
// 79 is MEMBER_ID_REQUIRED in the Kafka error table: the response carries
// the member id to use, and the join must be repeated with it.
let assigned_id = r1.member_id.clone();
if assigned_id.is_empty() {
return Err(ConsumerError::RebalanceFailed(
"broker did not assign a member_id".into(),
));
}
// Persist the id immediately so it survives a failure of the second attempt.
state.member_id.clone_from(&assigned_id);
let r2 = with_coordinator_retry(
COORDINATOR_RETRY_TIMEOUT,
|r: &JoinGroupResponse| r.error_code,
|| {
let group_id = state.group_id.clone();
let assigned_id = assigned_id.clone();
let protocol_name = protocol_name.clone();
let subscription_bytes = subscription_bytes.clone();
let client = &state.client;
async move {
client
.send(JoinGroupRequest {
group_id,
protocol_type: "consumer".into(),
member_id: assigned_id,
session_timeout_ms,
rebalance_timeout_ms,
protocols: vec![JoinGroupRequestProtocol {
name: protocol_name,
metadata: subscription_bytes,
..Default::default()
}],
..Default::default()
})
.await
.map_err(ConsumerError::from)
}
},
)
.await?;
if r2.error_code != 0 {
return Err(ConsumerError::Server(r2.error_code));
}
r2
} else {
return Err(ConsumerError::Server(r1.error_code));
};
// Adopt the member id from the successful response, unless it is empty.
if !join_resp.member_id.is_empty() {
state.member_id.clone_from(&join_resp.member_id);
}
// Fall back to the protocol we proposed when the response names none.
let chosen_protocol = join_resp
.protocol_name
.clone()
.unwrap_or_else(|| protocol_name.clone());
let generation_id = join_resp.generation_id;
let is_leader = join_resp.leader == state.member_id;
let assignments_for_sync: Vec<SyncGroupRequestAssignment> = if is_leader {
// Leader path: compute the assignment for every member of the group.
let md = state.client.send(MetadataRequest::default()).await?;
let mut topic_partitions: HashMap<String, i32> = HashMap::new();
let mut resolved_ids: HashMap<String, WireUuid> = HashMap::new();
for t in &md.topics {
// Metadata entries without a name are skipped.
let Some(name) = &t.name else { continue };
if state.subscribed_topics.iter().any(|s| s == name) {
let count = i32::try_from(t.partitions.len()).unwrap_or(i32::MAX);
topic_partitions.insert(name.clone(), count);
resolved_ids.insert(name.clone(), t.topic_id);
}
}
{
// Merge the freshly resolved ids into the topic id cache; existing
// entries for other topics are left in place.
let mut ids = state.topic_ids.lock().await;
for (k, v) in resolved_ids {
ids.insert(k, v);
}
}
let decoded: Vec<(String, crate::builder::DecodedSubscription)> = join_resp
.members
.iter()
.map(|m| (m.member_id.clone(), decode_subscription(&m.metadata)))
.collect();
let assignments = match state.assignor {
Assignor::Range => {
// The range assignor only receives each member's topic list.
let inputs: Vec<(String, Vec<String>)> = decoded
.into_iter()
.map(|(id, sub)| (id, sub.topics))
.collect();
crate::assignor::range::assign(inputs, &topic_partitions)
}
Assignor::CooperativeSticky => {
// The cooperative-sticky assignor additionally receives each member's
// owned partitions and generation id.
let inputs: Vec<crate::assignor::cooperative_sticky::MemberInput> = decoded
.into_iter()
.map(|(id, sub)| (id, sub.topics, sub.owned, sub.generation_id))
.collect();
crate::assignor::cooperative_sticky::assign(&inputs, &topic_partitions)
}
};
assignments
.into_iter()
.map(|(m, partitions)| SyncGroupRequestAssignment {
member_id: m,
assignment: encode_assignment(&partitions),
..Default::default()
})
.collect()
} else {
// Followers send no assignments.
Vec::new()
};
let sync_resp = with_coordinator_retry(
COORDINATOR_RETRY_TIMEOUT,
|r: &SyncGroupResponse| r.error_code,
|| {
let group_id = state.group_id.clone();
let member_id = state.member_id.clone();
let chosen_protocol = chosen_protocol.clone();
let assignments_for_sync = assignments_for_sync.clone();
let client = &state.client;
async move {
client
.send(SyncGroupRequest {
group_id,
generation_id,
member_id,
protocol_type: Some("consumer".into()),
protocol_name: Some(chosen_protocol),
assignments: assignments_for_sync,
..Default::default()
})
.await
.map_err(ConsumerError::from)
}
},
)
.await?;
if sync_resp.error_code != 0 {
return Err(ConsumerError::Server(sync_resp.error_code));
}
let my_assignment = decode_assignment(&sync_resp.assignment);
Ok((my_assignment, generation_id, chosen_protocol))
}
/// Seeds `next_offsets` and `positions` for `partitions` from the group's
/// committed offsets.
///
/// One offset-fetch request covering all given partitions is sent. For every
/// partition in the parsed response, a non-negative committed offset becomes
/// the starting offset and its epoch is stored as the position's
/// `offset_epoch`. A negative committed offset, or a partition absent from
/// the response, falls back to the reset offset: `0` for
/// `AutoOffsetReset::Earliest`, `i64::MAX` for `Latest` and `None`.
///
/// Does nothing for an empty `partitions` slice.
///
/// # Errors
///
/// Returns the converted client error if the offset-fetch request fails; no
/// state is modified in that case.
async fn prime_offsets(
    state: &CoordinatorState,
    partitions: &[(String, i32)],
) -> Result<(), ConsumerError> {
    if partitions.is_empty() {
        return Ok(());
    }

    // Offset used whenever no usable committed offset exists.
    let reset_offset: i64 = match state.auto_offset_reset {
        AutoOffsetReset::Earliest => 0,
        AutoOffsetReset::Latest | AutoOffsetReset::None => i64::MAX,
    };

    let mut partitions_by_topic: HashMap<String, Vec<i32>> = HashMap::new();
    for (topic, partition) in partitions {
        partitions_by_topic
            .entry(topic.clone())
            .or_default()
            .push(*partition);
    }

    let topic_ids = state.topic_ids.lock().await.clone();
    let request = build_offset_fetch(&state.group_id, &partitions_by_topic, &topic_ids);
    let response = state.client.send(request).await?;
    let names_by_id = id_to_name(&topic_ids);

    // Both maps are locked only after the network round trip has finished.
    let mut next_offsets = state.next_offsets.lock().await;
    let mut positions = state.positions.lock().await;

    let mut answered: HashSet<(String, i32)> = HashSet::new();
    for (topic, partition, committed, committed_epoch) in
        parse_offset_fetch(&response, &names_by_id)
    {
        let tp = (topic, partition);
        let start_at = if committed < 0 { reset_offset } else { committed };
        next_offsets.insert(tp.clone(), start_at);
        positions.entry(tp.clone()).or_default().offset_epoch = committed_epoch;
        answered.insert(tp);
    }

    // Requested partitions missing from the response get the reset offset and
    // a default position (an existing position is left untouched).
    for tp in partitions.iter().filter(|tp| !answered.contains(*tp)) {
        next_offsets.insert(tp.clone(), reset_offset);
        positions.entry(tp.clone()).or_default();
    }
    Ok(())
}
// Unit tests for `with_coordinator_retry`. Every test is declared with
// `start_paused = true`, i.e. it runs against tokio's paused test clock.
#[cfg(test)]
mod retry_tests {
use super::*;
use assert2::assert;
use std::sync::atomic::{AtomicUsize, Ordering};
// Minimal stand-in response exposing only the error code the helper inspects.
struct Resp {
error_code: i16,
}
// Three responses with code 14 (COORDINATOR_LOAD_IN_PROGRESS) followed by a
// success: the helper must retry and return the fourth response.
#[tokio::test(start_paused = true)]
async fn retries_until_coordinator_finishes_loading() {
let calls = AtomicUsize::new(0);
let r = with_coordinator_retry(
Duration::from_secs(30),
|r: &Resp| r.error_code,
|| {
let n = calls.fetch_add(1, Ordering::SeqCst);
async move {
Ok::<_, ConsumerError>(Resp {
error_code: if n < 3 { 14 } else { 0 },
})
}
},
)
.await
.unwrap();
assert!(r.error_code == 0);
assert!(calls.load(Ordering::SeqCst) == 4);
}
// A response that stays retriable (code 15) past the deadline is handed back
// as `Ok`, with the retriable code still in it.
#[tokio::test(start_paused = true)]
async fn surfaces_last_response_after_deadline() {
let r = with_coordinator_retry(
Duration::from_secs(1),
|r: &Resp| r.error_code,
|| async { Ok::<_, ConsumerError>(Resp { error_code: 15 }) },
)
.await
.unwrap();
assert!(r.error_code == 15);
}
// Code 25 is not in the retriable set: exactly one call, response returned as is.
#[tokio::test(start_paused = true)]
async fn non_retriable_code_returns_immediately() {
let calls = AtomicUsize::new(0);
let r = with_coordinator_retry(
Duration::from_secs(30),
|r: &Resp| r.error_code,
|| {
calls.fetch_add(1, Ordering::SeqCst);
async move { Ok::<_, ConsumerError>(Resp { error_code: 25 }) } },
)
.await
.unwrap();
assert!(r.error_code == 25);
assert!(calls.load(Ordering::SeqCst) == 1);
}
// Persistent `Disconnected` errors are retried until the deadline and then
// surface as `CoordinatorUnavailable`.
#[tokio::test(start_paused = true)]
async fn disconnect_past_deadline_surfaces_coordinator_unavailable() {
let r = with_coordinator_retry(
Duration::from_secs(1),
|r: &Resp| r.error_code,
|| async {
Err::<Resp, _>(ConsumerError::Client(
crabka_client_core::ClientError::Disconnected,
))
},
)
.await;
assert!(matches!(r, Err(ConsumerError::CoordinatorUnavailable)));
}
}