engine/distributed/rebalance.rs

//! Automatic Rebalancing for Distributed Dakera
//!
//! Provides automatic data rebalancing when nodes join or leave:
//! - Shard redistribution on topology changes
//! - Replica placement optimization
//! - Minimal data movement strategies
//! - Progress tracking and cancellation
//! - Integration with consistent hashing
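//!
//! # Example
//!
//! A minimal usage sketch (node IDs and capacities are illustrative; the
//! surrounding cluster wiring is assumed to call back into the manager as
//! moves make progress):
//!
//! ```ignore
//! let manager = RebalanceManager::new(RebalanceConfig::default());
//!
//! // Register the initial topology.
//! manager.register_node(NodeLoad {
//!     node_id: "node-a".to_string(),
//!     shard_count: 4,
//!     total_vectors: 0,
//!     total_bytes: 0,
//!     available_capacity: 1 << 30,
//!     is_available: true,
//! });
//! for i in 0..4 {
//!     manager.register_shard(&format!("shard-{}", i), "node-a");
//! }
//!
//! // A joining node may trigger an automatic rebalance.
//! if manager.on_node_joined("node-b", 1 << 30).unwrap().is_some() {
//!     manager.start_execution().unwrap();
//! }
//! ```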

use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tracing::{debug, info, warn};

/// Configuration for rebalancing operations
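///
/// The defaults can be overridden with struct update syntax; the values below
/// are illustrative, not tuned recommendations:
///
/// ```ignore
/// let config = RebalanceConfig {
///     max_concurrent_moves: 4,
///     throttle_bytes_per_sec: 10 * 1024 * 1024, // cap transfers near 10 MiB/s
///     ..Default::default()
/// };
/// ```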
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalanceConfig {
    /// Maximum concurrent shard moves
    pub max_concurrent_moves: u32,
    /// Delay before starting rebalance after topology change (ms)
    pub rebalance_delay_ms: u64,
    /// Maximum time for a single shard move (ms)
    pub move_timeout_ms: u64,
    /// Throttle rate for data transfer (bytes/sec, 0 = unlimited)
    pub throttle_bytes_per_sec: u64,
    /// Minimum interval between rebalance operations (ms)
    pub min_rebalance_interval_ms: u64,
    /// Enable automatic rebalancing
    pub auto_rebalance: bool,
    /// Target balance threshold (0.0-1.0, lower = more balanced)
    pub balance_threshold: f64,
}

impl Default for RebalanceConfig {
    fn default() -> Self {
        Self {
            max_concurrent_moves: 2,
            rebalance_delay_ms: 5000,         // 5 second delay
            move_timeout_ms: 300000,          // 5 minute timeout per move
            throttle_bytes_per_sec: 0,        // Unlimited
            min_rebalance_interval_ms: 60000, // 1 minute between rebalances
            auto_rebalance: true,
            balance_threshold: 0.1, // 10% imbalance tolerated
        }
    }
}

/// Reason for triggering rebalance
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalanceTrigger {
    /// A new node joined the cluster
    NodeJoined,
    /// A node left the cluster (graceful or failure)
    NodeLeft,
    /// Manual rebalance request
    Manual,
    /// Periodic rebalance check
    Periodic,
    /// Replica count changed
    ReplicaChange,
    /// Shard split or merge
    ShardChange,
}

/// State of the rebalance operation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalanceState {
    /// No rebalance in progress
    Idle,
    /// Waiting for delay before starting
    Pending,
    /// Planning shard movements
    Planning,
    /// Executing shard movements
    Executing,
    /// Rebalance completed successfully
    Completed,
    /// Rebalance failed
    Failed,
    /// Rebalance was cancelled
    Cancelled,
}

/// A single shard movement operation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMove {
    /// Unique ID for this move
    pub move_id: String,
    /// Shard being moved
    pub shard_id: String,
    /// Source node
    pub source_node: String,
    /// Target node
    pub target_node: String,
    /// State of this move
    pub state: MoveState,
    /// Bytes transferred so far
    pub bytes_transferred: u64,
    /// Total bytes to transfer
    pub total_bytes: u64,
    /// When the move started
    pub started_at: Option<u64>,
    /// When the move completed
    pub completed_at: Option<u64>,
    /// Error message if failed
    pub error: Option<String>,
}

/// State of a shard move
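///
/// The manager's methods drive a successful move through `Queued` ->
/// `Copying` (`start_move`) -> `Verifying` (`advance_to_verify`) ->
/// `Routing` (`advance_to_routing`) -> `Completed` (`complete_move`);
/// `Failed` and `Cancelled` are terminal.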
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MoveState {
    /// Move is queued
    Queued,
    /// Data is being copied
    Copying,
    /// Verifying data integrity
    Verifying,
    /// Updating routing tables
    Routing,
    /// Cleaning up source
    Cleanup,
    /// Move completed
    Completed,
    /// Move failed
    Failed,
    /// Move was cancelled
    Cancelled,
}

/// A rebalance plan
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalancePlan {
    /// Unique ID for this plan
    pub plan_id: String,
    /// Trigger that initiated the rebalance
    pub trigger: RebalanceTrigger,
    /// List of shard moves to execute
    pub moves: Vec<ShardMove>,
    /// Current state
    pub state: RebalanceState,
    /// When the plan was created
    pub created_at: u64,
    /// When execution started
    pub started_at: Option<u64>,
    /// When the plan completed
    pub completed_at: Option<u64>,
    /// Imbalance score before rebalance
    pub initial_imbalance: f64,
    /// Imbalance score after rebalance (estimated or actual)
    pub final_imbalance: Option<f64>,
}

/// Statistics about rebalancing
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RebalanceStats {
    pub total_rebalances: u64,
    pub successful_rebalances: u64,
    pub failed_rebalances: u64,
    pub cancelled_rebalances: u64,
    pub total_shards_moved: u64,
    pub total_bytes_moved: u64,
    pub current_moves_in_progress: u32,
    pub pending_moves: u32,
}

/// Node capacity and load information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeLoad {
    pub node_id: String,
    pub shard_count: u32,
    pub total_vectors: u64,
    pub total_bytes: u64,
    pub available_capacity: u64,
    pub is_available: bool,
}

/// Errors during rebalancing
#[derive(Debug, Error)]
pub enum RebalanceError {
    #[error("Rebalance already in progress: {0}")]
    AlreadyInProgress(String),

    #[error("No rebalance in progress")]
    NoRebalanceInProgress,

    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Shard not found: {0}")]
    ShardNotFound(String),

    #[error("Move not found: {0}")]
    MoveNotFound(String),

    #[error("Move timed out: {0}")]
    MoveTimeout(String),

    #[error("Not enough capacity on target node: {0}")]
    InsufficientCapacity(String),

    #[error("Rebalance cancelled")]
    Cancelled,

    #[error("Rebalance failed: {0}")]
    Failed(String),

    #[error("Cluster not balanced enough to proceed")]
    ClusterUnbalanced,
}

pub type Result<T> = std::result::Result<T, RebalanceError>;

/// Manager for automatic rebalancing
pub struct RebalanceManager {
    config: RebalanceConfig,
    current_plan: Arc<RwLock<Option<RebalancePlan>>>,
    node_loads: Arc<RwLock<HashMap<String, NodeLoad>>>,
    shard_assignments: Arc<RwLock<HashMap<String, String>>>, // shard_id -> node_id
    stats: Arc<RwLock<RebalanceStats>>,
    cancelled: AtomicBool,
    last_rebalance_time: AtomicU64,
}

impl RebalanceManager {
    /// Create a new rebalance manager
    pub fn new(config: RebalanceConfig) -> Self {
        Self {
            config,
            current_plan: Arc::new(RwLock::new(None)),
            node_loads: Arc::new(RwLock::new(HashMap::new())),
            shard_assignments: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(RebalanceStats::default())),
            cancelled: AtomicBool::new(false),
            last_rebalance_time: AtomicU64::new(0),
        }
    }

    /// Register a node with its load information
    pub fn register_node(&self, load: NodeLoad) {
        let mut loads = self.node_loads.write();
        loads.insert(load.node_id.clone(), load);
    }

    /// Remove a node from tracking
    pub fn unregister_node(&self, node_id: &str) {
        let mut loads = self.node_loads.write();
        loads.remove(node_id);
    }

    /// Update node load information
    pub fn update_node_load(&self, node_id: &str, update: impl FnOnce(&mut NodeLoad)) {
        let mut loads = self.node_loads.write();
        if let Some(load) = loads.get_mut(node_id) {
            update(load);
        }
    }

    /// Register a shard assignment
    pub fn register_shard(&self, shard_id: &str, node_id: &str) {
        let mut assignments = self.shard_assignments.write();
        assignments.insert(shard_id.to_string(), node_id.to_string());
    }

    /// Remove a shard assignment
    pub fn unregister_shard(&self, shard_id: &str) {
        let mut assignments = self.shard_assignments.write();
        assignments.remove(shard_id);
    }

    /// Calculate cluster imbalance score (0.0 = perfectly balanced, 1.0 = maximally imbalanced)
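    ///
    /// The score is the coefficient of variation of per-node shard counts,
    /// clamped to 1.0: `min(std_dev / mean, 1.0)`. For example, shard counts
    /// of 30, 5, and 5 give a mean of ~13.3 and a standard deviation of
    /// ~11.8, so the score is ~0.88.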
    pub fn calculate_imbalance(&self) -> f64 {
        let loads = self.node_loads.read();
        let available_nodes: Vec<&NodeLoad> = loads.values().filter(|n| n.is_available).collect();

        if available_nodes.is_empty() {
            return 0.0;
        }

        let total_shards: u32 = available_nodes.iter().map(|n| n.shard_count).sum();
        let avg_shards = total_shards as f64 / available_nodes.len() as f64;

        if avg_shards == 0.0 {
            return 0.0;
        }

        let variance: f64 = available_nodes
            .iter()
            .map(|n| {
                let diff = n.shard_count as f64 - avg_shards;
                diff * diff
            })
            .sum::<f64>()
            / available_nodes.len() as f64;

        let std_dev = variance.sqrt();
        (std_dev / avg_shards).min(1.0)
    }

    /// Check if rebalance is needed
    pub fn needs_rebalance(&self) -> bool {
        let imbalance = self.calculate_imbalance();
        imbalance > self.config.balance_threshold
    }

    /// Check if enough time has passed since last rebalance
    pub fn can_rebalance(&self) -> bool {
        let last_time = self.last_rebalance_time.load(Ordering::SeqCst);
        let now = current_time_ms();
        // Saturate in case the clock moves backwards between calls
        now.saturating_sub(last_time) >= self.config.min_rebalance_interval_ms
    }

    /// Trigger a rebalance operation
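    ///
    /// Returns the new plan in the `Pending` state; call `start_execution`
    /// to compute and begin the moves. `Manual` triggers bypass the
    /// minimum-interval check; all other triggers are rejected until the
    /// interval has elapsed.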
    pub fn trigger_rebalance(&self, trigger: RebalanceTrigger) -> Result<RebalancePlan> {
        // Check if rebalance already in progress
        {
            let plan = self.current_plan.read();
            if let Some(ref p) = *plan {
                if matches!(
                    p.state,
                    RebalanceState::Pending | RebalanceState::Planning | RebalanceState::Executing
                ) {
                    return Err(RebalanceError::AlreadyInProgress(p.plan_id.clone()));
                }
            }
        }

        // Check interval
        if !self.can_rebalance() && trigger != RebalanceTrigger::Manual {
            return Err(RebalanceError::Failed(
                "Too soon since last rebalance".to_string(),
            ));
        }

        self.cancelled.store(false, Ordering::SeqCst);
        let initial_imbalance = self.calculate_imbalance();

        // Create initial plan
        let plan = RebalancePlan {
            plan_id: generate_plan_id(),
            trigger,
            moves: Vec::new(),
            state: RebalanceState::Pending,
            created_at: current_time_ms(),
            started_at: None,
            completed_at: None,
            initial_imbalance,
            final_imbalance: None,
        };

        {
            let mut current = self.current_plan.write();
            *current = Some(plan.clone());
        }

        // Track how many rebalances have been triggered
        {
            let mut stats = self.stats.write();
            stats.total_rebalances += 1;
        }

        info!(
            "Triggered rebalance: {} (trigger: {:?}, imbalance: {:.2}%)",
            plan.plan_id,
            trigger,
            initial_imbalance * 100.0
        );

        Ok(plan)
    }

    /// Plan the shard movements
    pub fn create_move_plan(&self) -> Result<Vec<ShardMove>> {
        let loads = self.node_loads.read();
        let assignments = self.shard_assignments.read();

        let available_nodes: Vec<&NodeLoad> = loads.values().filter(|n| n.is_available).collect();

        if available_nodes.is_empty() {
            return Ok(Vec::new());
        }

        let total_shards: u32 = available_nodes.iter().map(|n| n.shard_count).sum();
        let target_per_node = total_shards / available_nodes.len() as u32;
        let remainder = total_shards % available_nodes.len() as u32;

        // Identify overloaded and underloaded nodes
        let mut overloaded: Vec<(&NodeLoad, u32)> = Vec::new(); // (node, excess)
        let mut underloaded: Vec<(&NodeLoad, u32)> = Vec::new(); // (node, deficit)

        for (i, node) in available_nodes.iter().enumerate() {
            let target = target_per_node + if (i as u32) < remainder { 1 } else { 0 };
            if node.shard_count > target {
                overloaded.push((node, node.shard_count - target));
            } else if node.shard_count < target {
                underloaded.push((node, target - node.shard_count));
            }
        }

        // Sort for consistent ordering
        overloaded.sort_by(|a, b| b.1.cmp(&a.1)); // Most overloaded first
        underloaded.sort_by(|a, b| b.1.cmp(&a.1)); // Most underloaded first

        let mut moves = Vec::new();
        let mut move_count = 0;

        // Find shards to move from overloaded to underloaded nodes
        for (overloaded_node, mut excess) in overloaded {
            if excess == 0 {
                continue;
            }

            // Find shards on this node
            let shards_on_node: Vec<&String> = assignments
                .iter()
                .filter(|(_, node_id)| *node_id == &overloaded_node.node_id)
                .map(|(shard_id, _)| shard_id)
                .collect();

            for shard_id in shards_on_node {
                if excess == 0 {
                    break;
                }

                // Find an underloaded node that can accept this shard
                for (underloaded_node, deficit) in underloaded.iter_mut() {
                    if *deficit == 0 {
                        continue;
                    }

                    // Create move
                    moves.push(ShardMove {
                        move_id: format!("move-{}", move_count),
                        shard_id: shard_id.clone(),
                        source_node: overloaded_node.node_id.clone(),
                        target_node: underloaded_node.node_id.clone(),
                        state: MoveState::Queued,
                        bytes_transferred: 0,
                        total_bytes: 0, // Would be calculated from actual shard size
                        started_at: None,
                        completed_at: None,
                        error: None,
                    });

                    move_count += 1;
                    excess -= 1;
                    *deficit -= 1;
                    break;
                }
            }
        }

        debug!("Created rebalance plan with {} moves", moves.len());
        Ok(moves)
    }

    /// Start executing the rebalance plan
    pub fn start_execution(&self) -> Result<()> {
        let moves = self.create_move_plan()?;

        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        if p.state != RebalanceState::Pending {
            return Err(RebalanceError::Failed(
                "Invalid state for execution".to_string(),
            ));
        }

        p.moves = moves;
        p.state = RebalanceState::Executing;
        p.started_at = Some(current_time_ms());

        info!(
            "Started rebalance execution: {} ({} moves)",
            p.plan_id,
            p.moves.len()
        );
        Ok(())
    }

    /// Start a specific move
    pub fn start_move(&self, move_id: &str) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        if m.state != MoveState::Queued {
            return Ok(()); // Already started
        }

        m.state = MoveState::Copying;
        m.started_at = Some(current_time_ms());

        let mut stats = self.stats.write();
        stats.current_moves_in_progress += 1;

        debug!(
            "Started move: {} ({} -> {})",
            move_id, m.source_node, m.target_node
        );
        Ok(())
    }

    /// Update move progress
    pub fn update_move_progress(
        &self,
        move_id: &str,
        bytes_transferred: u64,
        total_bytes: u64,
    ) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        m.bytes_transferred = bytes_transferred;
        m.total_bytes = total_bytes;
        Ok(())
    }

    /// Advance move to verification stage
    pub fn advance_to_verify(&self, move_id: &str) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        m.state = MoveState::Verifying;
        debug!("Move {} advanced to verification", move_id);
        Ok(())
    }

    /// Advance move to routing update stage
    pub fn advance_to_routing(&self, move_id: &str) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        let m = p
            .moves
            .iter_mut()
            .find(|m| m.move_id == move_id)
            .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

        m.state = MoveState::Routing;
        debug!("Move {} advanced to routing update", move_id);
        Ok(())
    }

    /// Complete a move
    pub fn complete_move(&self, move_id: &str) -> Result<()> {
        let (shard_id, source_node, target_node, bytes_moved) = {
            let mut plan = self.current_plan.write();
            let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

            let m = p
                .moves
                .iter_mut()
                .find(|m| m.move_id == move_id)
                .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

            m.state = MoveState::Completed;
            m.completed_at = Some(current_time_ms());

            (
                m.shard_id.clone(),
                m.source_node.clone(),
                m.target_node.clone(),
                m.bytes_transferred,
            )
        };

        // Update shard assignment
        {
            let mut assignments = self.shard_assignments.write();
            assignments.insert(shard_id, target_node.clone());
        }

        // Keep per-node shard counts in sync so imbalance scores reflect the move
        {
            let mut loads = self.node_loads.write();
            if let Some(src) = loads.get_mut(&source_node) {
                src.shard_count = src.shard_count.saturating_sub(1);
            }
            if let Some(dst) = loads.get_mut(&target_node) {
                dst.shard_count += 1;
            }
        }

        // Update stats
        {
            let mut stats = self.stats.write();
            stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(1);
            stats.total_shards_moved += 1;
            stats.total_bytes_moved += bytes_moved;
        }

        // Check if rebalance is complete
        self.check_completion()?;

        debug!("Completed move: {}", move_id);
        Ok(())
    }

    /// Fail a move
    pub fn fail_move(&self, move_id: &str, error: &str) -> Result<()> {
        {
            let mut plan = self.current_plan.write();
            let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

            let m = p
                .moves
                .iter_mut()
                .find(|m| m.move_id == move_id)
                .ok_or_else(|| RebalanceError::MoveNotFound(move_id.to_string()))?;

            m.state = MoveState::Failed;
            m.error = Some(error.to_string());
            m.completed_at = Some(current_time_ms());
        }

        // Update stats
        {
            let mut stats = self.stats.write();
            stats.current_moves_in_progress = stats.current_moves_in_progress.saturating_sub(1);
        }

        // A failed move may be the last outstanding one, so re-check completion
        self.check_completion()?;

        warn!("Move {} failed: {}", move_id, error);
        Ok(())
    }

    /// Check if rebalance is complete
    fn check_completion(&self) -> Result<()> {
        let mut plan = self.current_plan.write();
        let p = match plan.as_mut() {
            Some(p) => p,
            None => return Ok(()),
        };

        if p.state != RebalanceState::Executing {
            return Ok(());
        }

        let all_done = p.moves.iter().all(|m| {
            matches!(
                m.state,
                MoveState::Completed | MoveState::Failed | MoveState::Cancelled
            )
        });

        if all_done {
            let any_failed = p.moves.iter().any(|m| m.state == MoveState::Failed);

            if any_failed {
                p.state = RebalanceState::Failed;
                let mut stats = self.stats.write();
                stats.failed_rebalances += 1;
            } else {
                p.state = RebalanceState::Completed;
                let mut stats = self.stats.write();
                stats.successful_rebalances += 1;
            }

            let final_state = p.state;
            p.completed_at = Some(current_time_ms());
            self.last_rebalance_time
                .store(current_time_ms(), Ordering::SeqCst);

            // Calculate final imbalance (requires releasing the plan lock first)
            drop(plan);
            let final_imbalance = self.calculate_imbalance();
            let mut plan = self.current_plan.write();
            if let Some(p) = plan.as_mut() {
                p.final_imbalance = Some(final_imbalance);
            }

            info!(
                "Rebalance finished in state {:?} (final imbalance: {:.2}%)",
                final_state,
                final_imbalance * 100.0
            );
        }

        Ok(())
    }

    /// Cancel the current rebalance
    pub fn cancel(&self) -> Result<()> {
        self.cancelled.store(true, Ordering::SeqCst);

        let mut plan = self.current_plan.write();
        let p = plan.as_mut().ok_or(RebalanceError::NoRebalanceInProgress)?;

        // Cancel all queued/in-progress moves
        for m in p.moves.iter_mut() {
            if matches!(
                m.state,
                MoveState::Queued | MoveState::Copying | MoveState::Verifying | MoveState::Routing
            ) {
                m.state = MoveState::Cancelled;
            }
        }

        p.state = RebalanceState::Cancelled;
        p.completed_at = Some(current_time_ms());

        let mut stats = self.stats.write();
        stats.cancelled_rebalances += 1;

        info!("Rebalance cancelled: {}", p.plan_id);
        Ok(())
    }

    /// Check if rebalance was cancelled
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Get the current rebalance plan
    pub fn get_plan(&self) -> Option<RebalancePlan> {
        self.current_plan.read().clone()
    }

    /// Get moves that are ready to start
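    ///
    /// A sketch of a driver loop over queued moves; `copy_shard` is a
    /// hypothetical transfer routine and `config` is the caller's copy of the
    /// `RebalanceConfig` (the manager does not enforce `max_concurrent_moves`
    /// itself):
    ///
    /// ```ignore
    /// manager.start_execution()?;
    /// for move_id in manager.get_queued_moves() {
    ///     if manager.get_active_moves().len() as u32 >= config.max_concurrent_moves {
    ///         break; // wait for in-flight moves before starting more
    ///     }
    ///     manager.start_move(&move_id)?;
    ///     match copy_shard(&move_id) {
    ///         Ok(()) => {
    ///             manager.advance_to_verify(&move_id)?;
    ///             manager.advance_to_routing(&move_id)?;
    ///             manager.complete_move(&move_id)?;
    ///         }
    ///         Err(e) => manager.fail_move(&move_id, &e.to_string())?,
    ///     }
    /// }
    /// ```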
    pub fn get_queued_moves(&self) -> Vec<String> {
        let plan = self.current_plan.read();
        match plan.as_ref() {
            Some(p) => p
                .moves
                .iter()
                .filter(|m| m.state == MoveState::Queued)
                .map(|m| m.move_id.clone())
                .collect(),
            None => Vec::new(),
        }
    }

    /// Get currently executing moves
    pub fn get_active_moves(&self) -> Vec<ShardMove> {
        let plan = self.current_plan.read();
        match plan.as_ref() {
            Some(p) => p
                .moves
                .iter()
                .filter(|m| {
                    matches!(
                        m.state,
                        MoveState::Copying | MoveState::Verifying | MoveState::Routing
                    )
                })
                .cloned()
                .collect(),
            None => Vec::new(),
        }
    }

    /// Get rebalance statistics
    pub fn get_stats(&self) -> RebalanceStats {
        let mut stats = self.stats.read().clone();

        // Count pending moves
        if let Some(ref plan) = *self.current_plan.read() {
            stats.pending_moves = plan
                .moves
                .iter()
                .filter(|m| m.state == MoveState::Queued)
                .count() as u32;
        }

        stats
    }

    /// Get node load information
    pub fn get_node_load(&self, node_id: &str) -> Option<NodeLoad> {
        self.node_loads.read().get(node_id).cloned()
    }

    /// Get all node loads
    pub fn get_all_node_loads(&self) -> Vec<NodeLoad> {
        self.node_loads.read().values().cloned().collect()
    }

    /// Handle node join event
    pub fn on_node_joined(&self, node_id: &str, capacity: u64) -> Result<Option<RebalancePlan>> {
        // Register the new node
        self.register_node(NodeLoad {
            node_id: node_id.to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: capacity,
            is_available: true,
        });

        info!("Node joined: {}", node_id);

        // Trigger rebalance if enabled
        if self.config.auto_rebalance && self.needs_rebalance() && self.can_rebalance() {
            return Ok(Some(self.trigger_rebalance(RebalanceTrigger::NodeJoined)?));
        }

        Ok(None)
    }

    /// Handle node leave event
    pub fn on_node_left(&self, node_id: &str) -> Result<Option<RebalancePlan>> {
        // Mark node as unavailable
        self.update_node_load(node_id, |load| {
            load.is_available = false;
        });

        info!("Node left: {}", node_id);

        // Trigger rebalance if enabled (to reassign shards from the departed node)
        if self.config.auto_rebalance {
            return Ok(Some(self.trigger_rebalance(RebalanceTrigger::NodeLeft)?));
        }

        Ok(None)
    }
}

/// Get current time in milliseconds
fn current_time_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64
}

/// Generate a plan ID from the current timestamp (millisecond resolution)
fn generate_plan_id() -> String {
    format!("rebalance-{}", current_time_ms())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rebalance_config_defaults() {
        let config = RebalanceConfig::default();
        assert_eq!(config.max_concurrent_moves, 2);
        assert_eq!(config.rebalance_delay_ms, 5000);
        assert!(config.auto_rebalance);
        assert_eq!(config.balance_threshold, 0.1);
    }

    #[test]
    fn test_register_and_unregister_node() {
        let manager = RebalanceManager::new(RebalanceConfig::default());

        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 5,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });

        assert!(manager.get_node_load("node1").is_some());

        manager.unregister_node("node1");
        assert!(manager.get_node_load("node1").is_none());
    }

    #[test]
    fn test_calculate_imbalance_balanced() {
        let manager = RebalanceManager::new(RebalanceConfig::default());

        // All nodes have the same shard count, so the cluster is perfectly balanced
        for i in 0..3 {
            manager.register_node(NodeLoad {
                node_id: format!("node{}", i),
                shard_count: 10,
                total_vectors: 1000,
                total_bytes: 10000,
                available_capacity: 100000,
                is_available: true,
            });
        }

        let imbalance = manager.calculate_imbalance();
        assert_eq!(imbalance, 0.0);
    }

    #[test]
    fn test_calculate_imbalance_unbalanced() {
        let manager = RebalanceManager::new(RebalanceConfig::default());

        // One node holds far more shards than the others
        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 30,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 5,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node2".to_string(),
            shard_count: 5,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });

        let imbalance = manager.calculate_imbalance();
        assert!(imbalance > 0.0);
        assert!(manager.needs_rebalance());
    }

    #[test]
    fn test_trigger_rebalance() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0, // Allow immediate rebalance
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 10,
            total_vectors: 1000,
            total_bytes: 10000,
            available_capacity: 100000,
            is_available: true,
        });

        let plan = manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        assert_eq!(plan.state, RebalanceState::Pending);
        assert_eq!(plan.trigger, RebalanceTrigger::Manual);
    }

    #[test]
    fn test_create_move_plan() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        // Create imbalanced cluster
        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 4,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        // Register shards on node0
        for i in 0..4 {
            manager.register_shard(&format!("shard{}", i), "node0");
        }

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        let moves = manager.create_move_plan().unwrap();

        // Should move 2 shards to balance (4 shards / 2 nodes = 2 each)
        assert_eq!(moves.len(), 2);
        for m in &moves {
            assert_eq!(m.source_node, "node0");
            assert_eq!(m.target_node, "node1");
        }
    }

    #[test]
    fn test_move_lifecycle() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 2,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_shard("shard0", "node0");
        manager.register_shard("shard1", "node0");

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();

        let queued = manager.get_queued_moves();
        assert!(!queued.is_empty());

        let move_id = &queued[0];

        // Progress through states
        manager.start_move(move_id).unwrap();
        manager.update_move_progress(move_id, 500, 1000).unwrap();
        manager.advance_to_verify(move_id).unwrap();
        manager.advance_to_routing(move_id).unwrap();
        manager.complete_move(move_id).unwrap();

        let stats = manager.get_stats();
        assert_eq!(stats.total_shards_moved, 1);
    }

    #[test]
    fn test_cancel_rebalance() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 4,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 0,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();
        manager.cancel().unwrap();

        let plan = manager.get_plan().unwrap();
        assert_eq!(plan.state, RebalanceState::Cancelled);
        assert!(manager.is_cancelled());

        let stats = manager.get_stats();
        assert_eq!(stats.cancelled_rebalances, 1);
    }

    #[test]
    fn test_on_node_joined() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            auto_rebalance: true,
            balance_threshold: 0.01, // Very low threshold
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        // Existing imbalanced cluster
        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 10,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        // New node joins
        let result = manager.on_node_joined("node1", 100000).unwrap();

        // Should trigger rebalance
        assert!(result.is_some());
        let plan = result.unwrap();
        assert_eq!(plan.trigger, RebalanceTrigger::NodeJoined);
    }

    #[test]
    fn test_on_node_left() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            auto_rebalance: true,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 5,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });
        manager.register_node(NodeLoad {
            node_id: "node1".to_string(),
            shard_count: 5,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        // Node leaves
        let result = manager.on_node_left("node1").unwrap();

        // Should trigger rebalance
        assert!(result.is_some());
        let plan = result.unwrap();
        assert_eq!(plan.trigger, RebalanceTrigger::NodeLeft);

        // Node should be marked unavailable
        let load = manager.get_node_load("node1").unwrap();
        assert!(!load.is_available);
    }

    #[test]
    fn test_rebalance_stats() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        let stats = manager.get_stats();
        assert_eq!(stats.total_rebalances, 0);
        assert_eq!(stats.total_shards_moved, 0);
    }

    #[test]
    fn test_duplicate_rebalance_rejected() {
        let config = RebalanceConfig {
            min_rebalance_interval_ms: 0,
            ..Default::default()
        };
        let manager = RebalanceManager::new(config);

        manager.register_node(NodeLoad {
            node_id: "node0".to_string(),
            shard_count: 10,
            total_vectors: 0,
            total_bytes: 0,
            available_capacity: 100000,
            is_available: true,
        });

        manager.trigger_rebalance(RebalanceTrigger::Manual).unwrap();
        manager.start_execution().unwrap();

        // Second rebalance should fail
        let result = manager.trigger_rebalance(RebalanceTrigger::Manual);
        assert!(matches!(result, Err(RebalanceError::AlreadyInProgress(_))));
    }
}