use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tracing::{debug, info, warn};
/// Tunables governing when and how shard rebalancing runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalanceConfig {
/// Upper bound on simultaneous shard moves.
/// NOTE(review): not enforced anywhere in this file — confirm callers enforce it.
pub max_concurrent_moves: u32,
/// Delay before starting a rebalance, in ms (not consulted in this file).
pub rebalance_delay_ms: u64,
/// Per-move timeout in ms (not consulted in this file; see `MoveTimeout` error).
pub move_timeout_ms: u64,
/// Transfer throttle in bytes/sec; 0 presumably means unlimited — not consulted in this file.
pub throttle_bytes_per_sec: u64,
/// Minimum gap between rebalances, in ms; checked by `can_rebalance`.
pub min_rebalance_interval_ms: u64,
/// When true, node join/leave events may trigger a rebalance automatically.
pub auto_rebalance: bool,
/// Imbalance level (0.0–1.0, see `calculate_imbalance`) above which a rebalance is needed.
pub balance_threshold: f64,
}
impl Default for RebalanceConfig {
    /// Defaults: 2 concurrent moves, 5 s start delay, 5 min move timeout,
    /// unthrottled transfer, 1 min between rebalances, auto-rebalance on,
    /// and a 10% imbalance threshold.
    fn default() -> Self {
        Self {
            max_concurrent_moves: 2,
            rebalance_delay_ms: 5000,
            move_timeout_ms: 300_000,
            throttle_bytes_per_sec: 0,
            min_rebalance_interval_ms: 60_000,
            auto_rebalance: true,
            balance_threshold: 0.1,
        }
    }
}
/// What caused a rebalance to be initiated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalanceTrigger {
/// A node joined the cluster (see `on_node_joined`).
NodeJoined,
/// A node left or became unavailable (see `on_node_left`).
NodeLeft,
/// Operator-requested; bypasses the minimum-interval check in `trigger_rebalance`.
Manual,
/// Periodic background check (variant not constructed in this file).
Periodic,
/// Replica configuration change (variant not constructed in this file).
ReplicaChange,
/// Shard layout change (variant not constructed in this file).
ShardChange,
}
/// Lifecycle of a rebalance plan as a whole.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalanceState {
/// No plan active.
Idle,
/// Plan created by `trigger_rebalance`, not yet executing.
Pending,
/// Planning phase (state not assigned anywhere in this file).
Planning,
/// Moves materialized and in progress (set by `start_execution`).
Executing,
/// All moves finished without failures.
Completed,
/// At least one move failed.
Failed,
/// Cancelled via `cancel()`.
Cancelled,
}
/// One shard relocation from `source_node` to `target_node`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMove {
/// Identifier unique within the plan ("move-N", assigned by `create_move_plan`).
pub move_id: String,
/// The shard being relocated.
pub shard_id: String,
/// Node currently hosting the shard.
pub source_node: String,
/// Node that will host the shard after the move.
pub target_node: String,
/// Current step in the move pipeline.
pub state: MoveState,
/// Progress counters, updated via `update_move_progress`.
pub bytes_transferred: u64,
pub total_bytes: u64,
/// Epoch-ms timestamps; set by `start_move` / completion paths.
pub started_at: Option<u64>,
pub completed_at: Option<u64>,
/// Failure description when `state == Failed`.
pub error: Option<String>,
}
/// Steps a single shard move progresses through.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MoveState {
/// Created but not started.
Queued,
/// Data copy in progress (set by `start_move`).
Copying,
/// Post-copy verification (set by `advance_to_verify`).
Verifying,
/// Routing-table update (set by `advance_to_routing`).
Routing,
/// Source cleanup (state not assigned anywhere in this file).
Cleanup,
/// Move finished successfully.
Completed,
/// Move failed; see `ShardMove::error`.
Failed,
/// Move cancelled via `cancel()`.
Cancelled,
}
/// A full rebalance: the set of moves plus bookkeeping timestamps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalancePlan {
/// Timestamp-derived identifier (see `generate_plan_id`).
pub plan_id: String,
/// What initiated this plan.
pub trigger: RebalanceTrigger,
/// Moves to execute; empty until `start_execution` fills them in.
pub moves: Vec<ShardMove>,
pub state: RebalanceState,
/// Epoch-ms timestamps for plan lifecycle events.
pub created_at: u64,
pub started_at: Option<u64>,
pub completed_at: Option<u64>,
/// Imbalance measured when the plan was triggered.
pub initial_imbalance: f64,
/// Imbalance measured after completion (recorded in `check_completion`).
pub final_imbalance: Option<f64>,
}
/// Cumulative counters across all rebalances managed by this instance.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RebalanceStats {
pub total_rebalances: u64,
pub successful_rebalances: u64,
pub failed_rebalances: u64,
pub cancelled_rebalances: u64,
pub total_shards_moved: u64,
pub total_bytes_moved: u64,
/// Maintained by `start_move` / `complete_move` / `fail_move`.
pub current_moves_in_progress: u32,
/// Derived from the live plan in `get_stats`, not stored.
pub pending_moves: u32,
}
/// Per-node capacity/occupancy snapshot used for imbalance math.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeLoad {
pub node_id: String,
/// Number of shards hosted; drives `calculate_imbalance`.
pub shard_count: u32,
pub total_vectors: u64,
pub total_bytes: u64,
/// Capacity reported at registration (not consulted elsewhere in this file).
pub available_capacity: u64,
/// Cleared by `on_node_left`; unavailable nodes are excluded from planning.
pub is_available: bool,
}
/// Errors surfaced by the rebalance manager. Display strings come from the
/// `thiserror` attributes below.
#[derive(Debug, Error)]
pub enum RebalanceError {
#[error("Rebalance already in progress: {0}")]
AlreadyInProgress(String),
#[error("No rebalance in progress")]
NoRebalanceInProgress,
#[error("Node not found: {0}")]
NodeNotFound(String),
#[error("Shard not found: {0}")]
ShardNotFound(String),
#[error("Move not found: {0}")]
MoveNotFound(String),
#[error("Move timed out: {0}")]
MoveTimeout(String),
#[error("Not enough capacity on target node: {0}")]
InsufficientCapacity(String),
#[error("Rebalance cancelled")]
Cancelled,
#[error("Rebalance failed: {0}")]
Failed(String),
#[error("Cluster not balanced enough to proceed")]
ClusterUnbalanced,
}
/// Module-local result alias.
pub type Result<T> = std::result::Result<T, RebalanceError>;
/// Coordinates shard rebalancing: tracks per-node load, shard placement,
/// the currently active plan, and cumulative statistics. All shared state
/// is behind `parking_lot` locks so `&self` methods can mutate it.
pub struct RebalanceManager {
/// Tunables, fixed at construction.
config: RebalanceConfig,
/// Active (or most recently finished) plan, if any.
current_plan: Arc<RwLock<Option<RebalancePlan>>>,
/// Load record keyed by node id.
node_loads: Arc<RwLock<HashMap<String, NodeLoad>>>,
/// shard id -> node id currently hosting it.
shard_assignments: Arc<RwLock<HashMap<String, String>>>,
/// Cumulative counters; `pending_moves` is derived on read.
stats: Arc<RwLock<RebalanceStats>>,
/// Set by `cancel()`, cleared when a new rebalance is triggered.
cancelled: AtomicBool,
/// Epoch-ms of the last finished rebalance (0 = never).
last_rebalance_time: AtomicU64,
}
impl RebalanceManager {
/// Builds a manager with no registered nodes or shards and no active plan.
pub fn new(config: RebalanceConfig) -> Self {
    let current_plan = Arc::new(RwLock::new(None));
    let node_loads = Arc::new(RwLock::new(HashMap::new()));
    let shard_assignments = Arc::new(RwLock::new(HashMap::new()));
    let stats = Arc::new(RwLock::new(RebalanceStats::default()));
    Self {
        config,
        current_plan,
        node_loads,
        shard_assignments,
        stats,
        cancelled: AtomicBool::new(false),
        last_rebalance_time: AtomicU64::new(0),
    }
}
/// Registers (or replaces) the load record for a node.
pub fn register_node(&self, load: NodeLoad) {
    self.node_loads.write().insert(load.node_id.clone(), load);
}

/// Removes a node's load record entirely.
pub fn unregister_node(&self, node_id: &str) {
    self.node_loads.write().remove(node_id);
}

/// Applies `update` to a node's load record; silently a no-op for unknown nodes.
pub fn update_node_load(&self, node_id: &str, update: impl FnOnce(&mut NodeLoad)) {
    if let Some(load) = self.node_loads.write().get_mut(node_id) {
        update(load);
    }
}

/// Records that `shard_id` currently lives on `node_id`.
pub fn register_shard(&self, shard_id: &str, node_id: &str) {
    self.shard_assignments
        .write()
        .insert(shard_id.to_string(), node_id.to_string());
}

/// Forgets the placement of `shard_id`.
pub fn unregister_shard(&self, shard_id: &str) {
    self.shard_assignments.write().remove(shard_id);
}
/// Measures cluster imbalance as the coefficient of variation
/// (std-dev / mean) of per-node shard counts over *available* nodes,
/// capped at 1.0. Returns 0.0 when there are no available nodes or no
/// shards at all.
pub fn calculate_imbalance(&self) -> f64 {
    let loads = self.node_loads.read();
    let nodes: Vec<&NodeLoad> = loads.values().filter(|n| n.is_available).collect();
    if nodes.is_empty() {
        return 0.0;
    }
    let count = nodes.len() as f64;
    let total: u32 = nodes.iter().map(|n| n.shard_count).sum();
    let mean = total as f64 / count;
    if mean == 0.0 {
        return 0.0;
    }
    let variance = nodes
        .iter()
        .map(|n| {
            let delta = n.shard_count as f64 - mean;
            delta * delta
        })
        .sum::<f64>()
        / count;
    (variance.sqrt() / mean).min(1.0)
}
/// True when measured imbalance exceeds the configured threshold.
pub fn needs_rebalance(&self) -> bool {
    self.calculate_imbalance() > self.config.balance_threshold
}

/// True when at least `min_rebalance_interval_ms` has elapsed since the
/// last finished rebalance.
///
/// Fix: uses `saturating_sub` — the plain `now - last_time` subtraction
/// underflows u64 if the wall clock steps backwards (panic in debug
/// builds, wraparound that spuriously permits a rebalance in release).
pub fn can_rebalance(&self) -> bool {
    let last_time = self.last_rebalance_time.load(Ordering::SeqCst);
    let now = current_time_ms();
    now.saturating_sub(last_time) >= self.config.min_rebalance_interval_ms
}
/// Creates a new rebalance plan in the `Pending` state and installs it as
/// the current plan.
///
/// # Errors
/// * `AlreadyInProgress` if a plan is still pending/planning/executing.
/// * `Failed` if the minimum interval since the last rebalance has not
///   elapsed (manual triggers bypass this check).
pub fn trigger_rebalance(&self, trigger: RebalanceTrigger) -> Result<RebalancePlan> {
    // Reject overlapping rebalances; read lock is enough for the check.
    {
        let plan = self.current_plan.read();
        if let Some(ref p) = *plan {
            if matches!(
                p.state,
                RebalanceState::Pending | RebalanceState::Planning | RebalanceState::Executing
            ) {
                return Err(RebalanceError::AlreadyInProgress(p.plan_id.clone()));
            }
        }
    }
    if !self.can_rebalance() && trigger != RebalanceTrigger::Manual {
        return Err(RebalanceError::Failed(
            "Too soon since last rebalance".to_string(),
        ));
    }
    // A fresh rebalance resets any prior cancellation request.
    self.cancelled.store(false, Ordering::SeqCst);
    let initial_imbalance = self.calculate_imbalance();
    let plan = RebalancePlan {
        plan_id: generate_plan_id(),
        trigger,
        moves: Vec::new(),
        state: RebalanceState::Pending,
        created_at: current_time_ms(),
        started_at: None,
        completed_at: None,
        initial_imbalance,
        final_imbalance: None,
    };
    {
        let mut current = self.current_plan.write();
        *current = Some(plan.clone());
    }
    // Fix: `RebalanceStats::total_rebalances` was declared but never
    // incremented anywhere in this module; count every triggered plan.
    self.stats.write().total_rebalances += 1;
    info!(
        "Triggered rebalance: {} (trigger: {:?}, imbalance: {:.2}%)",
        plan.plan_id,
        trigger,
        initial_imbalance * 100.0
    );
    Ok(plan)
}
/// Computes the shard moves that would even out per-node shard counts
/// across available nodes. Read-only: it does not mutate the current
/// plan; `start_execution` installs the result.
pub fn create_move_plan(&self) -> Result<Vec<ShardMove>> {
let loads = self.node_loads.read();
let assignments = self.shard_assignments.read();
let available_nodes: Vec<&NodeLoad> = loads.values().filter(|n| n.is_available).collect();
if available_nodes.is_empty() {
return Ok(Vec::new());
}
// Even split: the first `remainder` nodes get one extra shard.
// NOTE(review): `available_nodes` comes from HashMap iteration, whose
// order is unspecified — which nodes get the extra shard can vary
// between runs. Confirm this nondeterminism is acceptable.
let total_shards: u32 = available_nodes.iter().map(|n| n.shard_count).sum();
let target_per_node = total_shards / available_nodes.len() as u32;
let remainder = total_shards % available_nodes.len() as u32;
let mut overloaded: Vec<(&NodeLoad, u32)> = Vec::new(); let mut underloaded: Vec<(&NodeLoad, u32)> = Vec::new();
for (i, node) in available_nodes.iter().enumerate() {
let target = target_per_node + if (i as u32) < remainder { 1 } else { 0 };
if node.shard_count > target {
overloaded.push((node, node.shard_count - target));
} else if node.shard_count < target {
underloaded.push((node, target - node.shard_count));
}
}
// Largest surplus / largest deficit first.
overloaded.sort_by(|a, b| b.1.cmp(&a.1)); underloaded.sort_by(|a, b| b.1.cmp(&a.1));
let mut moves = Vec::new();
let mut move_count = 0;
// Greedy matching: for each shard on an overloaded node, move it to the
// first underloaded node that still has a deficit.
for (overloaded_node, mut excess) in overloaded {
if excess == 0 {
continue;
}
let shards_on_node: Vec<&String> = assignments
.iter()
.filter(|(_, node_id)| *node_id == &overloaded_node.node_id)
.map(|(shard_id, _)| shard_id)
.collect();
for shard_id in shards_on_node {
if excess == 0 {
break;
}
for (underloaded_node, deficit) in underloaded.iter_mut() {
if *deficit == 0 {
continue;
}
moves.push(ShardMove {
move_id: format!("move-{}", move_count),
shard_id: shard_id.clone(),
source_node: overloaded_node.node_id.clone(),
target_node: underloaded_node.node_id.clone(),
state: MoveState::Queued,
bytes_transferred: 0,
// Byte totals are filled in later via `update_move_progress`.
total_bytes: 0, started_at: None,
completed_at: None,
error: None,
});
move_count += 1;
excess -= 1;
*deficit -= 1;
// One recipient per shard; move on to the next shard.
break;
}
}
}
debug!("Created rebalance plan with {} moves", moves.len());
Ok(moves)
}
/// Materializes the move plan and transitions the current plan from
/// `Pending` to `Executing`, stamping `started_at`.
pub fn start_execution(&self) -> Result<()> {
    let moves = self.create_move_plan()?;
    let mut guard = self.current_plan.write();
    let plan = guard.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
    if plan.state != RebalanceState::Pending {
        return Err(RebalanceError::Failed(
            "Invalid state for execution".to_string(),
        ));
    }
    plan.moves = moves;
    plan.state = RebalanceState::Executing;
    plan.started_at = Some(current_time_ms());
    info!(
        "Started rebalance execution: {} ({} moves)",
        plan.plan_id,
        plan.moves.len()
    );
    Ok(())
}
/// Moves a queued entry into `Copying`, stamps its start time, and bumps
/// the in-progress counter. Calling it on a non-queued move is a no-op.
pub fn start_move(&self, move_id: &str) -> Result<()> {
    let mut guard = self.current_plan.write();
    let plan = guard.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
    let mv = plan
        .moves
        .iter_mut()
        .find(|m| m.move_id == move_id)
        .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;
    if mv.state != MoveState::Queued {
        return Ok(());
    }
    mv.state = MoveState::Copying;
    mv.started_at = Some(current_time_ms());
    self.stats.write().current_moves_in_progress += 1;
    debug!(
        "Started move: {} ({} -> {})",
        move_id, mv.source_node, mv.target_node
    );
    Ok(())
}

/// Records byte-level copy progress for a move.
pub fn update_move_progress(
    &self,
    move_id: &str,
    bytes_transferred: u64,
    total_bytes: u64,
) -> Result<()> {
    let mut guard = self.current_plan.write();
    let plan = guard.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
    let mv = plan
        .moves
        .iter_mut()
        .find(|m| m.move_id == move_id)
        .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;
    mv.bytes_transferred = bytes_transferred;
    mv.total_bytes = total_bytes;
    Ok(())
}
/// Marks a move as having entered post-copy verification.
/// NOTE: the prior state is not validated; callers drive the pipeline order.
pub fn advance_to_verify(&self, move_id: &str) -> Result<()> {
    let mut guard = self.current_plan.write();
    let plan = guard.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
    let mv = plan
        .moves
        .iter_mut()
        .find(|m| m.move_id == move_id)
        .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;
    mv.state = MoveState::Verifying;
    debug!("Move {} advanced to verification", move_id);
    Ok(())
}

/// Marks a move as having its routing-table update in progress.
/// NOTE: the prior state is not validated; callers drive the pipeline order.
pub fn advance_to_routing(&self, move_id: &str) -> Result<()> {
    let mut guard = self.current_plan.write();
    let plan = guard.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
    let mv = plan
        .moves
        .iter_mut()
        .find(|m| m.move_id == move_id)
        .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;
    mv.state = MoveState::Routing;
    debug!("Move {} advanced to routing update", move_id);
    Ok(())
}
/// Finalizes a move: marks it `Completed`, repoints the shard assignment,
/// keeps per-node load records in sync, updates statistics, and checks
/// whether the whole plan is now done.
pub fn complete_move(&self, move_id: &str) -> Result<()> {
    // Capture what we need, then release the plan lock before touching
    // the other locks.
    let (shard_id, source_node, target_node, bytes_moved) = {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;
        m.state = MoveState::Completed;
        m.completed_at = Some(current_time_ms());
        (
            m.shard_id.clone(),
            m.source_node.clone(),
            m.target_node.clone(),
            m.bytes_transferred,
        )
    };
    {
        let mut assignments = self.shard_assignments.write();
        assignments.insert(shard_id, target_node.clone());
    }
    // Fix: node load records were never adjusted when a shard moved, so
    // `calculate_imbalance` (and hence `final_imbalance`) kept reporting
    // the pre-rebalance distribution.
    {
        let mut loads = self.node_loads.write();
        if let Some(src) = loads.get_mut(&source_node) {
            src.shard_count = src.shard_count.saturating_sub(1);
        }
        if let Some(dst) = loads.get_mut(&target_node) {
            dst.shard_count += 1;
        }
    }
    {
        let mut stats = self.stats.write();
        stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(1);
        stats.total_shards_moved += 1;
        // Fix: `total_bytes_moved` was declared but never updated anywhere.
        stats.total_bytes_moved += bytes_moved;
    }
    self.check_completion()?;
    debug!("Completed move: {}", move_id);
    Ok(())
}
/// Records a move failure: marks it `Failed` with the given error text,
/// releases its in-progress slot, and checks plan completion.
///
/// Fix: previously `check_completion` was only invoked from
/// `complete_move`, so a plan whose final outstanding move failed stayed
/// in `Executing` forever instead of transitioning to `Failed`.
pub fn fail_move(&self, move_id: &str, error: &str) -> Result<()> {
    {
        let mut guard = self.current_plan.write();
        let plan = guard.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
        let mv = plan
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;
        mv.state = MoveState::Failed;
        mv.error = Some(error.to_string());
        mv.completed_at = Some(current_time_ms());
    }
    {
        let mut stats = self.stats.write();
        stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(1);
    }
    // A failed move may have been the last outstanding one.
    self.check_completion()?;
    warn!("Move {} failed: {}", move_id, error);
    Ok(())
}
/// If every move in the executing plan has reached a terminal state,
/// finalizes the plan (Completed/Failed), stamps timestamps, updates
/// stats, and records the post-rebalance imbalance.
fn check_completion(&self) -> Result<()> {
let mut plan = self.current_plan.write();
let p = match plan.as_mut() {
Some(p) => p,
None => return Ok(()),
};
// Only an executing plan can complete; anything else is left alone.
if p.state != RebalanceState::Executing {
return Ok(());
}
let all_done = p.moves.iter().all(|m| {
matches!(
m.state,
MoveState::Completed | MoveState::Failed | MoveState::Cancelled
)
});
if all_done {
// A single failed move marks the whole rebalance as failed.
let any_failed = p.moves.iter().any(|m| m.state == MoveState::Failed);
if any_failed {
p.state = RebalanceState::Failed;
let mut stats = self.stats.write();
stats.failed_rebalances += 1;
} else {
p.state = RebalanceState::Completed;
let mut stats = self.stats.write();
stats.successful_rebalances += 1;
}
p.completed_at = Some(current_time_ms());
self.last_rebalance_time
.store(current_time_ms(), Ordering::SeqCst);
// Release the plan lock before measuring imbalance, then reacquire
// to record the result. NOTE(review): another thread could replace
// the plan in that unlocked window — confirm callers serialize this.
drop(plan);
let final_imbalance = self.calculate_imbalance();
let mut plan = self.current_plan.write();
if let Some(p) = plan.as_mut() {
p.final_imbalance = Some(final_imbalance);
}
info!(
"Rebalance completed (final imbalance: {:.2}%)",
final_imbalance * 100.0
);
}
Ok(())
}
/// Cancels the current rebalance: flags cancellation, marks every
/// non-terminal move `Cancelled`, and finalizes the plan.
///
/// Fix: moves already in flight (`Copying`/`Verifying`/`Routing`) had
/// been counted into `current_moves_in_progress` by `start_move`, but
/// cancelling them never decremented the counter, leaking it upward
/// permanently. We now release one slot per in-flight move cancelled.
pub fn cancel(&self) -> Result<()> {
    self.cancelled.store(true, Ordering::SeqCst);
    let mut plan = self.current_plan.write();
    let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;
    let mut in_flight: u32 = 0;
    for m in p.moves.iter_mut() {
        match m.state {
            MoveState::Copying | MoveState::Verifying | MoveState::Routing => {
                in_flight += 1;
                m.state = MoveState::Cancelled;
            }
            MoveState::Queued => {
                m.state = MoveState::Cancelled;
            }
            _ => {}
        }
    }
    p.state = RebalanceState::Cancelled;
    p.completed_at = Some(current_time_ms());
    let mut stats = self.stats.write();
    stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(in_flight);
    stats.cancelled_rebalances += 1;
    info!("Rebalance cancelled: {}", p.plan_id);
    Ok(())
}
/// Whether a cancellation has been requested for the current plan.
pub fn is_cancelled(&self) -> bool {
    self.cancelled.load(Ordering::SeqCst)
}

/// Snapshot of the current (or most recent) rebalance plan, if any.
pub fn get_plan(&self) -> Option<RebalancePlan> {
    self.current_plan.read().clone()
}

/// IDs of moves that are still waiting to start.
pub fn get_queued_moves(&self) -> Vec<String> {
    self.current_plan
        .read()
        .as_ref()
        .map(|p| {
            p.moves
                .iter()
                .filter(|m| m.state == MoveState::Queued)
                .map(|m| m.move_id.clone())
                .collect()
        })
        .unwrap_or_default()
}

/// Snapshots of moves currently copying, verifying, or updating routing.
pub fn get_active_moves(&self) -> Vec<ShardMove> {
    self.current_plan
        .read()
        .as_ref()
        .map(|p| {
            p.moves
                .iter()
                .filter(|m| {
                    matches!(
                        m.state,
                        MoveState::Copying | MoveState::Verifying | MoveState::Routing
                    )
                })
                .cloned()
                .collect()
        })
        .unwrap_or_default()
}
/// Statistics snapshot; `pending_moves` is derived from the live plan
/// rather than read from the stored counters.
pub fn get_stats(&self) -> RebalanceStats {
    let mut snapshot = self.stats.read().clone();
    if let Some(plan) = self.current_plan.read().as_ref() {
        snapshot.pending_moves = plan
            .moves
            .iter()
            .filter(|m| m.state == MoveState::Queued)
            .count() as u32;
    }
    snapshot
}

/// Load record for one node, if registered.
pub fn get_node_load(&self, node_id: &str) -> Option<NodeLoad> {
    let loads = self.node_loads.read();
    loads.get(node_id).cloned()
}

/// Load records for every registered node, in unspecified order.
pub fn get_all_node_loads(&self) -> Vec<NodeLoad> {
    let loads = self.node_loads.read();
    loads.values().cloned().collect()
}
/// Registers a freshly joined, empty node. When auto-rebalance is enabled
/// and the cluster is both imbalanced and past its cool-down interval, a
/// `NodeJoined` rebalance is triggered and its plan returned.
pub fn on_node_joined(&self, node_id: &str, capacity: u64) -> Result<Option<RebalancePlan>> {
    let load = NodeLoad {
        node_id: node_id.to_string(),
        shard_count: 0,
        total_vectors: 0,
        total_bytes: 0,
        available_capacity: capacity,
        is_available: true,
    };
    self.register_node(load);
    info!("Node joined: {}", node_id);
    let should_rebalance =
        self.config.auto_rebalance && self.needs_rebalance() && self.can_rebalance();
    if should_rebalance {
        let plan = self.trigger_rebalance(RebalanceTrigger::NodeJoined)?;
        return Ok(Some(plan));
    }
    Ok(None)
}

/// Marks a departed node unavailable and, when auto-rebalance is enabled,
/// unconditionally triggers a `NodeLeft` rebalance (the interval check in
/// `trigger_rebalance` still applies and may surface as an error).
pub fn on_node_left(&self, node_id: &str) -> Result<Option<RebalancePlan>> {
    self.update_node_load(node_id, |load| {
        load.is_available = false;
    });
    info!("Node left: {}", node_id);
    if !self.config.auto_rebalance {
        return Ok(None);
    }
    let plan = self.trigger_rebalance(RebalanceTrigger::NodeLeft)?;
    Ok(Some(plan))
}
}
/// Milliseconds since the Unix epoch; 0 if the system clock reads before it.
fn current_time_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0)
}
/// Builds a timestamp-based plan identifier, e.g. `rebalance-1700000000000`.
fn generate_plan_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis())
        .unwrap_or(0);
    format!("rebalance-{}", millis)
}
// Unit tests exercising the manager through its public API.
#[cfg(test)]
mod tests {
use super::*;
// Defaults should match the values documented on `RebalanceConfig`.
#[test]
fn test_rebalance_config_defaults() {
let config = RebalanceConfig::default();
assert_eq!(config.max_concurrent_moves, 2);
assert_eq!(config.rebalance_delay_ms, 5000);
assert!(config.auto_rebalance);
assert_eq!(config.balance_threshold, 0.1);
}
// A registered node is visible; an unregistered one is gone.
#[test]
fn test_register_and_unregister_node() {
let manager = RebalanceManager::new(RebalanceConfig::default());
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 5,
total_vectors: 1000,
total_bytes: 10000,
available_capacity: 100000,
is_available: true,
});
assert!(manager.get_node_load("node1").is_some());
manager.unregister_node("node1");
assert!(manager.get_node_load("node1").is_none());
}
// Equal shard counts on every node -> zero imbalance.
#[test]
fn test_calculate_imbalance_balanced() {
let manager = RebalanceManager::new(RebalanceConfig::default());
for i in 0..3 {
manager.register_node(NodeLoad {
node_id: format!("node{}", i),
shard_count: 10,
total_vectors: 1000,
total_bytes: 10000,
available_capacity: 100000,
is_available: true,
});
}
let imbalance = manager.calculate_imbalance();
assert_eq!(imbalance, 0.0);
}
// Skewed shard counts -> positive imbalance above the default threshold.
#[test]
fn test_calculate_imbalance_unbalanced() {
let manager = RebalanceManager::new(RebalanceConfig::default());
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 30,
total_vectors: 1000,
total_bytes: 10000,
available_capacity: 100000,
is_available: true,
});
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 5,
total_vectors: 1000,
total_bytes: 10000,
available_capacity: 100000,
is_available: true,
});
manager.register_node(NodeLoad {
node_id: "node2".to_string(),
shard_count: 5,
total_vectors: 1000,
total_bytes: 10000,
available_capacity: 100000,
is_available: true,
});
let imbalance = manager.calculate_imbalance();
assert!(imbalance > 0.0);
assert!(manager.needs_rebalance());
}
// A manual trigger yields a Pending plan carrying the trigger kind.
#[test]
fn test_trigger_rebalance() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0, ..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 10,
total_vectors: 1000,
total_bytes: 10000,
available_capacity: 100000,
is_available: true,
});
let plan = manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
assert_eq!(plan.state, RebalanceState::Pending);
assert_eq!(plan.trigger, RebalanceTrigger::Manual);
}
// 4 shards on node0, 0 on node1 -> exactly 2 moves node0 -> node1.
#[test]
fn test_create_move_plan() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 4,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 0,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
for i in 0..4 {
manager.register_shard(&format!("shard{}", i), "node0");
}
manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
let moves = manager.create_move_plan().unwrap();
assert_eq!(moves.len(), 2);
for m in &moves {
assert_eq!(m.source_node, "node0");
assert_eq!(m.target_node, "node1");
}
}
// Drives one move through the whole pipeline: queued -> copying ->
// verifying -> routing -> completed, and checks the moved-shard counter.
#[test]
fn test_move_lifecycle() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 2,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 0,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.register_shard("shard0", "node0");
manager.register_shard("shard1", "node0");
manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
manager.start_execution().unwrap();
let queued = manager.get_queued_moves();
assert!(!queued.is_empty());
let move_id = &queued[0];
manager.start_move(move_id).unwrap();
manager.update_move_progress(move_id, 500, 1000).unwrap();
manager.advance_to_verify(move_id).unwrap();
manager.advance_to_routing(move_id).unwrap();
manager.complete_move(move_id).unwrap();
let stats = manager.get_stats();
assert_eq!(stats.total_shards_moved, 1);
}
// Cancelling an executing plan marks it Cancelled and bumps the counter.
#[test]
fn test_cancel_rebalance() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 4,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 0,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
manager.start_execution().unwrap();
manager.cancel().unwrap();
let plan = manager.get_plan().unwrap();
assert_eq!(plan.state, RebalanceState::Cancelled);
assert!(manager.is_cancelled());
let stats = manager.get_stats();
assert_eq!(stats.cancelled_rebalances, 1);
}
// With auto-rebalance and a tiny threshold, a join triggers a plan.
#[test]
fn test_on_node_joined() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
auto_rebalance: true,
balance_threshold: 0.01, ..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 10,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
let result = manager.on_node_joined("node1", 100000).unwrap();
assert!(result.is_some());
let plan = result.unwrap();
assert_eq!(plan.trigger, RebalanceTrigger::NodeJoined);
}
// A departure marks the node unavailable and triggers a NodeLeft plan.
#[test]
fn test_on_node_left() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
auto_rebalance: true,
..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 5,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.register_node(NodeLoad {
node_id: "node1".to_string(),
shard_count: 5,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
let result = manager.on_node_left("node1").unwrap();
assert!(result.is_some());
let plan = result.unwrap();
assert_eq!(plan.trigger, RebalanceTrigger::NodeLeft);
let load = manager.get_node_load("node1").unwrap();
assert!(!load.is_available);
}
// A fresh manager reports zeroed statistics.
#[test]
fn test_rebalance_stats() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
..Default::default()
};
let manager = RebalanceManager::new(config);
let stats = manager.get_stats();
assert_eq!(stats.total_rebalances, 0);
assert_eq!(stats.total_shards_moved, 0);
}
// Triggering while a plan is executing must return AlreadyInProgress.
#[test]
fn test_duplicate_rebalance_rejected() {
let config = RebalanceConfig {
min_rebalance_interval_ms: 0,
..Default::default()
};
let manager = RebalanceManager::new(config);
manager.register_node(NodeLoad {
node_id: "node0".to_string(),
shard_count: 10,
total_vectors: 0,
total_bytes: 0,
available_capacity: 100000,
is_available: true,
});
manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
manager.start_execution().unwrap();
let result = manager.trigger_rebalance(RebalanceTrigger::Manual);
assert!(matches!(result, Err(RebalanceError::AlreadyInProgress(_))));
}
}