use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::Arc;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::debug;
/// Errors returned by [`ShardManager`] operations.
#[derive(Debug, Error)]
pub enum ShardManagerError {
    /// The given node id is not registered (returned by `remove_node`).
    #[error("unknown node: {0}")]
    UnknownNode(String),
    /// The given node id is already registered (returned by `add_node`).
    #[error("node already registered: {0}")]
    NodeAlreadyExists(String),
    /// The configuration requested zero shards (returned by `ShardManager::new`).
    #[error("n_shards must be >= 1")]
    NoShards,
    /// No nodes are available to own shards.
    // NOTE(review): not constructed in the code visible here; removing the last node currently
    // yields an empty assignment rather than this error. Confirm whether any caller uses it.
    #[error("no nodes available to assign shards")]
    NoNodes,
}
/// Result alias for fallible [`ShardManager`] operations.
pub type ShardManagerResult<T> = std::result::Result<T, ShardManagerError>;
/// Zero-based shard index; shards are numbered `0..n_shards`.
pub type ShardId = u32;
/// Identifier of a registered node.
pub type NodeId = String;
/// Mapping from shard id to the node that owns it.
///
/// Backed by a `BTreeMap`, so iteration is always in ascending shard order.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShardAssignment {
    /// Owner of each assigned shard; shards missing from the map are unassigned.
    pub map: BTreeMap<ShardId, NodeId>,
}

impl ShardAssignment {
    /// Builds an assignment in which shard `i` is owned by `nodes_per_shard[i]`.
    pub fn from_vec(nodes_per_shard: Vec<NodeId>) -> Self {
        let mut map = BTreeMap::new();
        for (index, owner) in nodes_per_shard.into_iter().enumerate() {
            map.insert(index as ShardId, owner);
        }
        ShardAssignment { map }
    }

    /// Number of shards that currently have an owner.
    pub fn n_shards(&self) -> usize {
        self.map.len()
    }

    /// Owner of `shard`, or `None` if the shard is unassigned.
    pub fn owner_of(&self, shard: ShardId) -> Option<&NodeId> {
        self.map.get(&shard)
    }

    /// Number of shards owned by each node. Nodes that own nothing are absent.
    pub fn counts(&self) -> HashMap<NodeId, usize> {
        self.map.values().fold(HashMap::new(), |mut tally, owner| {
            *tally.entry(owner.clone()).or_default() += 1;
            tally
        })
    }
}
/// A single change of shard ownership between two assignments.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShardMove {
    /// The shard whose owner changes.
    pub shard: ShardId,
    /// Previous owner, or `None` if the shard was previously unassigned.
    pub from: Option<NodeId>,
    /// New owner. `ShardManager` sets this to an empty string when the shard is absent from the
    /// new assignment, i.e. it becomes unassigned.
    pub to: NodeId,
}
/// Outcome of committing an assignment: the assignment itself and the ownership changes
/// relative to the assignment it replaced.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RebalancePlan {
    /// Assignment in effect after the plan was committed.
    pub new_assignment: ShardAssignment,
    /// Ownership changes; plans produced by `ShardManager` list them by ascending shard id.
    pub moves: Vec<ShardMove>,
}
/// Static configuration for a [`ShardManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardManagerConfig {
    /// Total number of shards to distribute; `ShardManager::new` rejects zero.
    pub n_shards: u32,
}

impl ShardManagerConfig {
    /// Shard count used by the [`Default`] configuration.
    const DEFAULT_N_SHARDS: u32 = 8;
}

