use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::time::{Duration, SystemTime};
/// Identifier of a consumer group.
pub type GroupId = String;
/// Identifier of one member incarnation. A static member that rejoins may do so under a new
/// `MemberId` while keeping the same `GroupInstanceId`.
pub type MemberId = String;
/// Stable identity of a static member; `ConsumerGroup::static_members` maps it to the
/// member's current (or most recent) `MemberId`.
pub type GroupInstanceId = String;
/// Lifecycle state of a `ConsumerGroup`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupState {
/// No members. This is the initial state; re-entering it when the last member leaves also
/// resets the generation to 0.
Empty,
/// Membership changed and a new assignment has to be computed.
PreparingRebalance,
/// A cooperative rebalance is waiting for members to acknowledge revocations
/// (entered by `ConsumerGroup::request_revocations`).
CompletingRebalance,
/// The assignment for the current generation is in place.
Stable,
/// NOTE(review): no code visible here ever enters this state; presumably reserved for a
/// deleted/unloaded group — confirm.
Dead,
}
/// Partition assignment strategy recorded on a `ConsumerGroup`. The variants mirror the
/// assignors in the `assignment` module (`range_assignment`, `round_robin_assignment`,
/// `sticky_assignment`); `Range` is the default.
/// NOTE(review): the group stores this value, but no visible code dispatches on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum AssignmentStrategy {
#[default]
Range,
RoundRobin,
Sticky,
}
/// How a group moves partitions between members during a rebalance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum RebalanceProtocol {
/// Default. The complete new assignment is installed in a single step
/// (`ConsumerGroup::complete_rebalance`).
#[default]
Eager,
/// Partitions that change owner are first revoked from their current owner; the new
/// assignment is completed only after those revocations are acknowledged.
Cooperative,
}
impl RebalanceProtocol {
    /// Chooses the protocol that every participant can speak.
    ///
    /// Yields `Cooperative` only for a non-empty slice made up entirely of `Cooperative`
    /// entries; an empty slice or any `Eager` entry yields `Eager`.
    pub fn select_common(protocols: &[Self]) -> Self {
        match protocols {
            [] => RebalanceProtocol::Eager,
            every if every.iter().all(|p| *p == RebalanceProtocol::Cooperative) => {
                RebalanceProtocol::Cooperative
            }
            _ => RebalanceProtocol::Eager,
        }
    }
}
/// One member of a consumer group.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GroupMember {
/// Id of this member incarnation.
pub member_id: MemberId,
/// `Some` for static members.
pub group_instance_id: Option<GroupInstanceId>,
/// Client id supplied when joining.
pub client_id: String,
/// Subscriptions supplied when joining (topic names in this file's tests).
pub subscriptions: Vec<String>,
/// Partitions currently owned by the member.
pub assignment: Vec<PartitionAssignment>,
/// Partitions the member has been asked to give up in an ongoing cooperative rebalance;
/// drained by `ConsumerGroup::acknowledge_revocation`.
pub pending_revocation: Vec<PartitionAssignment>,
/// Wall-clock time of the last join or heartbeat; serialized as milliseconds since the
/// Unix epoch.
#[serde(
serialize_with = "serialize_systemtime",
deserialize_with = "deserialize_systemtime"
)]
pub last_heartbeat: SystemTime,
/// Opaque bytes supplied when joining; only stored by the code visible here.
pub metadata: Vec<u8>,
/// Whether the member joined with a group instance id.
pub is_static: bool,
/// Rebalance protocols the member supports; `ConsumerGroup::add_member_full` substitutes
/// `[Eager]` for an empty list.
pub supported_protocols: Vec<RebalanceProtocol>,
}
fn serialize_systemtime<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let duration = time
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(serde::ser::Error::custom)?;
serializer.serialize_u128(duration.as_millis())
}
fn deserialize_systemtime<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: serde::Deserializer<'de>,
{
let millis = u128::deserialize(deserializer)?;
let millis_u64 = u64::try_from(millis).map_err(serde::de::Error::custom)?;
Ok(SystemTime::UNIX_EPOCH + std::time::Duration::from_millis(millis_u64))
}
/// A single partition of a topic; the unit that is assigned to and revoked from members.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PartitionAssignment {
/// Topic name.
pub topic: String,
/// Partition number within `topic`.
pub partition: u32,
}
/// Outcome of `ConsumerGroup::rebalance_with_strategy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RebalanceResult {
/// The new assignment is installed and the group is `Stable`.
Complete,
/// Cooperative protocol only: some members must first give up partitions.
AwaitingRevocations {
/// Partitions each member must revoke before the rebalance can finish.
revocations: HashMap<MemberId, Vec<PartitionAssignment>>,
/// The target assignment that was passed in, returned so it can be handed to
/// `ConsumerGroup::complete_cooperative_rebalance` once all revocations are acknowledged.
pending_assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
},
}
/// Coordinator-side state of one consumer group: membership, static-membership bookkeeping,
/// the rebalance state machine and committed offsets.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ConsumerGroup {
/// Group identifier.
pub group_id: GroupId,
/// Current lifecycle state.
pub state: GroupState,
/// Incremented (wrapping) on every completed rebalance; reset to 0 when the group empties.
pub generation_id: u32,
/// Group leader: the first member to join a leaderless group. When a removal method drops
/// the leader, an arbitrary remaining member (HashMap order) takes over.
pub leader_id: Option<MemberId>,
/// Set to `"consumer"` by `ConsumerGroup::new`; not read by the code visible here.
pub protocol_name: String,
/// Configured assignor; see `AssignmentStrategy`.
pub assignment_strategy: AssignmentStrategy,
/// Protocol chosen from the members' `supported_protocols` by `update_rebalance_protocol`
/// (`Cooperative` only when every member supports it).
pub rebalance_protocol: RebalanceProtocol,
/// Live members keyed by member id.
pub members: HashMap<MemberId, GroupMember>,
/// Instance id -> current (or most recent) member id of each static member.
pub static_members: HashMap<GroupInstanceId, MemberId>,
/// Assignments parked for static members that left, with the `Instant` they were parked
/// (aged out by `check_pending_static_timeouts`). Skipped by serde, so this is empty after
/// deserialization.
#[serde(skip)]
pub pending_static_members:
HashMap<GroupInstanceId, (Vec<PartitionAssignment>, std::time::Instant)>,
/// Members that still owe a revocation acknowledgement, with the partitions to give up.
pub awaiting_revocation: HashMap<MemberId, Vec<PartitionAssignment>>,
/// Committed offsets: topic -> partition -> offset.
pub offsets: HashMap<String, HashMap<u32, i64>>,
/// Heartbeat silence after which `check_timeouts` removes a member; serialized as whole
/// milliseconds.
#[serde(
serialize_with = "serialize_duration",
deserialize_with = "deserialize_duration"
)]
pub session_timeout: Duration,
/// Serialized as whole milliseconds.
/// NOTE(review): not read by the code visible here.
#[serde(
serialize_with = "serialize_duration",
deserialize_with = "deserialize_duration"
)]
pub rebalance_timeout: Duration,
}
/// Serde `serialize_with` helper: writes a `Duration` as a `u64` count of whole milliseconds.
///
/// `Duration::as_millis` returns a `u128`. The previous `as u64` cast silently kept only the
/// low 64 bits, so a very long duration (e.g. `Duration::MAX` used as "no timeout")
/// round-tripped to an arbitrary, much shorter value. Out-of-range values now saturate to
/// `u64::MAX` milliseconds, which still reads back as an effectively unbounded duration.
fn serialize_duration<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let millis = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
    serializer.serialize_u64(millis)
}
/// Serde `deserialize_with` helper: reads a `u64` count of milliseconds into a `Duration`.
fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
    D: serde::Deserializer<'de>,
{
    u64::deserialize(deserializer).map(Duration::from_millis)
}
impl ConsumerGroup {
    /// Creates an empty group: generation 0, no leader, no members, eager protocol and the
    /// default assignment strategy.
    pub fn new(group_id: GroupId, session_timeout: Duration, rebalance_timeout: Duration) -> Self {
        Self {
            group_id,
            state: GroupState::Empty,
            generation_id: 0,
            leader_id: None,
            protocol_name: "consumer".to_string(),
            assignment_strategy: AssignmentStrategy::default(),
            rebalance_protocol: RebalanceProtocol::Eager,
            members: HashMap::new(),
            static_members: HashMap::new(),
            pending_static_members: HashMap::new(),
            awaiting_revocation: HashMap::new(),
            offsets: HashMap::new(),
            session_timeout,
            rebalance_timeout,
        }
    }

