use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crabka_client_core::Client;
use crabka_protocol::owned::streams_group_heartbeat_request::StreamsGroupHeartbeatRequest;
use super::coordinator::{self, CoordinatorState};
use super::status::map_status;
use super::types::{StreamsAssignment, StreamsEvent, TaskOffsetTracker};
use crate::error::StreamsClientError;
use crate::membership::assignment::resolve;
/// Kafka protocol error code 14 (`COORDINATOR_LOAD_IN_PROGRESS`): the group
/// coordinator has not finished loading its state yet. The initial join in
/// `StreamsMembership::start` treats this code as retryable.
const COORDINATOR_LOAD_IN_PROGRESS: i16 = 14;
/// Optional warm-up hook for `StreamsMembership::start`.
///
/// When supplied, it is awaited once per `start` call, after the `group_id`
/// check and before any client connection is built. Presumably intended to warm
/// schema caches (inferred from the name — confirm with implementors).
#[async_trait::async_trait]
pub trait SchemaPrewarm: Send + Sync {
    /// Performs the warm-up. An error aborts `StreamsMembership::start` and is
    /// returned to its caller as-is.
    async fn prewarm(&self) -> Result<(), StreamsClientError>;
}
/// Handle to this process's membership in a streams group.
///
/// Constructed through `StreamsMembership::builder()` (see `start`). Construction
/// joins the group and spawns a background coordinator task that shares the epoch,
/// the offset tracker and the sending side of the event channel with this handle.
pub struct StreamsMembership {
    /// Member id generated locally (UUID v4) and sent on the join request.
    member_id: String,
    /// Id of the group this member joined.
    group_id: String,
    /// Member epoch from the join response, shared with the coordinator task
    /// (presumably updated by it on later heartbeats — confirm in `coordinator`).
    member_epoch: Arc<Mutex<i32>>,
    /// Receiving end of the event channel; read via `next_event`.
    events: mpsc::UnboundedReceiver<StreamsEvent>,
    /// Token passed to the coordinator task; cancelled by `close`.
    shutdown: CancellationToken,
    /// Join handle of the coordinator task; taken and awaited by `close`.
    hb_handle: Option<JoinHandle<()>>,
    /// Task offset tracker shared with the coordinator task.
    tracker: Arc<Mutex<TaskOffsetTracker>>,
}
#[bon::bon]
impl StreamsMembership {
    /// Joins the streams group `group_id` and spawns the background coordinator
    /// task that keeps the membership alive.
    ///
    /// Exposed as `StreamsMembership::builder() ... .build().await` through `bon`.
    ///
    /// Steps:
    /// 1. Rejects an empty `group_id`, then runs `schema_prewarm` if one is supplied.
    /// 2. Sends an initial `StreamsGroupHeartbeatRequest` (`member_epoch: 0`) carrying
    ///    the wire form of `topology`. While the coordinator answers
    ///    `COORDINATOR_LOAD_IN_PROGRESS`, the request is retried every 200 ms for at
    ///    most `rebalance_timeout`. After that, the error is returned like any other.
    /// 3. Queues `StreamsEvent::NotReady` if the response lists task statuses, and
    ///    `StreamsEvent::Assigned` if the initial assignment is non-empty.
    /// 4. Hands the same client and the shared state to `coordinator::run` on a
    ///    spawned task.
    ///
    /// `process_id` defaults to a fresh UUID v4, and the member id is always a fresh
    /// UUID v4. A missing, zero or negative heartbeat interval from the broker falls
    /// back to 3 s. A `rebalance_timeout` too large for the protocol's `i32`
    /// milliseconds field is sent as `i32::MAX`.
    ///
    /// # Errors
    ///
    /// - `StreamsClientError::Server(0)` if `group_id` is empty.
    /// - Any error from `schema_prewarm`, from building the client, or from sending
    ///   the join request.
    /// - The error derived from the join response's error code by `map_error`.
    #[builder(start_fn = builder, finish_fn = build)]
    #[allow(clippy::too_many_lines)]
    pub async fn start(
        #[builder(into)] bootstrap: String,
        #[builder(into, default = "crabka-streams".to_string())] client_id: String,
        #[builder(into)] group_id: String,
        topology: std::sync::Arc<crate::topology::BuiltTopology>,
        #[builder(into)] process_id: Option<String>,
        #[builder(into)] instance_id: Option<String>,
        #[builder(default = Duration::from_secs(30))] rebalance_timeout: Duration,
        security: Option<crabka_client_core::security::ClientSecurity>,
        schema_prewarm: Option<std::sync::Arc<dyn SchemaPrewarm>>,
    ) -> Result<Self, StreamsClientError> {
        /// Pause between join attempts while the coordinator is still loading.
        const JOIN_RETRY_BACKOFF: Duration = Duration::from_millis(200);
        /// Heartbeat cadence used when the broker does not supply a positive interval.
        const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3);

