use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;
use tracing::{info, warn};
use scirs2_core::metrics::Counter;
use scirs2_core::profiling::Profiler;
use crate::raft::OxirsNodeId;
/// Strategy used when moving data between nodes during a rebalance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RebalancingStrategy {
/// Move data in small batches to limit impact on live traffic.
Incremental,
/// Move data in large batches for maximum throughput.
Bulk,
/// Pick the batch size at runtime from the source node's CPU load.
Adaptive,
}
/// Lifecycle state of a single data migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationState {
/// Created but not yet started.
Idle,
/// Plan accepted; preparing to move data.
Planning,
/// Batches are actively being transferred.
InProgress,
/// Data moved; verification pass running.
Verifying,
/// Finished successfully.
Completed,
/// Aborted with errors (also the terminal state after a rollback).
Failed,
/// Rollback in progress.
RollingBack,
}
/// Tuning parameters for automatic data rebalancing and migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalancingConfig {
/// How migration batches are sized (see [`RebalancingStrategy`]).
pub strategy: RebalancingStrategy,
/// When true, `check_rebalancing_needed` may report that a rebalance is due.
pub enable_auto_rebalancing: bool,
/// Relative imbalance (max - min) / avg of per-node data size that triggers rebalancing.
pub load_imbalance_threshold: f64,
/// Minimum number of registered nodes before rebalancing is considered.
pub min_nodes_for_rebalancing: usize,
/// Keys per batch under the incremental strategy.
pub incremental_batch_size: usize,
/// Keys per batch under the bulk strategy.
pub bulk_batch_size: usize,
/// Migration throttle; 0 disables bandwidth limiting.
pub bandwidth_limit_bytes_per_sec: usize,
/// Run a verification pass after the data has been moved.
pub enable_verification: bool,
/// Allow `rollback_migration` to run.
pub enable_rollback: bool,
/// Maximum time a migration may take, in seconds.
/// NOTE(review): not referenced anywhere in the visible code — confirm it is enforced elsewhere.
pub migration_timeout_secs: u64,
}
impl Default for RebalancingConfig {
fn default() -> Self {
Self {
strategy: RebalancingStrategy::Adaptive,
enable_auto_rebalancing: true,
load_imbalance_threshold: 0.2, min_nodes_for_rebalancing: 2,
incremental_batch_size: 100,
bulk_batch_size: 10000,
bandwidth_limit_bytes_per_sec: 10 * 1024 * 1024, enable_verification: true,
enable_rollback: true,
migration_timeout_secs: 3600, }
}
}
/// Description of a planned data migration between two nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationPlan {
/// Unique id of the form "migration-{source}-to-{target}-{unix_secs}".
pub plan_id: String,
/// Node the data moves from (the most loaded node at planning time).
pub source_node: OxirsNodeId,
/// Node the data moves to (the least loaded node at planning time).
pub target_node: OxirsNodeId,
/// Estimated number of keys to move.
pub key_count: usize,
/// Estimated payload size in bytes.
pub estimated_size_bytes: usize,
/// Partitions involved; currently always `vec![0]` (placeholder).
pub partition_ids: Vec<usize>,
/// When the plan was created.
pub created_at: SystemTime,
/// Rough duration estimate derived from the bandwidth limit.
pub estimated_duration_secs: u64,
}
/// Live progress of a single migration, keyed by plan id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationProgress {
/// Id of the plan this progress belongs to.
pub plan_id: String,
/// Current lifecycle state.
pub state: MigrationState,
/// Keys moved so far.
pub keys_migrated: usize,
/// Total keys planned.
pub total_keys: usize,
/// Bytes moved so far.
pub bytes_migrated: usize,
/// Total bytes planned.
pub total_bytes: usize,
/// Set when execution starts; `None` while idle.
pub started_at: Option<SystemTime>,
/// Set on completion, failure, or rollback.
pub completed_at: Option<SystemTime>,
/// Human-readable error messages accumulated along the way.
pub errors: Vec<String>,
}
impl MigrationProgress {
fn new(plan_id: String, total_keys: usize, total_bytes: usize) -> Self {
Self {
plan_id,
state: MigrationState::Idle,
keys_migrated: 0,
total_keys,
bytes_migrated: 0,
total_bytes,
started_at: None,
completed_at: None,
errors: Vec::new(),
}
}
pub fn progress_percentage(&self) -> f64 {
if self.total_keys == 0 {
return 0.0;
}
(self.keys_migrated as f64 / self.total_keys as f64) * 100.0
}
}
/// Point-in-time load report for a single cluster node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeLoad {
/// Node this report describes.
pub node_id: OxirsNodeId,
/// Number of keys stored on the node.
pub key_count: usize,
/// Total stored data size in bytes; this drives imbalance detection.
pub data_size_bytes: usize,
/// CPU utilization as reported by the caller (compared against 0.7 in adaptive batching).
pub cpu_utilization: f64,
/// Memory utilization as reported by the caller.
pub memory_utilization: f64,
/// Current network bandwidth usage in bytes/sec.
pub network_bandwidth_usage: usize,
/// When this report was last refreshed.
pub last_updated: SystemTime,
}
impl Default for NodeLoad {
fn default() -> Self {
Self {
node_id: 0,
key_count: 0,
data_size_bytes: 0,
cpu_utilization: 0.0,
memory_utilization: 0.0,
network_bandwidth_usage: 0,
last_updated: SystemTime::now(),
}
}
}
/// Aggregate counters and running averages across all rebalancing runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RebalancingStats {
/// Total rebalancing operations attempted.
pub total_rebalancing_ops: u64,
/// Migrations that completed successfully.
pub successful_migrations: u64,
/// Migrations that failed or were rolled back.
pub failed_migrations: u64,
/// Cumulative keys moved by successful migrations.
pub total_keys_migrated: usize,
/// Cumulative bytes moved by successful migrations.
pub total_bytes_migrated: usize,
/// Running average migration throughput in bytes/sec.
pub avg_migration_speed_bytes_per_sec: f64,
/// Timestamp of the most recent successful rebalance.
pub last_rebalancing: Option<SystemTime>,
/// Running average duration of successful migrations, in milliseconds.
pub avg_rebalancing_duration_ms: f64,
}
impl Default for RebalancingStats {
fn default() -> Self {
Self {
total_rebalancing_ops: 0,
successful_migrations: 0,
failed_migrations: 0,
total_keys_migrated: 0,
total_bytes_migrated: 0,
avg_migration_speed_bytes_per_sec: 0.0,
last_rebalancing: None,
avg_rebalancing_duration_ms: 0.0,
}
}
}
/// Coordinates data rebalancing across cluster nodes: tracks per-node load,
/// plans migrations between the most- and least-loaded nodes, and executes
/// them with batching, bandwidth throttling, verification, and rollback.
pub struct DataRebalancingManager {
// Tuning knobs, fixed at construction time.
config: RebalancingConfig,
// Latest load report per node; BTreeMap keeps iteration deterministic.
node_loads: Arc<RwLock<BTreeMap<OxirsNodeId, NodeLoad>>>,
// In-flight migrations keyed by plan id.
active_migrations: Arc<RwLock<HashMap<String, MigrationProgress>>>,
// Completed migrations, appended in completion order.
migration_history: Arc<RwLock<Vec<MigrationProgress>>>,
// Aggregate counters/averages across all rebalancing operations.
stats: Arc<RwLock<RebalancingStats>>,
// Counts invocations of the SIMD-capable helpers (metrics only).
simd_ops_counter: Counter,
// Profiler handle exposed via `get_profiling_report`.
profiler: Arc<Profiler>,
}
impl DataRebalancingManager {
/// Build a manager with the given configuration and no registered nodes.
pub fn new(config: RebalancingConfig) -> Self {
Self {
config,
node_loads: Arc::new(RwLock::new(BTreeMap::new())),
active_migrations: Arc::new(RwLock::new(HashMap::new())),
migration_history: Arc::new(RwLock::new(Vec::new())),
stats: Arc::new(RwLock::new(RebalancingStats::default())),
simd_ops_counter: Counter::new("data_rebalancing_simd_ops".to_string()),
profiler: Arc::new(Profiler::new()),
}
}
/// Register a node for load tracking. Re-registering an existing node
/// resets its load record to defaults (zero counters, fresh timestamp).
pub async fn register_node(&self, node_id: OxirsNodeId) {
let mut node_loads = self.node_loads.write().await;
node_loads.insert(
node_id,
NodeLoad {
node_id,
..Default::default()
},
);
info!("Registered node {} for rebalancing", node_id);
}
/// Record the latest load metrics for `node_id`. A node that was never
/// registered is silently ignored.
pub async fn update_node_load(
&self,
node_id: OxirsNodeId,
key_count: usize,
data_size_bytes: usize,
cpu_utilization: f64,
memory_utilization: f64,
network_bandwidth: usize,
) {
let mut node_loads = self.node_loads.write().await;
if let Some(load) = node_loads.get_mut(&node_id) {
load.key_count = key_count;
load.data_size_bytes = data_size_bytes;
load.cpu_utilization = cpu_utilization;
load.memory_utilization = memory_utilization;
load.network_bandwidth_usage = network_bandwidth;
load.last_updated = SystemTime::now();
}
}
/// True when auto-rebalancing is enabled, enough nodes are registered,
/// and the relative data-size imbalance (max - min) / avg exceeds the
/// configured threshold.
pub async fn check_rebalancing_needed(&self) -> bool {
if !self.config.enable_auto_rebalancing {
return false;
}
let node_loads = self.node_loads.read().await;
if node_loads.len() < self.config.min_nodes_for_rebalancing {
return false;
}
let loads: Vec<f64> = node_loads
.values()
.map(|l| l.data_size_bytes as f64)
.collect();
if loads.is_empty() {
return false;
}
let (max_load, min_load, avg_load) = self.calculate_load_stats_simd(&loads);
// All nodes empty: nothing to balance (also avoids division by zero).
if avg_load == 0.0 {
return false;
}
let imbalance = (max_load - min_load) / avg_load;
imbalance > self.config.load_imbalance_threshold
}
/// Compute (max, min, mean) of `loads`. Small slices take a plain
/// iterator path; larger ones delegate to scirs2 ndarray statistics.
fn calculate_load_stats_simd(&self, loads: &[f64]) -> (f64, f64, f64) {
if loads.is_empty() {
return (0.0, 0.0, 0.0);
}
self.simd_ops_counter.inc();
// Below 100 elements a scalar fold is cheaper than array setup.
if loads.len() < 100 {
let sum: f64 = loads.iter().sum();
let avg_load = sum / loads.len() as f64;
let max_load = loads.iter().copied().fold(f64::MIN, f64::max);
let min_load = loads.iter().copied().fold(f64::MAX, f64::min);
(max_load, min_load, avg_load)
} else {
use scirs2_core::ndarray_ext::stats::{max, mean, min};
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
let arr = Array1::from_vec(loads.to_vec());
let view: ArrayView1<f64> = arr.view();
// On any stats error, fall back to neutral sentinel values.
let avg_arr = mean(&view, None).unwrap_or_else(|_| Array1::from_vec(vec![0.0]));
let max_arr = max(&view, None).unwrap_or_else(|_| Array1::from_vec(vec![f64::MIN]));
let min_arr = min(&view, None).unwrap_or_else(|_| Array1::from_vec(vec![f64::MAX]));
let avg_load = avg_arr[0];
let max_load = max_arr[0];
let min_load = min_arr[0];
(max_load, min_load, avg_load)
}
}
/// Map `key` to a partition in `0..num_partitions` using the 64-bit
/// FNV-1a hash (offset basis 0xcbf29ce484222325, prime 0x100000001b3).
pub fn calculate_partition_hash_simd(&self, key: &str, num_partitions: usize) -> usize {
let key_bytes = key.as_bytes();
self.simd_ops_counter.inc();
let mut h = 0xcbf29ce484222325u64;
for &byte in key_bytes {
h ^= byte as u64;
h = h.wrapping_mul(0x100000001b3u64);
}
(h as usize) % num_partitions
}
/// Hash many keys at once. Small batches reuse the scalar helper; large
/// batches run the same FNV-1a computation in parallel via rayon.
pub fn batch_calculate_partition_hash_simd(
&self,
keys: &[String],
num_partitions: usize,
) -> Vec<usize> {
if keys.is_empty() {
return Vec::new();
}
self.simd_ops_counter.inc();
if keys.len() < 100 {
keys.iter()
.map(|key| self.calculate_partition_hash_simd(key, num_partitions))
.collect()
} else {
use rayon::prelude::*;
keys.par_iter()
.map(|key| {
// Inlined copy of the FNV-1a loop from the scalar helper.
let key_bytes = key.as_bytes();
let mut h = 0xcbf29ce484222325u64;
for &byte in key_bytes {
h ^= byte as u64;
h = h.wrapping_mul(0x100000001b3u64);
}
(h as usize) % num_partitions
})
.collect()
}
}
/// Variance of per-node data sizes; 0.0 with fewer than two nodes.
/// NOTE(review): the scalar branch divides by n (population variance)
/// while the ndarray branch passes ddof = 1 (sample variance) — the two
/// paths disagree for the same data; confirm which ddof is intended.
pub async fn calculate_load_variance_simd(&self) -> f64 {
let node_loads = self.node_loads.read().await;
let loads: Vec<f64> = node_loads
.values()
.map(|l| l.data_size_bytes as f64)
.collect();
if loads.len() < 2 {
return 0.0;
}
self.simd_ops_counter.inc();
if loads.len() < 100 {
let mean: f64 = loads.iter().sum::<f64>() / loads.len() as f64;
let variance: f64 =
loads.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / loads.len() as f64;
variance
} else {
use scirs2_core::ndarray_ext::stats::variance;
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
let arr = Array1::from_vec(loads);
let view: ArrayView1<f64> = arr.view();
let var_arr = variance(&view, None, 1).unwrap_or_else(|_| Array1::from_vec(vec![0.0]));
var_arr[0]
}
}
/// Render the profiler's report as text.
pub fn get_profiling_report(&self) -> String {
self.profiler.get_report()
}
/// NOTE(review): always returns a hardcoded 0 even though
/// `simd_ops_counter` is incremented throughout this impl — presumably
/// this should surface the counter's current value; confirm against
/// `Counter`'s read API before fixing.
pub fn simd_operations_count(&self) -> u64 {
0
}
/// Plan a migration from the most-loaded node to the least-loaded one,
/// moving roughly half of the source's excess over the mean data size.
/// Errors when fewer than two nodes are registered.
pub async fn create_migration_plan(&self) -> Result<MigrationPlan, String> {
let node_loads = self.node_loads.read().await;
if node_loads.len() < 2 {
return Err("Not enough nodes for migration".to_string());
}
let mut loads: Vec<_> = node_loads.values().cloned().collect();
loads.sort_by_key(|l| l.data_size_bytes);
// Sorted ascending: last = most loaded (source), first = least (target).
let source_load = loads.last().expect("collection validated to be non-empty");
let target_load = loads.first().expect("collection validated to be non-empty");
let total_data: usize = loads.iter().map(|l| l.data_size_bytes).sum();
let avg_data = total_data / loads.len();
// Move half the source's excess over the mean; key count scales with
// the byte fraction. Integer division truncates toward zero.
let migrate_size = (source_load.data_size_bytes - avg_data) / 2; let migrate_keys = (migrate_size as f64 / source_load.data_size_bytes as f64
* source_load.key_count as f64) as usize;
// NOTE(review): integer division can yield 0 secs for small migrations.
let estimated_duration_secs = if self.config.bandwidth_limit_bytes_per_sec > 0 {
migrate_size as u64 / self.config.bandwidth_limit_bytes_per_sec as u64
} else {
60 };
let plan_id = format!(
"migration-{}-to-{}-{}",
source_load.node_id,
target_load.node_id,
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs()
);
Ok(MigrationPlan {
plan_id,
source_node: source_load.node_id,
target_node: target_load.node_id,
key_count: migrate_keys,
estimated_size_bytes: migrate_size,
partition_ids: vec![0], created_at: SystemTime::now(),
estimated_duration_secs,
})
}
/// Execute a migration plan end to end: batch the keys (batch size is
/// strategy-dependent), throttle to the configured bandwidth, optionally
/// verify, then record stats and move the progress record to history.
/// NOTE(review): if `migrate_batch` or `verify_migration` fails, the `?`
/// early-return leaves the entry in `active_migrations` with stale state
/// and `failed_migrations` is never incremented — confirm intended.
pub async fn execute_migration(&self, plan: MigrationPlan) -> Result<(), String> {
let start = std::time::Instant::now();
let mut progress = MigrationProgress::new(
plan.plan_id.clone(),
plan.key_count,
plan.estimated_size_bytes,
);
progress.state = MigrationState::Planning;
progress.started_at = Some(SystemTime::now());
self.active_migrations
.write()
.await
.insert(plan.plan_id.clone(), progress.clone());
info!(
"Starting migration {} from node {} to node {} ({} keys, {} bytes)",
plan.plan_id,
plan.source_node,
plan.target_node,
plan.key_count,
plan.estimated_size_bytes
);
// Adaptive strategy picks small batches when the source CPU is busy.
let batch_size = match self.config.strategy {
RebalancingStrategy::Incremental => self.config.incremental_batch_size,
RebalancingStrategy::Bulk => self.config.bulk_batch_size,
RebalancingStrategy::Adaptive => {
let node_loads = self.node_loads.read().await;
if let Some(source_load) = node_loads.get(&plan.source_node) {
if source_load.cpu_utilization > 0.7 {
self.config.incremental_batch_size
} else {
self.config.bulk_batch_size
}
} else {
self.config.incremental_batch_size
}
}
};
progress.state = MigrationState::InProgress;
self.active_migrations
.write()
.await
.insert(plan.plan_id.clone(), progress.clone());
let mut migrated_keys = 0;
let mut migrated_bytes = 0;
while migrated_keys < plan.key_count {
let batch_keys = (plan.key_count - migrated_keys).min(batch_size);
// Bytes for this batch, proportional to its share of the keys.
let batch_bytes = (batch_keys as f64 / plan.key_count as f64
* plan.estimated_size_bytes as f64) as usize;
// Sleep long enough that this batch stays under the bandwidth cap.
if self.config.bandwidth_limit_bytes_per_sec > 0 {
let sleep_duration = Duration::from_secs_f64(
batch_bytes as f64 / self.config.bandwidth_limit_bytes_per_sec as f64,
);
tokio::time::sleep(sleep_duration).await;
}
self.migrate_batch(
&plan.plan_id,
plan.source_node,
plan.target_node,
batch_keys,
batch_bytes,
)
.await?;
migrated_keys += batch_keys;
migrated_bytes += batch_bytes;
// Publish progress after every batch so observers see live counts.
progress.keys_migrated = migrated_keys;
progress.bytes_migrated = migrated_bytes;
self.active_migrations
.write()
.await
.insert(plan.plan_id.clone(), progress.clone());
}
if self.config.enable_verification {
progress.state = MigrationState::Verifying;
self.active_migrations
.write()
.await
.insert(plan.plan_id.clone(), progress.clone());
self.verify_migration(&plan).await?;
}
progress.state = MigrationState::Completed;
progress.completed_at = Some(SystemTime::now());
self.active_migrations
.write()
.await
.insert(plan.plan_id.clone(), progress.clone());
let duration = start.elapsed();
// Fold this run into the running averages over successful migrations.
let mut stats = self.stats.write().await;
stats.total_rebalancing_ops += 1;
stats.successful_migrations += 1;
stats.total_keys_migrated += plan.key_count;
stats.total_bytes_migrated += plan.estimated_size_bytes;
stats.last_rebalancing = Some(SystemTime::now());
let total = stats.successful_migrations as f64;
stats.avg_rebalancing_duration_ms = (stats.avg_rebalancing_duration_ms * (total - 1.0)
+ duration.as_millis() as f64)
/ total;
let speed = plan.estimated_size_bytes as f64 / duration.as_secs_f64();
stats.avg_migration_speed_bytes_per_sec =
(stats.avg_migration_speed_bytes_per_sec * (total - 1.0) + speed) / total;
self.migration_history.write().await.push(progress.clone());
self.active_migrations.write().await.remove(&plan.plan_id);
info!(
"Migration {} completed in {:?} ({:.2} MB/s)",
plan.plan_id,
duration,
speed / (1024.0 * 1024.0)
);
Ok(())
}
/// Transfer one batch of keys. Currently a no-op placeholder; the real
/// data movement is presumably wired in elsewhere — confirm.
async fn migrate_batch(
&self,
_plan_id: &str,
_source: OxirsNodeId,
_target: OxirsNodeId,
_keys: usize,
_bytes: usize,
) -> Result<(), String> {
Ok(())
}
/// Verify migrated data on the target. Currently a no-op placeholder.
async fn verify_migration(&self, _plan: &MigrationPlan) -> Result<(), String> {
Ok(())
}
/// Roll back an active migration: marks it `RollingBack`, then `Failed`,
/// and counts it as a failed migration. Errors if rollback is disabled
/// or the plan id is not active.
/// NOTE(review): sleeps 1 s while holding the `active_migrations` write
/// lock (blocking all other accessors), and the rolled-back record stays
/// in `active_migrations` rather than moving to history — confirm intended.
pub async fn rollback_migration(&self, plan_id: &str) -> Result<(), String> {
if !self.config.enable_rollback {
return Err("Rollback is disabled".to_string());
}
let mut active = self.active_migrations.write().await;
if let Some(progress) = active.get_mut(plan_id) {
warn!("Rolling back migration {}", plan_id);
progress.state = MigrationState::RollingBack;
// Simulated rollback work.
tokio::time::sleep(Duration::from_secs(1)).await;
progress.state = MigrationState::Failed;
progress.completed_at = Some(SystemTime::now());
progress.errors.push("Migration rolled back".to_string());
let mut stats = self.stats.write().await;
stats.failed_migrations += 1;
Ok(())
} else {
Err(format!("Migration {} not found", plan_id))
}
}
/// Progress of an active migration, if any, by plan id.
pub async fn get_migration_progress(&self, plan_id: &str) -> Option<MigrationProgress> {
self.active_migrations.read().await.get(plan_id).cloned()
}
/// Snapshot of all currently active migrations.
pub async fn get_active_migrations(&self) -> Vec<MigrationProgress> {
self.active_migrations
.read()
.await
.values()
.cloned()
.collect()
}
/// Snapshot of completed migrations, in completion order.
pub async fn get_migration_history(&self) -> Vec<MigrationProgress> {
self.migration_history.read().await.clone()
}
/// Snapshot of the latest known load for every registered node.
pub async fn get_node_loads(&self) -> BTreeMap<OxirsNodeId, NodeLoad> {
self.node_loads.read().await.clone()
}
/// Snapshot of aggregate rebalancing statistics.
pub async fn get_stats(&self) -> RebalancingStats {
self.stats.read().await.clone()
}
/// Reset all state: node loads, active migrations, history, and stats.
pub async fn clear(&self) {
self.node_loads.write().await.clear();
self.active_migrations.write().await.clear();
self.migration_history.write().await.clear();
*self.stats.write().await = RebalancingStats::default();
}
}
#[cfg(test)]
mod tests {
use super::*;
// A freshly built manager reports zeroed statistics.
#[tokio::test]
async fn test_rebalancing_creation() {
let config = RebalancingConfig::default();
let manager = DataRebalancingManager::new(config);
let stats = manager.get_stats().await;
assert_eq!(stats.total_rebalancing_ops, 0);
}
// Registered nodes show up in the load map.
#[tokio::test]
async fn test_register_node() {
let config = RebalancingConfig::default();
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager.register_node(2).await;
let loads = manager.get_node_loads().await;
assert_eq!(loads.len(), 2);
}
// update_node_load overwrites the stored metrics for a node.
#[tokio::test]
async fn test_update_node_load() {
let config = RebalancingConfig::default();
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager
.update_node_load(1, 1000, 50000, 0.5, 0.6, 1000000)
.await;
let loads = manager.get_node_loads().await;
let load = loads.get(&1).unwrap();
assert_eq!(load.key_count, 1000);
assert_eq!(load.data_size_bytes, 50000);
assert_eq!(load.cpu_utilization, 0.5);
}
// Balanced nodes need no rebalancing; a 5x data-size skew does.
#[tokio::test]
async fn test_check_rebalancing_needed() {
let config = RebalancingConfig {
enable_auto_rebalancing: true,
load_imbalance_threshold: 0.2,
min_nodes_for_rebalancing: 2,
..Default::default()
};
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager.register_node(2).await;
manager.update_node_load(1, 100, 10000, 0.5, 0.5, 0).await;
manager.update_node_load(2, 100, 10000, 0.5, 0.5, 0).await;
let needed = manager.check_rebalancing_needed().await;
assert!(!needed);
manager.update_node_load(1, 100, 10000, 0.5, 0.5, 0).await;
manager.update_node_load(2, 100, 50000, 0.5, 0.5, 0).await;
let needed = manager.check_rebalancing_needed().await;
assert!(needed);
}
// The plan moves data from the most-loaded node (2) to the least (1).
#[tokio::test]
async fn test_create_migration_plan() {
let config = RebalancingConfig::default();
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager.register_node(2).await;
manager.update_node_load(1, 100, 10000, 0.5, 0.5, 0).await;
manager.update_node_load(2, 200, 50000, 0.5, 0.5, 0).await;
let plan = manager.create_migration_plan().await;
assert!(plan.is_ok());
let plan = plan.unwrap();
assert_eq!(plan.source_node, 2); assert_eq!(plan.target_node, 1); assert!(plan.key_count > 0);
}
// End-to-end execution (no throttle, no verification) records success.
#[tokio::test]
async fn test_execute_migration() {
let config = RebalancingConfig {
incremental_batch_size: 10,
bandwidth_limit_bytes_per_sec: 0, enable_verification: false,
..Default::default()
};
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager.register_node(2).await;
manager.update_node_load(1, 100, 10000, 0.5, 0.5, 0).await;
manager.update_node_load(2, 200, 50000, 0.5, 0.5, 0).await;
let plan = manager.create_migration_plan().await.unwrap();
let result = manager.execute_migration(plan).await;
assert!(result.is_ok());
let stats = manager.get_stats().await;
assert_eq!(stats.successful_migrations, 1);
}
// progress_percentage is keys_migrated / total_keys * 100.
#[tokio::test]
async fn test_migration_progress() {
let progress = MigrationProgress::new("test-1".to_string(), 100, 10000);
assert_eq!(progress.progress_percentage(), 0.0);
let mut progress = progress;
progress.keys_migrated = 50;
assert_eq!(progress.progress_percentage(), 50.0);
}
// A throttled migration is observable mid-flight via its progress entry.
// NOTE(review): timing-based (50 ms then 3 s sleeps) — slow, and
// potentially flaky on a loaded CI machine; confirm acceptable.
#[tokio::test]
async fn test_migration_states() {
let config = RebalancingConfig {
bandwidth_limit_bytes_per_sec: 10000, enable_verification: true,
incremental_batch_size: 10,
..Default::default()
};
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager.register_node(2).await;
manager.update_node_load(1, 100, 10000, 0.5, 0.5, 0).await;
manager.update_node_load(2, 200, 50000, 0.5, 0.5, 0).await;
let plan = manager.create_migration_plan().await.unwrap();
let plan_id = plan.plan_id.clone();
let manager_clone = Arc::new(manager);
let manager_ref = manager_clone.clone();
tokio::spawn(async move {
let _ = manager_ref.execute_migration(plan).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
let progress = manager_clone.get_migration_progress(&plan_id).await;
assert!(progress.is_some());
tokio::time::sleep(Duration::from_secs(3)).await;
}
// Rolling back an active migration counts as a failed migration.
#[tokio::test]
async fn test_rollback() {
let config = RebalancingConfig {
enable_rollback: true,
..Default::default()
};
let manager = DataRebalancingManager::new(config);
let progress = MigrationProgress::new("test-rollback".to_string(), 100, 10000);
manager
.active_migrations
.write()
.await
.insert("test-rollback".to_string(), progress);
let result = manager.rollback_migration("test-rollback").await;
assert!(result.is_ok());
let stats = manager.get_stats().await;
assert_eq!(stats.failed_migrations, 1);
}
// clear() wipes node loads (and the rest of the manager state).
#[tokio::test]
async fn test_clear() {
let config = RebalancingConfig::default();
let manager = DataRebalancingManager::new(config);
manager.register_node(1).await;
manager.register_node(2).await;
manager.clear().await;
let loads = manager.get_node_loads().await;
assert!(loads.is_empty());
}
// Derived PartialEq behaves as expected for strategy variants.
#[test]
fn test_rebalancing_strategy() {
assert_eq!(
RebalancingStrategy::Incremental,
RebalancingStrategy::Incremental
);
assert_ne!(RebalancingStrategy::Incremental, RebalancingStrategy::Bulk);
}
// Derived PartialEq behaves as expected for migration states.
#[test]
fn test_migration_state() {
assert_eq!(MigrationState::Idle, MigrationState::Idle);
assert_ne!(MigrationState::Idle, MigrationState::InProgress);
}
}