use super::*;
use crate::storage::HybridStorage;
use crate::topic_manager::TopicManager;
use crate::Result;
use crossbeam::channel;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::fs;
use tokio::sync::RwLock;
use tokio::time::interval;
use tracing::{debug, error, info, warn};
/// Coordinates consumer groups: membership (join / sync / heartbeat / leave),
/// partition assignment, committed offsets and, when a metadata directory is
/// configured, periodic persistence of group metadata to JSON files.
///
/// Mutable state sits behind `tokio` locks so the coordinator can be shared as
/// `Arc<ConsumerGroupCoordinator>`; `start` spawns background tasks that each hold
/// a clone of that `Arc`.
#[derive(Debug)]
pub struct ConsumerGroupCoordinator {
/// Session-timeout bounds, check interval and default assignment strategy.
config: ConsumerGroupConfig,
/// In-memory metadata for every known group, keyed by group id.
groups: Arc<RwLock<HashMap<ConsumerGroupId, ConsumerGroupMetadata>>>,
/// Source of topic / partition information for assignment and offset validation.
topic_manager: Arc<TopicManager>,
/// Optional storage handle. In the code visible here it is only checked in
/// `commit_offsets_internal`, which logs that it "would persist" the offset.
storage: Option<Arc<HybridStorage>>,
/// Committed offsets keyed by (group, topic, partition). Held in memory only.
offset_storage: Arc<RwLock<HashMap<(ConsumerGroupId, TopicName, PartitionId), ConsumerOffset>>>,
/// Producer side of the internal group-state event channel.
state_change_tx: channel::Sender<GroupStateChange>,
/// Consumer side of the event channel; parked here until
/// `process_state_changes` takes it (only the first caller gets it).
state_change_rx: Arc<tokio::sync::Mutex<Option<channel::Receiver<GroupStateChange>>>>,
/// Directory holding `group_<id>.json` files; `None` disables persistence.
metadata_dir: Option<PathBuf>,
}
/// Internal events sent on the coordinator's state-change channel whenever group
/// membership or group state changes.
///
/// The only consumer visible here, `process_state_changes`, just logs them at
/// debug level.
#[derive(Debug, Clone)]
enum GroupStateChange {
/// A member joined a group in a way that started a rebalance.
MemberJoined {
group_id: ConsumerGroupId,
consumer_id: ConsumerId,
},
/// A member left voluntarily or was removed after its session expired.
MemberLeft {
group_id: ConsumerGroupId,
consumer_id: ConsumerId,
},
/// The group moved to `PreparingRebalance` after a membership change or a
/// manual trigger.
RebalanceTriggered {
group_id: ConsumerGroupId,
},
/// The group's last member is gone.
GroupEmpty {
group_id: ConsumerGroupId,
},
}
/// JSON form of `ConsumerGroupMetadata`, read back from the `group_<id>.json`
/// files in the metadata directory.
///
/// `SystemTime` fields are stored as whole seconds since the UNIX epoch;
/// `from_serializable_metadata` rebuilds them with second precision.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableGroupMetadata {
pub group_id: ConsumerGroupId,
pub state: ConsumerGroupState,
pub protocol_type: String,
pub protocol_name: String,
pub leader_id: Option<ConsumerId>,
pub members: HashMap<ConsumerId, SerializableGroupMember>,
pub generation_id: i32,
/// Group creation time, in seconds since the UNIX epoch.
pub created_at: u64,
/// Time of the last state change, in seconds since the UNIX epoch.
pub state_timestamp: u64,
}
/// JSON form of a single `ConsumerGroupMember` inside `SerializableGroupMetadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableGroupMember {
pub consumer_id: ConsumerId,
pub group_id: ConsumerGroupId,
pub client_id: String,
pub client_host: String,
pub session_timeout_ms: u64,
pub rebalance_timeout_ms: u64,
pub subscribed_topics: Vec<TopicName>,
pub assigned_partitions: Vec<TopicPartition>,
/// Last heartbeat, in seconds since the UNIX epoch.
pub last_heartbeat: u64,
pub is_leader: bool,
}
impl ConsumerGroupCoordinator {
pub fn new(
config: ConsumerGroupConfig,
topic_manager: Arc<TopicManager>,
storage: Option<Arc<HybridStorage>>,
metadata_dir: Option<PathBuf>,
) -> Self {
let (state_change_tx, state_change_rx) = channel::unbounded();
Self {
config,
groups: Arc::new(RwLock::new(HashMap::new())),
topic_manager,
storage,
offset_storage: Arc::new(RwLock::new(HashMap::new())),
state_change_tx,
state_change_rx: Arc::new(tokio::sync::Mutex::new(Some(state_change_rx))),
metadata_dir,
}
}
/// Path of the persisted JSON file for `group_id`
/// (`<metadata_dir>/group_<group_id>.json`), or `None` if persistence is disabled.
fn get_metadata_file_path(&self, group_id: &ConsumerGroupId) -> Option<PathBuf> {
    let dir = self.metadata_dir.as_ref()?;
    let file_name = format!("group_{}.json", group_id);
    Some(dir.join(file_name))
}
fn get_metadata_dir(&self) -> Option<&Path> {
self.metadata_dir.as_deref()
}
/// Loads persisted group metadata, then spawns the coordinator's background tasks.
///
/// The spawned tasks are: consumer-expiration checks, state-change processing,
/// expired-offset cleanup and periodic metadata persistence. Their `JoinHandle`s
/// are discarded, so the tasks run detached. A failure to load persisted metadata
/// is only logged, which means this currently always returns `Ok(())`.
pub async fn start(self: Arc<Self>) -> Result<()> {
    let load_result = self.load_persisted_metadata().await;
    if let Err(e) = load_result {
        warn!("Failed to load persisted group metadata: {}", e);
    }

    let expiry = Arc::clone(&self);
    tokio::spawn(async move { expiry.expiration_checker_loop().await });

    let state = Arc::clone(&self);
    tokio::spawn(async move { state.process_state_changes().await });

    let cleanup = Arc::clone(&self);
    tokio::spawn(async move { cleanup.offset_cleanup_loop().await });

    // The last task can take ownership of `self` instead of another clone.
    tokio::spawn(async move { self.metadata_persistence_loop().await });

    info!("Consumer group coordinator started with persistence");
    Ok(())
}
/// Routes a consumer-group request to its handler and returns the response.
///
/// Never fails in the code shown; every outcome is expressed through the
/// response's `error_code`.
pub async fn handle_message(
    &self,
    request: ConsumerGroupMessage,
) -> Result<ConsumerGroupMessage> {
    let response = match request {
        ConsumerGroupMessage::JoinGroup { .. } => self.handle_join_group(request).await,
        ConsumerGroupMessage::SyncGroup { .. } => self.handle_sync_group(request).await,
        ConsumerGroupMessage::Heartbeat { .. } => self.handle_heartbeat(request).await,
        ConsumerGroupMessage::LeaveGroup { .. } => self.handle_leave_group(request).await,
        ConsumerGroupMessage::ListGroups => self.handle_list_groups().await,
        ConsumerGroupMessage::DescribeGroups { .. } => {
            self.handle_describe_groups(request).await
        }
        ConsumerGroupMessage::OffsetCommit { .. } => self.handle_offset_commit(request).await,
        ConsumerGroupMessage::OffsetFetch { .. } => self.handle_offset_fetch(request).await,
        // NOTE(review): any other message (e.g. a response variant) is answered with
        // a JoinGroupResponse carrying UNKNOWN_CONSUMER_ID; confirm clients expect this.
        _ => ConsumerGroupMessage::JoinGroupResponse {
            error_code: error_codes::UNKNOWN_CONSUMER_ID,
            generation_id: -1,
            group_protocol: String::new(),
            leader_id: String::new(),
            consumer_id: String::new(),
            members: Vec::new(),
        },
    };
    Ok(response)
}
/// Handles a JoinGroup request.
///
/// The session timeout must lie within the configured
/// `[min_session_timeout_ms, max_session_timeout_ms]` range, otherwise the response
/// carries `INVALID_SESSION_TIMEOUT`. A non-JoinGroup message yields
/// `INVALID_GROUP_ID`. On failure, generation is -1 and all other fields are empty.
pub async fn handle_join_group(&self, request: ConsumerGroupMessage) -> ConsumerGroupMessage {
    let ConsumerGroupMessage::JoinGroup {
        group_id,
        consumer_id,
        client_id,
        client_host,
        session_timeout_ms,
        rebalance_timeout_ms,
        protocol_type,
        group_protocols,
    } = request
    else {
        return ConsumerGroupMessage::JoinGroupResponse {
            error_code: error_codes::INVALID_GROUP_ID,
            generation_id: -1,
            group_protocol: String::new(),
            leader_id: String::new(),
            consumer_id: String::new(),
            members: Vec::new(),
        };
    };

    let allowed_timeouts =
        self.config.min_session_timeout_ms..=self.config.max_session_timeout_ms;
    if !allowed_timeouts.contains(&session_timeout_ms) {
        return ConsumerGroupMessage::JoinGroupResponse {
            error_code: error_codes::INVALID_SESSION_TIMEOUT,
            generation_id: -1,
            group_protocol: String::new(),
            leader_id: String::new(),
            consumer_id,
            members: Vec::new(),
        };
    }

    let outcome = self
        .join_group_internal(
            group_id,
            consumer_id.clone(),
            client_id,
            client_host,
            session_timeout_ms,
            rebalance_timeout_ms,
            protocol_type,
            group_protocols,
        )
        .await;

    let (error_code, generation_id, group_protocol, leader_id, members) = match outcome {
        Ok((generation, protocol, leader, joined)) => {
            (error_codes::NONE, generation, protocol, leader, joined)
        }
        Err(code) => (code, -1, String::new(), String::new(), Vec::new()),
    };
    ConsumerGroupMessage::JoinGroupResponse {
        error_code,
        generation_id,
        group_protocol,
        leader_id,
        consumer_id,
        members,
    }
}
/// Handles a SyncGroup request and returns the caller's partition assignment.
///
/// A non-SyncGroup message yields `INVALID_GROUP_ID`. Any error produces an empty
/// assignment.
pub async fn handle_sync_group(&self, request: ConsumerGroupMessage) -> ConsumerGroupMessage {
    let (error_code, assignment) = match request {
        ConsumerGroupMessage::SyncGroup {
            group_id,
            consumer_id,
            generation_id,
            group_assignments,
        } => {
            let outcome = self
                .sync_group_internal(group_id, consumer_id, generation_id, group_assignments)
                .await;
            match outcome {
                Ok(assigned) => (error_codes::NONE, assigned),
                Err(code) => (code, Vec::new()),
            }
        }
        _ => (error_codes::INVALID_GROUP_ID, Vec::new()),
    };
    ConsumerGroupMessage::SyncGroupResponse {
        error_code,
        assignment,
    }
}
/// Handles a Heartbeat request. A non-Heartbeat message yields `INVALID_GROUP_ID`.
pub async fn handle_heartbeat(&self, request: ConsumerGroupMessage) -> ConsumerGroupMessage {
    let error_code = match request {
        ConsumerGroupMessage::Heartbeat {
            group_id,
            consumer_id,
            generation_id,
        } => {
            self.heartbeat_internal(group_id, consumer_id, generation_id)
                .await
        }
        _ => error_codes::INVALID_GROUP_ID,
    };
    ConsumerGroupMessage::HeartbeatResponse { error_code }
}
/// Handles a LeaveGroup request. A non-LeaveGroup message yields `INVALID_GROUP_ID`.
pub async fn handle_leave_group(&self, request: ConsumerGroupMessage) -> ConsumerGroupMessage {
    let error_code = match request {
        ConsumerGroupMessage::LeaveGroup {
            group_id,
            consumer_id,
        } => self.leave_group_internal(group_id, consumer_id).await,
        _ => error_codes::INVALID_GROUP_ID,
    };
    ConsumerGroupMessage::LeaveGroupResponse { error_code }
}
/// Lists every known group whose state is not `Dead`, reporting only the group
/// id and protocol type.
pub async fn handle_list_groups(&self) -> ConsumerGroupMessage {
    let groups = self.groups.read().await;
    let mut group_overviews = Vec::with_capacity(groups.len());
    for metadata in groups.values() {
        if matches!(metadata.state, ConsumerGroupState::Dead) {
            continue;
        }
        group_overviews.push(GroupOverview {
            group_id: metadata.group_id.clone(),
            protocol_type: metadata.protocol_type.clone(),
        });
    }
    debug!("Listed {} active consumer groups", group_overviews.len());
    ConsumerGroupMessage::ListGroupsResponse {
        error_code: error_codes::NONE,
        groups: group_overviews,
    }
}
/// Lists groups, optionally restricted to the given states.
///
/// With `Some(states)`, only groups whose state is in `states` are returned (this
/// can include `Dead`). With `None`, this behaves like `handle_list_groups` and
/// hides `Dead` groups.
pub async fn handle_list_groups_filtered(
    &self,
    states: Option<Vec<ConsumerGroupState>>,
) -> ConsumerGroupMessage {
    let groups = self.groups.read().await;
    let is_selected = |state: &ConsumerGroupState| -> bool {
        match states.as_deref() {
            Some(wanted) => wanted.contains(state),
            None => !matches!(state, ConsumerGroupState::Dead),
        }
    };
    let mut group_overviews = Vec::new();
    for metadata in groups.values().filter(|m| is_selected(&m.state)) {
        group_overviews.push(GroupOverview {
            group_id: metadata.group_id.clone(),
            protocol_type: metadata.protocol_type.clone(),
        });
    }
    debug!(
        "Listed {} filtered consumer groups (filter: {:?})",
        group_overviews.len(),
        states
    );
    ConsumerGroupMessage::ListGroupsResponse {
        error_code: error_codes::NONE,
        groups: group_overviews,
    }
}
pub async fn handle_describe_groups(
&self,
request: ConsumerGroupMessage,
) -> ConsumerGroupMessage {
if let ConsumerGroupMessage::DescribeGroups { group_ids } = request {
let groups = self.groups.read().await;
let mut descriptions = Vec::new();
for group_id in group_ids {
if let Some(metadata) = groups.get(&group_id) {
let member_descriptions: Vec<MemberDescription> = metadata
.members
.values()
.map(|member| {
let member_metadata =
self.serialize_member_metadata(&member.subscribed_topics);
let member_assignment =
self.serialize_member_assignment(&member.assigned_partitions);
MemberDescription {
consumer_id: member.consumer_id.clone(),
client_id: member.client_id.clone(),
client_host: member.client_host.clone(),
member_metadata,
member_assignment,
}
})
.collect();
let total_partitions: usize = metadata
.members
.values()
.map(|m| m.assigned_partitions.len())
.sum();
debug!(
"Describing group {}: state={:?}, members={}, partitions={}",
group_id,
metadata.state,
metadata.members.len(),
total_partitions
);
descriptions.push(ConsumerGroupDescription {
error_code: error_codes::NONE,
group_id: group_id.clone(),
state: metadata.state.clone(),
protocol_type: metadata.protocol_type.clone(),
protocol_data: metadata.protocol_name.clone(),
members: member_descriptions,
});
} else {
warn!("Requested description for unknown group: {}", group_id);
descriptions.push(ConsumerGroupDescription {
error_code: error_codes::UNKNOWN_GROUP_ID,
group_id: group_id.clone(),
state: ConsumerGroupState::Dead,
protocol_type: String::new(),
protocol_data: String::new(),
members: Vec::new(),
});
}
}
ConsumerGroupMessage::DescribeGroupsResponse {
groups: descriptions,
}
} else {
ConsumerGroupMessage::DescribeGroupsResponse { groups: Vec::new() }
}
}
/// Handles an OffsetCommit request.
///
/// Group-level failures (unknown group, bad generation, unknown member) go in
/// `error_code`. Per-partition failures go in `topic_partition_errors`. A
/// non-OffsetCommit message yields `INVALID_GROUP_ID`.
pub async fn handle_offset_commit(
    &self,
    request: ConsumerGroupMessage,
) -> ConsumerGroupMessage {
    let (error_code, topic_partition_errors) = match request {
        ConsumerGroupMessage::OffsetCommit {
            group_id,
            consumer_id,
            generation_id,
            retention_time_ms,
            offsets,
        } => {
            let outcome = self
                .commit_offsets_internal(
                    group_id,
                    consumer_id,
                    generation_id,
                    retention_time_ms,
                    offsets,
                )
                .await;
            match outcome {
                Ok(per_partition) => (error_codes::NONE, per_partition),
                Err(global_error) => (global_error, Vec::new()),
            }
        }
        _ => (error_codes::INVALID_GROUP_ID, Vec::new()),
    };
    ConsumerGroupMessage::OffsetCommitResponse {
        error_code,
        topic_partition_errors,
    }
}
/// Handles an OffsetFetch request. A non-OffsetFetch message yields
/// `INVALID_GROUP_ID`.
pub async fn handle_offset_fetch(&self, request: ConsumerGroupMessage) -> ConsumerGroupMessage {
    let (error_code, offsets) = match request {
        ConsumerGroupMessage::OffsetFetch {
            group_id,
            topic_partitions,
        } => match self.fetch_offsets_internal(group_id, topic_partitions).await {
            Ok(found) => (error_codes::NONE, found),
            Err(code) => (code, Vec::new()),
        },
        _ => (error_codes::INVALID_GROUP_ID, Vec::new()),
    };
    ConsumerGroupMessage::OffsetFetchResponse {
        error_code,
        offsets,
    }
}
/// Adds `consumer_id` to `group_id`, creating the group on first use, or refreshes
/// an existing member's registration.
///
/// Behavior:
/// * A join by a new member, or any join into an `Empty` group, moves the group to
///   `PreparingRebalance`, bumps the generation and emits `MemberJoined`.
/// * If the group has no leader, or its recorded leader is no longer a member, the
///   joiner becomes leader.
/// * The joining member's `assigned_partitions` is reset. The leader repopulates
///   assignments during SyncGroup.
/// * A group with no members adopts the joiner's protocol type (and re-selects its
///   protocol name).
///
/// # Returns
/// `(generation_id, protocol_name, leader_id, members)` on success, or
/// `Err(INCONSISTENT_GROUP_PROTOCOL)` when `protocol_type` differs from that of a
/// group that still has members.
async fn join_group_internal(
    &self,
    group_id: ConsumerGroupId,
    consumer_id: ConsumerId,
    client_id: String,
    client_host: String,
    session_timeout_ms: u64,
    rebalance_timeout_ms: u64,
    protocol_type: String,
    group_protocols: Vec<GroupProtocol>,
) -> std::result::Result<(i32, String, ConsumerId, Vec<ConsumerGroupMember>), i16> {
    let mut groups = self.groups.write().await;
    let now = SystemTime::now();
    let group = groups
        .entry(group_id.clone())
        .or_insert_with(|| ConsumerGroupMetadata {
            group_id: group_id.clone(),
            state: ConsumerGroupState::Empty,
            protocol_type: protocol_type.clone(),
            protocol_name: self.select_compatible_protocol(&group_protocols),
            leader_id: None,
            members: HashMap::new(),
            generation_id: 0,
            created_at: now,
            state_timestamp: now,
        });

    if group.protocol_type != protocol_type {
        if !group.members.is_empty() {
            return Err(error_codes::INCONSISTENT_GROUP_PROTOCOL);
        }
        // An empty group is re-formed by whoever joins first, so adopt that
        // member's protocol. Keeping the stale one would make every later joiner
        // that speaks the new protocol fail the check above.
        group.protocol_type = protocol_type;
        group.protocol_name = self.select_compatible_protocol(&group_protocols);
    }

    let subscribed_topics = self.extract_subscribed_topics(&group_protocols);
    debug!(
        "Extracted {} subscribed topics for consumer {}: {:?}",
        subscribed_topics.len(),
        consumer_id,
        subscribed_topics
    );

    let is_new_member = !group.members.contains_key(&consumer_id);
    // A rejoining leader must keep its leader flag. sync_group_internal only
    // computes and applies assignments when the syncing member has `is_leader` set.
    let is_leader = group.leader_id.as_ref() == Some(&consumer_id);
    let member = ConsumerGroupMember {
        consumer_id: consumer_id.clone(),
        group_id: group_id.clone(),
        client_id,
        client_host,
        session_timeout_ms,
        rebalance_timeout_ms,
        subscribed_topics,
        assigned_partitions: Vec::new(),
        last_heartbeat: now,
        is_leader,
    };
    group.members.insert(consumer_id.clone(), member);

    // Elect the joiner whenever the group lacks a live leader, not only on the
    // rebalance path, so a group with members always has someone able to assign.
    let has_live_leader = match &group.leader_id {
        Some(leader) => group.members.contains_key(leader),
        None => false,
    };
    if !has_live_leader {
        group.leader_id = Some(consumer_id.clone());
        if let Some(leader_member) = group.members.get_mut(&consumer_id) {
            leader_member.is_leader = true;
        }
    }

    if is_new_member || group.state == ConsumerGroupState::Empty {
        group.state = ConsumerGroupState::PreparingRebalance;
        group.generation_id += 1;
        group.state_timestamp = now;
        let _ = self.state_change_tx.send(GroupStateChange::MemberJoined {
            group_id: group_id.clone(),
            consumer_id: consumer_id.clone(),
        });
    }

    let leader_id = group.leader_id.clone().unwrap_or_default();
    let generation_id = group.generation_id;
    let protocol_name = group.protocol_name.clone();
    let members: Vec<ConsumerGroupMember> = group.members.values().cloned().collect();
    debug!(
        "JoinGroup returning: group={}, protocol_name={}, generation={}",
        group_id, protocol_name, generation_id
    );
    Ok((generation_id, protocol_name, leader_id, members))
}
/// Completes a rebalance step for `consumer_id` and returns its partitions.
///
/// The caller's heartbeat is refreshed first.
///
/// For the leader: the assignment is the leader-supplied `group_assignments`, or,
/// when that is empty, one generated by the coordinator. It is applied to every
/// member and the group becomes `Stable`. Members absent from the assignment are
/// left with no partitions, so nobody keeps previous-generation partitions that
/// may now belong to another member.
///
/// For a follower: returns whatever the leader last applied to it.
///
/// # Errors
/// `UNKNOWN_GROUP_ID`, `ILLEGAL_GENERATION`, `UNKNOWN_CONSUMER_ID`, or an error
/// from `generate_partition_assignments`.
async fn sync_group_internal(
    &self,
    group_id: ConsumerGroupId,
    consumer_id: ConsumerId,
    generation_id: i32,
    group_assignments: HashMap<ConsumerId, Vec<TopicPartition>>,
) -> std::result::Result<Vec<TopicPartition>, i16> {
    let mut groups = self.groups.write().await;
    let group = groups
        .get_mut(&group_id)
        .ok_or(error_codes::UNKNOWN_GROUP_ID)?;
    if group.generation_id != generation_id {
        return Err(error_codes::ILLEGAL_GENERATION);
    }
    let member = group
        .members
        .get_mut(&consumer_id)
        .ok_or(error_codes::UNKNOWN_CONSUMER_ID)?;
    member.last_heartbeat = SystemTime::now();
    if !member.is_leader {
        return Ok(member.assigned_partitions.clone());
    }

    let final_assignments = if group_assignments.is_empty() {
        self.generate_partition_assignments(&group_id, group)
            .await?
    } else {
        group_assignments
    };

    // Overwrite every member's assignment. Anyone the assignment omits gets an
    // empty list rather than keeping stale partitions.
    for (member_id, target_member) in group.members.iter_mut() {
        let assignment = final_assignments
            .get(member_id)
            .cloned()
            .unwrap_or_default();
        debug!(
            "Assigned {} partitions to consumer {} in group {}: {:?}",
            assignment.len(),
            member_id,
            group_id,
            assignment
        );
        target_member.assigned_partitions = assignment;
    }

    group.state = ConsumerGroupState::Stable;
    group.state_timestamp = SystemTime::now();
    info!(
        "Group {} transitioned to stable state with {} members",
        group_id,
        group.members.len()
    );
    Ok(final_assignments
        .get(&consumer_id)
        .cloned()
        .unwrap_or_default())
}
/// Records a heartbeat for `consumer_id` in `group_id`.
///
/// Returns `NONE` on success. Otherwise returns `UNKNOWN_GROUP_ID`,
/// `ILLEGAL_GENERATION` (the heartbeat is then not recorded) or
/// `UNKNOWN_CONSUMER_ID`.
async fn heartbeat_internal(
    &self,
    group_id: ConsumerGroupId,
    consumer_id: ConsumerId,
    generation_id: i32,
) -> i16 {
    let mut groups = self.groups.write().await;
    let Some(group) = groups.get_mut(&group_id) else {
        return error_codes::UNKNOWN_GROUP_ID;
    };
    if group.generation_id != generation_id {
        return error_codes::ILLEGAL_GENERATION;
    }
    match group.members.get_mut(&consumer_id) {
        Some(member) => {
            member.last_heartbeat = SystemTime::now();
            error_codes::NONE
        }
        None => error_codes::UNKNOWN_CONSUMER_ID,
    }
}
/// Removes `consumer_id` from `group_id`.
///
/// If the leader leaves, an arbitrary remaining member (first key in map
/// iteration order) is promoted. An emptied group becomes `Empty` with no leader
/// and emits `GroupEmpty`. Otherwise the group enters `PreparingRebalance`, bumps
/// its generation and emits `RebalanceTriggered`. `MemberLeft` is always emitted
/// last.
///
/// Returns `NONE`, `UNKNOWN_GROUP_ID` or `UNKNOWN_CONSUMER_ID`.
async fn leave_group_internal(
    &self,
    group_id: ConsumerGroupId,
    consumer_id: ConsumerId,
) -> i16 {
    let mut groups = self.groups.write().await;
    let Some(group) = groups.get_mut(&group_id) else {
        return error_codes::UNKNOWN_GROUP_ID;
    };
    if group.members.remove(&consumer_id).is_none() {
        return error_codes::UNKNOWN_CONSUMER_ID;
    }

    if group.leader_id.as_ref() == Some(&consumer_id) {
        let successor = group.members.keys().next().cloned();
        if let Some(new_leader) = successor
            .as_ref()
            .and_then(|id| group.members.get_mut(id))
        {
            new_leader.is_leader = true;
        }
        group.leader_id = successor;
    }

    let follow_up = if group.members.is_empty() {
        group.state = ConsumerGroupState::Empty;
        group.leader_id = None;
        GroupStateChange::GroupEmpty {
            group_id: group_id.clone(),
        }
    } else {
        group.state = ConsumerGroupState::PreparingRebalance;
        group.generation_id += 1;
        GroupStateChange::RebalanceTriggered {
            group_id: group_id.clone(),
        }
    };
    let _ = self.state_change_tx.send(follow_up);

    group.state_timestamp = SystemTime::now();
    let _ = self.state_change_tx.send(GroupStateChange::MemberLeft {
        group_id,
        consumer_id,
    });
    error_codes::NONE
}
/// Runs `check_expired_consumers` every `consumer_expiration_check_interval_ms`,
/// forever. Errors are logged and never stop the loop.
async fn expiration_checker_loop(&self) {
    // `tokio::time::interval` panics on a zero period. Clamp to 1 ms so a zero
    // config value cannot kill this background task.
    let period_ms = self.config.consumer_expiration_check_interval_ms.max(1);
    let mut interval = interval(Duration::from_millis(period_ms));
    loop {
        interval.tick().await;
        if let Err(e) = self.check_expired_consumers().await {
            error!("Error checking expired consumers: {}", e);
        }
    }
}
async fn check_expired_consumers(&self) -> Result<()> {
let mut groups = self.groups.write().await;
let now = SystemTime::now();
let mut expired_members = Vec::new();
for (group_id, group) in groups.iter() {
for (consumer_id, member) in &group.members {
let elapsed = now
.duration_since(member.last_heartbeat)
.unwrap_or(Duration::from_secs(0));
if elapsed.as_millis() > member.session_timeout_ms as u128 {
expired_members.push((group_id.clone(), consumer_id.clone()));
}
}
}
for (group_id, consumer_id) in expired_members {
if let Some(group) = groups.get_mut(&group_id) {
group.members.remove(&consumer_id);
if group.members.is_empty() {
group.state = ConsumerGroupState::Empty;
group.leader_id = None;
} else {
if group.leader_id.as_ref() == Some(&consumer_id) {
group.leader_id = group.members.keys().next().cloned();
if let Some(new_leader_id) = &group.leader_id {
if let Some(new_leader) = group.members.get_mut(new_leader_id) {
new_leader.is_leader = true;
}
}
}
group.state = ConsumerGroupState::PreparingRebalance;
group.generation_id += 1;
}
group.state_timestamp = now;
warn!(
"Removed expired consumer {} from group {}",
consumer_id, group_id
);
let _ = self.state_change_tx.send(GroupStateChange::MemberLeft {
group_id: group_id.clone(),
consumer_id,
});
if group.members.is_empty() {
let _ = self
.state_change_tx
.send(GroupStateChange::GroupEmpty { group_id });
} else {
let _ = self
.state_change_tx
.send(GroupStateChange::RebalanceTriggered { group_id });
}
}
}
Ok(())
}
/// Drains the internal state-change channel and logs each event at debug level.
///
/// Only the first call does any work: it takes the parked receiver, and later
/// calls find `None` and return immediately. Because the coordinator itself owns
/// a sender, the channel stays open, and this keeps running for as long as the
/// coordinator is alive.
async fn process_state_changes(&self) {
    let receiver = {
        let mut rx_guard = self.state_change_rx.lock().await;
        match rx_guard.take() {
            Some(rx) => rx,
            None => return,
        }
    };

    // `Receiver::recv` blocks its thread. Drain the channel from one blocking task
    // instead of spawning a new blocking task (and cloning the receiver) per event.
    // `iter()` ends once every sender has been dropped.
    let drain = tokio::task::spawn_blocking(move || {
        for change in receiver.iter() {
            match change {
                GroupStateChange::MemberJoined {
                    group_id,
                    consumer_id,
                } => {
                    debug!("Consumer {} joined group {}", consumer_id, group_id);
                }
                GroupStateChange::MemberLeft {
                    group_id,
                    consumer_id,
                } => {
                    debug!("Consumer {} left group {}", consumer_id, group_id);
                }
                GroupStateChange::RebalanceTriggered { group_id } => {
                    debug!("Rebalance triggered for group {}", group_id);
                }
                GroupStateChange::GroupEmpty { group_id } => {
                    debug!("Group {} is now empty", group_id);
                }
            }
        }
    });
    // A JoinError (panic or cancellation of the blocking task) just ends processing.
    let _ = drain.await;
}
/// Computes a fresh partition assignment for `group_id` over every partition of
/// `topics`, using the configured default assignment strategy.
///
/// Consumers are passed to the assignor in sorted order. The result is returned
/// only; group state is not modified.
///
/// # Errors
/// `FluxmqError::Config` if the group is unknown.
pub async fn assign_partitions(
    &self,
    group_id: &ConsumerGroupId,
    topics: &[TopicName],
) -> Result<HashMap<ConsumerId, Vec<TopicPartition>>> {
    let groups = self.groups.read().await;
    let group = groups
        .get(group_id)
        .ok_or_else(|| crate::FluxmqError::Config(format!("Unknown group: {}", group_id)))?;

    let mut all_partitions = Vec::new();
    for topic in topics {
        all_partitions.extend(
            self.topic_manager
                .get_partitions(topic)
                .into_iter()
                .map(|partition_id| TopicPartition::new(topic, partition_id)),
        );
    }

    let mut consumers: Vec<ConsumerId> = group.members.keys().cloned().collect();
    consumers.sort_unstable();

    let assignments = PartitionAssignor::new(self.config.default_assignment_strategy.clone())
        .assign(&consumers, &all_partitions);
    debug!(
        "Generated partition assignments for group {}: {} consumers, {} partitions",
        group_id,
        consumers.len(),
        all_partitions.len()
    );
    Ok(assignments)
}
/// Same computation as `assign_partitions`, but for a group the caller already
/// holds, so no lock is taken here.
///
/// Currently never returns an error.
pub fn assign_partitions_without_lock(
    &self,
    group: &ConsumerGroupMetadata,
    topics: &[TopicName],
) -> Result<HashMap<ConsumerId, Vec<TopicPartition>>> {
    let mut all_partitions = Vec::new();
    for topic in topics {
        let partitions = self.topic_manager.get_partitions(topic);
        debug!(
            "Topic '{}' has {} partitions: {:?}",
            topic,
            partitions.len(),
            partitions
        );
        all_partitions.extend(
            partitions
                .into_iter()
                .map(|partition_id| TopicPartition::new(topic, partition_id)),
        );
    }

    let mut consumers: Vec<ConsumerId> = group.members.keys().cloned().collect();
    consumers.sort_unstable();

    let assignments = PartitionAssignor::new(self.config.default_assignment_strategy.clone())
        .assign(&consumers, &all_partitions);
    debug!(
        "Generated partition assignments: {} consumers, {} partitions",
        consumers.len(),
        all_partitions.len()
    );
    Ok(assignments)
}
/// Builds the coordinator-side assignment used when the leader supplies none.
///
/// Topic selection: the union of all members' subscriptions, or every known topic
/// if nobody subscribed to anything.
///
/// If there are no topics at all, or the assignor fails, every member gets an
/// empty assignment (the failure is logged) instead of an error. The topic order
/// comes from a `HashSet`, so it is not deterministic.
async fn generate_partition_assignments(
    &self,
    group_id: &ConsumerGroupId,
    group: &ConsumerGroupMetadata,
) -> std::result::Result<HashMap<ConsumerId, Vec<TopicPartition>>, i16> {
    let empty_for_all = || -> HashMap<ConsumerId, Vec<TopicPartition>> {
        group
            .members
            .keys()
            .map(|consumer_id| (consumer_id.clone(), Vec::new()))
            .collect()
    };

    let mut subscribed_topics: std::collections::HashSet<TopicName> = group
        .members
        .values()
        .flat_map(|member| member.subscribed_topics.iter().cloned())
        .collect();

    if subscribed_topics.is_empty() {
        let all_topics = self.topic_manager.list_topics();
        if all_topics.is_empty() {
            debug!("No topics available for assignment in group {}", group_id);
            return Ok(empty_for_all());
        }
        subscribed_topics.extend(all_topics);
    }

    let topics_vec: Vec<TopicName> = subscribed_topics.into_iter().collect();
    debug!(
        "Generating assignments for group {} with topics: {:?}, strategy: {:?}",
        group_id, topics_vec, self.config.default_assignment_strategy
    );

    let assignments = match self.assign_partitions_without_lock(group, &topics_vec) {
        Ok(assignments) => assignments,
        Err(e) => {
            warn!(
                "Failed to generate partition assignments for group {}: {}",
                group_id, e
            );
            return Ok(empty_for_all());
        }
    };

    let total_partitions = assignments.values().map(Vec::len).sum::<usize>();
    info!(
        "Generated {} partition assignments across {} consumers for group {}",
        total_partitions,
        assignments.len(),
        group_id
    );
    Ok(assignments)
}
/// True when the group exists and is in `PreparingRebalance` or
/// `CompletingRebalance`.
pub async fn needs_rebalance(&self, group_id: &ConsumerGroupId) -> bool {
    let groups = self.groups.read().await;
    match groups.get(group_id) {
        Some(group) => matches!(
            group.state,
            ConsumerGroupState::PreparingRebalance | ConsumerGroupState::CompletingRebalance
        ),
        None => false,
    }
}
/// Forces a rebalance of a non-empty group: moves it to `PreparingRebalance`,
/// bumps the generation and emits `RebalanceTriggered`.
///
/// Unknown or empty groups are left untouched. Always returns `Ok(())`.
pub async fn trigger_rebalance(&self, group_id: &ConsumerGroupId) -> Result<()> {
    let mut groups = self.groups.write().await;
    let Some(group) = groups.get_mut(group_id) else {
        return Ok(());
    };
    if group.members.is_empty() {
        return Ok(());
    }

    group.state = ConsumerGroupState::PreparingRebalance;
    group.generation_id += 1;
    group.state_timestamp = SystemTime::now();
    info!(
        "Manually triggered rebalance for group {} (generation {})",
        group_id, group.generation_id
    );
    let _ = self
        .state_change_tx
        .send(GroupStateChange::RebalanceTriggered {
            group_id: group_id.clone(),
        });
    Ok(())
}
/// Snapshot of a group's state, membership and assignment size, or `None` if the
/// group is unknown.
///
/// `assignment_strategy` reports the coordinator's configured default strategy.
pub async fn get_group_stats(&self, group_id: &ConsumerGroupId) -> Option<GroupStats> {
    let groups = self.groups.read().await;
    let group = groups.get(group_id)?;
    let total_assigned_partitions = group
        .members
        .values()
        .map(|member| member.assigned_partitions.len())
        .sum::<usize>();
    Some(GroupStats {
        group_id: group_id.clone(),
        state: group.state.clone(),
        member_count: group.members.len(),
        generation_id: group.generation_id,
        total_assigned_partitions,
        leader_id: group.leader_id.clone(),
        assignment_strategy: self.config.default_assignment_strategy.clone(),
        created_at: group.created_at,
        last_state_change: group.state_timestamp,
    })
}
/// Stores committed offsets for a group member in the in-memory offset map.
///
/// The member must belong to the group's current generation. A positive
/// `retention_time_ms` sets an expiry relative to now; zero or negative means the
/// offset never expires. Offsets for unknown topics or partitions are skipped and
/// reported in the returned per-partition error list.
///
/// # Errors
/// `UNKNOWN_GROUP_ID`, `ILLEGAL_GENERATION` or `UNKNOWN_CONSUMER_ID` for
/// group-level validation failures.
async fn commit_offsets_internal(
    &self,
    group_id: ConsumerGroupId,
    consumer_id: ConsumerId,
    generation_id: i32,
    retention_time_ms: i64,
    offsets: Vec<TopicPartitionOffset>,
) -> std::result::Result<Vec<TopicPartitionError>, i16> {
    // Validate under a short-lived read lock that is released before taking the
    // offset lock.
    {
        let groups = self.groups.read().await;
        let Some(group) = groups.get(&group_id) else {
            return Err(error_codes::UNKNOWN_GROUP_ID);
        };
        if group.generation_id != generation_id {
            return Err(error_codes::ILLEGAL_GENERATION);
        }
        if !group.members.contains_key(&consumer_id) {
            return Err(error_codes::UNKNOWN_CONSUMER_ID);
        }
    }

    let now = SystemTime::now();
    let expire_timestamp =
        (retention_time_ms > 0).then(|| now + Duration::from_millis(retention_time_ms as u64));

    let mut topic_partition_errors = Vec::new();
    let mut offset_storage = self.offset_storage.write().await;
    for offset in offsets {
        let known_partition = self.topic_manager.get_topic(&offset.topic).is_some()
            && self
                .topic_manager
                .partition_exists(&offset.topic, offset.partition);
        if !known_partition {
            topic_partition_errors.push(TopicPartitionError {
                topic: offset.topic,
                partition: offset.partition,
                // NOTE(review): unknown topic/partition is reported as
                // UNKNOWN_CONSUMER_ID; a dedicated code may be intended - confirm.
                error_code: super::error_codes::UNKNOWN_CONSUMER_ID,
            });
            continue;
        }

        offset_storage.insert(
            (group_id.clone(), offset.topic.clone(), offset.partition),
            ConsumerOffset {
                group_id: group_id.clone(),
                topic: offset.topic.clone(),
                partition: offset.partition,
                offset: offset.offset,
                metadata: offset.metadata,
                commit_timestamp: now,
                expire_timestamp,
            },
        );

        // Durable persistence is not implemented; only log what would be written.
        if self.storage.is_some() {
            debug!(
                "Would persist offset to disk: group={}, topic={}, partition={}, offset={}",
                group_id, offset.topic, offset.partition, offset.offset
            );
        }
    }
    Ok(topic_partition_errors)
}
async fn fetch_offsets_internal(
&self,
group_id: ConsumerGroupId,
topic_partitions: Option<Vec<TopicPartition>>,
) -> std::result::Result<Vec<TopicPartitionOffsetResult>, i16> {
{
let groups = self.groups.read().await;
if !groups.contains_key(&group_id) {
return Err(error_codes::UNKNOWN_GROUP_ID);
}
}
let offset_storage = self.offset_storage.read().await;
let mut results = Vec::new();
match topic_partitions {
Some(partitions) => {
for tp in partitions {
let key = (group_id.clone(), tp.topic.clone(), tp.partition);
if let Some(stored_offset) = offset_storage.get(&key) {
let is_expired = stored_offset
.expire_timestamp
.map(|expire_time| SystemTime::now() > expire_time)
.unwrap_or(false);
if !is_expired {
results.push(TopicPartitionOffsetResult {
topic: tp.topic,
partition: tp.partition,
offset: stored_offset.offset,
leader_epoch: -1, metadata: stored_offset.metadata.clone(),
error_code: error_codes::NONE,
});
} else {
results.push(TopicPartitionOffsetResult {
topic: tp.topic,
partition: tp.partition,
offset: -1,
leader_epoch: -1,
metadata: None,
error_code: error_codes::NONE,
});
}
} else {
results.push(TopicPartitionOffsetResult {
topic: tp.topic,
partition: tp.partition,
offset: -1, leader_epoch: -1,
metadata: None,
error_code: error_codes::NONE,
});
}
}
}
None => {
for ((stored_group_id, topic, partition), stored_offset) in offset_storage.iter() {
if stored_group_id == &group_id {
let is_expired = stored_offset
.expire_timestamp
.map(|expire_time| SystemTime::now() > expire_time)
.unwrap_or(false);
if !is_expired {
results.push(TopicPartitionOffsetResult {
topic: topic.clone(),
partition: *partition,
offset: stored_offset.offset,
leader_epoch: -1, metadata: stored_offset.metadata.clone(),
error_code: error_codes::NONE,
});
}
}
}
}
}
Ok(results)
}
/// Runs `cleanup_expired_offsets` every two expiration-check intervals, forever.
/// Errors are logged and never stop the loop.
async fn offset_cleanup_loop(&self) {
    // Saturate rather than overflow on huge configs, and clamp to >= 1 ms because
    // `tokio::time::interval` panics on a zero period.
    let period_ms = self
        .config
        .consumer_expiration_check_interval_ms
        .saturating_mul(2)
        .max(1);
    let mut interval = interval(Duration::from_millis(period_ms));
    loop {
        interval.tick().await;
        if let Err(e) = self.cleanup_expired_offsets().await {
            error!("Error cleaning up expired offsets: {}", e);
        }
    }
}
/// Drops every committed offset whose expiry time has passed. Offsets without an
/// expiry are kept. Always returns `Ok(())`.
pub async fn cleanup_expired_offsets(&self) -> Result<()> {
    let mut offset_storage = self.offset_storage.write().await;
    let now = SystemTime::now();
    offset_storage.retain(|_, offset| match offset.expire_timestamp {
        Some(expire_time) => expire_time >= now,
        None => true,
    });
    Ok(())
}
/// Encodes subscribed topics as: a big-endian `u32` count, then for each topic a
/// big-endian `u32` byte length followed by its bytes.
///
/// Lengths are narrowed with `as u32`, so values above `u32::MAX` would truncate.
fn serialize_member_metadata(&self, subscribed_topics: &[TopicName]) -> Vec<u8> {
    let body_len: usize = subscribed_topics.iter().map(|topic| 4 + topic.len()).sum();
    let mut buf = Vec::with_capacity(4 + body_len);
    buf.extend_from_slice(&(subscribed_topics.len() as u32).to_be_bytes());
    for topic in subscribed_topics {
        let name_len = topic.len() as u32;
        buf.extend_from_slice(&name_len.to_be_bytes());
        buf.extend_from_slice(topic.as_bytes());
    }
    buf
}
/// Encodes assigned partitions as: a big-endian `u32` count, then for each entry a
/// big-endian `u32` topic byte length, the topic bytes, and the partition id in
/// big-endian form.
///
/// The partition id's width depends on the `PartitionId` type, which is not
/// visible here.
fn serialize_member_assignment(&self, assigned_partitions: &[TopicPartition]) -> Vec<u8> {
    let mut buf = (assigned_partitions.len() as u32).to_be_bytes().to_vec();
    for entry in assigned_partitions {
        let topic_len = entry.topic.len() as u32;
        buf.extend_from_slice(&topic_len.to_be_bytes());
        buf.extend_from_slice(entry.topic.as_bytes());
        buf.extend_from_slice(&entry.partition.to_be_bytes());
    }
    buf
}
/// Ids of every group currently tracked, in any state (including `Dead`).
pub async fn get_all_groups(&self) -> Vec<ConsumerGroupId> {
    let groups = self.groups.read().await;
    let mut ids = Vec::with_capacity(groups.len());
    ids.extend(groups.keys().cloned());
    ids
}
/// Cloned snapshot of a group's metadata, or `None` if the group is unknown.
pub async fn get_group_metadata(
    &self,
    group_id: &ConsumerGroupId,
) -> Option<ConsumerGroupMetadata> {
    let groups = self.groups.read().await;
    let metadata = groups.get(group_id)?;
    Some(metadata.clone())
}
/// True when the group exists and is neither `Dead` nor `Empty`.
pub async fn is_group_active(&self, group_id: &ConsumerGroupId) -> bool {
    let groups = self.groups.read().await;
    match groups.get(group_id) {
        Some(group) => !matches!(
            group.state,
            ConsumerGroupState::Dead | ConsumerGroupState::Empty
        ),
        None => false,
    }
}
/// Loads every `group_*.json` file in the metadata directory into the in-memory
/// group table. Groups already present with the same id are overwritten.
///
/// Returns `Ok(())` without doing anything when no directory is configured or the
/// directory does not exist. A file that fails to load is logged and skipped.
///
/// # Errors
/// Propagates I/O errors from listing the directory.
async fn load_persisted_metadata(&self) -> Result<()> {
    let metadata_dir = match self.get_metadata_dir() {
        Some(dir) => dir,
        None => {
            debug!("No metadata directory configured, skipping persistence loading");
            return Ok(());
        }
    };
    if !metadata_dir.exists() {
        debug!("Metadata directory does not exist: {:?}", metadata_dir);
        return Ok(());
    }

    let mut groups = self.groups.write().await;
    let mut loaded_count = 0;
    let mut entries = fs::read_dir(metadata_dir).await?;
    while let Some(entry) = entries.next_entry().await? {
        let path = entry.path();
        let is_group_file = matches!(path.extension(), Some(ext) if ext == "json")
            && matches!(
                path.file_name().and_then(|name| name.to_str()),
                Some(name) if name.starts_with("group_")
            );
        if !is_group_file {
            continue;
        }

        match self.load_group_metadata_file(&path).await {
            Ok(group_metadata) => {
                info!(
                    "Loaded persisted group metadata for group: {}",
                    group_metadata.group_id
                );
                groups.insert(group_metadata.group_id.clone(), group_metadata);
                loaded_count += 1;
            }
            Err(e) => warn!("Failed to load group metadata from {:?}: {}", path, e),
        }
    }

    info!(
        "Loaded {} consumer group metadata files from disk",
        loaded_count
    );
    Ok(())
}
/// Reads one persisted JSON file and converts it into live group metadata.
///
/// # Errors
/// I/O errors from reading the file, or JSON errors from parsing it.
async fn load_group_metadata_file(&self, path: &Path) -> Result<ConsumerGroupMetadata> {
    let raw = fs::read_to_string(path).await?;
    let parsed = serde_json::from_str::<SerializableGroupMetadata>(&raw)?;
    Ok(self.from_serializable_metadata(parsed))
}
/// Converts an on-disk `SerializableGroupMetadata` back into `ConsumerGroupMetadata`.
///
/// Timestamps are stored as whole seconds since the UNIX epoch. A value too large
/// for `SystemTime` (for example, from a corrupted file) falls back to
/// `UNIX_EPOCH` instead of panicking. `SystemTime + Duration` panics on overflow,
/// and that panic would escape `load_persisted_metadata` and abort `start`.
/// A member whose heartbeat falls back to the epoch will be treated as expired by
/// the next expiration check.
fn from_serializable_metadata(
    &self,
    serializable: SerializableGroupMetadata,
) -> ConsumerGroupMetadata {
    let from_epoch_secs = |secs: u64| {
        SystemTime::UNIX_EPOCH
            .checked_add(Duration::from_secs(secs))
            .unwrap_or(SystemTime::UNIX_EPOCH)
    };

    let members = serializable
        .members
        .into_iter()
        .map(|(consumer_id, stored)| {
            let member = ConsumerGroupMember {
                consumer_id: stored.consumer_id,
                group_id: stored.group_id,
                client_id: stored.client_id,
                client_host: stored.client_host,
                session_timeout_ms: stored.session_timeout_ms,
                rebalance_timeout_ms: stored.rebalance_timeout_ms,
                subscribed_topics: stored.subscribed_topics,
                assigned_partitions: stored.assigned_partitions,
                last_heartbeat: from_epoch_secs(stored.last_heartbeat),
                is_leader: stored.is_leader,
            };
            (consumer_id, member)
        })
        .collect();

    ConsumerGroupMetadata {
        group_id: serializable.group_id,
        state: serializable.state,
        protocol_type: serializable.protocol_type,
        protocol_name: serializable.protocol_name,
        leader_id: serializable.leader_id,
        members,
        generation_id: serializable.generation_id,
        created_at: from_epoch_secs(serializable.created_at),
        state_timestamp: from_epoch_secs(serializable.state_timestamp),
    }
}
/// Runs forever, calling `persist_metadata_changes` once per tick and logging
/// any error instead of stopping.
///
/// The tick period is ten times `consumer_expiration_check_interval_ms`. The
/// first tick completes immediately.
async fn metadata_persistence_loop(&self) {
    // `tokio::time::interval` panics on a zero period, and a plain `* 10` can
    // overflow on a huge configured value. Saturate and clamp to at least 1 ms.
    let period_ms = self
        .config
        .consumer_expiration_check_interval_ms
        .saturating_mul(10)
        .max(1);
    let mut ticker = interval(Duration::from_millis(period_ms));
    loop {
        ticker.tick().await;
        if let Err(e) = self.persist_metadata_changes().await {
            error!("Error persisting group metadata: {}", e);
        }
    }
}
async fn persist_metadata_changes(&self) -> Result<()> {
let Some(metadata_dir) = self.get_metadata_dir() else {
return Ok(());
};
if !metadata_dir.exists() {
fs::create_dir_all(metadata_dir).await?;
}
let groups = self.groups.read().await;
let mut persisted_count = 0;
for group in groups.values() {
if !matches!(group.state, ConsumerGroupState::Dead) {
match self.persist_group_metadata(group).await {
Ok(()) => persisted_count += 1,
Err(e) => {
error!(
"Failed to persist metadata for group {}: {}",
group.group_id, e
);
}
}
}
}
if persisted_count > 0 {
debug!(
"Persisted {} consumer group metadata files",
persisted_count
);
}
Ok(())
}
async fn persist_group_metadata(&self, group: &ConsumerGroupMetadata) -> Result<()> {
let Some(file_path) = self.get_metadata_file_path(&group.group_id) else {
return Ok(());
};
let serializable = self.to_serializable_metadata(group);
let json_content = serde_json::to_string_pretty(&serializable)?;
fs::write(&file_path, json_content).await?;
debug!(
"Persisted metadata for group {} to {:?}",
group.group_id, file_path
);
Ok(())
}
fn to_serializable_metadata(
&self,
metadata: &ConsumerGroupMetadata,
) -> SerializableGroupMetadata {
let members = metadata
.members
.iter()
.map(|(consumer_id, member)| {
let serializable_member = SerializableGroupMember {
consumer_id: member.consumer_id.clone(),
group_id: member.group_id.clone(),
client_id: member.client_id.clone(),
client_host: member.client_host.clone(),
session_timeout_ms: member.session_timeout_ms,
rebalance_timeout_ms: member.rebalance_timeout_ms,
subscribed_topics: member.subscribed_topics.clone(),
assigned_partitions: member.assigned_partitions.clone(),
last_heartbeat: member
.last_heartbeat
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs(),
is_leader: member.is_leader,
};
(consumer_id.clone(), serializable_member)
})
.collect();
SerializableGroupMetadata {
group_id: metadata.group_id.clone(),
state: metadata.state.clone(),
protocol_type: metadata.protocol_type.clone(),
protocol_name: metadata.protocol_name.clone(),
leader_id: metadata.leader_id.clone(),
members,
generation_id: metadata.generation_id,
created_at: metadata
.created_at
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs(),
state_timestamp: metadata
.state_timestamp
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs(),
}
}
pub async fn force_persist_metadata(&self) -> Result<()> {
self.persist_metadata_changes().await
}
/// Deletes the persisted metadata file for `group_id`, if there is one.
///
/// Succeeds without doing anything when no metadata path is configured or the
/// file is already gone. The removal is attempted directly, not guarded by an
/// `exists()` check: such a check is racy (a concurrent removal turns into an
/// error) and is a blocking filesystem call inside async code.
///
/// # Errors
/// Propagates any removal error other than "not found".
pub async fn remove_persisted_group(&self, group_id: &ConsumerGroupId) -> Result<()> {
    let Some(file_path) = self.get_metadata_file_path(group_id) else {
        return Ok(());
    };
    match fs::remove_file(&file_path).await {
        Ok(()) => {
            debug!(
                "Removed persisted metadata file for dead group {}: {:?}",
                group_id, file_path
            );
            Ok(())
        }
        // Never persisted, or removed concurrently: nothing left to do.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(e) => Err(e.into()),
    }
}
}
/// Registry of consumer group coordinators keyed by an arbitrary string.
///
/// `Default` yields an empty registry, the same as `ConsumerGroupManager::new`.
#[derive(Debug, Default)]
pub struct ConsumerGroupManager {
    coordinators: HashMap<String, Arc<ConsumerGroupCoordinator>>,
}
impl ConsumerGroupManager {
    /// Creates a manager with no registered coordinators.
    pub fn new() -> Self {
        let coordinators = HashMap::default();
        Self { coordinators }
    }