        // NOTE(review): error code 0 means "no error" on the wire; a dedicated
        // variant would be clearer, but callers may already match on `Server(0)`.
        if group_id.is_empty() {
            return Err(StreamsClientError::Server(0));
        }
        if let Some(prewarm) = &schema_prewarm {
            prewarm.prewarm().await?;
        }

        let process_id = process_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        let member_id = uuid::Uuid::new_v4().to_string();
        // The wire field is i32 milliseconds: saturate rather than silently
        // substituting an unrelated default when the duration does not fit.
        let rebalance_timeout_ms =
            i32::try_from(rebalance_timeout.as_millis()).unwrap_or(i32::MAX);

        // A single client serves both the join and the coordinator task, so nothing
        // fallible runs between a successful join and spawning the heartbeat loop
        // (otherwise an error there would leave a joined member with no heartbeats).
        let client = Client::builder()
            .bootstrap(&bootstrap)
            .client_id(client_id)
            .maybe_security(security)
            .build()
            .await?;
        let (events_tx, events_rx) = mpsc::unbounded_channel();

        // `None` means the deadline overflowed `Instant`; treat that as "no limit".
        let join_deadline = tokio::time::Instant::now().checked_add(rebalance_timeout);
        // NOTE(review): COORDINATOR_NOT_AVAILABLE / NOT_COORDINATOR are not retried
        // here; confirm whether `Client::send` handles coordinator rediscovery.
        let join = loop {
            let resp = client
                .send(StreamsGroupHeartbeatRequest {
                    group_id: group_id.clone(),
                    member_id: member_id.clone(),
                    member_epoch: 0,
                    process_id: Some(process_id.clone()),
                    instance_id: instance_id.clone(),
                    rebalance_timeout_ms,
                    topology: Some(topology.to_wire_request()),
                    ..Default::default()
                })
                .await?;
            let may_retry = join_deadline
                .is_none_or(|deadline| tokio::time::Instant::now() < deadline);
            if resp.error_code == COORDINATOR_LOAD_IN_PROGRESS && may_retry {
                tokio::time::sleep(JOIN_RETRY_BACKOFF).await;
                continue;
            }
            break map_error(resp)?;
        };

        let heartbeat_interval = u64::try_from(join.heartbeat_interval_ms)
            .ok()
            .filter(|&ms| ms > 0)
            .map_or(DEFAULT_HEARTBEAT_INTERVAL, Duration::from_millis);

        // `events_rx` is still held locally, so these sends cannot fail.
        if let Some(statuses) = &join.status
            && !statuses.is_empty()
        {
            let _ = events_tx.send(StreamsEvent::NotReady(
                statuses.iter().map(map_status).collect(),
            ));
        }

        let owned_active = Arc::new(Mutex::new(join.active_tasks.clone().unwrap_or_default()));
        let owned_standby = Arc::new(Mutex::new(join.standby_tasks.clone().unwrap_or_default()));
        let owned_warmup = Arc::new(Mutex::new(join.warmup_tasks.clone().unwrap_or_default()));
        let tracker = Arc::new(Mutex::new(TaskOffsetTracker::default()));
        let initial = StreamsAssignment {
            active: resolve(join.active_tasks.as_ref(), &topology),
            standby: resolve(join.standby_tasks.as_ref(), &topology),
            warmup: resolve(join.warmup_tasks.as_ref(), &topology),
        };
        if initial != StreamsAssignment::default() {
            let _ = events_tx.send(StreamsEvent::Assigned(initial.clone()));
        }

        let shutdown = CancellationToken::new();
        let member_epoch = Arc::new(Mutex::new(join.member_epoch));
        let state = CoordinatorState {
            client,
            group_id: group_id.clone(),
            member_id: member_id.clone(),
            process_id,
            instance_id,
            rebalance_timeout_ms,
            topology: Arc::clone(&topology),
            member_epoch: Arc::clone(&member_epoch),
            owned_active,
            owned_standby,
            owned_warmup,
            tracker: Arc::clone(&tracker),
            heartbeat_interval,
            events: events_tx,
            last_assignment: tokio::sync::Mutex::new(initial),
        };
        let hb_handle = tokio::spawn(coordinator::run(state, shutdown.clone()));
        Ok(Self {
            member_id,
            group_id,
            member_epoch,
            events: events_rx,
            shutdown,
            hb_handle: Some(hb_handle),
            tracker,
        })
    }
}
impl StreamsMembership {
    /// Member id generated locally when the group was joined.
    #[must_use]
    pub fn member_id(&self) -> &str {
        &self.member_id
    }

    /// Shared handle to the task offset tracker (the coordinator task holds a clone).
    #[must_use]
    pub fn tracker(&self) -> Arc<Mutex<TaskOffsetTracker>> {
        Arc::clone(&self.tracker)
    }

    /// Id of the group this member joined.
    #[must_use]
    pub fn group_id(&self) -> &str {
        &self.group_id
    }

