use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use crabka_client_core::Client;
use crabka_protocol::owned::share_group_heartbeat_request::ShareGroupHeartbeatRequest;
use crabka_protocol::primitives::uuid::Uuid as WireUuid;
// Heartbeat response error codes that `heartbeat_once` maps to
// `HeartbeatOutcome::RejoinFromScratch`: the member resets its epoch to 0 and
// the following heartbeats carry its subscription until one succeeds. Every
// other non-zero code is treated as transient and retried on the next tick.
// NOTE(review): the numeric values appear to mirror the Kafka protocol error
// table — confirm against crabka_protocol if it defines its own constants.
const FENCED_MEMBER_EPOCH: i16 = 110;
const UNKNOWN_MEMBER_ID: i16 = 25;
const STALE_MEMBER_EPOCH: i16 = 113;
/// State owned by the share-group heartbeat task (`run`).
pub(crate) struct ShareCoordinatorState {
/// Client used to send the periodic heartbeats and the final leave request.
pub client: Client,
/// Share group id placed on every heartbeat.
pub group_id: String,
/// This member's id; sent unchanged on every heartbeat, including rejoins and
/// the leave request.
pub member_id: String,
/// Current member epoch. Overwritten from each successful response and reset
/// to 0 when the coordinator reports the member fenced, unknown or stale.
/// Kept behind `Arc<Mutex<_>>`, presumably so other tasks can read it —
/// NOTE(review): confirm the other readers.
pub member_epoch: Arc<Mutex<i32>>,
/// Currently assigned `(topic_id, topic_name, partition)` entries, replaced
/// wholesale whenever a heartbeat response carries an assignment.
pub assignment: Arc<Mutex<Vec<(WireUuid, String, i32)>>>,
/// Topic id -> name lookup used to label assignment entries. This module only
/// reads it; ids missing from the map are labelled with their hex form.
pub topic_names: Arc<Mutex<HashMap<WireUuid, String>>>,
/// Topic names sent as `subscribed_topic_names` when (re)joining.
pub subscribe: Vec<String>,
/// Period between heartbeats; missed ticks are skipped rather than bunched up.
pub heartbeat_interval: Duration,
}
/// How `run` should react to a single heartbeat round-trip.
enum HeartbeatOutcome {
/// Coordinator accepted the heartbeat (error code 0); `run` clears its
/// rejoin flag.
Ok,
/// Member was fenced, unknown or stale: `run` resets the stored epoch to 0
/// and subsequent heartbeats carry the subscription.
RejoinFromScratch,
/// Any other error code, or the request could not be sent. Nothing is
/// changed and the next tick simply retries.
Transient,
}
pub(crate) async fn run(state: ShareCoordinatorState, shutdown: CancellationToken) {
let mut ticker = tokio::time::interval(state.heartbeat_interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut rejoining = false;
loop {
tokio::select! {
() = shutdown.cancelled() => break,
_ = ticker.tick() => {}
}
tokio::select! {
() = shutdown.cancelled() => break,
outcome = heartbeat_once(&state, rejoining) => match outcome {
HeartbeatOutcome::Ok => rejoining = false,
HeartbeatOutcome::Transient => {}
HeartbeatOutcome::RejoinFromScratch => {
*state.member_epoch.lock().await = 0;
rejoining = true;
}
},
}
}
let leave = state.client.send(ShareGroupHeartbeatRequest {
group_id: state.group_id.clone(),
member_id: state.member_id.clone(),
member_epoch: -1,
..Default::default()
});
let _ = tokio::time::timeout(Duration::from_secs(5), leave).await;
}
async fn heartbeat_once(state: &ShareCoordinatorState, rejoining: bool) -> HeartbeatOutcome {
let epoch = *state.member_epoch.lock().await;
let subscribed = if rejoining {
Some(state.subscribe.clone())
} else {
None
};
let result = state
.client
.send(ShareGroupHeartbeatRequest {
group_id: state.group_id.clone(),
member_id: state.member_id.clone(),
member_epoch: epoch,
subscribed_topic_names: subscribed,
..Default::default()
})
.await;
match result {
Ok(r) if r.error_code == 0 => {
*state.member_epoch.lock().await = r.member_epoch;
if let Some(assignment) = r.assignment {
update_assignment(state, assignment).await;
}
HeartbeatOutcome::Ok
}
Ok(r)
if r.error_code == FENCED_MEMBER_EPOCH
|| r.error_code == UNKNOWN_MEMBER_ID
|| r.error_code == STALE_MEMBER_EPOCH =>
{
tracing::warn!(
error_code = r.error_code,
"share heartbeat fenced; rejoining from epoch 0"
);
HeartbeatOutcome::RejoinFromScratch
}
Ok(r) => {
tracing::warn!(
error_code = r.error_code,
"unexpected share heartbeat error"
);
HeartbeatOutcome::Transient
}
Err(e) => {
tracing::warn!(error = %e, "share heartbeat send failed");
HeartbeatOutcome::Transient
}
}
}
/// Replaces `state.assignment` with the partitions listed in `assignment`.
///
/// Each topic id is resolved to a name through `state.topic_names`. Ids with
/// no known name are labelled with `hex_topic_id`. The result holds one
/// `(topic_id, name, partition)` entry per partition, in response order.
///
/// The name-map lock is released before the assignment lock is taken, so the
/// two are never held together.
async fn update_assignment(
    state: &ShareCoordinatorState,
    assignment: crabka_protocol::owned::share_group_heartbeat_response::Assignment,
) {
    let known_names = state.topic_names.lock().await;
    let resolved: Vec<(WireUuid, String, i32)> = assignment
        .topic_partitions
        .iter()
        .flat_map(|entry| {
            let topic = known_names
                .get(&entry.topic_id)
                .map_or_else(|| hex_topic_id(entry.topic_id), String::clone);
            entry
                .partitions
                .iter()
                .map(move |&partition| (entry.topic_id, topic.clone(), partition))
        })
        .collect();
    drop(known_names);
    *state.assignment.lock().await = resolved;
}
/// Renders the bytes of a topic id as lowercase hex, two zero-padded digits
/// per byte.
///
/// Used as a fallback label for topic ids with no known name.
fn hex_topic_id(id: WireUuid) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(32);
    for byte in id.0.iter() {
        // Writing into a String cannot fail; the Result is discarded on purpose.
        let _ = write!(out, "{byte:02x}");
    }
    out
}