    /// Adds a dynamic (non-static), eager-only member.
    ///
    /// Shorthand for [`ConsumerGroup::add_member_full`] with no instance id and
    /// `[RebalanceProtocol::Eager]`.
    pub fn add_member(
        &mut self,
        member_id: MemberId,
        client_id: String,
        subscriptions: Vec<String>,
        metadata: Vec<u8>,
    ) {
        self.add_member_full(
            member_id,
            None,
            client_id,
            subscriptions,
            metadata,
            vec![RebalanceProtocol::Eager],
        )
    }

    /// Adds an eager-only member that is static when `group_instance_id` is `Some`.
    ///
    /// Shorthand for [`ConsumerGroup::add_member_full`] with `[RebalanceProtocol::Eager]`.
    pub fn add_member_with_instance_id(
        &mut self,
        member_id: MemberId,
        group_instance_id: Option<GroupInstanceId>,
        client_id: String,
        subscriptions: Vec<String>,
        metadata: Vec<u8>,
    ) {
        self.add_member_full(
            member_id,
            group_instance_id,
            client_id,
            subscriptions,
            metadata,
            vec![RebalanceProtocol::Eager],
        )
    }

    /// Adds a member to the group.
    ///
    /// * If `group_instance_id` is mapped to a *different* member id, that older incarnation
    ///   is removed first (fencing). Leadership and revocation bookkeeping are updated as for
    ///   any other removal.
    /// * If the instance has a parked assignment (see `remove_member` /
    ///   `fence_static_member`), the new member gets it back and **no rebalance starts**.
    ///   An `Empty` group goes straight to `Stable`.
    /// * Otherwise the member joins with an empty assignment and the group moves to
    ///   `PreparingRebalance`.
    ///
    /// An empty `supported_protocols` is treated as `[Eager]`. The group's rebalance protocol
    /// is recomputed, and the member becomes leader if the group has none.
    pub fn add_member_full(
        &mut self,
        member_id: MemberId,
        group_instance_id: Option<GroupInstanceId>,
        client_id: String,
        subscriptions: Vec<String>,
        metadata: Vec<u8>,
        supported_protocols: Vec<RebalanceProtocol>,
    ) {
        let is_static = group_instance_id.is_some();
        let supported_protocols = if supported_protocols.is_empty() {
            vec![RebalanceProtocol::Eager]
        } else {
            supported_protocols
        };
        if let Some(ref instance_id) = group_instance_id {
            if let Some(old_member_id) = self.static_members.get(instance_id).cloned() {
                if old_member_id != member_id {
                    // Going through `detach_member` re-elects the leader if the fenced
                    // incarnation held it (previously `leader_id` was left dangling).
                    self.detach_member(&old_member_id);
                }
            }
            if let Some((saved_assignment, _pending_since)) =
                self.pending_static_members.remove(instance_id)
            {
                // Static rejoin: restore the parked partitions without a rebalance.
                let member = GroupMember {
                    member_id: member_id.clone(),
                    group_instance_id: Some(instance_id.clone()),
                    client_id,
                    subscriptions,
                    assignment: saved_assignment,
                    pending_revocation: Vec::new(),
                    last_heartbeat: SystemTime::now(),
                    metadata,
                    is_static: true,
                    supported_protocols,
                };
                self.members.insert(member_id.clone(), member);
                self.static_members
                    .insert(instance_id.clone(), member_id.clone());
                self.update_rebalance_protocol();
                if self.leader_id.is_none() {
                    self.leader_id = Some(member_id);
                }
                if self.state == GroupState::Empty {
                    self.state = GroupState::Stable;
                }
                return;
            }
            self.static_members
                .insert(instance_id.clone(), member_id.clone());
        }
        let member = GroupMember {
            member_id: member_id.clone(),
            group_instance_id,
            client_id,
            subscriptions,
            assignment: Vec::new(),
            pending_revocation: Vec::new(),
            last_heartbeat: SystemTime::now(),
            metadata,
            is_static,
            supported_protocols,
        };
        self.members.insert(member_id.clone(), member);
        self.update_rebalance_protocol();
        if self.leader_id.is_none() {
            self.leader_id = Some(member_id);
        }
        // A genuinely new member always needs a fresh assignment. (The old logic could leave
        // a non-empty group stuck in `Empty`.)
        self.state = GroupState::PreparingRebalance;
    }