    /// Snapshot of this membership as a `StreamsGroupMeta`, with the current
    /// member epoch reported as `generation_id`.
    ///
    /// NOTE(review): `group_instance_id` is always `None`, even when `start` was
    /// given an `instance_id`; confirm this is intended.
    pub async fn group_metadata(&self) -> crate::runtime::eos::StreamsGroupMeta {
        let epoch = *self.member_epoch.lock().await;
        crate::runtime::eos::StreamsGroupMeta {
            group_id: self.group_id.clone(),
            generation_id: epoch,
            member_id: self.member_id.clone(),
            group_instance_id: None,
        }
    }

    /// Waits for the next event.
    ///
    /// # Errors
    ///
    /// `StreamsClientError::Closed` once every sender is gone and no events remain.
    pub async fn next_event(&mut self) -> Result<StreamsEvent, StreamsClientError> {
        self.events.recv().await.ok_or(StreamsClientError::Closed)
    }

    /// Cancels the coordinator task and waits for it to finish.
    ///
    /// Best-effort: a panic inside the task is swallowed, and the call always
    /// returns `Ok(())`. Calling it again is a no-op.
    pub async fn close(&mut self) -> Result<(), StreamsClientError> {
        self.shutdown.cancel();
        if let Some(h) = self.hb_handle.take() {
            let _ = h.await;
        }
        Ok(())
    }
}

impl Drop for StreamsMembership {
    // Without this, dropping the handle without `close()` would leave the spawned
    // coordinator task heartbeating indefinitely. Cancelling is idempotent, so this
    // is harmless after `close()`; the task itself is not awaited here.
    fn drop(&mut self) {
        self.shutdown.cancel();
    }
}
/// Turns a heartbeat response's `error_code` into a typed result.
///
/// Code `0` passes the response through untouched. The streams topology codes
/// (130–132) become `InvalidTopology` with the broker's message, if any.
/// Topic/group authorization failures (29/30) become `Authorization` with their
/// code. `GROUP_ID_NOT_FOUND` (69) becomes `GroupIdNotFound`. Any other code is
/// reported verbatim as `Server`.
fn map_error(
    resp: crabka_protocol::owned::streams_group_heartbeat_response::StreamsGroupHeartbeatResponse,
) -> Result<
    crabka_protocol::owned::streams_group_heartbeat_response::StreamsGroupHeartbeatResponse,
    StreamsClientError,
> {
    const TOPIC_AUTHORIZATION_FAILED: i16 = 29;
    const GROUP_AUTHORIZATION_FAILED: i16 = 30;
    const GROUP_ID_NOT_FOUND: i16 = 69;
    const STREAMS_INVALID_TOPOLOGY: i16 = 130;
    const STREAMS_INVALID_TOPOLOGY_EPOCH: i16 = 131;
    const STREAMS_TOPOLOGY_FENCED: i16 = 132;

    let code = resp.error_code;
    let failure = match code {
        0 => return Ok(resp),
        TOPIC_AUTHORIZATION_FAILED | GROUP_AUTHORIZATION_FAILED => {
            StreamsClientError::Authorization(code)
        }
        GROUP_ID_NOT_FOUND => StreamsClientError::GroupIdNotFound,
        STREAMS_INVALID_TOPOLOGY | STREAMS_INVALID_TOPOLOGY_EPOCH | STREAMS_TOPOLOGY_FENCED => {
            StreamsClientError::InvalidTopology {
                code,
                message: resp.error_message.unwrap_or_default(),
            }
        }
        _ => StreamsClientError::Server(code),
    };
    Err(failure)
}
#[cfg(test)]
mod tests {
    use super::map_error;
    use crate::error::StreamsClientError;
    use assert2::check;
    use crabka_protocol::owned::streams_group_heartbeat_response::StreamsGroupHeartbeatResponse;

    /// Heartbeat response carrying `error_code` plus a fixed `"detail"` message.
    fn response_with(error_code: i16) -> StreamsGroupHeartbeatResponse {
        StreamsGroupHeartbeatResponse {
            error_code,
            error_message: Some("detail".into()),
            ..Default::default()
        }
    }

    #[test]
    fn ok_code_passes_through() {
        let outcome = map_error(response_with(0));
        check!(outcome.is_ok());
    }

    #[test]
    fn invalid_topology_family_maps() {
        // 130..=132: invalid topology, invalid topology epoch, topology fenced.
        for expected in 130i16..=132 {
            let outcome = map_error(response_with(expected));
            check!(matches!(
                outcome,
                Err(StreamsClientError::InvalidTopology { code, .. }) if code == expected
            ));
        }
    }

    #[test]
    fn auth_not_found_and_unknown_codes_map() {
        // Group (30) and topic (29) authorization failures keep their code.
        for expected in [30i16, 29] {
            let outcome = map_error(response_with(expected));
            check!(matches!(
                outcome,
                Err(StreamsClientError::Authorization(code)) if code == expected
            ));
        }
        let not_found = map_error(response_with(69));
        check!(matches!(not_found, Err(StreamsClientError::GroupIdNotFound)));
        // Anything unrecognised is surfaced verbatim as a server error.
        let unknown = map_error(response_with(99));
        check!(matches!(unknown, Err(StreamsClientError::Server(99))));
    }
}