use std::collections::{HashMap, HashSet};
use std::time::Instant;
use bytes::{BufMut, Bytes, BytesMut};
use crabka_protocol::owned::consumer_protocol_assignment::{
ConsumerProtocolAssignment, TopicPartition,
};
use crabka_protocol::owned::consumer_protocol_subscription::ConsumerProtocolSubscription;
use crabka_protocol::primitives::uuid::Uuid;
use crabka_protocol::{Decode, Encode};
use super::classic_state::{Group as ClassicState, Member as ClassicMember, select_protocol};
use super::consumer_state::{ClassicMemberFacade, GroupState as ConsumerState, MemberState};
use super::persistence_next_gen::MemberAssignmentState;
use super::reconciler::ReconcileInput;
/// Decodes a classic-protocol consumer subscription blob.
///
/// The blob is a big-endian `i16` version header followed by the
/// `ConsumerProtocolSubscription` body encoded at that version. Returns `None`
/// when the blob is shorter than the 2-byte header, when the version is outside
/// `0..=3`, or when the body fails to decode.
pub(crate) fn decode_consumer_subscription(
    metadata: &[u8],
) -> Option<ConsumerProtocolSubscription> {
    // Split off the 2-byte header; a blob too short to hold it is rejected.
    let (header, mut body) = metadata.split_first_chunk::<2>()?;
    let version = i16::from_be_bytes(*header);
    match version {
        0..=3 => ConsumerProtocolSubscription::decode(&mut body, version).ok(),
        _ => None,
    }
}
/// Reports whether a classic group can be upgraded to a consumer group.
///
/// The group must use the `"consumer"` protocol type, and every member's
/// protocol metadata must decode as a consumer subscription (see
/// `decode_consumer_subscription`). A consumer-type group with no members
/// qualifies.
pub(crate) fn classic_is_convertible(state: &ClassicState) -> bool {
    let is_consumer_type = matches!(state.protocol_type.as_deref(), Some("consumer"));
    is_consumer_type
        && state
            .members
            .values()
            .all(|member| decode_consumer_subscription(&member.protocol_metadata).is_some())
}
/// Reports whether a consumer group may be downgraded to a classic group.
///
/// This check is unconditional and always returns `true`.
/// `convert_consumer_to_classic` still panics if a member lacks a classic
/// facade, so callers must ensure that precondition separately.
pub(crate) fn consumer_is_convertible() -> bool {
    true
}
/// Upgrades a classic group into consumer-group state.
///
/// Each classic member becomes a hosted classic member. Its subscription is
/// decoded from the classic protocol metadata, and its last classic
/// assignment is kept in the `ClassicMemberFacade`. The facade is marked
/// `awaiting_sync` and the member has no next-gen assigned partitions yet.
/// The resulting state is marked dirty.
///
/// The group epoch is the classic generation clamped to be non-negative. The
/// same clamped value is used for every member's `member_epoch` and facade
/// `generation_id`, so all three always agree. This matches
/// `upsert_classic_member`, which derives the facade generation from the
/// group epoch.
pub(crate) fn convert_classic_to_consumer(classic: &ClassicState) -> ConsumerState {
    let mut state = ConsumerState::new(classic.group_id.clone());
    let epoch = classic.generation_id.max(0);
    state.group_epoch = epoch;
    for m in classic.members.values() {
        // Undecodable metadata yields an empty subscription rather than failing.
        let names: HashSet<String> = decode_consumer_subscription(&m.protocol_metadata)
            .map(|s| s.topics.into_iter().collect())
            .unwrap_or_default();
        let facade = ClassicMemberFacade {
            // Use the clamped epoch; the raw generation could disagree with member_epoch.
            generation_id: epoch,
            supported_protocols: m.protocols.clone(),
            session_timeout: m.session_timeout,
            last_synced_assignment: m.assignment.clone().unwrap_or_default(),
            awaiting_sync: true,
        };
        state.add_or_update_member(MemberState {
            member_id: m.member_id.clone(),
            instance_id: m.group_instance_id.clone(),
            rack_id: None,
            client_id: m.client_id.clone(),
            client_host: m.host.clone(),
            subscribed_topic_names: names,
            subscribed_topic_regex: None,
            compiled_regex: None,
            server_assignor: None,
            rebalance_timeout: m.rebalance_timeout,
            member_epoch: epoch,
            previous_member_epoch: 0,
            assignment_state: MemberAssignmentState::Stable,
            assigned_partitions: HashMap::new(),
            partitions_pending_revocation: HashMap::new(),
            last_seen: Instant::now(),
            classic: Some(facade),
        });
    }
    state.dirty = true;
    state
}
/// Builds the records to persist after an upgrade.
///
/// The result is the full consumer-group record set for `state`, plus a
/// tombstone for the old classic group metadata record.
pub(crate) fn upgrade_pending_records(state: &ConsumerState) -> super::actor::PendingRecords {
    super::actor::PendingRecords {
        classic_group_metadata_tombstone: true,
        ..super::actor::full_pending_records(state)
    }
}
pub(crate) fn convert_consumer_to_classic(
state: &ConsumerState,
image: &ReconcileInput,
) -> ClassicState {
let mut classic = ClassicState::new(state.group_id.clone());
classic.protocol_type = Some("consumer".into());
for (mid, m) in &state.members {
let facade = m
.classic
.as_ref()
.expect("downgrade precondition: all members are hosted classic members");
let seed = member_target_assignment(state, mid, image);
let mut cm = ClassicMember::new(
mid.clone(),
m.client_id.clone(),
m.client_host.clone(),
facade.session_timeout,
m.rebalance_timeout,
facade.supported_protocols.clone(),
)
.with_instance_id(m.instance_id.clone());
cm.assignment = Some(seed);
classic.add_member(cm);
}
if let Some(name) = select_protocol(&classic.members) {
classic.complete_rebalance(&name);
let assignments: std::collections::HashMap<String, bytes::Bytes> = classic
.members
.iter()
.filter_map(|(id, m)| m.assignment.clone().map(|a| (id.clone(), a)))
.collect();
classic.install_assignments(assignments);
}
classic.generation_id = state.group_epoch.max(0);
classic
}
/// Builds the records to persist after a downgrade.
///
/// The result writes the new classic group metadata record and tombstones the
/// next-gen group and target metadata. It also adds per-member tombstones
/// (`None`) to the member-metadata, target, and current-assignment record
/// lists for every member of the former consumer group.
pub(crate) fn downgrade_pending_records(
    consumer: &ConsumerState,
    classic: &ClassicState,
) -> super::actor::PendingRecords {
    let mut records = super::actor::PendingRecords {
        classic_group_metadata: Some(super::actor::classic_group_metadata_record(classic)),
        next_gen_group_metadata_tombstone: true,
        next_gen_target_metadata_tombstone: true,
        ..Default::default()
    };
    for member_id in consumer.members.keys().cloned() {
        records.member_metadata.push((member_id.clone(), None));
        records.target_per_member.push((member_id.clone(), None));
        records.current_per_member.push((member_id, None));
    }
    records
}
/// Translates a topic-id keyed target assignment into a classic
/// `ConsumerProtocolAssignment` blob.
///
/// The blob is a big-endian `i16` version header (0) followed by the
/// version-0 encoding of the assignment. Topic ids with no name in
/// `image.topic_id_by_name` are dropped. Each partition list is sorted
/// ascending and topics are ordered by name, so the bytes do not depend on
/// hash-map iteration order.
pub(crate) fn target_to_consumer_assignment(
    target: &HashMap<Uuid, Vec<i32>>,
    image: &ReconcileInput,
) -> Bytes {
    // The header version and the body encoding version must agree.
    const ASSIGNMENT_VERSION: i16 = 0;

    // Invert the name -> id index so target entries can be resolved to names.
    let name_by_id: HashMap<Uuid, &str> = image
        .topic_id_by_name
        .iter()
        .map(|(name, id)| (*id, name.as_str()))
        .collect();

    let mut assigned_partitions = Vec::with_capacity(target.len());
    for (topic_id, parts) in target {
        // An id with no known name cannot be expressed in the classic format.
        let Some(&topic_name) = name_by_id.get(topic_id) else {
            continue;
        };
        let mut partitions = parts.clone();
        partitions.sort_unstable();
        assigned_partitions.push(TopicPartition {
            topic: topic_name.to_owned(),
            partitions,
            ..Default::default()
        });
    }
    assigned_partitions.sort_by(|lhs, rhs| lhs.topic.cmp(&rhs.topic));

    let assignment = ConsumerProtocolAssignment {
        assigned_partitions,
        ..Default::default()
    };
    let mut buf = BytesMut::new();
    buf.put_i16(ASSIGNMENT_VERSION);
    assignment
        .encode(&mut buf, ASSIGNMENT_VERSION)
        .expect("ConsumerProtocolAssignment encode is infallible into BytesMut");
    buf.freeze()
}
use super::actor::{JoinResult, JoinResultMember, SyncResult};
use crate::codes;
/// Serves a classic-protocol heartbeat for a member hosted in a consumer group.
///
/// Returns `UNKNOWN_MEMBER_ID` if the member is absent. Otherwise it refreshes
/// the member's `last_seen`. It returns `NONE` only when the member has a
/// classic facade whose last synced assignment equals its current encoded
/// target. In every other case it returns `REBALANCE_IN_PROGRESS`.
pub(crate) fn serve_classic_heartbeat(
    state: &mut ConsumerState,
    member_id: &str,
    image: &ReconcileInput,
) -> i16 {
    let Some(member) = state.members.get(member_id) else {
        return codes::UNKNOWN_MEMBER_ID;
    };
    let expected = member_target_assignment(state, member_id, image);
    let in_sync = member
        .classic
        .as_ref()
        .is_some_and(|facade| facade.last_synced_assignment == expected);

    if let Some(member) = state.members.get_mut(member_id) {
        member.last_seen = Instant::now();
    }

    if in_sync {
        codes::NONE
    } else {
        codes::REBALANCE_IN_PROGRESS
    }
}
fn member_target_assignment(
state: &ConsumerState,
member_id: &str,
image: &ReconcileInput,
) -> Bytes {
let target = state
.target
.per_member
.get(member_id)
.cloned()
.unwrap_or_default();
target_to_consumer_assignment(&target, image)
}
/// Serves a classic-protocol SyncGroup for a member hosted in a consumer group.
///
/// Returns `UNKNOWN_MEMBER_ID` if the member is absent. Otherwise the member's
/// current target is encoded and returned as the assignment, with protocol
/// type `"consumer"` and no protocol name. If the member has a classic facade,
/// the returned blob is recorded as its last synced assignment and the facade
/// is cleared of `awaiting_sync`.
pub(crate) fn serve_classic_sync(
    state: &mut ConsumerState,
    member_id: &str,
    image: &ReconcileInput,
) -> SyncResult {
    if state.members.get(member_id).is_none() {
        return SyncResult {
            error_code: codes::UNKNOWN_MEMBER_ID,
            ..Default::default()
        };
    }

    let assignment = member_target_assignment(state, member_id, image);
    let facade = state
        .members
        .get_mut(member_id)
        .and_then(|member| member.classic.as_mut());
    if let Some(facade) = facade {
        facade.last_synced_assignment = assignment.clone();
        facade.awaiting_sync = false;
    }

    SyncResult {
        error_code: codes::NONE,
        protocol_type: Some("consumer".into()),
        protocol_name: None,
        assignment,
    }
}
// Each argument maps to one field of the member record written below.
#[allow(clippy::too_many_arguments)]
/// Inserts or refreshes a classic-protocol member hosted in a consumer group.
///
/// The subscription, protocols, client identity, timeouts, and instance id are
/// taken from the arguments. Assignment bookkeeping is carried over from an
/// existing member with the same id: assigned and pending-revocation
/// partitions, member/previous epochs, assignment state, and the facade's last
/// synced assignment.
///
/// A brand-new member starts at the group epoch, previous epoch 0, `Stable`,
/// with nothing assigned, and is marked `awaiting_sync`. The facade's
/// generation is always the current group epoch.
///
/// NOTE(review): an existing member that rejoins gets `awaiting_sync = false`
/// regardless of its previous value; confirm this is intended.
pub(crate) fn upsert_classic_member(
    state: &mut ConsumerState,
    member_id: &str,
    subscription_topics: HashSet<String>,
    protocols: Vec<(String, Bytes)>,
    client_id: String,
    client_host: String,
    session_timeout: std::time::Duration,
    rebalance_timeout: std::time::Duration,
    instance_id: Option<String>,
) {
    let prior = state.members.get(member_id);
    let is_new = prior.is_none();

    let (assigned_partitions, partitions_pending_revocation) = prior
        .map(|m| {
            (
                m.assigned_partitions.clone(),
                m.partitions_pending_revocation.clone(),
            )
        })
        .unwrap_or_default();
    let (member_epoch, previous_member_epoch, assignment_state) = match prior {
        Some(m) => (m.member_epoch, m.previous_member_epoch, m.assignment_state),
        None => (state.group_epoch, 0, MemberAssignmentState::Stable),
    };
    let last_synced_assignment = prior
        .and_then(|m| m.classic.as_ref())
        .map_or_else(Default::default, |c| c.last_synced_assignment.clone());

    let generation_id = state.group_epoch;
    let facade = ClassicMemberFacade {
        generation_id,
        supported_protocols: protocols,
        session_timeout,
        last_synced_assignment,
        awaiting_sync: is_new,
    };

    state.add_or_update_member(MemberState {
        member_id: member_id.to_string(),
        instance_id,
        rack_id: None,
        client_id,
        client_host,
        subscribed_topic_names: subscription_topics,
        subscribed_topic_regex: None,
        compiled_regex: None,
        server_assignor: None,
        rebalance_timeout,
        member_epoch,
        previous_member_epoch,
        assignment_state,
        assigned_partitions,
        partitions_pending_revocation,
        last_seen: Instant::now(),
        classic: Some(facade),
    });
}
/// Builds the JoinGroup response for a classic member hosted in a consumer group.
///
/// The generation is the consumer group epoch and the protocol type is
/// `"consumer"`. The requesting member is reported both as itself and as the
/// leader. The roster holds only that member, with its instance id (if known)
/// and empty metadata.
///
/// NOTE(review): a leader normally decodes the members' subscription metadata.
/// Confirm clients tolerate being named leader of a roster whose metadata is
/// empty.
pub(crate) fn build_hosted_classic_join_result(
    state: &ConsumerState,
    member_id: &str,
    protocol_name: Option<String>,
) -> JoinResult {
    let group_instance_id = state
        .members
        .get(member_id)
        .and_then(|member| member.instance_id.clone());
    let self_entry = JoinResultMember {
        member_id: member_id.to_owned(),
        group_instance_id,
        metadata: Bytes::new(),
    };
    JoinResult {
        error_code: codes::NONE,
        generation_id: state.group_epoch,
        protocol_type: Some("consumer".into()),
        protocol_name,
        leader: member_id.to_owned(),
        member_id: member_id.to_owned(),
        members: vec![self_entry],
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;
    use bytes::{Buf, BufMut, Bytes, BytesMut};
    use crabka_protocol::Encode;
    use crabka_protocol::primitives::uuid::Uuid;
    use std::time::Duration;
    use super::super::classic_state::{Group, Member};

    /// Encodes a classic subscription blob: an `i16` version header (0)
    /// followed by a version-0 `ConsumerProtocolSubscription` for `topics`.
    fn subscription_blob(topics: &[&str]) -> Bytes {
        let sub = ConsumerProtocolSubscription {
            topics: topics.iter().map(|s| (*s).to_string()).collect(),
            ..Default::default()
        };
        let mut out = BytesMut::new();
        out.put_i16(0);
        sub.encode(&mut out, 0).unwrap();
        out.freeze()
    }

    /// Builds a classic member that advertises a single `range` protocol whose
    /// metadata is `metadata`, and also stores `metadata` as the member's
    /// `protocol_metadata` (the field `classic_is_convertible` inspects).
    fn consumer_member(id: &str, metadata: Bytes) -> Member {
        let mut m = Member::new(
            id,
            "client",
            "127.0.0.1",
            Duration::from_secs(30),
            Duration::from_mins(1),
            vec![("range".into(), metadata.clone())],
        );
        m.protocol_metadata = metadata;
        m
    }

    // A consumer-type group with no members has nothing that could fail to decode.
    #[test]
    fn empty_consumer_group_is_convertible() {
        let mut g = Group::new("g");
        g.protocol_type = Some("consumer".into());
        assert!(classic_is_convertible(&g));
    }

    // Both a foreign protocol type and an unset protocol type block the upgrade.
    #[test]
    fn non_consumer_protocol_type_is_not_convertible() {
        let mut g = Group::new("g");
        g.protocol_type = Some("connect".into());
        assert!(!classic_is_convertible(&g));
        let g2 = Group::new("g2");
        assert!(!classic_is_convertible(&g2));
    }

    #[test]
    fn group_of_valid_consumer_members_is_convertible() {
        let mut g = Group::new("g");
        g.protocol_type = Some("consumer".into());
        g.add_member(consumer_member("m1", subscription_blob(&["t1"])));
        g.add_member(consumer_member("m2", subscription_blob(&["t1", "t2"])));
        assert!(classic_is_convertible(&g));
    }

    // One bad member is enough to block conversion. The header 0xffff is
    // version -1, outside the accepted 0..=3 range.
    #[test]
    fn member_with_undecodable_metadata_blocks_conversion() {
        let mut g = Group::new("g");
        g.protocol_type = Some("consumer".into());
        g.add_member(consumer_member("ok", subscription_blob(&["t1"])));
        g.add_member(consumer_member(
            "bad",
            Bytes::from_static(&[0xff, 0xff, 0x01]),
        ));
        assert!(!classic_is_convertible(&g));
    }

    // Empty and 1-byte inputs are too short for the header; [0, 99] is version 99.
    #[test]
    fn decode_rejects_short_and_bad_version() {
        assert!(decode_consumer_subscription(&[]).is_none());
        assert!(decode_consumer_subscription(&[0]).is_none());
        assert!(decode_consumer_subscription(&[0, 99]).is_none());
    }

    #[test]
    fn consumer_group_always_downgradable() {
        assert!(consumer_is_convertible());
    }

    // Upgrade keeps group id, epoch (= classic generation), members and their
    // subscriptions, and marks every hosted member as awaiting a sync.
    #[test]
    fn convert_preserves_members_subscriptions_and_facade() {
        let mut g = Group::new("g");
        g.protocol_type = Some("consumer".into());
        g.generation_id = 3;
        g.add_member(consumer_member("m1", subscription_blob(&["t1"])));
        g.add_member(consumer_member("m2", subscription_blob(&["t1", "t2"])));
        let state = convert_classic_to_consumer(&g);
        assert!(state.group_id == "g");
        assert!(state.group_epoch == 3);
        assert!(state.members.len() == 2);
        let m1 = &state.members["m1"];
        assert!(m1.is_classic());
        assert!(m1.subscribed_topic_names.contains("t1"));
        let facade = m1.classic.as_ref().unwrap();
        assert!(facade.generation_id == 3);
        assert!(facade.awaiting_sync);
        let m2 = &state.members["m2"];
        assert!(m2.subscribed_topic_names.len() == 2);
        assert!(state.dirty);
    }

    // Topics come out sorted by name and partitions sorted ascending,
    // regardless of input order.
    #[test]
    fn target_translates_to_consumer_assignment_blob() {
        let t1 = Uuid([1; 16]);
        let t2 = Uuid([2; 16]);
        let image = ReconcileInput {
            topic_id_by_name: [("orders".to_string(), t1), ("events".to_string(), t2)].into(),
            ..Default::default()
        };
        let target: std::collections::HashMap<Uuid, Vec<i32>> =
            [(t1, vec![2, 0, 1]), (t2, vec![5])].into();
        let blob = target_to_consumer_assignment(&target, &image);
        let mut cur = &blob[..];
        let version = cur.get_i16();
        assert!(version == 0);
        let decoded = ConsumerProtocolAssignment::decode(&mut cur, version).unwrap();
        let names: Vec<&str> = decoded
            .assigned_partitions
            .iter()
            .map(|tp| tp.topic.as_str())
            .collect();
        assert!(names == vec!["events", "orders"]);
        let orders = decoded
            .assigned_partitions
            .iter()
            .find(|tp| tp.topic == "orders")
            .unwrap();
        assert!(orders.partitions == vec![0, 1, 2]);
    }

    // A topic id that is missing from the image's name index is silently dropped.
    #[test]
    fn target_drops_unknown_topic_ids() {
        let known = Uuid([1; 16]);
        let ghost = Uuid([9; 16]);
        let image = ReconcileInput {
            topic_id_by_name: [("orders".to_string(), known)].into(),
            ..Default::default()
        };
        let target: std::collections::HashMap<Uuid, Vec<i32>> =
            [(known, vec![0]), (ghost, vec![0])].into();
        let blob = target_to_consumer_assignment(&target, &image);
        let mut cur = &blob[..];
        let _ = cur.get_i16();
        let decoded = ConsumerProtocolAssignment::decode(&mut cur, 0).unwrap();
        assert!(decoded.assigned_partitions.len() == 1);
        assert!(decoded.assigned_partitions[0].topic == "orders");
    }

    // Downgrade of a single hosted classic member: identity, timeouts,
    // generation, seeded assignment, and the resulting Stable state with the
    // member as leader.
    #[test]
    fn downgrade_re_expresses_members_as_classic() {
        use crate::coordinator::unified::classic_state::GroupState as ClassicGroupState;
        use crate::coordinator::unified::consumer_state::{
            ClassicMemberFacade, GroupState, MemberState,
        };
        use crate::coordinator::unified::persistence_next_gen::MemberAssignmentState;
        use std::time::{Duration, Instant};
        let t1 = Uuid([1; 16]);
        let image = ReconcileInput {
            topic_id_by_name: [("orders".to_string(), t1)].into(),
            ..Default::default()
        };
        let mut state = GroupState::new("g");
        state.group_epoch = 7;
        let m = MemberState {
            member_id: "m1".into(),
            instance_id: Some("inst-a".into()),
            rack_id: None,
            client_id: "c".into(),
            client_host: "/127.0.0.1".into(),
            subscribed_topic_names: ["orders".to_string()].into(),
            subscribed_topic_regex: None,
            compiled_regex: None,
            server_assignor: None,
            rebalance_timeout: Duration::from_mins(1),
            member_epoch: 7,
            previous_member_epoch: 6,
            assignment_state: MemberAssignmentState::Stable,
            assigned_partitions: std::collections::HashMap::new(),
            partitions_pending_revocation: std::collections::HashMap::new(),
            last_seen: Instant::now(),
            classic: Some(ClassicMemberFacade {
                generation_id: 7,
                supported_protocols: vec![("range".into(), bytes::Bytes::from_static(b"meta"))],
                session_timeout: Duration::from_secs(30),
                last_synced_assignment: bytes::Bytes::new(),
                awaiting_sync: false,
            }),
        };
        state.add_or_update_member(m);
        state.target.epoch = 7;
        state
            .target
            .per_member
            .insert("m1".into(), [(t1, vec![0, 1])].into());
        let classic = convert_consumer_to_classic(&state, &image);
        assert!(classic.group_id == "g");
        assert!(classic.generation_id == 7);
        let member = classic.members.get("m1").expect("member preserved");
        assert!(member.group_instance_id.as_deref() == Some("inst-a"));
        assert!(member.session_timeout == Duration::from_secs(30));
        let asn = member.assignment.clone().expect("seed assignment");
        let mut cur = &asn[..];
        let version = cur.get_i16();
        assert!(version == 0);
        let decoded = ConsumerProtocolAssignment::decode(&mut cur, 0).unwrap();
        assert!(decoded.assigned_partitions[0].topic == "orders");
        assert!(decoded.assigned_partitions[0].partitions == vec![0, 1]);
        assert!(classic.state == ClassicGroupState::Stable);
        // NOTE(review): `asn2` and `asn` are both clones of `member.assignment`
        // taken from the same final state, so this equality is always true; it
        // only re-checks that the assignment is still `Some`.
        let asn2 = member
            .assignment
            .clone()
            .expect("seed assignment still set after stabilize");
        assert!(asn2 == asn);
        assert!(classic.protocol_name.as_deref() == Some("range"));
        assert!(classic.leader_id.as_deref() == Some("m1"));
    }
}