    /// Returns `true` if `instance_id` currently has a static-membership mapping.
    pub fn has_static_member(&self, instance_id: &GroupInstanceId) -> bool {
        self.static_members.contains_key(instance_id)
    }

    /// Member id currently (or most recently) mapped to `instance_id`.
    pub fn get_member_for_instance(&self, instance_id: &GroupInstanceId) -> Option<&MemberId> {
        self.static_members.get(instance_id)
    }

    /// Evicts the live incarnation of a static member without starting a rebalance.
    ///
    /// A non-empty assignment is parked in `pending_static_members` so that a later join with
    /// the same instance id reclaims it. The instance -> member mapping is left untouched.
    /// Returns the mapped member id, or `None` if `instance_id` is unknown.
    pub fn fence_static_member(&mut self, instance_id: &GroupInstanceId) -> Option<MemberId> {
        let old_member_id = self.static_members.get(instance_id).cloned()?;
        if let Some(old_member) = self.detach_member(&old_member_id) {
            if !old_member.assignment.is_empty() {
                self.pending_static_members.insert(
                    instance_id.clone(),
                    (old_member.assignment, std::time::Instant::now()),
                );
            }
        }
        Some(old_member_id)
    }

    /// Removes a member. Returns `false` if `member_id` is not in the group.
    ///
    /// * Static member: a non-empty assignment is parked under its instance id and no
    ///   rebalance starts, so the instance can rejoin and reclaim it.
    /// * Dynamic member: the group moves to `PreparingRebalance`.
    /// * Removing the last member resets the group to `Empty` (generation 0, no leader) and
    ///   forgets all static-membership state.
    pub fn remove_member(&mut self, member_id: &MemberId) -> bool {
        let member = match self.detach_member(member_id) {
            Some(member) => member,
            None => return false,
        };
        if self.members.is_empty() {
            self.reset_to_empty();
        } else if member.is_static {
            if let Some(instance_id) = member.group_instance_id {
                if !member.assignment.is_empty() {
                    self.pending_static_members
                        .insert(instance_id, (member.assignment, std::time::Instant::now()));
                }
            }
        } else {
            self.transition_to_preparing_rebalance();
        }
        true
    }

    /// Permanently removes a static instance: its mapping, any parked assignment and its live
    /// member, if present.
    ///
    /// Returns `true` if any of those existed. Removing a live member rebalances the
    /// survivors, or resets an emptied group to `Empty`. Dropping only a parked assignment
    /// also rebalances the survivors, because its partitions are now ownerless. (The old
    /// implementation removed the parked entry before testing for it, so this case returned
    /// `false` and started no rebalance.)
    pub fn remove_static_member(&mut self, instance_id: &GroupInstanceId) -> bool {
        let mapped_member = self.static_members.remove(instance_id);
        let had_parked_assignment = self.pending_static_members.remove(instance_id).is_some();
        let removed_live_member = match mapped_member {
            Some(ref member_id) => self.detach_member(member_id).is_some(),
            None => false,
        };
        if removed_live_member {
            if self.members.is_empty() {
                self.reset_to_empty();
            } else {
                self.transition_to_preparing_rebalance();
            }
        } else if had_parked_assignment && !self.members.is_empty() {
            self.transition_to_preparing_rebalance();
        }
        mapped_member.is_some() || had_parked_assignment
    }

    /// Records a heartbeat for `member_id`.
    ///
    /// # Errors
    /// Returns `Err("Unknown member: <id>")` if the member is not in the group.
    pub fn heartbeat(&mut self, member_id: &MemberId) -> Result<(), String> {
        if let Some(member) = self.members.get_mut(member_id) {
            member.last_heartbeat = SystemTime::now();
            Ok(())
        } else {
            Err(format!("Unknown member: {}", member_id))
        }
    }

    /// Removes every member whose last heartbeat is older than `session_timeout`, and
    /// returns their ids.
    ///
    /// Dynamic members go through `remove_member`. Static members are removed for good
    /// through `remove_static_member`, which drops both the mapping and any parked
    /// assignment. A heartbeat that lies in the future (the clock stepped backwards) counts
    /// as fresh.
    pub fn check_timeouts(&mut self) -> Vec<MemberId> {
        let now = SystemTime::now();
        let mut timed_out = Vec::new();
        let mut static_timeouts: Vec<GroupInstanceId> = Vec::new();
        for (member_id, member) in &self.members {
            if let Ok(elapsed) = now.duration_since(member.last_heartbeat) {
                if elapsed > self.session_timeout {
                    timed_out.push(member_id.clone());
                    if let Some(ref instance_id) = member.group_instance_id {
                        static_timeouts.push(instance_id.clone());
                    }
                }
            }
        }
        for member_id in &timed_out {
            if let Some(member) = self.members.get(member_id) {
                if !member.is_static {
                    self.remove_member(member_id);
                }
            }
        }
        for instance_id in &static_timeouts {
            self.remove_static_member(instance_id);
        }
        timed_out
    }

    /// Drops parked static assignments that have waited at least `pending_timeout`, together
    /// with their instance mappings, and returns the affected instance ids.
    ///
    /// If anything expired and members remain, the group moves to `PreparingRebalance` so
    /// the orphaned partitions get reassigned.
    pub fn check_pending_static_timeouts(
        &mut self,
        pending_timeout: Duration,
    ) -> Vec<GroupInstanceId> {
        let now = std::time::Instant::now();
        let timed_out: Vec<GroupInstanceId> = self
            .pending_static_members
            .iter()
            .filter(|(_, (_, pending_since))| now.duration_since(*pending_since) >= pending_timeout)
            .map(|(id, _)| id.clone())
            .collect();
        for id in &timed_out {
            self.pending_static_members.remove(id);
            self.static_members.remove(id);
        }
        if !timed_out.is_empty() && !self.members.is_empty() {
            self.transition_to_preparing_rebalance();
        }
        timed_out
    }