impl Default for ShardManagerConfig {
    /// Returns a configuration with [`ShardManagerConfig::DEFAULT_N_SHARDS`] (8) shards.
    fn default() -> Self {
        ShardManagerConfig {
            n_shards: Self::DEFAULT_N_SHARDS,
        }
    }
}
/// Tracks registered nodes and the shard-to-node assignment, recomputing a round-robin
/// assignment whenever membership changes.
///
/// Mutable state sits behind `RwLock`s, so every method takes `&self`.
pub struct ShardManager {
    /// Configuration supplied at construction; never modified afterwards.
    config: ShardManagerConfig,
    /// Registered nodes keyed by id, each with its join sequence number.
    nodes: RwLock<BTreeMap<NodeId, NodeMeta>>,
    /// Current shard ownership.
    assignment: RwLock<ShardAssignment>,
    /// Number of plans committed so far.
    // NOTE(review): the `Arc` is never cloned in the code visible here; an `AtomicU64` may suffice.
    plans_emitted: Arc<RwLock<u64>>,
}
impl std::fmt::Debug for ShardManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ShardManager")
.field("config", &self.config)
.field("nodes", &self.nodes.read().keys().collect::<Vec<_>>())
.field("plans_emitted", &*self.plans_emitted.read())
.finish()
}
}
/// Per-node bookkeeping kept by [`ShardManager`].
#[derive(Debug, Clone)]
struct NodeMeta {
    /// Join sequence number; nodes are sorted by it before shards are dealt out.
    seq: u64,
}
impl ShardManager {
    /// Creates a manager with no registered nodes and an empty shard assignment.
    ///
    /// # Errors
    ///
    /// Returns [`ShardManagerError::NoShards`] if `config.n_shards` is zero.
    pub fn new(config: ShardManagerConfig) -> ShardManagerResult<Self> {
        if config.n_shards == 0 {
            return Err(ShardManagerError::NoShards);
        }
        Ok(Self {
            config,
            nodes: RwLock::new(BTreeMap::new()),
            assignment: RwLock::new(ShardAssignment::default()),
            plans_emitted: Arc::new(RwLock::new(0)),
        })
    }

    /// Creates a manager and registers each of `nodes` in iteration order.
    ///
    /// The assignment is recomputed after every registration; only the final one is kept.
    ///
    /// # Errors
    ///
    /// Returns [`ShardManagerError::NoShards`] if `config.n_shards` is zero, or
    /// [`ShardManagerError::NodeAlreadyExists`] if `nodes` yields the same id twice.
    pub fn with_nodes(
        config: ShardManagerConfig,
        nodes: impl IntoIterator<Item = impl Into<NodeId>>,
    ) -> ShardManagerResult<Self> {
        let mgr = Self::new(config)?;
        for node in nodes {
            // Intermediate plans are intentionally discarded.
            mgr.add_node(node.into())?;
        }
        Ok(mgr)
    }

    /// Number of plans committed so far, counting add/remove rebalances and
    /// `install_assignment` calls.
    pub fn plans_emitted(&self) -> u64 {
        *self.plans_emitted.read()
    }

    /// Returns the node currently owning `shard`, or `None` if it is unassigned.
    pub fn owner_of(&self, shard: ShardId) -> Option<NodeId> {
        self.assignment.read().owner_of(shard).cloned()
    }

    /// Returns the shards currently owned by `node_id`, in ascending shard order.
    pub fn shards_owned_by(&self, node_id: &str) -> Vec<ShardId> {
        self.assignment
            .read()
            .map
            .iter()
            .filter(|(_, owner)| owner.as_str() == node_id)
            .map(|(s, _)| *s)
            .collect()
    }

    /// Returns a copy of the current shard assignment.
    pub fn current_assignment(&self) -> ShardAssignment {
        self.assignment.read().clone()
    }

    /// Returns the number of registered nodes.
    pub fn node_count(&self) -> usize {
        self.nodes.read().len()
    }

