use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
/// One group member's input to [`assign`], as a tuple:
/// `(member id, subscribed topics, owned (topic, partition) pairs, generation)`.
///
/// The owned pairs are the partitions the member reports holding. The
/// generation is compared across members when several of them claim the same
/// partition.
pub type MemberInput = (String, Vec<String>, Vec<(String, i32)>, i32);
/// Computes one round of a cooperative, sticky partition assignment.
///
/// `members` holds one [`MemberInput`] per group member:
/// `(member id, subscribed topics, partitions the member reports owning,
/// generation the member reports)`. `topic_partitions` maps each known topic
/// to its partition count; partitions of a topic are numbered `0..count`.
///
/// Steps:
/// 1. Reported ownership is validated and de-conflicted
///    (`prepopulate_current_assignments`).
/// 2. If every member has the same subscription set, the equal-subscription
///    algorithm (`constrained_assign`) runs; otherwise the general algorithm
///    (`general_assign`) runs.
/// 3. Cooperative adjustment: a partition whose validated previous owner is a
///    *different* member is left out of this round entirely, so that the
///    previous owner can give it up first. The returned assignment can
///    therefore deliberately cover fewer partitions than exist; a follow-up
///    call with the updated ownership places them.
///
/// The result has an entry (possibly empty) for every member id, and each
/// partition list is sorted by `(topic, partition)`. An empty `members` slice
/// yields an empty map.
#[must_use]
pub fn assign(
    members: &[MemberInput],
    topic_partitions: &HashMap<String, i32>,
) -> HashMap<String, Vec<(String, i32)>> {
    if members.is_empty() {
        return HashMap::new();
    }
    let mut member_ids: Vec<String> = members.iter().map(|(id, _, _, _)| id.clone()).collect();
    member_ids.sort();
    let subs: BTreeMap<String, BTreeSet<String>> = members
        .iter()
        .map(|(id, topics, _, _)| (id.clone(), topics.iter().cloned().collect()))
        .collect();
    let gens: BTreeMap<String, i32> = members
        .iter()
        .map(|(id, _, _, generation)| (id.clone(), *generation))
        .collect();
    let current_assignment =
        prepopulate_current_assignments(members, &subs, &gens, topic_partitions);

    // Validated previous owner of each partition. `current_assignment` lists
    // every partition under at most one member.
    let previous_owner: HashMap<&(String, i32), &String> = current_assignment
        .iter()
        .flat_map(|(member, parts)| parts.iter().map(move |tp| (tp, member)))
        .collect();

    let all_equal = {
        let mut sets = subs.values();
        match sets.next() {
            None => true,
            Some(first) => sets.all(|s| s == first),
        }
    };
    let raw_assignment = if all_equal {
        constrained_assign(&member_ids, &subs, &current_assignment, topic_partitions)
    } else {
        general_assign(&member_ids, &subs, &current_assignment, topic_partitions)
    };

    member_ids
        .iter()
        .map(|id| {
            let mut parts: Vec<(String, i32)> = raw_assignment
                .get(id)
                .into_iter()
                .flatten()
                // Withhold partitions that a different member must revoke first.
                .filter(|tp| !matches!(previous_owner.get(*tp), Some(prev) if *prev != id))
                .cloned()
                .collect();
            parts.sort();
            (id.clone(), parts)
        })
        .collect()
}
/// Builds the validated current assignment from the ownership members report.
///
/// A reported `(topic, partition)` claim is ignored unless all of these hold:
/// - the topic is in `topic_partitions`;
/// - the partition number lies in `0..count`;
/// - the claiming member subscribes to the topic (per `subs`).
///
/// Among the remaining claims for one partition, the highest generation wins.
/// If two or more *different* members claim it at that highest generation,
/// the partition is treated as unowned. A member listing the same partition
/// more than once is not a conflict.
///
/// `_gens` is not read; each member's generation comes from `members`.
///
/// Returns an entry (possibly empty) for every member id, each list sorted.
fn prepopulate_current_assignments(
    members: &[MemberInput],
    subs: &BTreeMap<String, BTreeSet<String>>,
    _gens: &BTreeMap<String, i32>,
    topic_partitions: &HashMap<String, i32>,
) -> BTreeMap<String, Vec<(String, i32)>> {
    let no_topics = BTreeSet::new();
    // Per partition: (highest generation seen, owner at that generation,
    // whether a different member also claimed it at that generation).
    let mut best: HashMap<(String, i32), (i32, String, bool)> = HashMap::new();
    for (id, _, owned, generation) in members {
        let id_subs = subs.get(id).unwrap_or(&no_topics);
        for (topic, partition) in owned {
            let Some(&count) = topic_partitions.get(topic) else {
                continue;
            };
            if !(0..count).contains(partition) || !id_subs.contains(topic) {
                continue;
            }
            best.entry((topic.clone(), *partition))
                .and_modify(|(best_gen, owner, contested)| {
                    if *generation > *best_gen {
                        *best_gen = *generation;
                        owner.clone_from(id);
                        *contested = false;
                    } else if *generation == *best_gen && *owner != *id {
                        // Same generation from a different member: neither
                        // claim can be trusted.
                        *contested = true;
                    }
                })
                .or_insert_with(|| (*generation, id.clone(), false));
        }
    }

    let mut out: BTreeMap<String, Vec<(String, i32)>> = members
        .iter()
        .map(|(id, _, _, _)| (id.clone(), Vec::new()))
        .collect();
    for (tp, (_, owner, contested)) in best {
        if !contested {
            out.entry(owner).or_default().push(tp);
        }
    }
    for parts in out.values_mut() {
        parts.sort();
    }
    out
}
/// Assignment for groups in which every member has the same subscription.
///
/// All partitions of the shared topics that appear in `topic_partitions` are
/// spread so member loads differ by at most one.
/// - Quotas: each member's quota is `total / members`. The `total % members`
///   extra slots go to the members that currently own the most partitions
///   (ties by member id), so members already holding more than the base
///   quota can keep one more of them.
/// - Keeping: each member keeps up to its quota of its currently owned
///   partitions, lowest first. Owned partitions outside the shared topics'
///   valid range, or already kept by another member, are not kept.
/// - Dealing: the remaining partitions are dealt round-robin, in member-id
///   order, to members still below quota.
///
/// Every id in `member_ids` gets an entry; each list is sorted.
fn constrained_assign(
    member_ids: &[String],
    subs: &BTreeMap<String, BTreeSet<String>>,
    current_assignment: &BTreeMap<String, Vec<(String, i32)>>,
    topic_partitions: &HashMap<String, i32>,
) -> HashMap<String, Vec<(String, i32)>> {
    let Some(first_member) = member_ids.first() else {
        return HashMap::new();
    };
    let num_members = member_ids.len();

    // All members share one subscription, so the first member's topics are
    // everyone's. BTreeSet iteration is already in topic order, so the list
    // comes out sorted by (topic, partition).
    let all_partitions: Vec<(String, i32)> = subs
        .get(first_member)
        .into_iter()
        .flatten()
        .filter_map(|topic| topic_partitions.get(topic).map(|&count| (topic, count)))
        .flat_map(|(topic, count)| (0..count).map(move |p| (topic.clone(), p)))
        .collect();

    let total = all_partitions.len();
    let base = total / num_members;
    let remainder = total % num_members;

    // Give the extra slots to the heaviest current owners to reduce movement.
    let owned_count = |id: &String| current_assignment.get(id).map_or(0, Vec::len);
    let mut by_ownership: Vec<&String> = member_ids.iter().collect();
    by_ownership.sort_by(|a, b| owned_count(*b).cmp(&owned_count(*a)).then_with(|| a.cmp(b)));
    let extra_slot: HashSet<&String> = by_ownership.into_iter().take(remainder).collect();
    let quota = |id: &String| base + usize::from(extra_slot.contains(id));

    // Keep each member's valid owned partitions, lowest first, up to quota.
    let valid: HashSet<&(String, i32)> = all_partitions.iter().collect();
    let mut taken: HashSet<(String, i32)> = HashSet::new();
    let mut out: HashMap<String, Vec<(String, i32)>> = HashMap::new();
    for id in member_ids {
        let mut owned: Vec<&(String, i32)> = current_assignment
            .get(id)
            .into_iter()
            .flatten()
            .filter(|tp| valid.contains(*tp))
            .collect();
        owned.sort();
        let target = quota(id);
        let mut kept: Vec<(String, i32)> = Vec::new();
        for tp in owned {
            if kept.len() >= target {
                break;
            }
            if taken.insert(tp.clone()) {
                kept.push(tp.clone());
            }
        }
        out.insert(id.clone(), kept);
    }

    // Deal the rest round-robin to members below quota. Quotas sum to
    // `total` and kept partitions are distinct valid ones, so a slot exists.
    let unassigned: Vec<(String, i32)> = all_partitions
        .iter()
        .filter(|tp| !taken.contains(*tp))
        .cloned()
        .collect();
    let mut cursor = 0usize;
    for tp in unassigned {
        let mut placed = false;
        for _ in 0..num_members {
            let id = &member_ids[cursor % num_members];
            cursor += 1;
            let target = quota(id);
            let slot = out.entry(id.clone()).or_default();
            if slot.len() < target {
                slot.push(tp);
                placed = true;
                break;
            }
        }
        debug_assert!(
            placed,
            "constrained_assign: unassigned partition found no slot despite Σtargets == total"
        );
        if !placed {
            break;
        }
    }
    for parts in out.values_mut() {
        parts.sort();
    }
    out
}
#[allow(clippy::too_many_lines)]
fn general_assign(
member_ids: &[String],
subs: &BTreeMap<String, BTreeSet<String>>,
current_assignment: &BTreeMap<String, Vec<(String, i32)>>,
topic_partitions: &HashMap<String, i32>,
) -> HashMap<String, Vec<(String, i32)>> {
let mut subscribers_per_topic: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
for (id, s) in subs {
for t in s {
if topic_partitions.contains_key(t) {
subscribers_per_topic
.entry(t.clone())
.or_default()
.insert(id.clone());
}
}
}
let mut sorted_partitions: Vec<(String, i32)> = Vec::new();
for t in subscribers_per_topic.keys() {
let Some(&n) = topic_partitions.get(t) else {
continue;
};
for p in 0..n {
sorted_partitions.push((t.clone(), p));
}
}
sorted_partitions.sort_by(|a, b| {
let sa = subscribers_per_topic.get(&a.0).map_or(0, BTreeSet::len);
let sb = subscribers_per_topic.get(&b.0).map_or(0, BTreeSet::len);
sa.cmp(&sb)
.then_with(|| a.0.cmp(&b.0))
.then_with(|| a.1.cmp(&b.1))
});
let mut new_assignment: BTreeMap<String, Vec<(String, i32)>> = BTreeMap::new();
for id in member_ids {
new_assignment.insert(id.clone(), Vec::new());
}
let mut prev_owner: HashMap<(String, i32), String> = HashMap::new();
for (m, parts) in current_assignment {
for tp in parts {
prev_owner.insert(tp.clone(), m.clone());
}
}
for tp in &sorted_partitions {
let subscribed: BTreeSet<String> = subscribers_per_topic
.get(&tp.0)
.cloned()
.unwrap_or_default();
if subscribed.is_empty() {
continue;
}
let mut best: Option<(usize, &String)> = None;
for id in &subscribed {
let load = new_assignment.get(id).map_or(0, Vec::len);
match best {
None => best = Some((load, id)),
Some((bload, _)) if load < bload => best = Some((load, id)),
_ => {}
}
}
let (min_load, _) = best.expect("subscribed non-empty");
let chosen: String = match prev_owner.get(tp) {
Some(prev) if subscribed.contains(prev) => {
let prev_load = new_assignment.get(prev).map_or(0, Vec::len);
if prev_load <= min_load {
prev.clone()
} else {
pick_least_loaded(&subscribed, &new_assignment)
}
}
_ => pick_least_loaded(&subscribed, &new_assignment),
};
new_assignment
.get_mut(&chosen)
.expect("chosen exists")
.push(tp.clone());
}
let max_iters = sorted_partitions.len().saturating_mul(member_ids.len()) + 16;
for _ in 0..max_iters {
let mut heaviest: Option<(usize, String)> = None;
let mut lightest: Option<(usize, String)> = None;
for id in member_ids {
let load = new_assignment.get(id).map_or(0, Vec::len);
match &heaviest {
None => heaviest = Some((load, id.clone())),
Some((hl, _)) if load > *hl => heaviest = Some((load, id.clone())),
_ => {}
}
match &lightest {
None => lightest = Some((load, id.clone())),
Some((ll, _)) if load < *ll => lightest = Some((load, id.clone())),
_ => {}
}
}
let Some((hload, hid)) = heaviest else { break };
let Some((lload, lid)) = lightest else { break };
if hload <= lload + 1 {
break;
}
let l_subs = subs.get(&lid).cloned().unwrap_or_default();
let h_parts = new_assignment.get(&hid).cloned().unwrap_or_default();
let mut candidates: Vec<(String, i32)> = h_parts
.into_iter()
.filter(|tp| l_subs.contains(&tp.0))
.collect();
if candidates.is_empty() {
break;
}
candidates.sort_by(|a, b| {
let a_was_lids = prev_owner.get(a) == Some(&lid);
let b_was_lids = prev_owner.get(b) == Some(&lid);
a_was_lids
.cmp(&b_was_lids)
.then_with(|| a.0.cmp(&b.0))
.then_with(|| a.1.cmp(&b.1))
});
let moved = candidates.into_iter().next().expect("non-empty");
let h_vec = new_assignment.get_mut(&hid).unwrap();
if let Some(pos) = h_vec.iter().position(|x| x == &moved) {
h_vec.remove(pos);
}
new_assignment.get_mut(&lid).unwrap().push(moved);
}
let mut out: HashMap<String, Vec<(String, i32)>> = HashMap::new();
for (k, mut v) in new_assignment {
v.sort();
out.insert(k, v);
}
out
}
/// Returns the member of `subscribed` holding the fewest partitions in
/// `assignment`, where members absent from `assignment` count as holding none.
/// Ties go to the lexicographically smallest member id.
///
/// # Panics
///
/// Panics if `subscribed` is empty.
fn pick_least_loaded(
    subscribed: &BTreeSet<String>,
    assignment: &BTreeMap<String, Vec<(String, i32)>>,
) -> String {
    let (_, winner) = subscribed
        .iter()
        .map(|member| (assignment.get(member).map_or(0, Vec::len), member))
        .min()
        .expect("non-empty subscription");
    winner.clone()
}
// Scenario tests for `assign`. The `assert!` used below is `assert2::assert`
// (imported in the module), not the std macro.
#[cfg(test)]
mod tests {
use super::*;
use assert2::assert;
/// Builds an owned `(topic, partition)` list from borrowed pairs.
fn tp(items: &[(&str, i32)]) -> Vec<(String, i32)> {
items.iter().map(|(t, p)| ((*t).to_string(), *p)).collect()
}
/// Builds an owned topic list from string slices.
fn topics(ts: &[&str]) -> Vec<String> {
ts.iter().map(|s| (*s).to_string()).collect()
}
/// Total number of partitions handed out across all members.
fn total_assigned(out: &HashMap<String, Vec<(String, i32)>>) -> usize {
out.values().map(Vec::len).sum()
}
// A lone member with no prior ownership receives every partition, in order.
#[test]
fn single_member_takes_all() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 4);
let a = assign(
&[("m1".to_string(), topics(&["t"]), vec![], 0)],
&topic_parts,
);
assert!(a["m1"].len() == 4);
assert!(a["m1"] == tp(&[("t", 0), ("t", 1), ("t", 2), ("t", 3)]));
}
// Two fresh members with identical subscriptions split 4 partitions 2/2.
#[test]
fn fresh_join_balances() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 4);
let a = assign(
&[
("m1".to_string(), topics(&["t"]), vec![], 0),
("m2".to_string(), topics(&["t"]), vec![], 0),
],
&topic_parts,
);
assert!(a["m1"].len() == 2);
assert!(a["m2"].len() == 2);
assert!(total_assigned(&a) == 4);
}
// An already balanced ownership is returned unchanged.
#[test]
fn all_owned_stable_no_op() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 4);
let owned_m1 = tp(&[("t", 0), ("t", 1)]);
let owned_m2 = tp(&[("t", 2), ("t", 3)]);
let a = assign(
&[
("m1".to_string(), topics(&["t"]), owned_m1.clone(), 5),
("m2".to_string(), topics(&["t"]), owned_m2.clone(), 5),
],
&topic_parts,
);
assert!(a["m1"] == owned_m1);
assert!(a["m2"] == owned_m2);
assert!(total_assigned(&a) == 4);
}
// Cooperative phase 1 when m3 joins: partitions that must leave m1/m2 are
// withheld this round, so fewer than 9 are assigned and m3 only gets
// partitions nobody previously owned.
#[test]
fn partial_revocation_on_member_join() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 9);
let owned_m1 = tp(&[("t", 0), ("t", 1), ("t", 2), ("t", 3)]);
let owned_m2 = tp(&[("t", 4), ("t", 5), ("t", 6), ("t", 7)]);
let a = assign(
&[
("m1".to_string(), topics(&["t"]), owned_m1.clone(), 5),
("m2".to_string(), topics(&["t"]), owned_m2.clone(), 5),
("m3".to_string(), topics(&["t"]), vec![], 5),
],
&topic_parts,
);
assert!(a["m1"].len() == 3);
assert!(a["m2"].len() == 3);
assert!(a["m3"].len() <= 1, "m3 got {:?}", a["m3"]);
assert!(
total_assigned(&a) < 9,
"expected omitted partitions, got {} total",
total_assigned(&a)
);
for tp_m3 in &a["m3"] {
assert!(!a["m1"].contains(tp_m3));
assert!(!a["m2"].contains(tp_m3));
}
}
// Cooperative phase 2: feeding phase 1's result back as ownership (next
// generation) places all 9 partitions, 3 per member.
#[test]
fn phase2_picks_up_revoked() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 9);
let phase1 = assign(
&[
(
"m1".to_string(),
topics(&["t"]),
tp(&[("t", 0), ("t", 1), ("t", 2), ("t", 3)]),
5,
),
(
"m2".to_string(),
topics(&["t"]),
tp(&[("t", 4), ("t", 5), ("t", 6), ("t", 7)]),
5,
),
("m3".to_string(), topics(&["t"]), vec![], 5),
],
&topic_parts,
);
let phase2 = assign(
&[
("m1".to_string(), topics(&["t"]), phase1["m1"].clone(), 6),
("m2".to_string(), topics(&["t"]), phase1["m2"].clone(), 6),
("m3".to_string(), topics(&["t"]), phase1["m3"].clone(), 6),
],
&topic_parts,
);
assert!(total_assigned(&phase2) == 9, "phase 2 must place all 9");
assert!(phase2["m1"].len() == 3);
assert!(phase2["m2"].len() == 3);
assert!(phase2["m3"].len() == 3);
}
// A departed member's partitions have no previous owner in the group, so
// the remaining member picks them up immediately.
#[test]
fn member_leaves_partitions_redistributed() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 4);
let a = assign(
&[(
"m2".to_string(),
topics(&["t"]),
tp(&[("t", 2), ("t", 3)]),
5,
)],
&topic_parts,
);
assert!(a["m2"].len() == 4);
assert!(total_assigned(&a) == 4);
}
// Differing subscriptions take the general path: m2 only receives t1, and
// t2 (subscribed only by m1) goes entirely to m1.
#[test]
fn multi_topic_asymmetric_subscriptions() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t1".to_string(), 2);
topic_parts.insert("t2".to_string(), 2);
let a = assign(
&[
("m1".to_string(), topics(&["t1", "t2"]), vec![], 0),
("m2".to_string(), topics(&["t1"]), vec![], 0),
],
&topic_parts,
);
for (t, _) in &a["m2"] {
assert!(t == "t1");
}
let m1_t2: Vec<&(String, i32)> = a["m1"].iter().filter(|(t, _)| t == "t2").collect();
assert!(m1_t2.len() == 2);
assert!(total_assigned(&a) == 4);
assert!(a["m1"].len() == 2);
assert!(a["m2"].len() == 2);
}
// A claim from a lower generation loses to a higher-generation claim on
// the same partition.
#[test]
fn generation_zombie_lower_loses() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 4);
let a = assign(
&[
(
"m1".to_string(),
topics(&["t"]),
tp(&[("t", 0), ("t", 1)]),
5,
),
("m2".to_string(), topics(&["t"]), tp(&[("t", 0)]), 4),
],
&topic_parts,
);
assert!(a["m1"].contains(&("t".to_string(), 0)));
assert!(a["m1"].contains(&("t".to_string(), 1)));
assert!(!a["m2"].contains(&("t".to_string(), 0)));
}
// Two members claiming t-0 at the same generation: the claim is discarded,
// so t-0 counts as unowned and can be placed this round (all 4 assigned).
#[test]
fn generation_zombie_tie_both_lose() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 4);
let a = assign(
&[
(
"m1".to_string(),
topics(&["t"]),
tp(&[("t", 0), ("t", 1)]),
5,
),
(
"m2".to_string(),
topics(&["t"]),
tp(&[("t", 0), ("t", 2)]),
5,
),
],
&topic_parts,
);
assert!(total_assigned(&a) == 4);
}
// 3 unowned partitions over 2 members: the extra one goes to m1.
#[test]
fn brand_new_topic() {
let mut topic_parts = HashMap::new();
topic_parts.insert("newt".to_string(), 3);
let a = assign(
&[
("m1".to_string(), topics(&["newt"]), vec![], 5),
("m2".to_string(), topics(&["newt"]), vec![], 5),
],
&topic_parts,
);
assert!(total_assigned(&a) == 3);
assert!(a["m1"].len() == 2);
assert!(a["m2"].len() == 1);
}
// An owned partition beyond the topic's current count (t-5 of 3) is dropped.
#[test]
fn partition_count_decreased() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 3);
let a = assign(
&[(
"m1".to_string(),
topics(&["t"]),
tp(&[("t", 0), ("t", 5)]),
5,
)],
&topic_parts,
);
assert!(a["m1"].len() == 3);
assert!(!a["m1"].contains(&("t".to_string(), 5)));
}
// No members: empty result.
#[test]
fn empty_members_returns_empty() {
let topic_parts = HashMap::new();
let a = assign(&[], &topic_parts);
assert!(a.is_empty());
}
// A member subscribed to nothing still appears, with no partitions.
#[test]
fn member_with_no_subscriptions() {
let mut topic_parts = HashMap::new();
topic_parts.insert("t".to_string(), 2);
let a = assign(
&[
("m1".to_string(), topics(&["t"]), vec![], 0),
("m2".to_string(), vec![], vec![], 0),
],
&topic_parts,
);
assert!(a["m1"].len() == 2);
assert!(a["m2"].len() == 0);
}
}