    /// Recomputes `rebalance_protocol` from the current members: `Cooperative` only if every
    /// member lists it among its supported protocols; `Eager` for an empty group.
    fn update_rebalance_protocol(&mut self) {
        let protocols: Vec<RebalanceProtocol> = self
            .members
            .values()
            .map(|m| {
                if m.supported_protocols
                    .contains(&RebalanceProtocol::Cooperative)
                {
                    RebalanceProtocol::Cooperative
                } else {
                    RebalanceProtocol::Eager
                }
            })
            .collect();
        self.rebalance_protocol = RebalanceProtocol::select_common(&protocols);
    }

    /// Returns `true` when the group currently uses the cooperative protocol.
    pub fn is_cooperative(&self) -> bool {
        self.rebalance_protocol == RebalanceProtocol::Cooperative
    }

    /// For each member, lists the currently owned partitions that `new_assignments` does not
    /// give back to it. A member missing from `new_assignments` loses everything. Members
    /// with nothing to revoke are omitted.
    pub fn compute_revocations(
        &self,
        new_assignments: &HashMap<MemberId, Vec<PartitionAssignment>>,
    ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
        let mut revocations: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
        for (member_id, member) in &self.members {
            // Hash the member's target once instead of scanning it for every owned partition.
            let retained: HashSet<&PartitionAssignment> = new_assignments
                .get(member_id)
                .map(|target| target.iter().collect())
                .unwrap_or_default();
            let to_revoke: Vec<PartitionAssignment> = member
                .assignment
                .iter()
                .filter(|partition| !retained.contains(*partition))
                .cloned()
                .collect();
            if !to_revoke.is_empty() {
                revocations.insert(member_id.clone(), to_revoke);
            }
        }
        revocations
    }

    /// Starts the revocation phase of a cooperative rebalance.
    ///
    /// Any revocation requests left over from an earlier, abandoned round are discarded. Each
    /// listed member that is in the group then receives its `pending_revocation`, and the
    /// group moves to `CompletingRebalance`.
    pub fn request_revocations(
        &mut self,
        revocations: HashMap<MemberId, Vec<PartitionAssignment>>,
    ) {
        self.clear_revocations();
        for (member_id, partitions) in &revocations {
            if let Some(member) = self.members.get_mut(member_id) {
                member.pending_revocation = partitions.clone();
            }
        }
        self.awaiting_revocation = revocations;
        self.state = GroupState::CompletingRebalance;
    }

    /// Records that `member_id` released its pending revocations. Those partitions are
    /// removed from its assignment.
    ///
    /// Returns `true` once no member is left owing an acknowledgement.
    pub fn acknowledge_revocation(&mut self, member_id: &MemberId) -> bool {
        self.awaiting_revocation.remove(member_id);
        if let Some(member) = self.members.get_mut(member_id) {
            let revoked: HashSet<_> = member.pending_revocation.drain(..).collect();
            member.assignment.retain(|p| !revoked.contains(p));
        }
        self.awaiting_revocation.is_empty()
    }

    /// Returns `true` while some member still owes a revocation acknowledgement.
    pub fn has_pending_revocations(&self) -> bool {
        !self.awaiting_revocation.is_empty()
    }

    /// Partitions `member_id` has been asked to revoke; empty for unknown members.
    pub fn get_pending_revocations(&self, member_id: &MemberId) -> Vec<PartitionAssignment> {
        self.members
            .get(member_id)
            .map(|m| m.pending_revocation.clone())
            .unwrap_or_default()
    }

    /// Finishes a cooperative rebalance.
    ///
    /// Each listed member gains the partitions from `final_assignments` it does not already
    /// own. No partition is removed here, because revocations happened in the previous
    /// phase. Outstanding revocation bookkeeping is cleared, the generation is bumped, and
    /// the group becomes `Stable`.
    pub fn complete_cooperative_rebalance(
        &mut self,
        final_assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
    ) {
        for (member_id, new_partitions) in final_assignments {
            if let Some(member) = self.members.get_mut(&member_id) {
                for partition in new_partitions {
                    if !member.assignment.contains(&partition) {
                        member.assignment.push(partition);
                    }
                }
            }
        }
        self.clear_revocations();
        self.generation_id = self.generation_id.wrapping_add(1);
        self.state = GroupState::Stable;
    }

    /// Applies `new_assignments` according to the group's rebalance protocol.
    ///
    /// With the eager protocol, the assignment is installed immediately. With the
    /// cooperative protocol, the rebalance completes immediately if no member loses a
    /// partition. Otherwise revocations are requested and `AwaitingRevocations` is returned;
    /// the caller then collects acknowledgements and calls `complete_cooperative_rebalance`.
    pub fn rebalance_with_strategy(
        &mut self,
        new_assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
    ) -> RebalanceResult {
        match self.rebalance_protocol {
            RebalanceProtocol::Eager => {
                self.complete_rebalance(new_assignments);
                RebalanceResult::Complete
            }
            RebalanceProtocol::Cooperative => {
                let revocations = self.compute_revocations(&new_assignments);
                if revocations.is_empty() {
                    self.complete_cooperative_rebalance(new_assignments);
                    RebalanceResult::Complete
                } else {
                    self.request_revocations(revocations.clone());
                    RebalanceResult::AwaitingRevocations {
                        revocations,
                        pending_assignments: new_assignments,
                    }
                }
            }
        }
    }

    /// Moves a non-empty group to `PreparingRebalance`; an `Empty` group is left alone.
    fn transition_to_preparing_rebalance(&mut self) {
        if self.state != GroupState::Empty {
            self.state = GroupState::PreparingRebalance;
        }
    }