    /// Registers `node_id` and rebalances every shard across the registered nodes.
    ///
    /// # Errors
    ///
    /// Returns [`ShardManagerError::NodeAlreadyExists`] if `node_id` is already registered;
    /// the manager's state is left unchanged in that case.
    pub fn add_node(&self, node_id: NodeId) -> ShardManagerResult<RebalancePlan> {
        {
            let mut nodes = self.nodes.write();
            if nodes.contains_key(&node_id) {
                return Err(ShardManagerError::NodeAlreadyExists(node_id));
            }
            // One past the largest live seq. `nodes.len()` is not safe here: after a removal
            // it can equal the seq of a node that is still registered, and the tie would then
            // be broken by node id instead of join order.
            let seq = nodes.values().map(|meta| meta.seq + 1).max().unwrap_or(0);
            nodes.insert(node_id.clone(), NodeMeta { seq });
        }
        let plan = self.recompute_plan()?;
        debug!(node = %node_id, moves = plan.moves.len(), "shard manager: add_node");
        Ok(plan)
    }

    /// Deregisters `node_id` and rebalances its shards onto the remaining nodes.
    ///
    /// If no nodes remain, every shard becomes unassigned and each returned move has an
    /// empty `to`.
    ///
    /// # Errors
    ///
    /// Returns [`ShardManagerError::UnknownNode`] if `node_id` is not registered.
    pub fn remove_node(&self, node_id: &str) -> ShardManagerResult<RebalancePlan> {
        {
            let mut nodes = self.nodes.write();
            if nodes.remove(node_id).is_none() {
                return Err(ShardManagerError::UnknownNode(node_id.to_string()));
            }
        }
        let plan = self.recompute_plan()?;
        debug!(node = %node_id, moves = plan.moves.len(), "shard manager: remove_node");
        Ok(plan)
    }

    /// Replaces the whole assignment with `new_assignment`, bypassing the balancer.
    ///
    /// The assignment is not checked against the registered nodes. The next `add_node` or
    /// `remove_node` recomputes the assignment from scratch and overwrites it.
    pub fn install_assignment(&self, new_assignment: ShardAssignment) -> RebalancePlan {
        // Diff and replace under one write lock so no concurrent commit slips in between.
        let mut current = self.assignment.write();
        self.commit(&mut current, new_assignment)
    }

    /// Recomputes a balanced assignment for the current node set and commits it.
    fn recompute_plan(&self) -> ShardManagerResult<RebalancePlan> {
        // Take the assignment write lock *before* snapshotting membership. This serializes
        // concurrent add/remove calls: a plan built from a stale node set can never overwrite a
        // newer one, and the moves are diffed against the exact assignment being replaced.
        // Lock order is assignment -> nodes; no other path takes them in the reverse order.
        let mut current = self.assignment.write();
        let nodes = self.nodes_in_join_order();
        // With no nodes, `balanced_assignment` yields an empty assignment, and `compute_moves`
        // then reports every previously owned shard as moving to an empty node id.
        let new_assignment = balanced_assignment(self.config.n_shards, &nodes);
        Ok(self.commit(&mut current, new_assignment))
    }

    /// Snapshot of the registered node ids, ordered by join sequence.
    ///
    /// The node read lock is released before this returns.
    fn nodes_in_join_order(&self) -> Vec<NodeId> {
        let mut by_seq: Vec<(u64, NodeId)> = self
            .nodes
            .read()
            .iter()
            .map(|(id, meta)| (meta.seq, id.clone()))
            .collect();
        // Ids are unique, so (seq, id) pairs are totally ordered and an unstable sort is exact.
        by_seq.sort_unstable();
        by_seq.into_iter().map(|(_, id)| id).collect()
    }