    /// Registers `coordinator` under `key`. A coordinator previously registered
    /// under the same key is replaced.
    pub fn add_coordinator(&mut self, key: String, coordinator: Arc<ConsumerGroupCoordinator>) {
        self.coordinators.insert(key, coordinator);
    }

    /// Returns the coordinator registered under the `"default"` key, if any.
    ///
    /// `_group_id` is currently ignored: every group is routed to that single
    /// coordinator.
    pub fn get_coordinator(&self, _group_id: &str) -> Option<Arc<ConsumerGroupCoordinator>> {
        const DEFAULT_KEY: &str = "default";
        self.coordinators.get(DEFAULT_KEY).map(Arc::clone)
    }
}
/// Assignment protocol names this coordinator can run.
const SUPPORTED_ASSIGNMENT_PROTOCOLS: [&str; 3] = ["range", "roundrobin", "sticky"];

impl ConsumerGroupCoordinator {
    /// Picks the assignment protocol for a group from the protocols a member offered.
    ///
    /// Returns the name of the first entry in `group_protocols` (in the order given)
    /// that appears in `SUPPORTED_ASSIGNMENT_PROTOCOLS`. Otherwise falls back to
    /// `"range"`. The fallback covers an empty list and a list containing only
    /// the generic `"consumer"` name.
    fn select_compatible_protocol(&self, group_protocols: &[GroupProtocol]) -> String {
        group_protocols
            .iter()
            .map(|protocol| protocol.name.as_str())
            .find(|name| SUPPORTED_ASSIGNMENT_PROTOCOLS.contains(name))
            .unwrap_or("range")
            .to_string()
    }