    /// Installs a complete (eager) assignment for a new generation.
    ///
    /// Every member's assignment is *replaced*. Members absent from `assignments` end up
    /// with nothing; assignors such as `range_assignment` omit members that get no
    /// partitions. Previously those members silently kept their old partitions, which let
    /// two members own the same partition. This matches `compute_revocations`, which already
    /// treats absent members as owning nothing. Entries for unknown members are ignored.
    pub fn complete_rebalance(
        &mut self,
        mut assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
    ) {
        for (member_id, member) in self.members.iter_mut() {
            member.assignment = assignments.remove(member_id).unwrap_or_default();
        }
        self.clear_revocations();
        self.generation_id = self.generation_id.wrapping_add(1);
        self.state = GroupState::Stable;
    }

    /// Stores `offset` as the committed offset of `topic`/`partition`, overwriting any
    /// previous value.
    pub fn commit_offset(&mut self, topic: &str, partition: u32, offset: i64) {
        self.offsets
            .entry(topic.to_string())
            .or_default()
            .insert(partition, offset);
    }

    /// Committed offset of `topic`/`partition`, if any.
    pub fn fetch_offset(&self, topic: &str, partition: u32) -> Option<i64> {
        self.offsets.get(topic)?.get(&partition).copied()
    }

    /// Inverts the current assignments into partition -> owner. If a partition is listed by
    /// several members, one of them (HashMap iteration order) wins.
    pub fn all_assignments(&self) -> HashMap<PartitionAssignment, MemberId> {
        let mut assignments = HashMap::new();
        for (member_id, member) in &self.members {
            for partition in &member.assignment {
                assignments.insert(partition.clone(), member_id.clone());
            }
        }
        assignments
    }

    /// Removes `member_id` from `members` and drops all per-member bookkeeping keyed by it.
    ///
    /// This covers its outstanding revocation, the leader role (handed to an arbitrary
    /// remaining member) and its vote in the protocol selection. Returns the removed member.
    /// Does not touch group state, generation or static-membership maps; callers decide
    /// those.
    fn detach_member(&mut self, member_id: &MemberId) -> Option<GroupMember> {
        let member = self.members.remove(member_id)?;
        // A departed member can never acknowledge, so waiting on it would stall the rebalance.
        self.awaiting_revocation.remove(member_id);
        if self.leader_id.as_ref() == Some(member_id) {
            self.leader_id = self.members.keys().next().cloned();
        }
        // An eager-only member leaving may let the rest switch to cooperative.
        self.update_rebalance_protocol();
        Some(member)
    }

    /// Resets a group that has no members left: `Empty`, generation 0, no leader, and no
    /// static-membership or revocation state.
    fn reset_to_empty(&mut self) {
        self.state = GroupState::Empty;
        self.generation_id = 0;
        self.leader_id = None;
        self.static_members.clear();
        self.pending_static_members.clear();
        self.awaiting_revocation.clear();
    }

    /// Forgets every outstanding revocation request, both group-wide and per member.
    fn clear_revocations(&mut self) {
        self.awaiting_revocation.clear();
        for member in self.members.values_mut() {
            member.pending_revocation.clear();
        }
    }
}
pub mod assignment {
    use super::*;

    /// Range assignor.
    ///
    /// For every topic, splits partitions `0..partition_count` into contiguous runs handed
    /// to `members` in slice order. The first `partition_count % members.len()` members
    /// receive one extra partition of that topic. Members that receive nothing do not
    /// appear in the result. An empty `members` slice yields an empty map.
    pub fn range_assignment(
        members: &[MemberId],
        topic_partitions: &HashMap<String, u32>,
    ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
        let mut assignments: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
        if members.is_empty() {
            return assignments;
        }
        for (topic, partition_count) in topic_partitions {
            let partitions_per_member = partition_count / members.len() as u32;
            let extra_partitions = partition_count % members.len() as u32;
            let mut current_partition = 0;
            for (idx, member_id) in members.iter().enumerate() {
                let mut member_partitions = partitions_per_member;
                if (idx as u32) < extra_partitions {
                    member_partitions += 1;
                }
                for _ in 0..member_partitions {
                    assignments
                        .entry(member_id.clone())
                        .or_default()
                        .push(PartitionAssignment {
                            topic: topic.clone(),
                            partition: current_partition,
                        });
                    current_partition += 1;
                }
            }
        }
        assignments
    }

    /// Round-robin assignor.
    ///
    /// Deals every partition of every topic to `members` in turn. The dealing position
    /// carries over from one topic to the next. Members that receive nothing do not appear
    /// in the result. An empty `members` slice yields an empty map.
    pub fn round_robin_assignment(
        members: &[MemberId],
        topic_partitions: &HashMap<String, u32>,
    ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
        let mut assignments: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
        if members.is_empty() {
            return assignments;
        }
        let mut member_idx = 0;
        for (topic, partition_count) in topic_partitions {
            for partition in 0..*partition_count {
                assignments
                    .entry(members[member_idx].clone())
                    .or_default()
                    .push(PartitionAssignment {
                        topic: topic.clone(),
                        partition,
                    });
                member_idx = (member_idx + 1) % members.len();
            }
        }
        assignments
    }

