use rivven_core::consumer_group::{
ConsumerGroup, GroupId, GroupState, MemberId, PartitionAssignment,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::Duration;
use thiserror::Error;
/// Convenience alias: every coordinator/persistence operation returns this.
pub type CoordinatorResult<T> = Result<T, CoordinatorError>;

/// Errors produced by the consumer-group coordinator and its persistence
/// backends. Derives `Serialize`/`Deserialize` so errors can be shipped to
/// clients over the wire; `Clone` so they can be logged and returned.
#[derive(Debug, Error, Clone, Serialize, Deserialize)]
pub enum CoordinatorError {
    /// No group registered under the given id.
    #[error("Group not found: {0}")]
    GroupNotFound(String),
    /// No member with the given id in the group.
    #[error("Member not found: {0}")]
    MemberNotFound(String),
    /// The group is mid-rebalance and cannot service the request.
    #[error("Rebalance in progress")]
    RebalanceInProgress,
    /// Caller supplied a stale/incorrect generation id.
    #[error("Invalid generation: expected {expected}, got {actual}")]
    InvalidGeneration { expected: u32, actual: u32 },
    /// Operation requires the group leader, and the caller is not it.
    #[error("Not group leader")]
    NotGroupLeader,
    /// Generation-related protocol violation.
    #[error("Illegal generation: {0}")]
    IllegalGeneration(String),
    /// Member id is not part of the group's current membership.
    #[error("Unknown member: {0}")]
    UnknownMember(String),
    /// Session timeout outside the accepted range.
    #[error("Invalid session timeout: {0}")]
    InvalidSessionTimeout(String),
    /// A lock was poisoned by a panicking thread (see `From<PoisonError<T>>`).
    #[error("Internal error: lock poisoned")]
    LockPoisoned,
    /// Persistence backend failure (serialization, Raft rejection, ...).
    #[error("Persistence error: {0}")]
    PersistenceError(String),
}
impl<T> From<PoisonError<T>> for CoordinatorError {
fn from(_: PoisonError<T>) -> Self {
CoordinatorError::LockPoisoned
}
}
/// Serializable snapshot of a consumer group, suitable for writing to a
/// [`GroupPersistence`] backend and restoring after a restart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedGroupState {
    /// Group identifier (also used as the storage key).
    pub group_id: GroupId,
    /// Generation at the time the snapshot was taken.
    pub generation_id: u32,
    /// Group state rendered as text (via `Debug` formatting in
    /// `from_group`). Nothing in this module parses it back.
    pub state: String,
    /// Member ids present when the snapshot was taken.
    pub members: Vec<String>,
    /// Committed offsets: topic -> (partition -> offset).
    pub offsets: HashMap<String, HashMap<u32, i64>>,
    /// Last-update time in epoch milliseconds.
    pub updated_at: i64,
}
impl PersistedGroupState {
    /// Builds a snapshot of a live [`ConsumerGroup`], timestamped with the
    /// current wall-clock time.
    ///
    /// NOTE(review): `state` is captured with `{:?}` (Debug), so the text
    /// depends on `GroupState`'s Debug derive — fine for display, but not a
    /// stable serialization format; confirm nothing downstream parses it.
    pub fn from_group(group: &ConsumerGroup) -> Self {
        Self {
            group_id: group.group_id.clone(),
            generation_id: group.generation_id,
            state: format!("{:?}", group.state),
            members: group.members.keys().cloned().collect(),
            offsets: group.offsets.clone(),
            updated_at: chrono::Utc::now().timestamp_millis(),
        }
    }
}
/// Storage abstraction for consumer-group state and committed offsets.
/// Implementations must be usable from multiple threads (`Send + Sync`);
/// all methods are synchronous.
pub trait GroupPersistence: Send + Sync {
    /// Saves (or overwrites) a group snapshot.
    fn save_group(&self, state: &PersistedGroupState) -> CoordinatorResult<()>;
    /// Loads a group snapshot, or `None` when the group is unknown.
    fn load_group(&self, group_id: &str) -> CoordinatorResult<Option<PersistedGroupState>>;
    /// Lists the ids of all persisted groups.
    fn list_groups(&self) -> CoordinatorResult<Vec<String>>;
    /// Deletes a group's persisted state.
    fn delete_group(&self, group_id: &str) -> CoordinatorResult<()>;
    /// Records a single committed offset for `group_id`/`topic`/`partition`.
    fn save_offset(
        &self,
        group_id: &str,
        topic: &str,
        partition: u32,
        offset: i64,
    ) -> CoordinatorResult<()>;
    /// Loads all committed offsets for a group (topic -> partition -> offset).
    fn load_offsets(&self, group_id: &str)
        -> CoordinatorResult<HashMap<String, HashMap<u32, i64>>>;
}
/// Volatile [`GroupPersistence`] backend: a process-local map behind an
/// `RwLock`. All state is lost when the process exits.
#[derive(Default)]
pub struct InMemoryPersistence {
    // Group snapshots keyed by group id.
    groups: RwLock<HashMap<String, PersistedGroupState>>,
}
impl InMemoryPersistence {
pub fn new() -> Self {
Self {
groups: RwLock::new(HashMap::new()),
}
}
}
impl GroupPersistence for InMemoryPersistence {
    /// Stores (or overwrites) the snapshot for `state.group_id`.
    fn save_group(&self, state: &PersistedGroupState) -> CoordinatorResult<()> {
        // `?` converts lock poisoning to `LockPoisoned` via the
        // `From<PoisonError<T>>` impl defined above — the previous
        // hand-rolled `map_err` duplicated that conversion.
        let mut groups = self.groups.write()?;
        groups.insert(state.group_id.clone(), state.clone());
        Ok(())
    }