    /// Replaces `*current` with `new_assignment`, bumps the plan counter, and returns the plan
    /// describing the transition. The caller must hold the assignment write lock.
    fn commit(
        &self,
        current: &mut ShardAssignment,
        new_assignment: ShardAssignment,
    ) -> RebalancePlan {
        let moves = compute_moves(current, &new_assignment);
        *current = new_assignment.clone();
        *self.plans_emitted.write() += 1;
        RebalancePlan {
            new_assignment,
            moves,
        }
    }
}
fn balanced_assignment(n_shards: u32, nodes: &[NodeId]) -> ShardAssignment {
if nodes.is_empty() {
return ShardAssignment::default();
}
let n = nodes.len() as u32;
let mut map = BTreeMap::new();
for shard in 0..n_shards {
let owner = &nodes[(shard % n) as usize];
map.insert(shard, owner.clone());
}
ShardAssignment { map }
}
/// Diffs two assignments into the shard ownership changes, ordered by ascending shard id.
///
/// - A shard with the same owner on both sides (or absent from both) produces no move.
/// - A shard present only in `new_assignment` gets `from: None`.
/// - A shard present only in `old` gets `to` set to an empty string, because
///   `ShardMove::to` has no way to express "unassigned".
fn compute_moves(old: &ShardAssignment, new_assignment: &ShardAssignment) -> Vec<ShardMove> {
    // De-duplicate the union of shard ids, then sort so moves come out in shard order.
    let mut shards: Vec<ShardId> = old
        .map
        .keys()
        .chain(new_assignment.map.keys())
        .copied()
        .collect::<HashSet<ShardId>>()
        .into_iter()
        .collect();
    shards.sort_unstable();

    shards
        .into_iter()
        .filter_map(|shard| {
            let before = old.map.get(&shard);
            let after = new_assignment.map.get(&shard);
            if before == after {
                return None;
            }
            Some(ShardMove {
                shard,
                from: before.cloned(),
                to: after.cloned().unwrap_or_default(),
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn balanced_assignment_round_robins() {
        let assignment =
            balanced_assignment(6, &["n1".to_string(), "n2".to_string(), "n3".to_string()]);
        let counts = assignment.counts();
        // Pin the node set so an empty assignment cannot pass the loop below vacuously.
        assert_eq!(counts.len(), 3);
        for c in counts.values() {
            assert_eq!(*c, 2);
        }
        assert_eq!(assignment.owner_of(0).map(String::as_str), Some("n1"));
        assert_eq!(assignment.owner_of(4).map(String::as_str), Some("n2"));
    }

    #[test]
    fn balanced_assignment_with_no_nodes_is_empty() {
        assert!(balanced_assignment(4, &[]).map.is_empty());
    }

    #[test]
    fn add_node_initial_plan() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 4 }).expect("ok");
        let plan = mgr.add_node("n1".into()).expect("add");
        assert_eq!(plan.new_assignment.n_shards(), 4);
        for owner in plan.new_assignment.map.values() {
            assert_eq!(owner, "n1");
        }
        // Every shard was previously unassigned.
        assert_eq!(plan.moves.len(), 4);
        assert!(plan.moves.iter().all(|m| m.from.is_none() && m.to == "n1"));
    }

    #[test]
    fn add_node_balances_existing() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 6 }).expect("ok");
        mgr.add_node("n1".into()).expect("ok");
        let plan = mgr.add_node("n2".into()).expect("ok");
        let counts = plan.new_assignment.counts();
        assert_eq!(counts.get("n1"), Some(&3));
        assert_eq!(counts.get("n2"), Some(&3));
        assert_eq!(plan.moves.len(), 3);
        assert!(plan
            .moves
            .iter()
            .all(|m| m.from.as_deref() == Some("n1") && m.to == "n2"));
    }

    #[test]
    fn remove_node_redistributes() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 6 }).expect("ok");
        mgr.add_node("n1".into()).expect("ok");
        mgr.add_node("n2".into()).expect("ok");
        mgr.add_node("n3".into()).expect("ok");
        let plan = mgr.remove_node("n2").expect("ok");
        let counts = plan.new_assignment.counts();
        assert!(!counts.contains_key("n2"));
        let total: usize = counts.values().sum();
        assert_eq!(total, 6);
        // The survivors must end up balanced, not just cover every shard.
        assert_eq!(counts.get("n1"), Some(&3));
        assert_eq!(counts.get("n3"), Some(&3));
        assert_eq!(mgr.node_count(), 2);
    }

    #[test]
    fn empty_node_list_returns_empty_assignment() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 3 }).expect("ok");
        mgr.add_node("n1".into()).expect("ok");
        let plan = mgr.remove_node("n1").expect("ok");
        assert!(plan.new_assignment.map.is_empty());
        assert_eq!(plan.moves.len(), 3);
        assert!(plan
            .moves
            .iter()
            .all(|m| m.from.as_deref() == Some("n1") && m.to.is_empty()));
        assert_eq!(mgr.owner_of(0), None);
        assert_eq!(mgr.plans_emitted(), 2);
    }

    #[test]
    fn install_assignment_overrides_state() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 2 }).expect("ok");
        let new_assignment = ShardAssignment::from_vec(vec!["nA".into(), "nB".into()]);
        let plan = mgr.install_assignment(new_assignment.clone());
        assert_eq!(plan.new_assignment, new_assignment);
        assert_eq!(mgr.owner_of(0), Some("nA".to_string()));
        assert_eq!(mgr.owner_of(1), Some("nB".to_string()));
        assert_eq!(plan.moves.len(), 2);
        assert_eq!(mgr.plans_emitted(), 1);
        assert_eq!(mgr.current_assignment(), new_assignment);
    }

    #[test]
    fn compute_moves_reports_only_changes() {
        let old = ShardAssignment::from_vec(vec!["a".into(), "b".into()]);
        let mut new_assignment = ShardAssignment::default();
        new_assignment.map.insert(1, "b".into());
        new_assignment.map.insert(2, "c".into());
        let moves = compute_moves(&old, &new_assignment);
        assert_eq!(
            moves,
            vec![
                ShardMove {
                    shard: 0,
                    from: Some("a".into()),
                    to: String::new(),
                },
                ShardMove {
                    shard: 2,
                    from: None,
                    to: "c".into(),
                },
            ]
        );
    }

    #[test]
    fn duplicate_add_rejected() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 2 }).expect("ok");
        mgr.add_node("n1".into()).expect("ok");
        let err = mgr.add_node("n1".into()).expect_err("should fail");
        assert!(matches!(err, ShardManagerError::NodeAlreadyExists(_)));
        assert_eq!(mgr.node_count(), 1);
        assert_eq!(mgr.plans_emitted(), 1);
    }

    #[test]
    fn with_nodes_rejects_duplicates() {
        let err = ShardManager::with_nodes(ShardManagerConfig { n_shards: 2 }, ["n1", "n1"])
            .expect_err("should fail");
        assert!(matches!(err, ShardManagerError::NodeAlreadyExists(_)));
    }

    #[test]
    fn unknown_remove_rejected() {
        let mgr = ShardManager::new(ShardManagerConfig { n_shards: 2 }).expect("ok");
        let err = mgr.remove_node("ghost").expect_err("should fail");
        assert!(matches!(err, ShardManagerError::UnknownNode(_)));
        assert_eq!(mgr.plans_emitted(), 0);
    }

    #[test]
    fn n_shards_zero_rejected() {
        let err = ShardManager::new(ShardManagerConfig { n_shards: 0 }).expect_err("should fail");
        assert!(matches!(err, ShardManagerError::NoShards));
    }

    #[test]
    fn shards_owned_by_returns_correct_subset() {
        let mgr =
            ShardManager::with_nodes(ShardManagerConfig { n_shards: 4 }, ["n1", "n2"]).expect("ok");
        let s1 = mgr.shards_owned_by("n1");
        let s2 = mgr.shards_owned_by("n2");
        assert_eq!(s1, vec![0, 2]);
        assert_eq!(s2, vec![1, 3]);
        assert!(mgr.shards_owned_by("ghost").is_empty());
    }
}