    /// Collects the topics a member subscribed to from its protocol metadata.
    ///
    /// Metadata is parsed for protocols named `"consumer"` and for every supported
    /// assignment protocol, consistent with `select_compatible_protocol`. A member
    /// offering several protocols may carry the same subscription under each, so
    /// topics are de-duplicated, keeping first-seen order. Metadata that fails to
    /// parse contributes nothing. An empty result is logged at debug level.
    fn extract_subscribed_topics(&self, group_protocols: &[GroupProtocol]) -> Vec<TopicName> {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        let mut topics = Vec::new();
        for protocol in group_protocols {
            let name = protocol.name.as_str();
            if name != "consumer" && !SUPPORTED_ASSIGNMENT_PROTOCOLS.contains(&name) {
                continue;
            }
            if let Ok(parsed_topics) = self.parse_consumer_protocol_metadata(&protocol.metadata)
            {
                for topic in parsed_topics {
                    if seen.insert(topic.clone()) {
                        topics.push(topic);
                    }
                }
            }
        }
        if topics.is_empty() {
            debug!("No topics found in protocol metadata, will use topic inference");
        }
        topics
    }

    /// Parses the subscribed topic list out of a member's protocol metadata.
    ///
    /// The layout read here is big-endian (as `bytes::Buf::get_i16`/`get_i32` read):
    /// `version: i16`, `topic_count: i32`, then up to `topic_count` strings. Each
    /// string is `len: i16` followed by `len` bytes. Any bytes after the topic
    /// list are ignored. This presumably mirrors Kafka's
    /// `ConsumerProtocolSubscription`; confirm against the client encoders.
    ///
    /// Malformed input never produces an error:
    /// - a buffer shorter than the header yields an empty list;
    /// - a negative count, or a count above 1000, yields an empty list;
    /// - a truncated or non-positive length stops parsing and keeps the topics read so far;
    /// - a name that is not valid UTF-8 is skipped.
    fn parse_consumer_protocol_metadata(&self, metadata: &[u8]) -> Result<Vec<TopicName>> {
        use bytes::Buf;

        // Declared topic counts above this are treated as garbage input.
        const MAX_TOPIC_COUNT: i32 = 1000;
        // Fixed header: i16 version + i32 topic count.
        const HEADER_LEN: usize = 6;

        if metadata.len() < HEADER_LEN {
            return Ok(Vec::new());
        }
        let mut cursor = std::io::Cursor::new(metadata);
        let _version = cursor.get_i16();
        let topic_count = cursor.get_i32();
        if !(0..=MAX_TOPIC_COUNT).contains(&topic_count) {
            return Ok(Vec::new());
        }

        // Every topic needs at least its 2-byte length prefix, so never reserve more
        // slots than the remaining bytes could hold.
        let mut topics = Vec::with_capacity((topic_count as usize).min(cursor.remaining() / 2));
        for _ in 0..topic_count {
            if cursor.remaining() < 2 {
                break;
            }
            let topic_len = cursor.get_i16();
            if topic_len <= 0 || cursor.remaining() < topic_len as usize {
                break;
            }
            let mut topic_bytes = vec![0u8; topic_len as usize];
            cursor.copy_to_slice(&mut topic_bytes);
            if let Ok(topic_name) = String::from_utf8(topic_bytes) {
                topics.push(topic_name);
            }
        }
        Ok(topics)
    }
}