    /// Returns the snapshot for `group_id`, or `None` when absent.
    fn load_group(&self, group_id: &str) -> CoordinatorResult<Option<PersistedGroupState>> {
        let groups = self.groups.read()?;
        Ok(groups.get(group_id).cloned())
    }

    /// Lists all stored group ids (unordered).
    fn list_groups(&self) -> CoordinatorResult<Vec<String>> {
        let groups = self.groups.read()?;
        Ok(groups.keys().cloned().collect())
    }

    /// Removes a group; succeeds even when the group does not exist.
    fn delete_group(&self, group_id: &str) -> CoordinatorResult<()> {
        let mut groups = self.groups.write()?;
        groups.remove(group_id);
        Ok(())
    }

    /// Records one committed offset and bumps the group's `updated_at`.
    ///
    /// NOTE(review): an offset commit for an unknown group is silently a
    /// no-op (still returns `Ok`) — confirm callers always `save_group`
    /// first.
    fn save_offset(
        &self,
        group_id: &str,
        topic: &str,
        partition: u32,
        offset: i64,
    ) -> CoordinatorResult<()> {
        let mut groups = self.groups.write()?;
        if let Some(state) = groups.get_mut(group_id) {
            state
                .offsets
                .entry(topic.to_string())
                .or_default()
                .insert(partition, offset);
            state.updated_at = chrono::Utc::now().timestamp_millis();
        }
        Ok(())
    }

    /// Returns every committed offset for the group; empty map when the
    /// group is unknown.
    fn load_offsets(
        &self,
        group_id: &str,
    ) -> CoordinatorResult<HashMap<String, HashMap<u32, i64>>> {
        let groups = self.groups.read()?;
        Ok(groups
            .get(group_id)
            .map(|s| s.offsets.clone())
            .unwrap_or_default())
    }
}
/// Callback handed a serialized [`RaftLogEntry`]; returns `true` when the
/// entry was accepted (committed) by the Raft layer.
type PersistCallback = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;

