1use serde::{Deserialize, Serialize};
76use std::collections::{HashMap, HashSet};
77use std::time::{Duration, SystemTime};
78
/// Identifier of a consumer group (`ConsumerGroup::group_id`).
pub type GroupId = String;

/// Identifier of a member within a group; the key of `ConsumerGroup::members`.
pub type MemberId = String;

/// Client-supplied instance id that makes a member static. It is mapped to the
/// member's current `MemberId` in `ConsumerGroup::static_members` so that a new
/// incarnation of the same instance can replace the old one.
pub type GroupInstanceId = String;
90
/// Lifecycle state of a [`ConsumerGroup`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupState {
    /// No members. New groups start here, and a group returns here (with its
    /// generation reset to 0) when its last member is removed.
    Empty,

    /// Membership changed; a new assignment has to be computed.
    PreparingRebalance,

    /// A cooperative rebalance is waiting for members to acknowledge partition
    /// revocations (entered by `ConsumerGroup::request_revocations`).
    CompletingRebalance,

    /// An assignment has been installed for the current generation.
    Stable,

    /// NOTE(review): no visible code moves a group into `Dead`; presumably set by
    /// an outer coordinator when the group is deleted — confirm.
    Dead,
}
109
/// Partition assignment strategy configured for a group
/// (`ConsumerGroup::assignment_strategy`).
///
/// NOTE(review): the visible code stores this value but never reads it; presumably
/// each variant selects the same-named function in the `assignment` module —
/// confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum AssignmentStrategy {
    /// Default strategy; presumably `assignment::range_assignment`.
    #[default]
    Range,

    /// Presumably `assignment::round_robin_assignment`.
    RoundRobin,

    /// Presumably `assignment::sticky_assignment`.
    Sticky,
}
123
/// Rebalance protocol a group runs.
///
/// `ConsumerGroup` only switches to `Cooperative` when every member advertises
/// support for it (see `RebalanceProtocol::select_common`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum RebalanceProtocol {
    /// Default. Assignments are replaced in a single step
    /// (`ConsumerGroup::complete_rebalance`).
    #[default]
    Eager,

    /// Partitions that move are first revoked from their current owner and
    /// acknowledged, then handed out (`ConsumerGroup::rebalance_with_strategy`).
    Cooperative,
}
138
139impl RebalanceProtocol {
140 pub fn select_common(protocols: &[Self]) -> Self {
143 if protocols.is_empty() {
144 return RebalanceProtocol::Eager;
145 }
146
147 if protocols
149 .iter()
150 .all(|p| *p == RebalanceProtocol::Cooperative)
151 {
152 RebalanceProtocol::Cooperative
153 } else {
154 RebalanceProtocol::Eager
155 }
156 }
157}
158
/// A single consumer registered in a [`ConsumerGroup`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GroupMember {
    /// Identifier of this member; also its key in `ConsumerGroup::members`.
    pub member_id: MemberId,

    /// Instance id for static membership; `Some` exactly when `is_static` is set
    /// by `ConsumerGroup::add_member_full`.
    pub group_instance_id: Option<GroupInstanceId>,

    /// Client identifier supplied at join time; not interpreted by the group logic.
    pub client_id: String,

    /// Subscriptions supplied at join time — presumably topic names (TODO confirm);
    /// not read by the visible group logic.
    pub subscriptions: Vec<String>,

    /// Partitions this member currently owns.
    pub assignment: Vec<PartitionAssignment>,

    /// Partitions this member has been asked to give up during a cooperative
    /// rebalance; set by `request_revocations`, drained (and removed from
    /// `assignment`) by `acknowledge_revocation`.
    pub pending_revocation: Vec<PartitionAssignment>,

    /// Time of the most recent heartbeat; serialized as milliseconds since the
    /// Unix epoch.
    #[serde(
        serialize_with = "serialize_systemtime",
        deserialize_with = "deserialize_systemtime"
    )]
    pub last_heartbeat: SystemTime,

    /// Opaque metadata supplied at join time; not interpreted here.
    pub metadata: Vec<u8>,

    /// `true` when the member joined with a `group_instance_id`.
    pub is_static: bool,

    /// Rebalance protocols this member supports. `add_member_full` substitutes
    /// `[Eager]` when an empty list is supplied.
    pub supported_protocols: Vec<RebalanceProtocol>,
}
199
200fn serialize_systemtime<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
202where
203 S: serde::Serializer,
204{
205 let duration = time
206 .duration_since(SystemTime::UNIX_EPOCH)
207 .map_err(serde::ser::Error::custom)?;
208 serializer.serialize_u128(duration.as_millis())
209}
210
211fn deserialize_systemtime<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
212where
213 D: serde::Deserializer<'de>,
214{
215 let millis = u128::deserialize(deserializer)?;
216 Ok(SystemTime::UNIX_EPOCH + std::time::Duration::from_millis(millis as u64))
217}
218
/// One partition of one topic. Used both as an entry in member assignments and
/// as a map key (see `ConsumerGroup::all_assignments`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PartitionAssignment {
    /// Topic name.
    pub topic: String,
    /// Partition index within `topic`.
    pub partition: u32,
}
225
/// Outcome of `ConsumerGroup::rebalance_with_strategy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RebalanceResult {
    /// The new assignment has been installed: the group is `Stable` and its
    /// generation was incremented.
    Complete,

    /// A cooperative rebalance is paused until members release partitions.
    AwaitingRevocations {
        /// Partitions each member must give up before the rebalance can finish.
        revocations: HashMap<MemberId, Vec<PartitionAssignment>>,
        /// The full target assignment. Presumably the caller passes it to
        /// `complete_cooperative_rebalance` once all revocations are
        /// acknowledged — confirm against callers.
        pending_assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
    },
}
241
/// State of one consumer group: membership, leader, rebalance bookkeeping and
/// committed offsets.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ConsumerGroup {
    /// Identifier of this group.
    pub group_id: GroupId,

    /// Current lifecycle state.
    pub state: GroupState,

    /// Generation counter: incremented whenever a rebalance completes and reset
    /// to 0 when the group becomes empty.
    pub generation_id: u32,

    /// Member currently acting as group leader, if any.
    pub leader_id: Option<MemberId>,

    /// Protocol name; `ConsumerGroup::new` initialises it to `"consumer"`.
    pub protocol_name: String,

    /// Configured partition assignment strategy.
    pub assignment_strategy: AssignmentStrategy,

    /// Rebalance protocol in effect, recomputed from the members'
    /// `supported_protocols` when members join; `Cooperative` only if all of
    /// them support it.
    pub rebalance_protocol: RebalanceProtocol,

    /// Live members keyed by member id.
    pub members: HashMap<MemberId, GroupMember>,

    /// Instance id → current member id for static members. The mapping survives
    /// a static member leaving so the instance can rejoin.
    pub static_members: HashMap<GroupInstanceId, MemberId>,

    /// Assignments held for static members that left, keyed by instance id;
    /// handed back when the instance rejoins.
    pub pending_static_members: HashMap<GroupInstanceId, Vec<PartitionAssignment>>,

    /// Members that still have to acknowledge a revocation during a cooperative
    /// rebalance, with the partitions they must release.
    pub awaiting_revocation: HashMap<MemberId, Vec<PartitionAssignment>>,

    /// Committed offsets: topic → partition → offset.
    pub offsets: HashMap<String, HashMap<u32, i64>>,

    /// Maximum heartbeat gap tolerated by `check_timeouts`; serialized as
    /// milliseconds.
    #[serde(
        serialize_with = "serialize_duration",
        deserialize_with = "deserialize_duration"
    )]
    pub session_timeout: Duration,

    /// Rebalance timeout; serialized as milliseconds.
    /// NOTE(review): not read by any visible code.
    #[serde(
        serialize_with = "serialize_duration",
        deserialize_with = "deserialize_duration"
    )]
    pub rebalance_timeout: Duration,
}
299
300fn serialize_duration<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
302where
303 S: serde::Serializer,
304{
305 serializer.serialize_u64(duration.as_millis() as u64)
306}
307
308fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
309where
310 D: serde::Deserializer<'de>,
311{
312 let millis = u64::deserialize(deserializer)?;
313 Ok(Duration::from_millis(millis))
314}
315
316impl ConsumerGroup {
317 pub fn new(group_id: GroupId, session_timeout: Duration, rebalance_timeout: Duration) -> Self {
319 Self {
320 group_id,
321 state: GroupState::Empty,
322 generation_id: 0,
323 leader_id: None,
324 protocol_name: "consumer".to_string(),
325 assignment_strategy: AssignmentStrategy::default(),
326 rebalance_protocol: RebalanceProtocol::Eager,
327 members: HashMap::new(),
328 static_members: HashMap::new(),
329 pending_static_members: HashMap::new(),
330 awaiting_revocation: HashMap::new(),
331 offsets: HashMap::new(),
332 session_timeout,
333 rebalance_timeout,
334 }
335 }
336
337 pub fn add_member(
347 &mut self,
348 member_id: MemberId,
349 client_id: String,
350 subscriptions: Vec<String>,
351 metadata: Vec<u8>,
352 ) {
353 self.add_member_full(
354 member_id,
355 None,
356 client_id,
357 subscriptions,
358 metadata,
359 vec![RebalanceProtocol::Eager],
360 )
361 }
362
363 pub fn add_member_with_instance_id(
365 &mut self,
366 member_id: MemberId,
367 group_instance_id: Option<GroupInstanceId>,
368 client_id: String,
369 subscriptions: Vec<String>,
370 metadata: Vec<u8>,
371 ) {
372 self.add_member_full(
373 member_id,
374 group_instance_id,
375 client_id,
376 subscriptions,
377 metadata,
378 vec![RebalanceProtocol::Eager],
379 )
380 }
381
382 pub fn add_member_full(
384 &mut self,
385 member_id: MemberId,
386 group_instance_id: Option<GroupInstanceId>,
387 client_id: String,
388 subscriptions: Vec<String>,
389 metadata: Vec<u8>,
390 supported_protocols: Vec<RebalanceProtocol>,
391 ) {
392 let is_static = group_instance_id.is_some();
393 let supported_protocols = if supported_protocols.is_empty() {
394 vec![RebalanceProtocol::Eager]
395 } else {
396 supported_protocols
397 };
398
399 if let Some(ref instance_id) = group_instance_id {
401 if let Some(old_member_id) = self.static_members.get(instance_id).cloned() {
403 if old_member_id != member_id {
404 self.members.remove(&old_member_id);
406 }
407 }
408
409 if let Some(saved_assignment) = self.pending_static_members.remove(instance_id) {
411 let member = GroupMember {
413 member_id: member_id.clone(),
414 group_instance_id: Some(instance_id.clone()),
415 client_id,
416 subscriptions,
417 assignment: saved_assignment,
418 pending_revocation: Vec::new(),
419 last_heartbeat: SystemTime::now(),
420 metadata,
421 is_static: true,
422 supported_protocols: supported_protocols.clone(),
423 };
424
425 self.members.insert(member_id.clone(), member);
426 self.static_members
427 .insert(instance_id.clone(), member_id.clone());
428
429 self.update_rebalance_protocol();
431
432 if self.leader_id.is_none() {
434 self.leader_id = Some(member_id);
435 }
436
437 if self.state == GroupState::Empty {
439 self.state = GroupState::Stable;
440 }
441 return;
442 }
443
444 self.static_members
446 .insert(instance_id.clone(), member_id.clone());
447 }
448
449 let member = GroupMember {
451 member_id: member_id.clone(),
452 group_instance_id,
453 client_id,
454 subscriptions,
455 assignment: Vec::new(),
456 pending_revocation: Vec::new(),
457 last_heartbeat: SystemTime::now(),
458 metadata,
459 is_static,
460 supported_protocols,
461 };
462
463 self.members.insert(member_id.clone(), member);
464
465 self.update_rebalance_protocol();
467
468 if self.leader_id.is_none() {
470 self.leader_id = Some(member_id);
471 }
472
473 if self.state != GroupState::Empty {
475 self.transition_to_preparing_rebalance();
476 } else if self.members.len() == 1 {
477 self.state = GroupState::PreparingRebalance;
479 }
480 }
481
482 pub fn has_static_member(&self, instance_id: &GroupInstanceId) -> bool {
484 self.static_members.contains_key(instance_id)
485 }
486
487 pub fn get_member_for_instance(&self, instance_id: &GroupInstanceId) -> Option<&MemberId> {
489 self.static_members.get(instance_id)
490 }
491
492 pub fn fence_static_member(&mut self, instance_id: &GroupInstanceId) -> Option<MemberId> {
497 if let Some(old_member_id) = self.static_members.get(instance_id).cloned() {
498 if let Some(old_member) = self.members.get(&old_member_id) {
500 if !old_member.assignment.is_empty() {
501 self.pending_static_members
502 .insert(instance_id.clone(), old_member.assignment.clone());
503 }
504 }
505
506 self.members.remove(&old_member_id);
508
509 if self.leader_id.as_ref() == Some(&old_member_id) {
511 self.leader_id = self.members.keys().next().cloned();
512 }
513
514 Some(old_member_id)
515 } else {
516 None
517 }
518 }
519
520 pub fn remove_member(&mut self, member_id: &MemberId) -> bool {
525 if let Some(member) = self.members.remove(member_id) {
526 if member.is_static {
528 if let Some(ref instance_id) = member.group_instance_id {
529 if !member.assignment.is_empty() {
531 self.pending_static_members
532 .insert(instance_id.clone(), member.assignment);
533 }
534 }
537 } else {
538 if !self.members.is_empty() {
540 self.transition_to_preparing_rebalance();
541 }
542 }
543
544 if self.leader_id.as_ref() == Some(member_id) {
546 self.leader_id = self.members.keys().next().cloned();
547 }
548
549 if self.members.is_empty() {
551 self.state = GroupState::Empty;
552 self.generation_id = 0;
553 self.leader_id = None;
554 self.static_members.clear();
556 self.pending_static_members.clear();
557 }
558
559 true
560 } else {
561 false
562 }
563 }
564
565 pub fn remove_static_member(&mut self, instance_id: &GroupInstanceId) -> bool {
569 if let Some(member_id) = self.static_members.remove(instance_id) {
570 self.pending_static_members.remove(instance_id);
571
572 if self.members.remove(&member_id).is_some() {
573 if self.leader_id.as_ref() == Some(&member_id) {
575 self.leader_id = self.members.keys().next().cloned();
576 }
577
578 if self.members.is_empty() {
580 self.state = GroupState::Empty;
581 self.generation_id = 0;
582 self.leader_id = None;
583 self.static_members.clear();
584 self.pending_static_members.clear();
585 } else {
586 self.transition_to_preparing_rebalance();
588 }
589
590 return true;
591 }
592 }
593
594 if self.pending_static_members.remove(instance_id).is_some() {
596 self.static_members.remove(instance_id);
597 return true;
599 }
600
601 false
602 }
603
604 pub fn heartbeat(&mut self, member_id: &MemberId) -> Result<(), String> {
606 if let Some(member) = self.members.get_mut(member_id) {
607 member.last_heartbeat = SystemTime::now();
608 Ok(())
609 } else {
610 Err(format!("Unknown member: {}", member_id))
611 }
612 }
613
614 pub fn check_timeouts(&mut self) -> Vec<MemberId> {
619 let now = SystemTime::now();
620 let mut timed_out = Vec::new();
621 let mut static_timeouts: Vec<GroupInstanceId> = Vec::new();
622
623 for (member_id, member) in &self.members {
625 if let Ok(elapsed) = now.duration_since(member.last_heartbeat) {
626 if elapsed > self.session_timeout {
627 timed_out.push(member_id.clone());
628
629 if let Some(ref instance_id) = member.group_instance_id {
631 static_timeouts.push(instance_id.clone());
632 }
633 }
634 }
635 }
636
637 for member_id in &timed_out {
639 let member = self.members.get(member_id);
640 if let Some(m) = member {
641 if !m.is_static {
642 self.remove_member(member_id);
643 }
644 }
645 }
646
647 for instance_id in &static_timeouts {
649 self.remove_static_member(instance_id);
651 }
652
653 timed_out
654 }
655
656 pub fn check_pending_static_timeouts(
661 &mut self,
662 _pending_timeout: Duration,
663 ) -> Vec<GroupInstanceId> {
664 Vec::new()
669 }
670
671 fn update_rebalance_protocol(&mut self) {
680 let protocols: Vec<RebalanceProtocol> = self
681 .members
682 .values()
683 .flat_map(|m| {
684 if m.supported_protocols
686 .contains(&RebalanceProtocol::Cooperative)
687 {
688 Some(RebalanceProtocol::Cooperative)
689 } else {
690 Some(RebalanceProtocol::Eager)
691 }
692 })
693 .collect();
694
695 self.rebalance_protocol = RebalanceProtocol::select_common(&protocols);
696 }
697
698 pub fn is_cooperative(&self) -> bool {
700 self.rebalance_protocol == RebalanceProtocol::Cooperative
701 }
702
703 pub fn compute_revocations(
708 &self,
709 new_assignments: &HashMap<MemberId, Vec<PartitionAssignment>>,
710 ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
711 let mut revocations: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
712
713 for (member_id, member) in &self.members {
714 let new_assignment = new_assignments.get(member_id);
715
716 let mut to_revoke = Vec::new();
718 for partition in &member.assignment {
719 let still_assigned = new_assignment
720 .map(|a| a.contains(partition))
721 .unwrap_or(false);
722
723 if !still_assigned {
724 to_revoke.push(partition.clone());
725 }
726 }
727
728 if !to_revoke.is_empty() {
729 revocations.insert(member_id.clone(), to_revoke);
730 }
731 }
732
733 revocations
734 }
735
736 pub fn request_revocations(
741 &mut self,
742 revocations: HashMap<MemberId, Vec<PartitionAssignment>>,
743 ) {
744 self.awaiting_revocation = revocations.clone();
746
747 for (member_id, partitions) in revocations {
749 if let Some(member) = self.members.get_mut(&member_id) {
750 member.pending_revocation = partitions;
751 }
752 }
753
754 self.state = GroupState::CompletingRebalance;
757 }
758
759 pub fn acknowledge_revocation(&mut self, member_id: &MemberId) -> bool {
764 self.awaiting_revocation.remove(member_id);
766
767 if let Some(member) = self.members.get_mut(member_id) {
769 let revoked: HashSet<_> = member.pending_revocation.drain(..).collect();
771 member.assignment.retain(|p| !revoked.contains(p));
772 }
773
774 self.awaiting_revocation.is_empty()
776 }
777
778 pub fn has_pending_revocations(&self) -> bool {
780 !self.awaiting_revocation.is_empty()
781 }
782
783 pub fn get_pending_revocations(&self, member_id: &MemberId) -> Vec<PartitionAssignment> {
785 self.members
786 .get(member_id)
787 .map(|m| m.pending_revocation.clone())
788 .unwrap_or_default()
789 }
790
791 pub fn complete_cooperative_rebalance(
796 &mut self,
797 final_assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
798 ) {
799 for (member_id, new_partitions) in final_assignments {
801 if let Some(member) = self.members.get_mut(&member_id) {
802 for partition in new_partitions {
804 if !member.assignment.contains(&partition) {
805 member.assignment.push(partition);
806 }
807 }
808 }
809 }
810
811 self.awaiting_revocation.clear();
813
814 self.generation_id += 1;
816 self.state = GroupState::Stable;
817 }
818
819 pub fn rebalance_with_strategy(
824 &mut self,
825 new_assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
826 ) -> RebalanceResult {
827 match self.rebalance_protocol {
828 RebalanceProtocol::Eager => {
829 self.complete_rebalance(new_assignments);
831 RebalanceResult::Complete
832 }
833 RebalanceProtocol::Cooperative => {
834 let revocations = self.compute_revocations(&new_assignments);
836
837 if revocations.is_empty() {
838 self.complete_cooperative_rebalance(new_assignments);
840 RebalanceResult::Complete
841 } else {
842 self.request_revocations(revocations.clone());
844 RebalanceResult::AwaitingRevocations {
845 revocations,
846 pending_assignments: new_assignments,
847 }
848 }
849 }
850 }
851 }
852
853 fn transition_to_preparing_rebalance(&mut self) {
855 if self.state != GroupState::Empty {
856 self.state = GroupState::PreparingRebalance;
857 }
858 }
859
860 pub fn complete_rebalance(&mut self, assignments: HashMap<MemberId, Vec<PartitionAssignment>>) {
862 for (member_id, partitions) in assignments {
864 if let Some(member) = self.members.get_mut(&member_id) {
865 member.assignment = partitions;
866 }
867 }
868
869 self.generation_id += 1;
871 self.state = GroupState::Stable;
872 }
873
874 pub fn commit_offset(&mut self, topic: &str, partition: u32, offset: i64) {
876 self.offsets
877 .entry(topic.to_string())
878 .or_default()
879 .insert(partition, offset);
880 }
881
882 pub fn fetch_offset(&self, topic: &str, partition: u32) -> Option<i64> {
884 self.offsets.get(topic)?.get(&partition).copied()
885 }
886
887 pub fn all_assignments(&self) -> HashMap<PartitionAssignment, MemberId> {
889 let mut assignments = HashMap::new();
890 for (member_id, member) in &self.members {
891 for partition in &member.assignment {
892 assignments.insert(partition.clone(), member_id.clone());
893 }
894 }
895 assignments
896 }
897}
898
899pub mod assignment {
901 use super::*;
902
903 pub fn range_assignment(
910 members: &[MemberId],
911 topic_partitions: &HashMap<String, u32>,
912 ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
913 let mut assignments: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
914
915 if members.is_empty() {
916 return assignments;
917 }
918
919 for (topic, partition_count) in topic_partitions {
920 let partitions_per_member = partition_count / members.len() as u32;
921 let extra_partitions = partition_count % members.len() as u32;
922
923 let mut current_partition = 0;
924
925 for (idx, member_id) in members.iter().enumerate() {
926 let mut member_partitions = partitions_per_member;
927 if (idx as u32) < extra_partitions {
928 member_partitions += 1;
929 }
930
931 for _ in 0..member_partitions {
932 assignments
933 .entry(member_id.clone())
934 .or_default()
935 .push(PartitionAssignment {
936 topic: topic.clone(),
937 partition: current_partition,
938 });
939 current_partition += 1;
940 }
941 }
942 }
943
944 assignments
945 }
946
947 pub fn round_robin_assignment(
954 members: &[MemberId],
955 topic_partitions: &HashMap<String, u32>,
956 ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
957 let mut assignments: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
958
959 if members.is_empty() {
960 return assignments;
961 }
962
963 let mut member_idx = 0;
964
965 for (topic, partition_count) in topic_partitions {
966 for partition in 0..*partition_count {
967 assignments
968 .entry(members[member_idx].clone())
969 .or_default()
970 .push(PartitionAssignment {
971 topic: topic.clone(),
972 partition,
973 });
974
975 member_idx = (member_idx + 1) % members.len();
976 }
977 }
978
979 assignments
980 }
981
982 pub fn sticky_assignment(
986 members: &[MemberId],
987 topic_partitions: &HashMap<String, u32>,
988 previous_assignments: &HashMap<MemberId, Vec<PartitionAssignment>>,
989 ) -> HashMap<MemberId, Vec<PartitionAssignment>> {
990 let mut assignments: HashMap<MemberId, Vec<PartitionAssignment>> = HashMap::new();
991
992 if members.is_empty() {
993 return assignments;
994 }
995
996 let mut all_partitions = Vec::new();
998 for (topic, partition_count) in topic_partitions {
999 for partition in 0..*partition_count {
1000 all_partitions.push(PartitionAssignment {
1001 topic: topic.clone(),
1002 partition,
1003 });
1004 }
1005 }
1006
1007 let mut assigned: HashSet<PartitionAssignment> = HashSet::new();
1009
1010 for member_id in members {
1012 if let Some(prev_partitions) = previous_assignments.get(member_id) {
1013 let valid_partitions: Vec<_> = prev_partitions
1014 .iter()
1015 .filter(|p| all_partitions.contains(p) && !assigned.contains(p))
1016 .cloned()
1017 .collect();
1018
1019 for partition in &valid_partitions {
1020 assigned.insert(partition.clone());
1021 }
1022
1023 assignments.insert(member_id.clone(), valid_partitions);
1024 }
1025 }
1026
1027 let unassigned: Vec<_> = all_partitions
1029 .into_iter()
1030 .filter(|p| !assigned.contains(p))
1031 .collect();
1032
1033 let mut member_idx = 0;
1034 for partition in unassigned {
1035 assignments
1036 .entry(members[member_idx].clone())
1037 .or_default()
1038 .push(partition);
1039
1040 member_idx = (member_idx + 1) % members.len();
1041 }
1042
1043 assignments
1044 }
1045}
1046
1047#[cfg(test)]
1048mod tests {
1049 use super::*;
1050
1051 #[test]
1052 fn test_consumer_group_creation() {
1053 let group = ConsumerGroup::new(
1054 "test-group".to_string(),
1055 Duration::from_secs(30),
1056 Duration::from_secs(60),
1057 );
1058
1059 assert_eq!(group.group_id, "test-group");
1060 assert_eq!(group.state, GroupState::Empty);
1061 assert_eq!(group.generation_id, 0);
1062 assert!(group.leader_id.is_none());
1063 assert!(group.members.is_empty());
1064 }
1065
1066 #[test]
1067 fn test_add_first_member_becomes_leader() {
1068 let mut group = ConsumerGroup::new(
1069 "test-group".to_string(),
1070 Duration::from_secs(30),
1071 Duration::from_secs(60),
1072 );
1073
1074 group.add_member(
1075 "member-1".to_string(),
1076 "client-1".to_string(),
1077 vec!["topic-1".to_string()],
1078 vec![],
1079 );
1080
1081 assert_eq!(group.members.len(), 1);
1082 assert_eq!(group.leader_id, Some("member-1".to_string()));
1083 assert_eq!(group.state, GroupState::PreparingRebalance);
1084 }
1085
1086 #[test]
1087 fn test_remove_member_triggers_rebalance() {
1088 let mut group = ConsumerGroup::new(
1089 "test-group".to_string(),
1090 Duration::from_secs(30),
1091 Duration::from_secs(60),
1092 );
1093
1094 group.add_member(
1095 "member-1".to_string(),
1096 "client-1".to_string(),
1097 vec![],
1098 vec![],
1099 );
1100 group.add_member(
1101 "member-2".to_string(),
1102 "client-2".to_string(),
1103 vec![],
1104 vec![],
1105 );
1106
1107 group.state = GroupState::Stable;
1108
1109 group.remove_member(&"member-2".to_string());
1110
1111 assert_eq!(group.members.len(), 1);
1112 assert_eq!(group.state, GroupState::PreparingRebalance);
1113 }
1114
1115 #[test]
1116 fn test_remove_last_member_transitions_to_empty() {
1117 let mut group = ConsumerGroup::new(
1118 "test-group".to_string(),
1119 Duration::from_secs(30),
1120 Duration::from_secs(60),
1121 );
1122
1123 group.add_member(
1124 "member-1".to_string(),
1125 "client-1".to_string(),
1126 vec![],
1127 vec![],
1128 );
1129 group.remove_member(&"member-1".to_string());
1130
1131 assert_eq!(group.state, GroupState::Empty);
1132 assert_eq!(group.generation_id, 0);
1133 assert!(group.leader_id.is_none());
1134 }
1135
1136 #[test]
1137 fn test_offset_commit_and_fetch() {
1138 let mut group = ConsumerGroup::new(
1139 "test-group".to_string(),
1140 Duration::from_secs(30),
1141 Duration::from_secs(60),
1142 );
1143
1144 group.commit_offset("topic-1", 0, 100);
1145 group.commit_offset("topic-1", 1, 200);
1146 group.commit_offset("topic-2", 0, 300);
1147
1148 assert_eq!(group.fetch_offset("topic-1", 0), Some(100));
1149 assert_eq!(group.fetch_offset("topic-1", 1), Some(200));
1150 assert_eq!(group.fetch_offset("topic-2", 0), Some(300));
1151 assert_eq!(group.fetch_offset("topic-1", 2), None);
1152 }
1153
1154 #[test]
1155 fn test_range_assignment() {
1156 let members = vec!["m1".to_string(), "m2".to_string(), "m3".to_string()];
1157 let mut topic_partitions = HashMap::new();
1158 topic_partitions.insert("topic-1".to_string(), 10);
1159
1160 let assignments = assignment::range_assignment(&members, &topic_partitions);
1161
1162 assert_eq!(assignments.get("m1").unwrap().len(), 4);
1166 assert_eq!(assignments.get("m2").unwrap().len(), 3);
1167 assert_eq!(assignments.get("m3").unwrap().len(), 3);
1168
1169 let m1_partitions: Vec<u32> = assignments
1171 .get("m1")
1172 .unwrap()
1173 .iter()
1174 .map(|p| p.partition)
1175 .collect();
1176 assert_eq!(m1_partitions, vec![0, 1, 2, 3]);
1177 }
1178
1179 #[test]
1180 fn test_round_robin_assignment() {
1181 let members = vec!["m1".to_string(), "m2".to_string(), "m3".to_string()];
1182 let mut topic_partitions = HashMap::new();
1183 topic_partitions.insert("topic-1".to_string(), 10);
1184
1185 let assignments = assignment::round_robin_assignment(&members, &topic_partitions);
1186
1187 assert_eq!(assignments.get("m1").unwrap().len(), 4);
1191 assert_eq!(assignments.get("m2").unwrap().len(), 3);
1192 assert_eq!(assignments.get("m3").unwrap().len(), 3);
1193
1194 let m1_partitions: Vec<u32> = assignments
1195 .get("m1")
1196 .unwrap()
1197 .iter()
1198 .map(|p| p.partition)
1199 .collect();
1200 assert_eq!(m1_partitions, vec![0, 3, 6, 9]);
1201 }
1202
1203 #[test]
1204 fn test_sticky_assignment_preserves_assignments() {
1205 let members = vec!["m1".to_string(), "m2".to_string()];
1206 let mut topic_partitions = HashMap::new();
1207 topic_partitions.insert("topic-1".to_string(), 4);
1208
1209 let mut previous = HashMap::new();
1211 previous.insert(
1212 "m1".to_string(),
1213 vec![
1214 PartitionAssignment {
1215 topic: "topic-1".to_string(),
1216 partition: 0,
1217 },
1218 PartitionAssignment {
1219 topic: "topic-1".to_string(),
1220 partition: 1,
1221 },
1222 ],
1223 );
1224 previous.insert(
1225 "m2".to_string(),
1226 vec![
1227 PartitionAssignment {
1228 topic: "topic-1".to_string(),
1229 partition: 2,
1230 },
1231 PartitionAssignment {
1232 topic: "topic-1".to_string(),
1233 partition: 3,
1234 },
1235 ],
1236 );
1237
1238 let assignments = assignment::sticky_assignment(&members, &topic_partitions, &previous);
1239
1240 assert_eq!(assignments.get("m1").unwrap().len(), 2);
1242 assert_eq!(assignments.get("m2").unwrap().len(), 2);
1243
1244 let m1_partitions: HashSet<u32> = assignments
1245 .get("m1")
1246 .unwrap()
1247 .iter()
1248 .map(|p| p.partition)
1249 .collect();
1250 assert!(m1_partitions.contains(&0));
1251 assert!(m1_partitions.contains(&1));
1252 }
1253
1254 #[test]
1255 fn test_sticky_assignment_redistributes_on_new_member() {
1256 let members = vec!["m1".to_string(), "m2".to_string(), "m3".to_string()];
1257 let mut topic_partitions = HashMap::new();
1258 topic_partitions.insert("topic-1".to_string(), 6);
1259
1260 let mut previous = HashMap::new();
1262 previous.insert(
1263 "m1".to_string(),
1264 vec![
1265 PartitionAssignment {
1266 topic: "topic-1".to_string(),
1267 partition: 0,
1268 },
1269 PartitionAssignment {
1270 topic: "topic-1".to_string(),
1271 partition: 1,
1272 },
1273 PartitionAssignment {
1274 topic: "topic-1".to_string(),
1275 partition: 2,
1276 },
1277 ],
1278 );
1279 previous.insert(
1280 "m2".to_string(),
1281 vec![
1282 PartitionAssignment {
1283 topic: "topic-1".to_string(),
1284 partition: 3,
1285 },
1286 PartitionAssignment {
1287 topic: "topic-1".to_string(),
1288 partition: 4,
1289 },
1290 PartitionAssignment {
1291 topic: "topic-1".to_string(),
1292 partition: 5,
1293 },
1294 ],
1295 );
1296
1297 let assignments = assignment::sticky_assignment(&members, &topic_partitions, &previous);
1298
1299 let total_assigned: usize = assignments.values().map(|v| v.len()).sum();
1303 assert_eq!(total_assigned, 6, "All 6 partitions should be assigned");
1304
1305 let m1_partitions: HashSet<u32> = assignments
1307 .get("m1")
1308 .unwrap()
1309 .iter()
1310 .map(|p| p.partition)
1311 .collect();
1312 let m2_partitions: HashSet<u32> = assignments
1313 .get("m2")
1314 .unwrap()
1315 .iter()
1316 .map(|p| p.partition)
1317 .collect();
1318
1319 let m1_kept = m1_partitions.iter().filter(|p| **p <= 2).count();
1321 let m2_kept = m2_partitions.iter().filter(|p| **p >= 3).count();
1322
1323 assert!(
1324 m1_kept > 0 || m2_kept > 0,
1325 "Sticky assignment should preserve some assignments"
1326 );
1327 }
1328
1329 #[test]
1334 fn test_static_member_add() {
1335 let mut group = ConsumerGroup::new(
1336 "test-group".to_string(),
1337 Duration::from_secs(30),
1338 Duration::from_secs(60),
1339 );
1340
1341 group.add_member_with_instance_id(
1343 "member-1".to_string(),
1344 Some("instance-1".to_string()),
1345 "client-1".to_string(),
1346 vec!["topic-1".to_string()],
1347 vec![],
1348 );
1349
1350 assert_eq!(group.members.len(), 1);
1351 assert!(group.has_static_member(&"instance-1".to_string()));
1352 assert_eq!(
1353 group.get_member_for_instance(&"instance-1".to_string()),
1354 Some(&"member-1".to_string())
1355 );
1356
1357 let member = group.members.get("member-1").unwrap();
1358 assert!(member.is_static);
1359 assert_eq!(member.group_instance_id, Some("instance-1".to_string()));
1360 }
1361
1362 #[test]
1363 fn test_static_member_rejoin_no_rebalance() {
1364 let mut group = ConsumerGroup::new(
1365 "test-group".to_string(),
1366 Duration::from_secs(30),
1367 Duration::from_secs(60),
1368 );
1369
1370 group.add_member_with_instance_id(
1372 "member-1".to_string(),
1373 Some("instance-1".to_string()),
1374 "client-1".to_string(),
1375 vec!["topic-1".to_string()],
1376 vec![],
1377 );
1378 group.add_member(
1379 "member-2".to_string(),
1380 "client-2".to_string(),
1381 vec!["topic-1".to_string()],
1382 vec![],
1383 );
1384
1385 let mut assignments = HashMap::new();
1387 assignments.insert(
1388 "member-1".to_string(),
1389 vec![PartitionAssignment {
1390 topic: "topic-1".to_string(),
1391 partition: 0,
1392 }],
1393 );
1394 assignments.insert(
1395 "member-2".to_string(),
1396 vec![PartitionAssignment {
1397 topic: "topic-1".to_string(),
1398 partition: 1,
1399 }],
1400 );
1401 group.complete_rebalance(assignments);
1402 assert_eq!(group.state, GroupState::Stable);
1403 let gen_before = group.generation_id;
1404
1405 group.remove_member(&"member-1".to_string());
1407
1408 assert!(group.pending_static_members.contains_key("instance-1"));
1410
1411 group.add_member_with_instance_id(
1413 "member-1-new".to_string(),
1414 Some("instance-1".to_string()),
1415 "client-1".to_string(),
1416 vec!["topic-1".to_string()],
1417 vec![],
1418 );
1419
1420 let member = group.members.get("member-1-new").unwrap();
1422 assert_eq!(member.assignment.len(), 1);
1423 assert_eq!(member.assignment[0].partition, 0);
1424
1425 assert_eq!(group.generation_id, gen_before);
1427
1428 assert!(!group.pending_static_members.contains_key("instance-1"));
1430 }
1431
1432 #[test]
1433 fn test_static_member_fencing() {
1434 let mut group = ConsumerGroup::new(
1435 "test-group".to_string(),
1436 Duration::from_secs(30),
1437 Duration::from_secs(60),
1438 );
1439
1440 group.add_member_with_instance_id(
1442 "member-1".to_string(),
1443 Some("instance-1".to_string()),
1444 "client-1".to_string(),
1445 vec!["topic-1".to_string()],
1446 vec![],
1447 );
1448
1449 assert!(group.members.contains_key("member-1"));
1450
1451 group.add_member_with_instance_id(
1453 "member-1-new".to_string(),
1454 Some("instance-1".to_string()),
1455 "client-1".to_string(),
1456 vec!["topic-1".to_string()],
1457 vec![],
1458 );
1459
1460 assert!(!group.members.contains_key("member-1"));
1462 assert!(group.members.contains_key("member-1-new"));
1464 assert_eq!(
1466 group.get_member_for_instance(&"instance-1".to_string()),
1467 Some(&"member-1-new".to_string())
1468 );
1469 }
1470
1471 #[test]
1472 fn test_dynamic_member_removal_triggers_rebalance() {
1473 let mut group = ConsumerGroup::new(
1474 "test-group".to_string(),
1475 Duration::from_secs(30),
1476 Duration::from_secs(60),
1477 );
1478
1479 group.add_member_with_instance_id(
1481 "static-member".to_string(),
1482 Some("instance-1".to_string()),
1483 "client-1".to_string(),
1484 vec!["topic-1".to_string()],
1485 vec![],
1486 );
1487 group.add_member(
1488 "dynamic-member".to_string(),
1489 "client-2".to_string(),
1490 vec!["topic-1".to_string()],
1491 vec![],
1492 );
1493
1494 group.state = GroupState::Stable;
1495
1496 group.remove_member(&"dynamic-member".to_string());
1498
1499 assert_eq!(group.state, GroupState::PreparingRebalance);
1501 }
1502
1503 #[test]
1504 fn test_static_member_timeout_triggers_rebalance() {
1505 let mut group = ConsumerGroup::new(
1506 "test-group".to_string(),
1507 Duration::from_millis(10), Duration::from_secs(60),
1509 );
1510
1511 group.add_member_with_instance_id(
1513 "member-1".to_string(),
1514 Some("instance-1".to_string()),
1515 "client-1".to_string(),
1516 vec!["topic-1".to_string()],
1517 vec![],
1518 );
1519 group.add_member(
1520 "member-2".to_string(),
1521 "client-2".to_string(),
1522 vec!["topic-1".to_string()],
1523 vec![],
1524 );
1525
1526 group.state = GroupState::Stable;
1527
1528 group.remove_static_member(&"instance-1".to_string());
1530
1531 assert_eq!(group.state, GroupState::PreparingRebalance);
1533 assert!(!group.has_static_member(&"instance-1".to_string()));
1535 }
1536
/// A group can hold static and dynamic members side by side. Only the ones
/// registered with an instance id are tracked in `static_members` and
/// flagged `is_static`.
#[test]
fn test_mixed_static_and_dynamic_members() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );
    let topics = || vec![String::from("topic-1")];

    for (member, instance, client) in [
        ("static-1", "instance-1", "client-1"),
        ("static-2", "instance-2", "client-2"),
    ] {
        group.add_member_with_instance_id(
            member.to_owned(),
            Some(instance.to_owned()),
            client.to_owned(),
            topics(),
            Vec::new(),
        );
    }
    group.add_member(
        String::from("dynamic-1"),
        String::from("client-3"),
        topics(),
        Vec::new(),
    );

    assert_eq!(group.members.len(), 3);
    assert_eq!(group.static_members.len(), 2);

    // Look every member up first, then check the flags.
    let [static1, static2, dynamic1] = ["static-1", "static-2", "dynamic-1"]
        .map(|id| group.members.get(id).unwrap().is_static);

    assert!(static1);
    assert!(static2);
    assert!(!dynamic1);
}
1579
/// After the only member leaves and its instance id is removed, the group is
/// `Empty` and holds no static or pending-static mappings.
#[test]
fn test_all_members_leave_clears_static_mappings() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );
    let member_id = String::from("member-1");
    let instance_id = String::from("instance-1");

    group.add_member_with_instance_id(
        member_id.clone(),
        Some(instance_id.clone()),
        String::from("client-1"),
        vec![String::from("topic-1")],
        Vec::new(),
    );

    // Tear down both the member entry and its static instance mapping.
    group.remove_member(&member_id);
    group.remove_static_member(&instance_id);

    assert_eq!(group.state, GroupState::Empty);
    assert!(group.static_members.is_empty());
    assert!(group.pending_static_members.is_empty());
}
1607
/// A group whose members only speak `Eager` negotiates `Eager`.
#[test]
fn test_rebalance_protocol_selection_all_eager() {
    // A stack array is enough; `select_common` only borrows a slice.
    let chosen = RebalanceProtocol::select_common(&[RebalanceProtocol::Eager; 2]);
    assert_eq!(chosen, RebalanceProtocol::Eager);
}
1620
/// Unanimous `Cooperative` support is the only way to get `Cooperative`.
#[test]
fn test_rebalance_protocol_selection_all_cooperative() {
    let chosen = RebalanceProtocol::select_common(&[RebalanceProtocol::Cooperative; 2]);
    assert_eq!(chosen, RebalanceProtocol::Cooperative);
}
1632
/// Any `Eager` entry forces the negotiated protocol down to `Eager`, wherever
/// it appears in the list.
///
/// Both orderings are checked, so a regression that only looks at the first
/// or last entry is caught.
#[test]
fn test_rebalance_protocol_selection_mixed() {
    let orderings = [
        [RebalanceProtocol::Cooperative, RebalanceProtocol::Eager],
        [RebalanceProtocol::Eager, RebalanceProtocol::Cooperative],
    ];

    for protocols in &orderings {
        assert_eq!(
            RebalanceProtocol::select_common(protocols),
            RebalanceProtocol::Eager,
            "mixed protocols {:?} must fall back to Eager",
            protocols
        );
    }
}
1642
/// The group's negotiated protocol is re-evaluated as members join. A lone
/// cooperative member leaves the group cooperative; an eager-only member
/// joining drops it to `Eager`.
#[test]
fn test_cooperative_member_add_updates_protocol() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );

    // (member, client, supported protocol, expected is_cooperative after join)
    for (member, client, protocol, expect_cooperative) in [
        ("member-1", "client-1", RebalanceProtocol::Cooperative, true),
        ("member-2", "client-2", RebalanceProtocol::Eager, false),
    ] {
        group.add_member_full(
            member.to_owned(),
            None,
            client.to_owned(),
            vec![String::from("topic-1")],
            Vec::new(),
            vec![protocol],
        );
        assert_eq!(group.is_cooperative(), expect_cooperative);
    }

    assert_eq!(group.rebalance_protocol, RebalanceProtocol::Eager);
}
1678
/// `compute_revocations` reports only the partitions a member currently owns
/// but will not own under the target assignment.
#[test]
fn test_compute_revocations() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );

    for (member, client) in [("member-1", "client-1"), ("member-2", "client-2")] {
        group.add_member_full(
            member.to_owned(),
            None,
            client.to_owned(),
            vec![String::from("topic-1")],
            Vec::new(),
            vec![RebalanceProtocol::Cooperative],
        );
    }

    // Shorthand for a partition of "topic-1".
    let p = |partition| PartitionAssignment {
        topic: "topic-1".to_string(),
        partition,
    };

    // Current ownership: member-1 -> {0, 1}, member-2 -> {2, 3}.
    group.complete_rebalance(HashMap::from([
        ("member-1".to_string(), vec![p(0), p(1)]),
        ("member-2".to_string(), vec![p(2), p(3)]),
    ]));

    // Target: partition 1 moves from member-1 to member-2.
    let target = HashMap::from([
        ("member-1".to_string(), vec![p(0)]),
        ("member-2".to_string(), vec![p(1), p(2), p(3)]),
    ]);

    let revocations = group.compute_revocations(&target);

    // Only the member that loses a partition has anything to revoke.
    assert!(revocations.contains_key("member-1"));
    assert!(!revocations.contains_key("member-2"));

    let m1_revoked = revocations.get("member-1").unwrap();
    assert_eq!(m1_revoked.len(), 1);
    assert_eq!(m1_revoked[0].partition, 1);
}
1772
/// A cooperative rebalance that moves a partition runs in two phases.
///
/// 1. `rebalance_with_strategy` returns `AwaitingRevocations`. The group sits
///    in `CompletingRebalance` and the donor keeps its partitions, with the
///    moving one listed in `pending_revocation`.
/// 2. Once the donor acknowledges, its assignment shrinks.
///    `complete_cooperative_rebalance` then hands the partition to the new
///    owner, bumps the generation by one and returns the group to `Stable`.
#[test]
fn test_cooperative_rebalance_two_phase() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );

    for (member, client) in [("member-1", "client-1"), ("member-2", "client-2")] {
        group.add_member_full(
            member.to_owned(),
            None,
            client.to_owned(),
            vec![String::from("topic-1")],
            Vec::new(),
            vec![RebalanceProtocol::Cooperative],
        );
    }
    assert!(group.is_cooperative());

    let p = |partition| PartitionAssignment {
        topic: "topic-1".to_string(),
        partition,
    };

    // Starting point: member-1 owns both partitions, member-2 owns none.
    group.complete_rebalance(HashMap::from([
        ("member-1".to_string(), vec![p(0), p(1)]),
        ("member-2".to_string(), vec![]),
    ]));
    let gen_before = group.generation_id;

    // Target: partition 1 moves over to member-2.
    let target = HashMap::from([
        ("member-1".to_string(), vec![p(0)]),
        ("member-2".to_string(), vec![p(1)]),
    ]);

    // Phase 1: the rebalance must stop and wait for member-1 to revoke.
    let revocations = match group.rebalance_with_strategy(target.clone()) {
        RebalanceResult::AwaitingRevocations { revocations, .. } => revocations,
        RebalanceResult::Complete => panic!("Expected AwaitingRevocations"),
    };
    assert!(revocations.contains_key("member-1"));
    assert!(group.has_pending_revocations());
    assert_eq!(group.state, GroupState::CompletingRebalance);

    // Before acknowledging, member-1 still holds both partitions.
    let m1 = group.members.get("member-1").unwrap();
    assert_eq!(m1.assignment.len(), 2);
    assert_eq!(m1.pending_revocation.len(), 1);

    // Phase 2: member-1 is the only member with pending work, so its ack
    // completes the revocation round.
    assert!(group.acknowledge_revocation(&"member-1".to_string()));

    let m1 = group.members.get("member-1").unwrap();
    assert_eq!(m1.assignment.len(), 1);
    assert_eq!(m1.assignment[0].partition, 0);

    group.complete_cooperative_rebalance(target);

    let m2 = group.members.get("member-2").unwrap();
    assert_eq!(m2.assignment.len(), 1);
    assert_eq!(m2.assignment[0].partition, 1);

    assert_eq!(group.generation_id, gen_before + 1);
    assert_eq!(group.state, GroupState::Stable);
}
1880
/// In a non-cooperative group, `rebalance_with_strategy` finishes in one step.
/// It returns `Complete`, leaves the group `Stable` and records no pending
/// revocations.
#[test]
fn test_eager_rebalance_immediate() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );

    for (member, client) in [("member-1", "client-1"), ("member-2", "client-2")] {
        group.add_member(
            member.to_owned(),
            client.to_owned(),
            vec![String::from("topic-1")],
            Vec::new(),
        );
    }

    // Members joined via `add_member` leave the group non-cooperative.
    assert!(!group.is_cooperative());

    let p = |partition| PartitionAssignment {
        topic: "topic-1".to_string(),
        partition,
    };
    let target = HashMap::from([
        ("member-1".to_string(), vec![p(0)]),
        ("member-2".to_string(), vec![p(1)]),
    ]);

    let outcome = group.rebalance_with_strategy(target);

    assert_eq!(outcome, RebalanceResult::Complete);
    assert_eq!(group.state, GroupState::Stable);
    assert!(!group.has_pending_revocations());
}
1928
/// A cooperative rebalance that only adds partitions needs no revocation
/// round. It completes immediately and the member ends up with the union of
/// its old and new partitions.
#[test]
fn test_cooperative_no_revocations_needed() {
    let mut group = ConsumerGroup::new(
        String::from("test-group"),
        Duration::from_secs(30),
        Duration::from_secs(60),
    );

    group.add_member_full(
        String::from("member-1"),
        None,
        String::from("client-1"),
        vec![String::from("topic-1")],
        Vec::new(),
        vec![RebalanceProtocol::Cooperative],
    );

    let p = |partition| PartitionAssignment {
        topic: "topic-1".to_string(),
        partition,
    };

    // member-1 starts with partition 0 ...
    group.complete_rebalance(HashMap::from([("member-1".to_string(), vec![p(0)])]));

    // ... and is only granted partition 1 on top of it.
    let result =
        group.rebalance_with_strategy(HashMap::from([("member-1".to_string(), vec![p(0), p(1)])]));

    assert_eq!(result, RebalanceResult::Complete);
    assert_eq!(group.state, GroupState::Stable);

    let m1 = group.members.get("member-1").unwrap();
    assert_eq!(m1.assignment.len(), 2);
}
1983}