use crate::error::{ClusterError, Result};
use crate::node::{Node, NodeCapabilities, NodeId};
use hashring::HashRing;
use std::collections::{HashMap, HashSet};
/// Strategy used by `PartitionPlacer` to choose which nodes host a partition.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PlacementStrategy {
/// Walk the eligible nodes in sequence, starting at an offset derived from
/// the partition index.
RoundRobin,
/// Hash "topic-partition" keys onto a consistent-hash ring (the default).
#[default]
ConsistentHash,
/// Prefer the nodes currently hosting the fewest leaders/replicas.
LeastLoaded,
/// Pick one node per rack first, then top up by load.
RackAware,
}
/// Tunables controlling how `PartitionPlacer` assigns replicas.
#[derive(Debug, Clone)]
pub struct PlacementConfig {
/// Which placement algorithm `assign_partition` dispatches to.
pub strategy: PlacementStrategy,
/// When true, consistent-hash placement prefers spreading replicas across racks.
pub rack_aware: bool,
/// Virtual ring entries per physical node (more entries = smoother key spread).
pub virtual_nodes: usize,
/// Per-node partition cap. NOTE(review): not enforced anywhere in this file —
/// presumably checked by a caller; confirm.
pub max_partitions_per_node: usize,
}
impl Default for PlacementConfig {
    /// Rack-aware consistent hashing with 150 virtual nodes per physical
    /// node and no per-node partition cap.
    fn default() -> Self {
        Self {
            // Defer to the enum's `#[default]` attribute (ConsistentHash) so
            // the default strategy has a single source of truth.
            strategy: PlacementStrategy::default(),
            rack_aware: true,
            // 150 vnodes is a common consistent-hashing default for an even
            // key distribution.
            virtual_nodes: 150,
            // 0 presumably means "unlimited" — TODO confirm at the enforcement site.
            max_partitions_per_node: 0,
        }
    }
}
/// Assigns topic partitions to cluster nodes using a configurable strategy.
pub struct PartitionPlacer {
/// Strategy and tuning knobs; fixed at construction.
config: PlacementConfig,
/// Consistent-hash ring holding `config.virtual_nodes` entries ("id#i") per node.
ring: HashRing<NodeId>,
/// Per-node placement state, keyed by node id.
nodes: HashMap<NodeId, NodePlacementInfo>,
/// Rack label -> member node ids (nodes without a rack never appear here).
racks: HashMap<String, Vec<NodeId>>,
}
/// Cached per-node state the placer consults when assigning replicas.
#[derive(Debug, Clone)]
struct NodePlacementInfo {
id: NodeId,
/// Rack label, if the node reported one.
rack: Option<String>,
capabilities: NodeCapabilities,
/// Partitions this node currently leads (refreshed via `update_node_load`).
leader_count: u32,
/// Follower replicas hosted by this node (refreshed via `update_node_load`).
replica_count: u32,
/// Placement weight; always set to 100 and never read today — reserved.
#[allow(dead_code)]
weight: u32,
}
impl PartitionPlacer {
pub fn new(config: PlacementConfig) -> Self {
Self {
config,
ring: HashRing::new(),
nodes: HashMap::new(),
racks: HashMap::new(),
}
}
pub fn add_node(&mut self, node: &Node) {
let info = NodePlacementInfo {
id: node.info.id.clone(),
rack: node.info.rack.clone(),
capabilities: node.info.capabilities,
leader_count: node.partition_leader_count,
replica_count: node.partition_replica_count,
weight: 100, };
for i in 0..self.config.virtual_nodes {
let vnode = format!("{}#{}", node.info.id, i);
self.ring.add(vnode);
}
if let Some(rack) = &node.info.rack {
self.racks
.entry(rack.clone())
.or_default()
.push(node.info.id.clone());
}
self.nodes.insert(node.info.id.clone(), info);
}
pub fn remove_node(&mut self, node_id: &NodeId) {
for i in 0..self.config.virtual_nodes {
let vnode = format!("{}#{}", node_id, i);
self.ring.remove(&vnode);
}
if let Some(info) = self.nodes.get(node_id) {
if let Some(rack) = &info.rack {
if let Some(rack_nodes) = self.racks.get_mut(rack) {
rack_nodes.retain(|n| n != node_id);
}
}
}
self.nodes.remove(node_id);
}
pub fn update_node_load(&mut self, node_id: &NodeId, leader_count: u32, replica_count: u32) {
if let Some(info) = self.nodes.get_mut(node_id) {
info.leader_count = leader_count;
info.replica_count = replica_count;
}
}
pub fn assign_partition(
&self,
topic: &str,
partition: u32,
replication_factor: u16,
) -> Result<Vec<NodeId>> {
let eligible_nodes: Vec<_> = self
.nodes
.values()
.filter(|n| n.capabilities.replica_eligible)
.collect();
if eligible_nodes.len() < replication_factor as usize {
return Err(ClusterError::InvalidReplicationFactor {
factor: replication_factor,
nodes: eligible_nodes.len(),
});
}
match self.config.strategy {
PlacementStrategy::ConsistentHash => {
self.assign_consistent_hash(topic, partition, replication_factor)
}
PlacementStrategy::RoundRobin => {
self.assign_round_robin(topic, partition, replication_factor)
}
PlacementStrategy::LeastLoaded => self.assign_least_loaded(replication_factor),
PlacementStrategy::RackAware => {
self.assign_rack_aware(topic, partition, replication_factor)
}
}
}
fn assign_consistent_hash(
&self,
topic: &str,
partition: u32,
replication_factor: u16,
) -> Result<Vec<NodeId>> {
let key = format!("{}-{}", topic, partition);
let mut replicas = Vec::with_capacity(replication_factor as usize);
let mut seen_nodes = HashSet::new();
let mut seen_racks = HashSet::new();
let ring_nodes: Vec<_> = (0..self.config.virtual_nodes * self.nodes.len())
.filter_map(|i| {
let probe_key = format!("{}-{}", key, i);
self.ring.get(&probe_key).map(|vnode| {
vnode.split('#').next().unwrap_or(vnode).to_string()
})
})
.collect();
for node_id in ring_nodes {
if replicas.len() >= replication_factor as usize {
break;
}
if seen_nodes.contains(&node_id) {
continue;
}
if self.config.rack_aware {
if let Some(info) = self.nodes.get(&node_id) {
if let Some(rack) = &info.rack {
if seen_racks.contains(rack) && seen_racks.len() < self.racks.len() {
continue;
}
seen_racks.insert(rack.clone());
}
}
}
seen_nodes.insert(node_id.clone());
replicas.push(node_id);
}
if replicas.len() < replication_factor as usize {
for node_id in self.nodes.keys() {
if replicas.len() >= replication_factor as usize {
break;
}
if !seen_nodes.contains(node_id) {
replicas.push(node_id.clone());
seen_nodes.insert(node_id.clone());
}
}
}
if replicas.len() < replication_factor as usize {
return Err(ClusterError::InvalidReplicationFactor {
factor: replication_factor,
nodes: self.nodes.len(),
});
}
Ok(replicas)
}
fn assign_round_robin(
&self,
_topic: &str,
partition: u32,
replication_factor: u16,
) -> Result<Vec<NodeId>> {
let eligible: Vec<_> = self
.nodes
.values()
.filter(|n| n.capabilities.replica_eligible)
.map(|n| n.id.clone())
.collect();
if eligible.len() < replication_factor as usize {
return Err(ClusterError::InvalidReplicationFactor {
factor: replication_factor,
nodes: eligible.len(),
});
}
let mut replicas = Vec::with_capacity(replication_factor as usize);
let start = partition as usize % eligible.len();
for i in 0..replication_factor as usize {
let idx = (start + i) % eligible.len();
replicas.push(eligible[idx].clone());
}
Ok(replicas)
}
fn assign_least_loaded(&self, replication_factor: u16) -> Result<Vec<NodeId>> {
let mut eligible: Vec<_> = self
.nodes
.values()
.filter(|n| n.capabilities.replica_eligible)
.collect();
if eligible.len() < replication_factor as usize {
return Err(ClusterError::InvalidReplicationFactor {
factor: replication_factor,
nodes: eligible.len(),
});
}
eligible.sort_by_key(|n| n.leader_count * 3 + n.replica_count);
let replicas: Vec<_> = eligible
.iter()
.take(replication_factor as usize)
.map(|n| n.id.clone())
.collect();
Ok(replicas)
}
fn assign_rack_aware(
&self,
topic: &str,
partition: u32,
replication_factor: u16,
) -> Result<Vec<NodeId>> {
let key = format!("{}-{}", topic, partition);
let mut replicas = Vec::with_capacity(replication_factor as usize);
let mut used_racks = HashSet::new();
let mut used_nodes = HashSet::new();
let mut rack_list: Vec<_> = self.racks.keys().cloned().collect();
rack_list.sort();
for rack in &rack_list {
if replicas.len() >= replication_factor as usize {
break;
}
if let Some(rack_nodes) = self.racks.get(rack) {
let idx = {
let hash = fxhash(&format!("{}-{}", key, rack));
hash as usize % rack_nodes.len()
};
let node_id = &rack_nodes[idx];
if let Some(info) = self.nodes.get(node_id) {
if info.capabilities.replica_eligible && !used_nodes.contains(node_id) {
replicas.push(node_id.clone());
used_nodes.insert(node_id.clone());
used_racks.insert(rack.clone());
}
}
}
}
if replicas.len() < replication_factor as usize {
let mut remaining: Vec<_> = self
.nodes
.values()
.filter(|n| n.capabilities.replica_eligible && !used_nodes.contains(&n.id))
.collect();
remaining.sort_by_key(|n| n.leader_count * 3 + n.replica_count);
for info in remaining {
if replicas.len() >= replication_factor as usize {
break;
}
replicas.push(info.id.clone());
}
}
if replicas.len() < replication_factor as usize {
return Err(ClusterError::InvalidReplicationFactor {
factor: replication_factor,
nodes: self.nodes.len(),
});
}
Ok(replicas)
}
pub fn calculate_reassignments(
&self,
current_assignments: &HashMap<String, Vec<Vec<NodeId>>>,
_added_nodes: &[NodeId],
removed_nodes: &[NodeId],
) -> HashMap<String, Vec<(u32, Vec<NodeId>)>> {
let mut reassignments = HashMap::new();
let removed_set: HashSet<_> = removed_nodes.iter().cloned().collect();
for (topic, partitions) in current_assignments {
let mut topic_reassignments = Vec::new();
for (partition_idx, current_replicas) in partitions.iter().enumerate() {
let has_removed = current_replicas.iter().any(|n| removed_set.contains(n));
if has_removed {
let replication_factor = current_replicas.len() as u16;
if let Ok(new_replicas) =
self.assign_partition(topic, partition_idx as u32, replication_factor)
{
topic_reassignments.push((partition_idx as u32, new_replicas));
}
}
}
if !topic_reassignments.is_empty() {
reassignments.insert(topic.clone(), topic_reassignments);
}
}
reassignments
}
pub fn get_distribution_stats(&self) -> DistributionStats {
let mut leader_counts: Vec<u32> = self.nodes.values().map(|n| n.leader_count).collect();
let mut replica_counts: Vec<u32> = self.nodes.values().map(|n| n.replica_count).collect();
leader_counts.sort();
replica_counts.sort();
let leader_sum: u32 = leader_counts.iter().sum();
let replica_sum: u32 = replica_counts.iter().sum();
DistributionStats {
node_count: self.nodes.len(),
rack_count: self.racks.len(),
total_leaders: leader_sum,
total_replicas: replica_sum,
leader_min: leader_counts.first().copied().unwrap_or(0),
leader_max: leader_counts.last().copied().unwrap_or(0),
leader_avg: if self.nodes.is_empty() {
0.0
} else {
leader_sum as f64 / self.nodes.len() as f64
},
replica_min: replica_counts.first().copied().unwrap_or(0),
replica_max: replica_counts.last().copied().unwrap_or(0),
replica_avg: if self.nodes.is_empty() {
0.0
} else {
replica_sum as f64 / self.nodes.len() as f64
},
}
}
}
/// Snapshot of leader/replica spread, produced by
/// `PartitionPlacer::get_distribution_stats`.
#[derive(Debug, Clone)]
pub struct DistributionStats {
/// Number of registered nodes.
pub node_count: usize,
/// Number of distinct rack labels seen.
pub rack_count: usize,
/// Sum of per-node leader counts.
pub total_leaders: u32,
/// Sum of per-node replica counts.
pub total_replicas: u32,
pub leader_min: u32,
pub leader_max: u32,
/// Mean leaders per node (0.0 when there are no nodes).
pub leader_avg: f64,
pub replica_min: u32,
pub replica_max: u32,
/// Mean replicas per node (0.0 when there are no nodes).
pub replica_avg: f64,
}
/// 64-bit FNV-1a hash of `s`.
///
/// NOTE(review): the name is misleading — this is FNV-1a, not the rustc
/// FxHash algorithm. It is kept because callers in this file use the name;
/// consider renaming to `fnv1a` in a follow-up.
fn fxhash(s: &str) -> u64 {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    // FNV-1a: xor the byte in first, then multiply by the prime.
    s.bytes()
        .fold(FNV_OFFSET, |hash, byte| (hash ^ byte as u64).wrapping_mul(FNV_PRIME))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeInfo;

    /// Builds a throwaway node whose ports are derived from the trailing
    /// digit of `id`, optionally tagged with a rack label.
    fn create_test_node(id: &str, rack: Option<&str>) -> Node {
        let offset = id.chars().last().unwrap().to_digit(10).unwrap_or(0);
        let info = NodeInfo::new(
            id,
            format!("127.0.0.1:{}", 9092 + offset).parse().unwrap(),
            format!("127.0.0.1:{}", 9093 + offset).parse().unwrap(),
        );
        let info = match rack {
            Some(r) => info.with_rack(r),
            None => info,
        };
        Node::new(info)
    }

    #[test]
    fn test_consistent_hash_placement() {
        let mut placer = PartitionPlacer::new(PlacementConfig::default());
        for i in 1..=5 {
            let id = format!("node-{}", i);
            let rack = format!("rack-{}", (i % 3) + 1);
            placer.add_node(&create_test_node(&id, Some(&rack)));
        }
        // The same (topic, partition) must always yield the same replica set.
        let first = placer.assign_partition("test-topic", 0, 3).unwrap();
        assert_eq!(first.len(), 3);
        let second = placer.assign_partition("test-topic", 0, 3).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_rack_awareness() {
        let config = PlacementConfig {
            rack_aware: true,
            ..Default::default()
        };
        let mut placer = PartitionPlacer::new(config);
        // Six nodes, two per rack, across three racks.
        for i in 1..=6 {
            let rack = format!("rack-{}", ((i - 1) / 2) + 1);
            placer.add_node(&create_test_node(&format!("node-{}", i), Some(&rack)));
        }
        let replicas = placer.assign_partition("test", 0, 3).unwrap();
        assert_eq!(replicas.len(), 3);
        // Three replicas must land on three distinct racks.
        let racks: HashSet<_> = replicas
            .iter()
            .filter_map(|replica| placer.nodes.get(replica))
            .filter_map(|info| info.rack.clone())
            .collect();
        assert_eq!(racks.len(), 3);
    }

    #[test]
    fn test_replication_factor_validation() {
        let mut placer = PartitionPlacer::new(PlacementConfig::default());
        placer.add_node(&create_test_node("node-1", None));
        placer.add_node(&create_test_node("node-2", None));
        // Requesting more replicas than nodes must be rejected…
        assert!(matches!(
            placer.assign_partition("test", 0, 3),
            Err(ClusterError::InvalidReplicationFactor { .. })
        ));
        // …while a satisfiable factor succeeds.
        assert_eq!(placer.assign_partition("test", 0, 2).unwrap().len(), 2);
    }
}