/// [`GroupPersistence`] backend that offers every mutation to an optional
/// Raft-commit callback before applying it to a local materialized cache.
pub struct RaftPersistence {
    /// Materialized view built from applied log entries / snapshots.
    cache: RwLock<HashMap<String, PersistedGroupState>>,
    /// Optional hook into the Raft log; `None` means apply-only (e.g. a
    /// follower replaying entries). The former `#[allow(dead_code)]` was
    /// stale — the field is read in `persist_entry` — so it was removed.
    on_persist: Option<PersistCallback>,
}
impl RaftPersistence {
    /// Creates an apply-only instance with no Raft callback.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
            on_persist: None,
        }
    }

    /// Creates an instance whose mutations are first offered to `callback`
    /// (the Raft proposal path) before being applied locally.
    pub fn with_callback<F>(callback: F) -> Self
    where
        F: Fn(&[u8]) -> bool + Send + Sync + 'static,
    {
        Self {
            cache: RwLock::new(HashMap::new()),
            on_persist: Some(Box::new(callback)),
        }
    }

    /// Applies a committed log entry to the local cache. This is the replay
    /// path (used directly by followers); it never invokes the callback.
    ///
    /// NOTE(review): an `OffsetCommit` for a group not present in the cache
    /// is silently dropped — confirm entries are always replayed in log
    /// order so the `GroupStateChange` precedes its offset commits.
    pub fn apply_log_entry(&self, entry: &RaftLogEntry) -> CoordinatorResult<()> {
        // `?` maps lock poisoning to `LockPoisoned` via `From<PoisonError<T>>`.
        let mut cache = self.cache.write()?;
        match entry {
            RaftLogEntry::GroupStateChange(state) => {
                cache.insert(state.group_id.clone(), state.clone());
            }
            RaftLogEntry::OffsetCommit {
                group_id,
                topic,
                partition,
                offset,
            } => {
                if let Some(state) = cache.get_mut(group_id) {
                    state
                        .offsets
                        .entry(topic.clone())
                        .or_default()
                        .insert(*partition, *offset);
                    state.updated_at = chrono::Utc::now().timestamp_millis();
                }
            }
            RaftLogEntry::GroupDeleted(group_id) => {
                cache.remove(group_id);
            }
        }
        Ok(())
    }

    /// Replaces the entire cache with `groups` (Raft snapshot install).
    pub fn restore_snapshot(
        &self,
        groups: HashMap<String, PersistedGroupState>,
    ) -> CoordinatorResult<()> {
        let mut cache = self.cache.write()?;
        *cache = groups;
        Ok(())
    }

    /// Clones the current cache for use as a Raft snapshot.
    pub fn create_snapshot(&self) -> CoordinatorResult<HashMap<String, PersistedGroupState>> {
        let cache = self.cache.read()?;
        Ok(cache.clone())
    }

    /// Serializes `entry` as JSON, offers it to the Raft callback (if any),
    /// and on acceptance applies it to the local cache. The callback runs
    /// *without* the cache lock held; `apply_log_entry` takes it afterwards.
    fn persist_entry(&self, entry: &RaftLogEntry) -> CoordinatorResult<()> {
        if let Some(ref callback) = self.on_persist {
            let bytes = serde_json::to_vec(entry)
                .map_err(|e| CoordinatorError::PersistenceError(e.to_string()))?;
            if !callback(&bytes) {
                return Err(CoordinatorError::PersistenceError(
                    "Raft rejected write".to_string(),
                ));
            }
        }
        self.apply_log_entry(entry)
    }
}
impl Default for RaftPersistence {
    /// Equivalent to [`RaftPersistence::new`]: empty cache, no callback.
    fn default() -> Self {
        Self::new()
    }
}
/// One replicated mutation for [`RaftPersistence`]: serialized with
/// `serde_json` in `persist_entry` and re-applied via `apply_log_entry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RaftLogEntry {
    /// Full replacement of one group's persisted state.
    GroupStateChange(PersistedGroupState),
    /// A single offset commit for a group/topic/partition.
    OffsetCommit {
        group_id: String,
        topic: String,
        partition: u32,
        offset: i64,
    },
    /// Removal of a group's persisted state (payload is the group id).
    GroupDeleted(String),
}
impl GroupPersistence for RaftPersistence {
    /// Proposes a full group-state replacement through the Raft log.
    fn save_group(&self, state: &PersistedGroupState) -> CoordinatorResult<()> {
        self.persist_entry(&RaftLogEntry::GroupStateChange(state.clone()))
    }

    /// Reads a group from the local (already-applied) cache.
    fn load_group(&self, group_id: &str) -> CoordinatorResult<Option<PersistedGroupState>> {
        // `?` maps lock poisoning to `LockPoisoned` via `From<PoisonError<T>>`
        // instead of the previous hand-rolled `map_err`.
        let cache = self.cache.read()?;
        Ok(cache.get(group_id).cloned())
    }

    /// Lists all cached group ids (unordered).
    fn list_groups(&self) -> CoordinatorResult<Vec<String>> {
        let cache = self.cache.read()?;
        Ok(cache.keys().cloned().collect())
    }

    /// Proposes group deletion through the Raft log.
    fn delete_group(&self, group_id: &str) -> CoordinatorResult<()> {
        self.persist_entry(&RaftLogEntry::GroupDeleted(group_id.to_string()))
    }

    /// Proposes a single offset commit through the Raft log.
    fn save_offset(
        &self,
        group_id: &str,
        topic: &str,
        partition: u32,
        offset: i64,
    ) -> CoordinatorResult<()> {
        self.persist_entry(&RaftLogEntry::OffsetCommit {
            group_id: group_id.to_string(),
            topic: topic.to_string(),
            partition,
            offset,
        })
    }

    /// Returns every committed offset for the group from the local cache;
    /// empty map when the group is unknown.
    fn load_offsets(
        &self,
        group_id: &str,
    ) -> CoordinatorResult<HashMap<String, HashMap<u32, i64>>> {
        let cache = self.cache.read()?;
        Ok(cache
            .get(group_id)
            .map(|s| s.offsets.clone())
            .unwrap_or_default())
    }
}
/// Broker-side coordinator: tracks live [`ConsumerGroup`]s in memory and
/// mirrors durable state (new groups, committed offsets) into a
/// [`GroupPersistence`] backend.
pub struct ConsumerCoordinator {
    // Live groups keyed by id; this single lock serializes all group ops.
    groups: Arc<RwLock<HashMap<GroupId, ConsumerGroup>>>,
    // Durable store consulted by `restore` and updated on joins/commits.
    persistence: Arc<dyn GroupPersistence>,
}
impl ConsumerCoordinator {
    /// Creates a coordinator backed by volatile in-memory persistence.
    pub fn new() -> Self {
        Self {
            groups: Arc::new(RwLock::new(HashMap::new())),
            persistence: Arc::new(InMemoryPersistence::new()),
        }
    }