    /// Sticky assignor: a balanced assignment that moves as few partitions as possible.
    ///
    /// Every member ends up with `total / members.len()` partitions or one more. At most
    /// `total % members.len()` members get the larger share.
    ///
    /// 1. Walking `members` in order, each member keeps the previously owned partitions
    ///    that still exist and that no earlier member has already claimed, up to its quota.
    /// 2. Every remaining partition goes to the member currently holding the fewest
    ///    partitions; ties go to the member that comes first in `members`.
    ///
    /// The previous implementation let members keep *all* their old partitions and dealt
    /// leftovers round-robin from the first member. A newly joined member could therefore
    /// receive nothing at all.
    ///
    /// Every member in `members` appears in the result, possibly with an empty list. An
    /// empty `members` slice yields an empty map.
    pub fn sticky_assignment(
        members: &[MemberId],
        topic_partitions: &HashMap<String, u32>,
        previous_assignments: &HashMap<MemberId, Vec<PartitionAssignment>>,
    ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
        let mut assignments: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
        if members.is_empty() {
            return assignments;
        }
        let mut all_partitions = Vec::new();
        for (topic, partition_count) in topic_partitions {
            for partition in 0..*partition_count {
                all_partitions.push(PartitionAssignment {
                    topic: topic.clone(),
                    partition,
                });
            }
        }
        // Set view of the partition universe: O(1) membership instead of a linear scan.
        let universe: HashSet<&PartitionAssignment> = all_partitions.iter().collect();
        let base_quota = all_partitions.len() / members.len();
        let mut larger_quotas_left = all_partitions.len() % members.len();

        let mut assigned: HashSet<PartitionAssignment> = HashSet::new();
        let mut owned: Vec<Vec<PartitionAssignment>> = Vec::with_capacity(members.len());
        for member_id in members {
            let mut kept = Vec::new();
            for partition in previous_assignments.get(member_id).into_iter().flatten() {
                // `insert` is false for partitions already claimed (by an earlier member or
                // a duplicate entry), so each partition is kept at most once.
                if universe.contains(partition) && assigned.insert(partition.clone()) {
                    kept.push(partition.clone());
                }
            }
            let quota = if kept.len() > base_quota && larger_quotas_left > 0 {
                larger_quotas_left -= 1;
                base_quota + 1
            } else {
                base_quota
            };
            // Release everything beyond the quota so it can be redistributed below.
            let keep = quota.min(kept.len());
            for released in kept.split_off(keep) {
                assigned.remove(&released);
            }
            owned.push(kept);
        }

        for partition in all_partitions.iter().filter(|p| !assigned.contains(*p)) {
            let target = (0..owned.len())
                .min_by_key(|&idx| owned[idx].len())
                .expect("members is non-empty");
            owned[target].push(partition.clone());
        }

        // `entry` + `extend` merges rather than overwrites if `members` lists an id twice.
        for (member_id, partitions) in members.iter().zip(owned) {
            assignments
                .entry(member_id.clone())
                .or_default()
                .extend(partitions);
        }
        assignments
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures -------------------------------------------------------------------

    /// Group used by nearly every test: 30 s session timeout, 60 s rebalance timeout.
    fn new_group() -> ConsumerGroup {
        group_with_session_timeout(Duration::from_secs(30))
    }

    /// Like `new_group`, but with a caller-chosen session timeout.
    fn group_with_session_timeout(session_timeout: Duration) -> ConsumerGroup {
        ConsumerGroup::new(
            "test-group".to_string(),
            session_timeout,
            Duration::from_secs(60),
        )
    }

    /// Subscription list containing only `topic-1`.
    fn topic_1() -> Vec<String> {
        vec!["topic-1".to_string()]
    }

    /// One partition of `topic-1`.
    fn tp(partition: u32) -> PartitionAssignment {
        PartitionAssignment {
            topic: "topic-1".to_string(),
            partition,
        }
    }

    /// Member -> `topic-1` partitions, built from `(member, partition numbers)` pairs.
    fn assignment_map(entries: Vec<(&str, Vec<u32>)>) -> HashMap<MemberId, Vec<PartitionAssignment>> {
        entries
            .into_iter()
            .map(|(member, partitions)| {
                (
                    member.to_string(),
                    partitions.into_iter().map(tp).collect::<Vec<_>>(),
                )
            })
            .collect()
    }

    /// Topic map containing only `topic-1` with `count` partitions.
    fn single_topic(count: u32) -> HashMap<String, u32> {
        let mut topics = HashMap::new();
        topics.insert("topic-1".to_string(), count);
        topics
    }

    fn member_ids(names: &[&str]) -> Vec<MemberId> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Partition numbers assigned to `member`, in assignment order; panics if the member is
    /// missing from the map.
    fn partition_ids(
        assignments: &HashMap<MemberId, Vec<PartitionAssignment>>,
        member: &str,
    ) -> Vec<u32> {
        assignments
            .get(member)
            .unwrap()
            .iter()
            .map(|p| p.partition)
            .collect()
    }

    fn join_dynamic(group: &mut ConsumerGroup, member: &str, client: &str, subscriptions: Vec<String>) {
        group.add_member(member.to_string(), client.to_string(), subscriptions, vec![]);
    }

    fn join_static(group: &mut ConsumerGroup, member: &str, instance: &str, client: &str) {
        group.add_member_with_instance_id(
            member.to_string(),
            Some(instance.to_string()),
            client.to_string(),
            topic_1(),
            vec![],
        );
    }

    fn join_with_protocol(
        group: &mut ConsumerGroup,
        member: &str,
        client: &str,
        protocol: RebalanceProtocol,
    ) {
        group.add_member_full(
            member.to_string(),
            None,
            client.to_string(),
            topic_1(),
            vec![],
            vec![protocol],
        );
    }

    // ---- membership -----------------------------------------------------------------

    #[test]
    fn test_consumer_group_creation() {
        let group = new_group();
        assert_eq!(group.group_id, "test-group");
        assert_eq!(group.state, GroupState::Empty);
        assert_eq!(group.generation_id, 0);
        assert!(group.leader_id.is_none());
        assert!(group.members.is_empty());
    }

    #[test]
    fn test_add_first_member_becomes_leader() {
        let mut group = new_group();
        join_dynamic(&mut group, "member-1", "client-1", topic_1());
        assert_eq!(group.members.len(), 1);
        assert_eq!(group.leader_id, Some("member-1".to_string()));
        assert_eq!(group.state, GroupState::PreparingRebalance);
    }

    #[test]
    fn test_remove_member_triggers_rebalance() {
        let mut group = new_group();
        join_dynamic(&mut group, "member-1", "client-1", vec![]);
        join_dynamic(&mut group, "member-2", "client-2", vec![]);
        group.state = GroupState::Stable;
        group.remove_member(&"member-2".to_string());
        assert_eq!(group.members.len(), 1);
        assert_eq!(group.state, GroupState::PreparingRebalance);
    }

    #[test]
    fn test_remove_last_member_transitions_to_empty() {
        let mut group = new_group();
        join_dynamic(&mut group, "member-1", "client-1", vec![]);
        group.remove_member(&"member-1".to_string());
        assert_eq!(group.state, GroupState::Empty);
        assert_eq!(group.generation_id, 0);
        assert!(group.leader_id.is_none());
    }

    #[test]
    fn test_offset_commit_and_fetch() {
        let mut group = new_group();
        let commits = [("topic-1", 0, 100), ("topic-1", 1, 200), ("topic-2", 0, 300)];
        for &(topic, partition, offset) in commits.iter() {
            group.commit_offset(topic, partition, offset);
        }
        for &(topic, partition, offset) in commits.iter() {
            assert_eq!(group.fetch_offset(topic, partition), Some(offset));
        }
        assert_eq!(group.fetch_offset("topic-1", 2), None);
    }

    // ---- assignors ------------------------------------------------------------------

    #[test]
    fn test_range_assignment() {
        let members = member_ids(&["m1", "m2", "m3"]);
        let assignments = assignment::range_assignment(&members, &single_topic(10));
        for &(member, expected) in [("m1", 4), ("m2", 3), ("m3", 3)].iter() {
            assert_eq!(assignments.get(member).unwrap().len(), expected);
        }
        assert_eq!(partition_ids(&assignments, "m1"), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_round_robin_assignment() {
        let members = member_ids(&["m1", "m2", "m3"]);
        let assignments = assignment::round_robin_assignment(&members, &single_topic(10));
        for &(member, expected) in [("m1", 4), ("m2", 3), ("m3", 3)].iter() {
            assert_eq!(assignments.get(member).unwrap().len(), expected);
        }
        assert_eq!(partition_ids(&assignments, "m1"), vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_sticky_assignment_preserves_assignments() {
        let members = member_ids(&["m1", "m2"]);
        let previous = assignment_map(vec![("m1", vec![0, 1]), ("m2", vec![2, 3])]);
        let assignments =
            assignment::sticky_assignment(&members, &single_topic(4), &previous);
        assert_eq!(assignments.get("m1").unwrap().len(), 2);
        assert_eq!(assignments.get("m2").unwrap().len(), 2);
        let m1_partitions: HashSet<u32> = partition_ids(&assignments, "m1").into_iter().collect();
        assert!(m1_partitions.contains(&0));
        assert!(m1_partitions.contains(&1));
    }

    #[test]
    fn test_sticky_assignment_redistributes_on_new_member() {
        let members = member_ids(&["m1", "m2", "m3"]);
        let previous = assignment_map(vec![("m1", vec![0, 1, 2]), ("m2", vec![3, 4, 5])]);
        let assignments =
            assignment::sticky_assignment(&members, &single_topic(6), &previous);
        let total_assigned: usize = assignments.values().map(Vec::len).sum();
        assert_eq!(total_assigned, 6, "All 6 partitions should be assigned");
        let m1_partitions: HashSet<u32> = partition_ids(&assignments, "m1").into_iter().collect();
        let m2_partitions: HashSet<u32> = partition_ids(&assignments, "m2").into_iter().collect();
        let m1_kept = m1_partitions.iter().filter(|&&p| p <= 2).count();
        let m2_kept = m2_partitions.iter().filter(|&&p| p >= 3).count();
        assert!(
            m1_kept > 0 || m2_kept > 0,
            "Sticky assignment should preserve some assignments"
        );
    }

    // ---- static membership ----------------------------------------------------------

    #[test]
    fn test_static_member_add() {
        let mut group = new_group();
        join_static(&mut group, "member-1", "instance-1", "client-1");
        let instance = "instance-1".to_string();
        assert_eq!(group.members.len(), 1);
        assert!(group.has_static_member(&instance));
        assert_eq!(
            group.get_member_for_instance(&instance),
            Some(&"member-1".to_string())
        );
        let member = &group.members["member-1"];
        assert!(member.is_static);
        assert_eq!(member.group_instance_id, Some(instance));
    }

    #[test]
    fn test_static_member_rejoin_no_rebalance() {
        let mut group = new_group();
        join_static(&mut group, "member-1", "instance-1", "client-1");
        join_dynamic(&mut group, "member-2", "client-2", topic_1());
        group.complete_rebalance(assignment_map(vec![("member-1", vec![0]), ("member-2", vec![1])]));
        assert_eq!(group.state, GroupState::Stable);
        let gen_before = group.generation_id;

        group.remove_member(&"member-1".to_string());
        assert!(group.pending_static_members.contains_key("instance-1"));

        join_static(&mut group, "member-1-new", "instance-1", "client-1");
        let rejoined = &group.members["member-1-new"];
        assert_eq!(rejoined.assignment.len(), 1);
        assert_eq!(rejoined.assignment[0].partition, 0);
        assert_eq!(group.generation_id, gen_before);
        assert!(!group.pending_static_members.contains_key("instance-1"));
    }

    #[test]
    fn test_static_member_fencing() {
        let mut group = new_group();
        join_static(&mut group, "member-1", "instance-1", "client-1");
        assert!(group.members.contains_key("member-1"));
        join_static(&mut group, "member-1-new", "instance-1", "client-1");
        assert!(!group.members.contains_key("member-1"));
        assert!(group.members.contains_key("member-1-new"));
        assert_eq!(
            group.get_member_for_instance(&"instance-1".to_string()),
            Some(&"member-1-new".to_string())
        );
    }

    #[test]
    fn test_dynamic_member_removal_triggers_rebalance() {
        let mut group = new_group();
        join_static(&mut group, "static-member", "instance-1", "client-1");
        join_dynamic(&mut group, "dynamic-member", "client-2", topic_1());
        group.state = GroupState::Stable;
        group.remove_member(&"dynamic-member".to_string());
        assert_eq!(group.state, GroupState::PreparingRebalance);
    }

    #[test]
    fn test_static_member_timeout_triggers_rebalance() {
        let mut group = group_with_session_timeout(Duration::from_millis(10));
        join_static(&mut group, "member-1", "instance-1", "client-1");
        join_dynamic(&mut group, "member-2", "client-2", topic_1());
        group.state = GroupState::Stable;
        let instance = "instance-1".to_string();
        group.remove_static_member(&instance);
        assert_eq!(group.state, GroupState::PreparingRebalance);
        assert!(!group.has_static_member(&instance));
    }

    #[test]
    fn test_mixed_static_and_dynamic_members() {
        let mut group = new_group();
        join_static(&mut group, "static-1", "instance-1", "client-1");
        join_static(&mut group, "static-2", "instance-2", "client-2");
        join_dynamic(&mut group, "dynamic-1", "client-3", topic_1());
        assert_eq!(group.members.len(), 3);
        assert_eq!(group.static_members.len(), 2);
        for &id in ["static-1", "static-2"].iter() {
            assert!(group.members[id].is_static);
        }
        assert!(!group.members["dynamic-1"].is_static);
    }

    #[test]
    fn test_all_members_leave_clears_static_mappings() {
        let mut group = new_group();
        join_static(&mut group, "member-1", "instance-1", "client-1");
        group.remove_member(&"member-1".to_string());
        group.remove_static_member(&"instance-1".to_string());
        assert_eq!(group.state, GroupState::Empty);
        assert!(group.static_members.is_empty());
        assert!(group.pending_static_members.is_empty());
    }

    // ---- protocol selection ---------------------------------------------------------

    #[test]
    fn test_rebalance_protocol_selection_all_eager() {
        let protocols = [RebalanceProtocol::Eager, RebalanceProtocol::Eager];
        assert_eq!(
            RebalanceProtocol::select_common(&protocols),
            RebalanceProtocol::Eager
        );
    }

    #[test]
    fn test_rebalance_protocol_selection_all_cooperative() {
        let protocols = [RebalanceProtocol::Cooperative, RebalanceProtocol::Cooperative];
        assert_eq!(
            RebalanceProtocol::select_common(&protocols),
            RebalanceProtocol::Cooperative
        );
    }

    #[test]
    fn test_rebalance_protocol_selection_mixed() {
        let protocols = [RebalanceProtocol::Cooperative, RebalanceProtocol::Eager];
        assert_eq!(
            RebalanceProtocol::select_common(&protocols),
            RebalanceProtocol::Eager
        );
    }

    #[test]
    fn test_cooperative_member_add_updates_protocol() {
        let mut group = new_group();
        join_with_protocol(&mut group, "member-1", "client-1", RebalanceProtocol::Cooperative);
        assert!(group.is_cooperative());
        join_with_protocol(&mut group, "member-2", "client-2", RebalanceProtocol::Eager);
        assert!(!group.is_cooperative());
        assert_eq!(group.rebalance_protocol, RebalanceProtocol::Eager);
    }

    // ---- rebalancing ----------------------------------------------------------------

    #[test]
    fn test_compute_revocations() {
        let mut group = new_group();
        join_with_protocol(&mut group, "member-1", "client-1", RebalanceProtocol::Cooperative);
        join_with_protocol(&mut group, "member-2", "client-2", RebalanceProtocol::Cooperative);
        group.complete_rebalance(assignment_map(vec![
            ("member-1", vec![0, 1]),
            ("member-2", vec![2, 3]),
        ]));
        let new_assignment =
            assignment_map(vec![("member-1", vec![0]), ("member-2", vec![1, 2, 3])]);
        let revocations = group.compute_revocations(&new_assignment);
        assert!(revocations.contains_key("member-1"));
        assert!(!revocations.contains_key("member-2"));
        let m1_revoked = &revocations["member-1"];
        assert_eq!(m1_revoked.len(), 1);
        assert_eq!(m1_revoked[0].partition, 1);
    }

    #[test]
    fn test_cooperative_rebalance_two_phase() {
        let mut group = new_group();
        join_with_protocol(&mut group, "member-1", "client-1", RebalanceProtocol::Cooperative);
        join_with_protocol(&mut group, "member-2", "client-2", RebalanceProtocol::Cooperative);
        assert!(group.is_cooperative());
        group.complete_rebalance(assignment_map(vec![
            ("member-1", vec![0, 1]),
            ("member-2", vec![]),
        ]));
        let gen_before = group.generation_id;
        let new_assignment = assignment_map(vec![("member-1", vec![0]), ("member-2", vec![1])]);

        // Phase 1: member-1 must give up partition 1 before anyone else may take it.
        match group.rebalance_with_strategy(new_assignment.clone()) {
            RebalanceResult::AwaitingRevocations { revocations, .. } => {
                assert!(revocations.contains_key("member-1"));
                assert!(group.has_pending_revocations());
                assert_eq!(group.state, GroupState::CompletingRebalance);
            }
            RebalanceResult::Complete => panic!("Expected AwaitingRevocations"),
        }
        {
            let m1 = &group.members["member-1"];
            assert_eq!(m1.assignment.len(), 2);
            assert_eq!(m1.pending_revocation.len(), 1);
        }
        assert!(group.acknowledge_revocation(&"member-1".to_string()));
        let m1 = &group.members["member-1"];
        assert_eq!(m1.assignment.len(), 1);
        assert_eq!(m1.assignment[0].partition, 0);

        // Phase 2: hand out the released partition.
        group.complete_cooperative_rebalance(new_assignment);
        let m2 = &group.members["member-2"];
        assert_eq!(m2.assignment.len(), 1);
        assert_eq!(m2.assignment[0].partition, 1);
        assert_eq!(group.generation_id, gen_before + 1);
        assert_eq!(group.state, GroupState::Stable);
    }

    #[test]
    fn test_eager_rebalance_immediate() {
        let mut group = new_group();
        join_dynamic(&mut group, "member-1", "client-1", topic_1());
        join_dynamic(&mut group, "member-2", "client-2", topic_1());
        assert!(!group.is_cooperative());
        let result = group.rebalance_with_strategy(assignment_map(vec![
            ("member-1", vec![0]),
            ("member-2", vec![1]),
        ]));
        assert_eq!(result, RebalanceResult::Complete);
        assert_eq!(group.state, GroupState::Stable);
        assert!(!group.has_pending_revocations());
    }

    #[test]
    fn test_cooperative_no_revocations_needed() {
        let mut group = new_group();
        join_with_protocol(&mut group, "member-1", "client-1", RebalanceProtocol::Cooperative);
        group.complete_rebalance(assignment_map(vec![("member-1", vec![0])]));
        let result = group.rebalance_with_strategy(assignment_map(vec![("member-1", vec![0, 1])]));
        assert_eq!(result, RebalanceResult::Complete);
        assert_eq!(group.state, GroupState::Stable);
        assert_eq!(group.members["member-1"].assignment.len(), 2);
    }
}