use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tracing::{debug, info, warn};

/// Configuration for shard rebalancing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalanceConfig {
    /// Maximum number of shard moves allowed to run concurrently.
    pub max_concurrent_moves: u32,
    /// Delay between triggering a rebalance and starting execution, in milliseconds.
    pub rebalance_delay_ms: u64,
    /// Timeout for a single shard move, in milliseconds.
    pub move_timeout_ms: u64,
    /// Transfer rate limit in bytes per second (0 = unlimited).
    pub throttle_bytes_per_sec: u64,
    /// Minimum interval between consecutive rebalances, in milliseconds.
    pub min_rebalance_interval_ms: u64,
    /// Whether rebalances are triggered automatically on cluster changes.
    pub auto_rebalance: bool,
    /// Imbalance ratio above which the cluster is considered unbalanced.
    pub balance_threshold: f64,
}

impl Default for RebalanceConfig {
    fn default() -> Self {
        Self {
            max_concurrent_moves: 2,
            rebalance_delay_ms: 5000,
            move_timeout_ms: 300_000,
            throttle_bytes_per_sec: 0,
            min_rebalance_interval_ms: 60_000,
            auto_rebalance: true,
            balance_threshold: 0.1,
        }
    }
}

/// What caused a rebalance to be triggered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalanceTrigger {
    /// A new node joined the cluster.
    NodeJoined,
    /// A node left the cluster.
    NodeLeft,
    /// An operator requested a rebalance.
    Manual,
    /// A periodic background check fired.
    Periodic,
    /// The replica configuration changed.
    ReplicaChange,
    /// The shard configuration changed.
    ShardChange,
}

/// Lifecycle state of a rebalance plan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalanceState {
    /// No rebalance is active.
    Idle,
    /// A rebalance has been triggered but not yet planned.
    Pending,
    /// Moves are being planned.
    Planning,
    /// Moves are being executed.
    Executing,
    /// All moves finished successfully.
    Completed,
    /// At least one move failed.
    Failed,
    /// The rebalance was cancelled.
    Cancelled,
}

/// A single shard relocation from one node to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMove {
    /// Unique identifier for this move.
    pub move_id: String,
    /// Shard being relocated.
    pub shard_id: String,
    /// Node the shard is moving from.
    pub source_node: String,
    /// Node the shard is moving to.
    pub target_node: String,
    /// Current state of the move.
    pub state: MoveState,
    /// Bytes transferred so far.
    pub bytes_transferred: u64,
    /// Total bytes to transfer.
    pub total_bytes: u64,
    /// Start time (epoch milliseconds), if started.
    pub started_at: Option<u64>,
    /// Completion time (epoch milliseconds), if finished.
    pub completed_at: Option<u64>,
    /// Error message, if the move failed.
    pub error: Option<String>,
}

/// Lifecycle state of an individual shard move.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MoveState {
    /// Waiting to start.
    Queued,
    /// Data is being copied to the target node.
    Copying,
    /// Copied data is being verified.
    Verifying,
    /// Routing tables are being updated.
    Routing,
    /// The source copy is being cleaned up.
    Cleanup,
    /// Move finished successfully.
    Completed,
    /// Move failed.
    Failed,
    /// Move was cancelled.
    Cancelled,
}

/// A planned set of shard moves produced by a rebalance trigger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalancePlan {
    /// Unique identifier for this plan.
    pub plan_id: String,
    /// What triggered the rebalance.
    pub trigger: RebalanceTrigger,
    /// The shard moves that make up the plan.
    pub moves: Vec<ShardMove>,
    /// Current state of the plan.
    pub state: RebalanceState,
    /// Creation time (epoch milliseconds).
    pub created_at: u64,
    /// Execution start time (epoch milliseconds), if started.
    pub started_at: Option<u64>,
    /// Completion time (epoch milliseconds), if finished.
    pub completed_at: Option<u64>,
    /// Cluster imbalance when the plan was created.
    pub initial_imbalance: f64,
    /// Cluster imbalance after the plan finished, if recorded.
    pub final_imbalance: Option<f64>,
}

/// Aggregate counters for rebalance activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RebalanceStats {
    pub total_rebalances: u64,
    pub successful_rebalances: u64,
    pub failed_rebalances: u64,
    pub cancelled_rebalances: u64,
    pub total_shards_moved: u64,
    pub total_bytes_moved: u64,
    pub current_moves_in_progress: u32,
    pub pending_moves: u32,
}

/// A snapshot of the load carried by a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeLoad {
    pub node_id: String,
    pub shard_count: u32,
    pub total_vectors: u64,
    pub total_bytes: u64,
    pub available_capacity: u64,
    pub is_available: bool,
}

/// Errors produced by rebalance operations.
#[derive(Debug, Error)]
pub enum RebalanceError {
    #[error("Rebalance already in progress: {0}")]
    AlreadyInProgress(String),

    #[error("No rebalance in progress")]
    NoRebalanceInProgress,

    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Shard not found: {0}")]
    ShardNotFound(String),

    #[error("Move not found: {0}")]
    MoveNotFound(String),

    #[error("Move timed out: {0}")]
    MoveTimeout(String),

    #[error("Not enough capacity on target node: {0}")]
    InsufficientCapacity(String),

    #[error("Rebalance cancelled")]
    Cancelled,

    #[error("Rebalance failed: {0}")]
    Failed(String),

    #[error("Cluster not balanced enough to proceed")]
    ClusterUnbalanced,
}

/// Convenience alias for results in this module.
pub type Result<T> = std::result::Result<T, RebalanceError>;

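/// Coordinates shard rebalancing across cluster nodes: tracks per-node load,
/// detects imbalance, plans shard moves, and drives each move through its
/// lifecycle.
///
/// A minimal driver sketch (hypothetical node and shard names; error handling
/// and the actual data transfer elided):
///
/// ```ignore
/// let manager = RebalanceManager::new(RebalanceConfig::default());
/// manager.on_node_joined("node-a", 1_000_000)?;
/// manager.register_shard("shard-0", "node-a");
///
/// manager.trigger_rebalance(RebalanceTrigger::Manual)?;
/// manager.start_execution()?;
/// for move_id in manager.get_queued_moves() {
///     manager.start_move(&move_id)?;
///     // ... copy shard data, reporting progress via update_move_progress ...
///     manager.advance_to_verify(&move_id)?;
///     manager.advance_to_routing(&move_id)?;
///     manager.complete_move(&move_id)?;
/// }
/// ```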
pub struct RebalanceManager {
    config: RebalanceConfig,
    current_plan: Arc<RwLock<Option<RebalancePlan>>>,
    node_loads: Arc<RwLock<HashMap<String, NodeLoad>>>,
    /// Maps shard_id -> node_id of the node that currently owns the shard.
    shard_assignments: Arc<RwLock<HashMap<String, String>>>,
    stats: Arc<RwLock<RebalanceStats>>,
    cancelled: AtomicBool,
    last_rebalance_time: AtomicU64,
}

impl RebalanceManager {
    /// Creates a manager with the given configuration.
    pub fn new(config: RebalanceConfig) -> Self {
        Self {
            config,
            current_plan: Arc::new(RwLock::new(None)),
            node_loads: Arc::new(RwLock::new(HashMap::new())),
            shard_assignments: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(RebalanceStats::default())),
            cancelled: AtomicBool::new(false),
            last_rebalance_time: AtomicU64::new(0),
        }
    }

    /// Registers (or replaces) a node's load snapshot.
    pub fn register_node(&self, load: NodeLoad) {
        let mut loads = self.node_loads.write();
        loads.insert(load.node_id.clone(), load);
    }

    /// Removes a node from load tracking.
    pub fn unregister_node(&self, node_id: &str) {
        let mut loads = self.node_loads.write();
        loads.remove(node_id);
    }

    /// Applies an in-place update to a node's load snapshot, if the node is known.
    pub fn update_node_load(&self, node_id: &str, update: impl FnOnce(&mut NodeLoad)) {
        let mut loads = self.node_loads.write();
        if let Some(load) = loads.get_mut(node_id) {
            update(load);
        }
    }

    /// Records that a shard is assigned to a node.
    pub fn register_shard(&self, shard_id: &str, node_id: &str) {
        let mut assignments = self.shard_assignments.write();
        assignments.insert(shard_id.to_string(), node_id.to_string());
    }

    /// Removes a shard from assignment tracking.
    pub fn unregister_shard(&self, shard_id: &str) {
        let mut assignments = self.shard_assignments.write();
        assignments.remove(shard_id);
    }

    /// Computes cluster imbalance as the coefficient of variation of shard
    /// counts across available nodes (standard deviation / mean), capped at
    /// 1.0. Returns 0.0 when there are no available nodes or no shards.
    pub fn calculate_imbalance(&self) -> f64 {
        let loads = self.node_loads.read();
        let available_nodes: Vec<&NodeLoad> = loads.values().filter(|n| n.is_available).collect();

        if available_nodes.is_empty() {
            return 0.0;
        }

        let total_shards: u32 = available_nodes.iter().map(|n| n.shard_count).sum();
        let avg_shards = total_shards as f64 / available_nodes.len() as f64;

        if avg_shards == 0.0 {
            return 0.0;
        }

        let variance: f64 = available_nodes
            .iter()
            .map(|n| {
                let diff = n.shard_count as f64 - avg_shards;
                diff * diff
            })
            .sum::<f64>()
            / available_nodes.len() as f64;

        let std_dev = variance.sqrt();
        (std_dev / avg_shards).min(1.0)
    }

    /// Returns true if the current imbalance exceeds the configured threshold.
    pub fn needs_rebalance(&self) -> bool {
        let imbalance = self.calculate_imbalance();
        imbalance > self.config.balance_threshold
    }

    /// Returns true if enough time has elapsed since the last completed
    /// rebalance to start another one.
    pub fn can_rebalance(&self) -> bool {
        let last_time = self.last_rebalance_time.load(Ordering::SeqCst);
        let now = current_time_ms();
        now - last_time >= self.config.min_rebalance_interval_ms
    }

    /// Triggers a new rebalance. Fails if one is already pending, planning,
    /// or executing, or if the minimum interval has not elapsed (manual
    /// triggers bypass the interval check).
    pub fn trigger_rebalance(&self, trigger: RebalanceTrigger) -> Result<RebalancePlan> {
        {
            let plan = self.current_plan.read();
            if let Some(ref p) = *plan {
                if matches!(
                    p.state,
                    RebalanceState::Pending | RebalanceState::Planning | RebalanceState::Executing
                ) {
                    return Err(RebalanceError::AlreadyInProgress(p.plan_id.clone()));
                }
            }
        }

        if !self.can_rebalance() && trigger != RebalanceTrigger::Manual {
            return Err(RebalanceError::Failed(
                "Too soon since last rebalance".to_string(),
            ));
        }

        self.cancelled.store(false, Ordering::SeqCst);
        let initial_imbalance = self.calculate_imbalance();

        let plan = RebalancePlan {
            plan_id: generate_plan_id(),
            trigger,
            moves: Vec::new(),
            state: RebalanceState::Pending,
            created_at: current_time_ms(),
            started_at: None,
            completed_at: None,
            initial_imbalance,
            final_imbalance: None,
        };

        {
            let mut current = self.current_plan.write();
            *current = Some(plan.clone());
        }

        // Count every triggered rebalance.
        self.stats.write().total_rebalances += 1;

        info!(
            "Triggered rebalance: {} (trigger: {:?}, imbalance: {:.2}%)",
            plan.plan_id,
            trigger,
            initial_imbalance * 100.0
        );

        Ok(plan)
    }

    /// Builds a list of shard moves that evens out shard counts across
    /// available nodes. Each node's target is total_shards / node_count, with
    /// the remainder spread one extra shard per node; shards are then moved
    /// from nodes above their target to nodes below it. For example, 4 shards
    /// across 2 nodes gives each node a target of 2.
    pub fn create_move_plan(&self) -> Result<Vec<ShardMove>> {
        let loads = self.node_loads.read();
        let assignments = self.shard_assignments.read();

        let available_nodes: Vec<&NodeLoad> = loads.values().filter(|n| n.is_available).collect();

        if available_nodes.is_empty() {
            return Ok(Vec::new());
        }

        let total_shards: u32 = available_nodes.iter().map(|n| n.shard_count).sum();
        let target_per_node = total_shards / available_nodes.len() as u32;
        let remainder = total_shards % available_nodes.len() as u32;

        // Partition nodes into those above and below their per-node target.
        let mut overloaded: Vec<(&NodeLoad, u32)> = Vec::new();
        let mut underloaded: Vec<(&NodeLoad, u32)> = Vec::new();
        for (i, node) in available_nodes.iter().enumerate() {
            let target = target_per_node + if (i as u32) < remainder { 1 } else { 0 };
            if node.shard_count > target {
                overloaded.push((node, node.shard_count - target));
            } else if node.shard_count < target {
                underloaded.push((node, target - node.shard_count));
            }
        }

        // Pair the most overloaded nodes with the most underloaded first.
        overloaded.sort_by(|a, b| b.1.cmp(&a.1));
        underloaded.sort_by(|a, b| b.1.cmp(&a.1));

        let mut moves = Vec::new();
        let mut move_count = 0;

        for (overloaded_node, mut excess) in overloaded {
            if excess == 0 {
                continue;
            }

            let shards_on_node: Vec<&String> = assignments
                .iter()
                .filter(|(_, node_id)| *node_id == &overloaded_node.node_id)
                .map(|(shard_id, _)| shard_id)
                .collect();

            for shard_id in shards_on_node {
                if excess == 0 {
                    break;
                }

                for (underloaded_node, deficit) in underloaded.iter_mut() {
                    if *deficit == 0 {
                        continue;
                    }

                    moves.push(ShardMove {
                        move_id: format!("move-{}", move_count),
                        shard_id: shard_id.clone(),
                        source_node: overloaded_node.node_id.clone(),
                        target_node: underloaded_node.node_id.clone(),
                        state: MoveState::Queued,
                        bytes_transferred: 0,
                        total_bytes: 0,
                        started_at: None,
                        completed_at: None,
                        error: None,
                    });

                    move_count += 1;
                    excess -= 1;
                    *deficit -= 1;
                    break;
                }
            }
        }

        debug!("Created rebalance plan with {} moves", moves.len());
        Ok(moves)
    }

    /// Plans moves and transitions the current plan into the executing state.
    pub fn start_execution(&self) -> Result<()> {
        let moves = self.create_move_plan()?;

        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        if p.state != RebalanceState::Pending {
            return Err(RebalanceError::Failed(
                "Invalid state for execution".to_string(),
            ));
        }

        p.moves = moves;
        p.state = RebalanceState::Executing;
        p.started_at = Some(current_time_ms());

        info!(
            "Started rebalance execution: {} ({} moves)",
            p.plan_id,
            p.moves.len()
        );
        Ok(())
    }

    /// Moves a queued shard move into the copying state. Calling this on a
    /// move that is no longer queued is a no-op.
    pub fn start_move(&self, move_id: &str) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        if m.state != MoveState::Queued {
            return Ok(());
        }

        m.state = MoveState::Copying;
        m.started_at = Some(current_time_ms());

        let mut stats = self.stats.write();
        stats.current_moves_in_progress += 1;

        debug!(
            "Started move: {} ({} -> {})",
            move_id, m.source_node, m.target_node
        );
        Ok(())
    }

    /// Records transfer progress for an in-flight move.
    pub fn update_move_progress(
        &self,
        move_id: &str,
        bytes_transferred: u64,
        total_bytes: u64,
    ) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        m.bytes_transferred = bytes_transferred;
        m.total_bytes = total_bytes;
        Ok(())
    }

    /// Advances a move to the verification phase.
    pub fn advance_to_verify(&self, move_id: &str) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        m.state = MoveState::Verifying;
        debug!("Move {} advanced to verification", move_id);
        Ok(())
    }

    /// Advances a move to the routing-update phase.
    pub fn advance_to_routing(&self, move_id: &str) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        m.state = MoveState::Routing;
        debug!("Move {} advanced to routing update", move_id);
        Ok(())
    }

    /// Marks a move as completed, reassigns the shard to the target node,
    /// updates stats, and finalizes the plan if every move has finished.
    pub fn complete_move(&self, move_id: &str) -> Result<()> {
        let (shard_id, target_node) = {
            let mut plan = self.current_plan.write();
            let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

            let m = p
                .moves
                .iter_mut()
                .find(|m| m.move_id == move_id)
                .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

            m.state = MoveState::Completed;
            m.completed_at = Some(current_time_ms());

            (m.shard_id.clone(), m.target_node.clone())
        };

        {
            let mut assignments = self.shard_assignments.write();
            assignments.insert(shard_id, target_node);
        }

        {
            let mut stats = self.stats.write();
            stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(1);
            stats.total_shards_moved += 1;
        }

        self.check_completion()?;

        debug!("Completed move: {}", move_id);
        Ok(())
    }

    /// Marks a move as failed, records the error, updates stats, and
    /// finalizes the plan if every move has finished.
    pub fn fail_move(&self, move_id: &str, error: &str) -> Result<()> {
        {
            let mut plan = self.current_plan.write();
            let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

            let m = p
                .moves
                .iter_mut()
                .find(|m| m.move_id == move_id)
                .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

            m.state = MoveState::Failed;
            m.error = Some(error.to_string());
            m.completed_at = Some(current_time_ms());
        }

        {
            let mut stats = self.stats.write();
            stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(1);
        }

        // A failed move can be the last outstanding one, so check whether the
        // plan as a whole is now finished.
        self.check_completion()?;

        warn!("Move {} failed: {}", move_id, error);
        Ok(())
    }

    /// Finalizes the current plan once every move has reached a terminal
    /// state: the plan becomes Failed if any move failed, otherwise
    /// Completed, and the final imbalance is recorded.
    fn check_completion(&self) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = match plan.as_mut() {
            Some(p) => p,
            None => return Ok(()),
        };

        if p.state != RebalanceState::Executing {
            return Ok(());
        }

        let all_done = p.moves.iter().all(|m| {
            matches!(
                m.state,
                MoveState::Completed | MoveState::Failed | MoveState::Cancelled
            )
        });

        if all_done {
            let any_failed = p.moves.iter().any(|m| m.state == MoveState::Failed);

            if any_failed {
                p.state = RebalanceState::Failed;
                let mut stats = self.stats.write();
                stats.failed_rebalances += 1;
            } else {
                p.state = RebalanceState::Completed;
                let mut stats = self.stats.write();
                stats.successful_rebalances += 1;
            }

            p.completed_at = Some(current_time_ms());
            self.last_rebalance_time
                .store(current_time_ms(), Ordering::SeqCst);

            // Release the plan lock before recomputing imbalance, then
            // reacquire it to record the result.
            drop(plan);
            let final_imbalance = self.calculate_imbalance();
            let mut plan = self.current_plan.write();
            if let Some(p) = plan.as_mut() {
                p.final_imbalance = Some(final_imbalance);
            }

            info!(
                "Rebalance completed (final imbalance: {:.2}%)",
                final_imbalance * 100.0
            );
        }

        Ok(())
    }

    /// Cancels the current rebalance, marking all non-terminal moves as
    /// cancelled.
    pub fn cancel(&self) -> Result<()> {
        self.cancelled.store(true, Ordering::SeqCst);

        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        for m in p.moves.iter_mut() {
            if matches!(
                m.state,
                MoveState::Queued | MoveState::Copying | MoveState::Verifying | MoveState::Routing
            ) {
                m.state = MoveState::Cancelled;
            }
        }

        p.state = RebalanceState::Cancelled;
        p.completed_at = Some(current_time_ms());

        let mut stats = self.stats.write();
        stats.cancelled_rebalances += 1;

        info!("Rebalance cancelled: {}", p.plan_id);
        Ok(())
    }

    /// Returns true if a cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Returns a clone of the current plan, if any.
    pub fn get_plan(&self) -> Option<RebalancePlan> {
        self.current_plan.read().clone()
    }

    /// Returns the IDs of all moves still waiting to start.
    pub fn get_queued_moves(&self) -> Vec<String> {
        let plan = self.current_plan.read();
        match plan.as_ref() {
            Some(p) => p
                .moves
                .iter()
                .filter(|m| m.state == MoveState::Queued)
                .map(|m| m.move_id.clone())
                .collect(),
            None => Vec::new(),
        }
    }

    /// Returns all moves currently in flight (copying, verifying, or routing).
    pub fn get_active_moves(&self) -> Vec<ShardMove> {
        let plan = self.current_plan.read();
        match plan.as_ref() {
            Some(p) => p
                .moves
                .iter()
                .filter(|m| {
                    matches!(
                        m.state,
                        MoveState::Copying | MoveState::Verifying | MoveState::Routing
                    )
                })
                .cloned()
                .collect(),
            None => Vec::new(),
        }
    }

    /// Returns a stats snapshot, with pending_moves derived from the current
    /// plan.
    pub fn get_stats(&self) -> RebalanceStats {
        let mut stats = self.stats.read().clone();

        if let Some(ref plan) = *self.current_plan.read() {
            stats.pending_moves = plan
                .moves
                .iter()
                .filter(|m| m.state == MoveState::Queued)
                .count() as u32;
        }

        stats
    }

    /// Returns the load snapshot for a single node, if known.
    pub fn get_node_load(&self, node_id: &str) -> Option<NodeLoad> {
        self.node_loads.read().get(node_id).cloned()
    }

    /// Returns load snapshots for all known nodes.
    pub fn get_all_node_loads(&self) -> Vec<NodeLoad> {
        self.node_loads.read().values().cloned().collect()
    }

    /// Handles a node joining: registers it as empty and, if auto-rebalance
    /// is enabled and the cluster is sufficiently unbalanced, triggers a
    /// rebalance.
    pub fn on_node_joined(&self, node_id: &str, capacity: u64) -> Result<Option<RebalancePlan>> {
        self.register_node(NodeLoad {
            node_id: node_id.to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: capacity,
            is_available: true,
        });

        info!("Node joined: {}", node_id);

        if self.config.auto_rebalance && self.needs_rebalance() && self.can_rebalance() {
            return Ok(Some(self.trigger_rebalance(RebalanceTrigger::NodeJoined)?));
        }

        Ok(None)
    }

    /// Handles a node leaving: marks it unavailable and, if auto-rebalance is
    /// enabled, triggers a rebalance immediately.
    pub fn on_node_left(&self, node_id: &str) -> Result<Option<RebalancePlan>> {
        self.update_node_load(node_id, |load| {
            load.is_available = false;
        });

        info!("Node left: {}", node_id);

        if self.config.auto_rebalance {
            return Ok(Some(self.trigger_rebalance(RebalanceTrigger::NodeLeft)?));
        }

        Ok(None)
    }
}

/// Current wall-clock time in milliseconds since the Unix epoch.
fn current_time_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64
}

/// Generates a timestamp-based plan identifier.
fn generate_plan_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    format!("rebalance-{}", timestamp)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rebalance_config_defaults() {
        let config = RebalanceConfig::default();
        assert_eq!(config.max_concurrent_moves, 2);
        assert_eq!(config.rebalance_delay_ms, 5000);
        assert!(config.auto_rebalance);
        assert_eq!(config.balance_threshold, 0.1);
    }

    #[test]
    fn test_register_and_unregister_node() {
        let manager = RebalanceManager::new(RebalanceConfig::default());

        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 5,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });

        assert!(manager.get_node_load("node1").is_some());

        manager.unregister_node("node1");
        assert!(manager.get_node_load("node1").is_none());
    }

    #[test]
    fn test_calculate_imbalance_balanced() {
        let manager = RebalanceManager::new(RebalanceConfig::default());

        for i in 0..3 {
            manager.register_node(NodeLoad {
                node_id: format!("node{}", i),
                shard_count: 10,
                total_vectors: 1000,
                total_bytes: 10000,
                available_capacity: 100000,
                is_available: true,
            });
        }

        let imbalance = manager.calculate_imbalance();
        assert_eq!(imbalance, 0.0);
    }

    #[test]
    fn test_calculate_imbalance_unbalanced() {
        let manager = RebalanceManager::new(RebalanceConfig::default());

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 30,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 5,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node2".to_string(),
            shard_count: 5,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });

        let imbalance = manager.calculate_imbalance();
        assert!(imbalance > 0.0);
        assert!(manager.needs_rebalance());
    }

    #[test]
    fn test_trigger_rebalance() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 10,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });

        let plan = manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        assert_eq!(plan.state, RebalanceState::Pending);
        assert_eq!(plan.trigger, RebalanceTrigger::Manual);
    }

    #[test]
    fn test_create_move_plan() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 4,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        for i in 0..4 {
            manager.register_shard(&format!("shard{}", i), "node0");
        }

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        let moves = manager.create_move_plan().unwrap();

        assert_eq!(moves.len(), 2);
        for m in &moves {
            assert_eq!(m.source_node, "node0");
            assert_eq!(m.target_node, "node1");
        }
    }

    #[test]
    fn test_move_lifecycle() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 2,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_shard("shard0", "node0");
        manager.register_shard("shard1", "node0");

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();

        let queued = manager.get_queued_moves();
        assert!(!queued.is_empty());

        let move_id = &queued[0];

        manager.start_move(move_id).unwrap();
        manager.update_move_progress(move_id, 500, 1000).unwrap();
        manager.advance_to_verify(move_id).unwrap();
        manager.advance_to_routing(move_id).unwrap();
        manager.complete_move(move_id).unwrap();

        let stats = manager.get_stats();
        assert_eq!(stats.total_shards_moved, 1);
    }
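
    // Failure-path counterpart to the lifecycle test above: a failed move
    // should release the in-progress counter and, once it is the last
    // outstanding move, drive the plan to the Failed state.
    #[test]
    fn test_failed_move_marks_plan_failed() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 2,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_shard("shard0", "node0");
        manager.register_shard("shard1", "node0");

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();

        for move_id in manager.get_queued_moves() {
            manager.start_move(&move_id).unwrap();
            manager.fail_move(&move_id, "simulated transfer error").unwrap();
        }

        let plan = manager.get_plan().unwrap();
        assert_eq!(plan.state, RebalanceState::Failed);
        assert_eq!(manager.get_stats().failed_rebalances, 1);
        assert_eq!(manager.get_stats().current_moves_in_progress, 0);
    }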

    #[test]
    fn test_cancel_rebalance() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 4,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();
        manager.cancel().unwrap();

        let plan = manager.get_plan().unwrap();
        assert_eq!(plan.state, RebalanceState::Cancelled);
        assert!(manager.is_cancelled());

        let stats = manager.get_stats();
        assert_eq!(stats.cancelled_rebalances, 1);
    }
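
    // While a move is copying it should be reported by get_active_moves()
    // and no longer by get_queued_moves().
    #[test]
    fn test_active_moves_listed_while_copying() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 2,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_shard("shard0", "node0");
        manager.register_shard("shard1", "node0");

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();

        let queued = manager.get_queued_moves();
        manager.start_move(&queued[0]).unwrap();

        let active = manager.get_active_moves();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].state, MoveState::Copying);
        assert!(manager.get_queued_moves().is_empty());
    }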

    #[test]
    fn test_on_node_joined() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            auto_rebalance: true,
            balance_threshold: 0.01,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 10,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        let result = manager.on_node_joined("node1", 100000).unwrap();

        assert!(result.is_some());
        let plan = result.unwrap();
        assert_eq!(plan.trigger, RebalanceTrigger::NodeJoined);
    }

    #[test]
    fn test_on_node_left() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            auto_rebalance: true,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 5,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 5,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        let result = manager.on_node_left("node1").unwrap();

        assert!(result.is_some());
        let plan = result.unwrap();
        assert_eq!(plan.trigger, RebalanceTrigger::NodeLeft);

        let load = manager.get_node_load("node1").unwrap();
        assert!(!load.is_available);
    }

    #[test]
    fn test_rebalance_stats() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        let stats = manager.get_stats();
        assert_eq!(stats.total_rebalances, 0);
        assert_eq!(stats.total_shards_moved, 0);
    }

    #[test]
    fn test_duplicate_rebalance_rejected() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 10,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();

        let result = manager.trigger_rebalance(RebalanceTrigger::Manual);
        assert!(matches!(result, Err(RebalanceError::AlreadyInProgress(_))));
    }
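
    // Non-manual triggers respect min_rebalance_interval_ms, while a manual
    // trigger bypasses the interval gate.
    #[test]
    fn test_manual_trigger_bypasses_min_interval() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: u64::MAX,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        assert!(matches!(
            manager.trigger_rebalance(RebalanceTrigger::Periodic),
            Err(RebalanceError::Failed(_))
        ));
        assert!(manager.trigger_rebalance(RebalanceTrigger::Manual).is_ok());
    }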
}