    /// Creates a coordinator with a caller-supplied persistence backend.
    pub fn with_persistence(persistence: Arc<dyn GroupPersistence>) -> Self {
        Self {
            groups: Arc::new(RwLock::new(HashMap::new())),
            persistence,
        }
    }

    /// Creates a coordinator backed by a callback-less [`RaftPersistence`].
    pub fn with_raft_persistence() -> Self {
        Self {
            groups: Arc::new(RwLock::new(HashMap::new())),
            persistence: Arc::new(RaftPersistence::new()),
        }
    }

    /// Rebuilds in-memory groups from the persistence layer, returning how
    /// many groups were restored. Groups already present in memory are left
    /// untouched (only vacant entries are filled).
    ///
    /// NOTE(review): only `offsets` and `generation_id` are restored; the
    /// member list and group state are rebuilt as members rejoin, and the
    /// session/rebalance timeouts are reset to fixed 30s/60s defaults —
    /// confirm those defaults are intended.
    pub fn restore(&self) -> CoordinatorResult<usize> {
        let group_ids = self.persistence.list_groups()?;
        let mut restored = 0;
        for group_id in group_ids {
            if let Some(state) = self.persistence.load_group(&group_id)? {
                // Lock is re-acquired per group so the persistence reads
                // above happen without holding the group map.
                let mut groups = self.write_groups()?;
                use std::collections::hash_map::Entry;
                if let Entry::Vacant(entry) = groups.entry(group_id) {
                    let mut group = ConsumerGroup::new(
                        entry.key().clone(),
                        Duration::from_secs(30),
                        Duration::from_secs(60),
                    );
                    group.offsets = state.offsets;
                    group.generation_id = state.generation_id;
                    entry.insert(group);
                    restored += 1;
                }
            }
        }
        tracing::info!(
            restored_groups = restored,
            "Restored consumer groups from persistence"
        );
        Ok(restored)
    }

    /// Acquires the group map for reading. `?` maps lock poisoning to
    /// `LockPoisoned` via the `From<PoisonError<T>>` impl, replacing the
    /// previous hand-rolled `map_err`.
    fn read_groups(
        &self,
    ) -> CoordinatorResult<RwLockReadGuard<'_, HashMap<GroupId, ConsumerGroup>>> {
        Ok(self.groups.read()?)
    }

    /// Acquires the group map for writing; poisoning maps to `LockPoisoned`.
    fn write_groups(
        &self,
    ) -> CoordinatorResult<RwLockWriteGuard<'_, HashMap<GroupId, ConsumerGroup>>> {
        Ok(self.groups.write()?)
    }

    /// Adds (or re-adds) a member to `group_id`, creating the group on
    /// first join. When `member_id` is `None`, a fresh id is minted as
    /// `"{client_id}-{uuid}"`. Returns the current generation, leader, and
    /// full membership view (the leader uses it to compute assignments).
    ///
    /// NOTE(review): the group snapshot is persisted only when the group is
    /// newly created; membership changes to an existing group are not
    /// written to the backend here — confirm that is intentional.
    #[allow(clippy::too_many_arguments)]
    pub fn join_group(
        &self,
        group_id: GroupId,
        member_id: Option<MemberId>,
        client_id: String,
        session_timeout: Duration,
        rebalance_timeout: Duration,
        subscriptions: Vec<String>,
        metadata: Vec<u8>,
    ) -> CoordinatorResult<JoinGroupResponse> {
        let mut groups = self.write_groups()?;
        let is_new_group = !groups.contains_key(&group_id);
        let group = groups.entry(group_id.clone()).or_insert_with(|| {
            ConsumerGroup::new(group_id.clone(), session_timeout, rebalance_timeout)
        });
        let member_id =
            member_id.unwrap_or_else(|| format!("{}-{}", client_id, uuid::Uuid::new_v4()));
        group.add_member(member_id.clone(), client_id, subscriptions, metadata);
        let generation_id = group.generation_id;
        // Leader may be unset on a brand-new group; empty string then.
        let leader_id = group.leader_id.clone().unwrap_or_default();
        let response = JoinGroupResponse {
            generation_id,
            member_id,
            leader_id,
            members: group
                .members
                .iter()
                .map(|(id, member)| MemberInfo {
                    member_id: id.clone(),
                    metadata: member.metadata.clone(),
                })
                .collect(),
        };
        if is_new_group {
            // Best effort: a persistence failure is logged, not surfaced,
            // so the join itself still succeeds.
            let state = PersistedGroupState::from_group(group);
            if let Err(e) = self.persistence.save_group(&state) {
                tracing::warn!(group_id = %group_id, error = %e, "Failed to persist new group state");
            }
        }
        Ok(response)
    }

    /// Completes the rebalance handshake: the leader supplies `assignments`
    /// for all members; every member receives its own assignment back.
    ///
    /// # Errors
    /// `GroupNotFound`, `InvalidGeneration`, or `UnknownMember`.
    pub fn sync_group(
        &self,
        group_id: GroupId,
        member_id: MemberId,
        generation_id: u32,
        assignments: HashMap<MemberId, Vec<PartitionAssignment>>,
    ) -> CoordinatorResult<SyncGroupResponse> {
        let mut groups = self.write_groups()?;
        let group = groups
            .get_mut(&group_id)
            .ok_or_else(|| CoordinatorError::GroupNotFound(group_id.clone()))?;
        if group.generation_id != generation_id {
            return Err(CoordinatorError::InvalidGeneration {
                expected: group.generation_id,
                actual: generation_id,
            });
        }
        if !group.members.contains_key(&member_id) {
            return Err(CoordinatorError::UnknownMember(member_id));
        }
        // Only the leader's non-empty assignment map completes the rebalance.
        if Some(&member_id) == group.leader_id.as_ref() && !assignments.is_empty() {
            group.complete_rebalance(assignments);
        }
        let assignment = group
            .members
            .get(&member_id)
            .map(|m| m.assignment.clone())
            .unwrap_or_default();
        Ok(SyncGroupResponse { assignment })
    }

    /// Records a member heartbeat and reports whether the member should
    /// rejoin (a rebalance is pending or other members timed out).
    ///
    /// NOTE(review): members detected as timed out here are only reported;
    /// they are actually removed by the periodic `check_timeouts` sweep —
    /// confirm that sweep is scheduled somewhere.
    pub fn heartbeat(
        &self,
        group_id: GroupId,
        member_id: MemberId,
        generation_id: u32,
    ) -> CoordinatorResult<HeartbeatResponse> {
        let mut groups = self.write_groups()?;
        let group = groups
            .get_mut(&group_id)
            .ok_or_else(|| CoordinatorError::GroupNotFound(group_id.clone()))?;
        if group.generation_id != generation_id {
            return Err(CoordinatorError::InvalidGeneration {
                expected: group.generation_id,
                actual: generation_id,
            });
        }
        if !group.members.contains_key(&member_id) {
            return Err(CoordinatorError::UnknownMember(member_id));
        }
        // Membership was just checked, so the heartbeat result is ignored.
        let _ = group.heartbeat(&member_id);
        let timed_out = group.check_timeouts();
        if !timed_out.is_empty() {
            return Ok(HeartbeatResponse {
                rebalance_required: true,
            });
        }
        Ok(HeartbeatResponse {
            rebalance_required: group.state == GroupState::PreparingRebalance,
        })
    }

    /// Removes a member from the group (voluntary departure).
    ///
    /// NOTE(review): the resulting group state is not persisted here.
    pub fn leave_group(
        &self,
        group_id: GroupId,
        member_id: MemberId,
    ) -> CoordinatorResult<LeaveGroupResponse> {
        let mut groups = self.write_groups()?;
        let group = groups
            .get_mut(&group_id)
            .ok_or_else(|| CoordinatorError::GroupNotFound(group_id.clone()))?;
        group.remove_member(&member_id);
        Ok(LeaveGroupResponse {})
    }

    /// Commits an offset both in memory and to the persistence backend.
    ///
    /// NOTE(review): the in-memory commit happens first; if persistence
    /// fails, the error is returned but memory stays updated — confirm
    /// callers treat that as retryable.
    pub fn commit_offset(
        &self,
        group_id: GroupId,
        topic: String,
        partition: u32,
        offset: i64,
    ) -> CoordinatorResult<CommitOffsetResponse> {
        let mut groups = self.write_groups()?;
        let group = groups
            .get_mut(&group_id)
            .ok_or_else(|| CoordinatorError::GroupNotFound(group_id.clone()))?;
        group.commit_offset(&topic, partition, offset);
        // Persistence is called while holding the group write lock.
        self.persistence
            .save_offset(&group_id, &topic, partition, offset)?;
        Ok(CommitOffsetResponse {})
    }

    /// Fetches the last committed offset for a topic-partition (from the
    /// in-memory group), or `None` if nothing was committed.
    pub fn fetch_offset(
        &self,
        group_id: GroupId,
        topic: String,
        partition: u32,
    ) -> CoordinatorResult<FetchOffsetResponse> {
        let groups = self.read_groups()?;
        let group = groups
            .get(&group_id)
            .ok_or_else(|| CoordinatorError::GroupNotFound(group_id.clone()))?;
        let offset = group.fetch_offset(&topic, partition);
        Ok(FetchOffsetResponse { offset })
    }

    /// Periodic sweep: evicts every member whose session has timed out, in
    /// every group. Intended to be called from a background task.
    pub fn check_timeouts(&self) -> CoordinatorResult<()> {
        let mut groups = self.write_groups()?;
        for group in groups.values_mut() {
            let timed_out = group.check_timeouts();
            for member_id in timed_out {
                group.remove_member(&member_id);
            }
        }
        Ok(())
    }
}
impl Default for ConsumerCoordinator {
    /// Equivalent to [`ConsumerCoordinator::new`] (in-memory persistence).
    fn default() -> Self {
        Self::new()
    }
}
/// Reply to a join-group request: the member's identity plus the full
/// membership view (which the leader uses to compute assignments).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JoinGroupResponse {
    /// Current group generation.
    pub generation_id: u32,
    /// Id assigned to (or echoed back for) the joining member.
    pub member_id: MemberId,
    /// Current leader id; empty string when no leader is set yet.
    pub leader_id: MemberId,
    /// All current members with their opaque protocol metadata.
    pub members: Vec<MemberInfo>,
}
/// One group member as seen in a [`JoinGroupResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemberInfo {
    pub member_id: MemberId,
    /// Opaque protocol metadata supplied by the member at join time.
    pub metadata: Vec<u8>,
}
/// Reply to a sync-group request: the partitions assigned to the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncGroupResponse {
    pub assignment: Vec<PartitionAssignment>,
}
/// Reply to a heartbeat: `rebalance_required` tells the member to rejoin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeartbeatResponse {
    pub rebalance_required: bool,
}
/// Empty acknowledgement for a leave-group request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaveGroupResponse {}
/// Empty acknowledgement for an offset commit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitOffsetResponse {}
/// Reply to an offset fetch; `None` when nothing has been committed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchOffsetResponse {
    pub offset: Option<i64>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // First join with no member id: the group is created, a member id
    // prefixed with the client id is minted, and generation starts at 0.
    #[test]
    fn test_join_group_creates_group() {
        let coordinator = ConsumerCoordinator::new();
        let response = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        assert_eq!(response.generation_id, 0);
        assert!(response.member_id.starts_with("client-1"));
        assert_eq!(response.members.len(), 1);
    }

    // The first member to join becomes the group leader.
    #[test]
    fn test_join_group_first_member_is_leader() {
        let coordinator = ConsumerCoordinator::new();
        let response = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        assert_eq!(response.leader_id, response.member_id);
    }

    // A second member joining grows the membership view; the generation
    // reported is still 0 at this point (rebalance not yet completed).
    #[test]
    fn test_join_group_triggers_rebalance() {
        let coordinator = ConsumerCoordinator::new();
        let _response1 = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        let response2 = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-2".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        assert_eq!(response2.generation_id, 0);
        assert_eq!(response2.members.len(), 2);
    }

    // The leader submits assignments via sync_group and receives its own
    // assignment back.
    #[test]
    fn test_sync_group_completes_rebalance() {
        let coordinator = ConsumerCoordinator::new();
        let response1 = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        let member_id = response1.member_id.clone();
        let generation_id = response1.generation_id;
        let mut assignments = HashMap::new();
        assignments.insert(
            member_id.clone(),
            vec![PartitionAssignment {
                topic: "topic-1".to_string(),
                partition: 0,
            }],
        );
        let sync_response = coordinator
            .sync_group(
                "test-group".to_string(),
                member_id,
                generation_id,
                assignments,
            )
            .unwrap();
        assert_eq!(sync_response.assignment.len(), 1);
        assert_eq!(sync_response.assignment[0].partition, 0);
    }

    // A heartbeat carrying a stale generation is rejected.
    #[test]
    fn test_heartbeat_detects_invalid_generation() {
        let coordinator = ConsumerCoordinator::new();
        let response = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        let result = coordinator.heartbeat(
            "test-group".to_string(),
            response.member_id,
            999, );
        assert!(matches!(
            result,
            Err(CoordinatorError::InvalidGeneration { .. })
        ));
    }

    // After leaving, a heartbeat from the departed member is rejected as
    // an unknown member.
    #[test]
    fn test_leave_group_removes_member() {
        let coordinator = ConsumerCoordinator::new();
        let response = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        coordinator
            .leave_group("test-group".to_string(), response.member_id.clone())
            .unwrap();
        let result = coordinator.heartbeat(
            "test-group".to_string(),
            response.member_id,
            response.generation_id,
        );
        assert!(matches!(result, Err(CoordinatorError::UnknownMember(_))));
    }

    // Round-trip: commit an offset, then fetch it back.
    #[test]
    fn test_commit_and_fetch_offset() {
        let coordinator = ConsumerCoordinator::new();
        coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        coordinator
            .commit_offset("test-group".to_string(), "topic-1".to_string(), 0, 42)
            .unwrap();
        let response = coordinator
            .fetch_offset("test-group".to_string(), "topic-1".to_string(), 0)
            .unwrap();
        assert_eq!(response.offset, Some(42));
    }

    // Fetching with no prior commit yields None rather than an error.
    #[test]
    fn test_fetch_offset_returns_none_when_not_committed() {
        let coordinator = ConsumerCoordinator::new();
        coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        let response = coordinator
            .fetch_offset("test-group".to_string(), "topic-1".to_string(), 0)
            .unwrap();
        assert_eq!(response.offset, None);
    }

    // Basic CRUD on the in-memory persistence backend.
    #[test]
    fn test_in_memory_persistence() {
        let persistence = InMemoryPersistence::new();
        let state = PersistedGroupState {
            group_id: "test-group".to_string(),
            generation_id: 1,
            state: "Stable".to_string(),
            members: vec!["member-1".to_string()],
            offsets: HashMap::new(),
            updated_at: 12345678,
        };
        persistence.save_group(&state).unwrap();
        let loaded = persistence.load_group("test-group").unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.group_id, "test-group");
        assert_eq!(loaded.generation_id, 1);
        assert_eq!(loaded.members.len(), 1);
        let groups = persistence.list_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0], "test-group");
        persistence.delete_group("test-group").unwrap();
        let loaded = persistence.load_group("test-group").unwrap();
        assert!(loaded.is_none());
    }

    // Offsets across multiple topics/partitions survive save/load.
    #[test]
    fn test_in_memory_persistence_offsets() {
        let persistence = InMemoryPersistence::new();
        let state = PersistedGroupState {
            group_id: "test-group".to_string(),
            generation_id: 1,
            state: "Stable".to_string(),
            members: vec!["member-1".to_string()],
            offsets: HashMap::new(),
            updated_at: 12345678,
        };
        persistence.save_group(&state).unwrap();
        persistence
            .save_offset("test-group", "topic-1", 0, 100)
            .unwrap();
        persistence
            .save_offset("test-group", "topic-1", 1, 200)
            .unwrap();
        persistence
            .save_offset("test-group", "topic-2", 0, 50)
            .unwrap();
        let offsets = persistence.load_offsets("test-group").unwrap();
        assert_eq!(offsets.len(), 2);
        let topic1_offsets = offsets.get("topic-1").unwrap();
        assert_eq!(topic1_offsets.get(&0), Some(&100));
        assert_eq!(topic1_offsets.get(&1), Some(&200));
        let topic2_offsets = offsets.get("topic-2").unwrap();
        assert_eq!(topic2_offsets.get(&0), Some(&50));
    }

    // Every mutating call on RaftPersistence produces exactly one log entry
    // through the callback.
    #[test]
    fn test_raft_persistence_log_entries() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let log_count = Arc::new(AtomicUsize::new(0));
        let log_count_clone = log_count.clone();
        let persistence = RaftPersistence::with_callback(move |_bytes| {
            log_count_clone.fetch_add(1, Ordering::SeqCst);
            true
        });
        let state = PersistedGroupState {
            group_id: "test-group".to_string(),
            generation_id: 1,
            state: "Stable".to_string(),
            members: vec!["member-1".to_string()],
            offsets: HashMap::new(),
            updated_at: 12345678,
        };
        persistence.save_group(&state).unwrap();
        assert_eq!(log_count.load(Ordering::SeqCst), 1);
        persistence
            .save_offset("test-group", "topic-1", 0, 100)
            .unwrap();
        assert_eq!(log_count.load(Ordering::SeqCst), 2);
        persistence.delete_group("test-group").unwrap();
        assert_eq!(log_count.load(Ordering::SeqCst), 3);
    }

    // Direct replay path: apply_log_entry mutates the cache without any
    // callback involvement (create, commit offset, delete).
    #[test]
    fn test_raft_persistence_apply_log_entry() {
        let persistence = RaftPersistence::new();
        let entry = RaftLogEntry::GroupStateChange(PersistedGroupState {
            group_id: "test-group".to_string(),
            generation_id: 1,
            state: "Stable".to_string(),
            members: vec!["member-1".to_string()],
            offsets: HashMap::new(),
            updated_at: 12345678,
        });
        persistence.apply_log_entry(&entry).unwrap();
        let loaded = persistence.load_group("test-group").unwrap();
        assert!(loaded.is_some());
        let entry = RaftLogEntry::OffsetCommit {
            group_id: "test-group".to_string(),
            topic: "topic-1".to_string(),
            partition: 0,
            offset: 100,
        };
        persistence.apply_log_entry(&entry).unwrap();
        let offsets = persistence.load_offsets("test-group").unwrap();
        let topic1_offsets = offsets.get("topic-1").unwrap();
        assert_eq!(topic1_offsets.get(&0), Some(&100));
        let entry = RaftLogEntry::GroupDeleted("test-group".to_string());
        persistence.apply_log_entry(&entry).unwrap();
        let loaded = persistence.load_group("test-group").unwrap();
        assert!(loaded.is_none());
    }

    // Snapshot round-trip: create_snapshot on one instance, restore into a
    // fresh instance, and verify groups + offsets carried over.
    #[test]
    fn test_raft_persistence_snapshot() {
        let persistence = RaftPersistence::new();
        let entry = RaftLogEntry::GroupStateChange(PersistedGroupState {
            group_id: "group-1".to_string(),
            generation_id: 1,
            state: "Stable".to_string(),
            members: vec!["member-1".to_string()],
            offsets: HashMap::new(),
            updated_at: 12345678,
        });
        persistence.apply_log_entry(&entry).unwrap();
        let entry = RaftLogEntry::OffsetCommit {
            group_id: "group-1".to_string(),
            topic: "topic-1".to_string(),
            partition: 0,
            offset: 100,
        };
        persistence.apply_log_entry(&entry).unwrap();
        let snapshot = persistence.create_snapshot().unwrap();
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.contains_key("group-1"));
        assert_eq!(snapshot.get("group-1").unwrap().group_id, "group-1");
        let persistence2 = RaftPersistence::new();
        persistence2.restore_snapshot(snapshot).unwrap();
        let loaded = persistence2.load_group("group-1").unwrap();
        assert!(loaded.is_some());
        let offsets = persistence2.load_offsets("group-1").unwrap();
        let topic1_offsets = offsets.get("topic-1").unwrap();
        assert_eq!(topic1_offsets.get(&0), Some(&100));
    }

    // End-to-end: a coordinator join through RaftPersistence emits a
    // deserializable GroupStateChange entry via the callback.
    #[test]
    fn test_coordinator_with_raft_persistence() {
        let entries = Arc::new(std::sync::Mutex::new(Vec::<Vec<u8>>::new()));
        let entries_clone = entries.clone();
        let raft_persistence = Arc::new(RaftPersistence::with_callback(move |bytes| {
            entries_clone.lock().unwrap().push(bytes.to_vec());
            true
        }));
        let coordinator = ConsumerCoordinator::with_persistence(raft_persistence);
        let _response = coordinator
            .join_group(
                "test-group".to_string(),
                None,
                "client-1".to_string(),
                Duration::from_secs(10),
                Duration::from_secs(60),
                vec!["topic-1".to_string()],
                vec![],
            )
            .unwrap();
        let logged_entries = entries.lock().unwrap();
        assert!(!logged_entries.is_empty());
        let entry: RaftLogEntry = serde_json::from_slice(&logged_entries[0]).unwrap();
        match entry {
            RaftLogEntry::GroupStateChange(state) => {
                assert_eq!(state.group_id, "test-group");
            }
            _ => panic!("Expected GroupStateChange entry"),
        }
    }

    // restore() rebuilds a group (with its committed offsets) from
    // persisted state after a simulated restart.
    #[test]
    fn test_coordinator_restore() {
        let raft_persistence = Arc::new(RaftPersistence::new());
        let mut topic_offsets = HashMap::new();
        topic_offsets.insert(0u32, 500i64);
        let mut offsets = HashMap::new();
        offsets.insert("topic-1".to_string(), topic_offsets);
        let entry = RaftLogEntry::GroupStateChange(PersistedGroupState {
            group_id: "restored-group".to_string(),
            generation_id: 5,
            state: "Stable".to_string(),
            members: vec!["member-1-uuid".to_string()],
            offsets,
            updated_at: 12345678,
        });
        raft_persistence.apply_log_entry(&entry).unwrap();
        let coordinator = ConsumerCoordinator::with_persistence(raft_persistence);
        coordinator.restore().unwrap();
        let response = coordinator
            .fetch_offset("restored-group".to_string(), "topic-1".to_string(), 0)
            .unwrap();
        assert_eq!(response.offset, Some(500));
    }
}