use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use dashmap::{DashMap, DashSet};
use tracing::{debug, info, warn};
use uuid::Uuid;
use crate::shard::{ShardId, ShardRegistry};
use crate::types::NodeId;
/// Lifecycle state of a single shard migration.
///
/// `Complete` and `Failed` are the two states that
/// `MigrationTracker::active_migrations` treats as finished; every other
/// state counts as still active.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`String`)
/// has total equality, which lets the status be used wherever an `Eq` bound
/// is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationStatus {
    /// Initial state assigned when a migration record is created.
    Pending,
    /// Entered from `Pending` on the first progress update.
    InProgress,
    /// NOTE(review): nothing in this file moves a migration into this state;
    /// presumably a post-transfer verification phase driven elsewhere — confirm.
    Verifying,
    /// The migration finished successfully.
    Complete,
    /// The migration was aborted.
    Failed {
        /// Caller-supplied description of why the migration failed.
        reason: String,
    },
}
/// Record describing one move of a shard from one node to another.
///
/// Instances are created by `Migration::new` and stored in a
/// `MigrationTracker`, which hands out clones as snapshots.
#[derive(Debug, Clone)]
pub struct Migration {
/// Unique identifier of this migration (a random v4 UUID assigned at creation).
pub id: Uuid,
/// Shard being moved.
pub shard_id: ShardId,
/// Node the shard is being moved away from.
pub from_node: NodeId,
/// Node the shard is being moved to.
pub to_node: NodeId,
/// Current lifecycle state; starts as `MigrationStatus::Pending`.
pub status: MigrationStatus,
/// Creation time in milliseconds since the Unix epoch, or 0 if the system
/// clock reported a time before the epoch.
pub started_at_ms: u64,
/// Most recently reported number of bytes transferred; starts at 0.
pub bytes_migrated: u64,
/// Most recently reported total number of bytes to transfer; starts at 0.
pub total_bytes: u64,
}
impl Migration {
fn new(shard_id: ShardId, from_node: NodeId, to_node: NodeId) -> Self {
let started_at_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
Self {
id: Uuid::new_v4(),
shard_id,
from_node,
to_node,
status: MigrationStatus::Pending,
started_at_ms,
bytes_migrated: 0,
total_bytes: 0,
}
}
}
/// Concurrent bookkeeping for shard migrations.
///
/// All methods take `&self`; state lives in two `DashMap`s so the tracker can
/// be mutated through a shared reference.
pub struct MigrationTracker {
// Every migration begun through this tracker, keyed by migration id.
// None of the methods visible in this file remove entries from this map.
migrations: DashMap<Uuid, Migration>,
// Shard id -> id of the migration currently holding that shard. An entry is
// added by `begin_migration` and removed by `complete_migration` /
// `fail_migration`.
active_shard_migrations: DashMap<ShardId, Uuid>,
}
impl MigrationTracker {
pub fn new() -> Self {
Self {
migrations: DashMap::new(),
active_shard_migrations: DashMap::new(),
}
}
pub fn begin_migration(
&self,
shard_id: ShardId,
from_node: NodeId,
to_node: NodeId,
) -> Result<Uuid, String> {
if self.active_shard_migrations.contains_key(&shard_id) {
return Err(format!(
"shard {} already has an active migration",
shard_id
));
}
let migration = Migration::new(shard_id, from_node, to_node);
let id = migration.id;
self.active_shard_migrations.insert(shard_id, id);
self.migrations.insert(id, migration);
info!(
migration_id = %id,
shard_id = shard_id,
from_node = from_node,
to_node = to_node,
"Migration begun"
);
Ok(id)
}
pub fn update_progress(&self, id: Uuid, bytes_migrated: u64, total_bytes: u64) -> bool {
match self.migrations.get_mut(&id) {
None => {
warn!(migration_id = %id, "update_progress: migration not found");
false
}
Some(mut m) => {
m.bytes_migrated = bytes_migrated;
m.total_bytes = total_bytes;
if m.status == MigrationStatus::Pending {
m.status = MigrationStatus::InProgress;
}
debug!(
migration_id = %id,
bytes_migrated = bytes_migrated,
total_bytes = total_bytes,
"Migration progress updated"
);
true
}
}
}
pub fn complete_migration(&self, id: Uuid) -> bool {
match self.migrations.get_mut(&id) {
None => {
warn!(migration_id = %id, "complete_migration: migration not found");
false
}
Some(mut m) => {
let shard_id = m.shard_id;
m.status = MigrationStatus::Complete;
drop(m);
self.active_shard_migrations
.remove_if(&shard_id, |_, v| *v == id);
info!(migration_id = %id, shard_id = shard_id, "Migration completed");
true
}
}
}
pub fn fail_migration(&self, id: Uuid, reason: String) -> bool {
match self.migrations.get_mut(&id) {
None => {
warn!(migration_id = %id, "fail_migration: migration not found");
false
}
Some(mut m) => {
let shard_id = m.shard_id;
m.status = MigrationStatus::Failed {
reason: reason.clone(),
};
drop(m);
self.active_shard_migrations
.remove_if(&shard_id, |_, v| *v == id);
warn!(
migration_id = %id,
shard_id = shard_id,
reason = %reason,
"Migration failed"
);
true
}
}
}
pub fn get_migration(&self, id: Uuid) -> Option<Migration> {
self.migrations.get(&id).map(|m| m.clone())
}
pub fn active_migrations(&self) -> Vec<Migration> {
self.migrations
.iter()
.filter(|r| {
!matches!(
r.status,
MigrationStatus::Complete | MigrationStatus::Failed { .. }
)
})
.map(|r| r.clone())
.collect()
}
pub fn is_shard_migrating(&self, shard_id: &ShardId) -> bool {
self.active_shard_migrations.contains_key(shard_id)
}
}
impl Default for MigrationTracker {
fn default() -> Self {
Self::new()
}
}
/// Proposes shard moves that bring an imbalanced cluster closer to an even
/// shard count per node.
///
/// Nodes are discovered from the shards in `registry`, so a node that owns no
/// shards is not considered. With `mean` being the average shard count per
/// node, a node is *overloaded* when its count exceeds
/// `mean * (1 + imbalance_threshold)` and *underloaded* when its count is
/// below `mean * (1 - imbalance_threshold)`.
///
/// Moves are planned from overloaded nodes (in ascending node-id order) to
/// the currently least-loaded underloaded node (ties broken by lowest node
/// id). Projected loads are updated after every planned move, so:
///
/// * a donor stops giving once it is down to `ceil(mean)` shards,
/// * a receiver stops taking once it reaches `floor(mean)` shards,
/// * moves are spread across all underloaded nodes instead of piling onto one.
///
/// Shards that `tracker` reports as already migrating are skipped.
///
/// # Parameters
///
/// * `registry` – source of the current shard-to-node assignment.
/// * `tracker` – used to skip shards with an active migration.
/// * `imbalance_threshold` – fractional tolerance around the mean (e.g. `0.2`).
/// * `max_concurrent_migrations` – upper bound on the number of planned moves;
///   `0` yields an empty plan.
///
/// # Returns
///
/// A list of `(shard_id, from_node, to_node)` moves, empty when there are no
/// shards, fewer than two nodes, or no overloaded/underloaded pair.
pub fn compute_rebalance_plan(
    registry: &ShardRegistry,
    tracker: &MigrationTracker,
    imbalance_threshold: f64,
    max_concurrent_migrations: usize,
) -> Vec<(ShardId, NodeId, NodeId)> {
    use std::collections::HashMap;

    let shards = registry.get_all();
    if shards.is_empty() || max_concurrent_migrations == 0 {
        return Vec::new();
    }

    // Group shard ids by the node that currently owns them.
    let mut node_shards: HashMap<NodeId, Vec<ShardId>> = HashMap::new();
    for shard in &shards {
        node_shards.entry(shard.node_id).or_default().push(shard.id);
    }
    if node_shards.len() < 2 {
        return Vec::new();
    }

    let mean = shards.len() as f64 / node_shards.len() as f64;
    let mut overloaded: Vec<(NodeId, Vec<ShardId>)> = node_shards
        .iter()
        .filter(|(_, ids)| ids.len() as f64 > mean * (1.0 + imbalance_threshold))
        .map(|(nid, ids)| (*nid, ids.clone()))
        .collect();
    let mut underloaded: Vec<(NodeId, usize)> = node_shards
        .iter()
        .filter(|(_, ids)| (ids.len() as f64) < mean * (1.0 - imbalance_threshold))
        .map(|(nid, ids)| (*nid, ids.len()))
        .collect();
    if overloaded.is_empty() || underloaded.is_empty() {
        return Vec::new();
    }

    // HashMap iteration order is arbitrary; sort so the plan is deterministic.
    overloaded.sort_by_key(|(nid, _)| *nid);
    underloaded.sort_by_key(|(nid, _)| *nid);

    // Balanced targets: donors never drop below ceil(mean) and receivers never
    // rise above floor(mean), so a plan cannot overshoot and flip the imbalance.
    let donor_floor = mean.ceil() as usize;
    let receiver_cap = mean.floor() as usize;

    let mut plan: Vec<(ShardId, NodeId, NodeId)> = Vec::new();
    'outer: for (from_node, shard_ids) in &overloaded {
        let mut from_load = shard_ids.len();
        for shard_id in shard_ids {
            if from_load <= donor_floor {
                // This donor is already at its balanced share.
                break;
            }
            if tracker.is_shard_migrating(shard_id) {
                continue;
            }
            // Least-loaded receiver that still has room; `min_by_key` returns
            // the first minimum, i.e. the lowest node id among ties.
            let target = underloaded
                .iter_mut()
                .filter(|entry| entry.1 < receiver_cap)
                .min_by_key(|entry| entry.1);
            match target {
                // Every receiver is full: no further move can help.
                None => break 'outer,
                Some((to_node, to_load)) => {
                    plan.push((*shard_id, *from_node, *to_node));
                    *to_load += 1;
                    from_load -= 1;
                    if plan.len() >= max_concurrent_migrations {
                        break 'outer;
                    }
                }
            }
        }
    }
    plan
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shard::{KeyRange, ShardMetadata, ShardRegistry};
    use amaters_core::Key;

    /// Builds a registry from `(shard_id, node_id, range_start, range_end)` tuples.
    fn make_registry_with_distribution(
        distribution: &[(ShardId, NodeId, &str, &str)],
    ) -> ShardRegistry {
        let registry = ShardRegistry::new();
        for &(shard_id, node_id, start, end) in distribution {
            let range =
                KeyRange::new(Key::from_str(start), Key::from_str(end)).expect("valid range");
            let shard = ShardMetadata::new(shard_id, range, node_id);
            registry.register(shard).expect("register");
        }
        registry
    }

    /// A second migration for a shard that is already migrating is rejected.
    #[test]
    fn test_begin_migration_prevents_duplicate() {
        let tracker = MigrationTracker::new();
        let result = tracker.begin_migration(1, 10, 20);
        assert!(result.is_ok(), "first migration should succeed");
        let result2 = tracker.begin_migration(1, 10, 20);
        assert!(
            result2.is_err(),
            "duplicate migration for shard 1 should be rejected"
        );
        let err_msg = result2.expect_err("second migration should fail");
        assert!(
            err_msg.contains("shard 1"),
            "error message should mention the shard id"
        );
    }

    /// Pending -> InProgress -> Complete, after which the shard can migrate again.
    #[test]
    fn test_migration_lifecycle() {
        let tracker = MigrationTracker::new();
        let id = tracker.begin_migration(2, 10, 20).expect("begin_migration");
        let m = tracker.get_migration(id).expect("get migration");
        assert_eq!(m.status, MigrationStatus::Pending);
        assert!(tracker.is_shard_migrating(&2));
        assert!(tracker.update_progress(id, 512, 1024));
        let m = tracker.get_migration(id).expect("get migration");
        assert_eq!(m.status, MigrationStatus::InProgress);
        assert_eq!(m.bytes_migrated, 512);
        assert_eq!(m.total_bytes, 1024);
        assert!(tracker.complete_migration(id));
        let m = tracker.get_migration(id).expect("get migration");
        assert_eq!(m.status, MigrationStatus::Complete);
        assert!(!tracker.is_shard_migrating(&2));
        assert!(tracker.begin_migration(2, 20, 10).is_ok());
    }

    /// Failing a migration records the reason and releases the shard.
    #[test]
    fn test_migration_failed_state() {
        let tracker = MigrationTracker::new();
        let id = tracker.begin_migration(3, 10, 20).expect("begin_migration");
        assert!(tracker.fail_migration(id, "disk full".to_string()));
        let m = tracker.get_migration(id).expect("get migration");
        assert!(
            matches!(m.status, MigrationStatus::Failed { ref reason } if reason == "disk full"),
            "expected Failed with reason 'disk full', got {:?}",
            m.status
        );
        assert!(!tracker.is_shard_migrating(&3));
    }

    /// With six shards on node 1 and two on node 2, moves go from node 1 to node 2.
    #[test]
    fn test_rebalance_plan_targets_overloaded_node() {
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 1, "a4", "a5"),
            (6, 1, "a5", "a6"),
            (7, 2, "b0", "b1"),
            (8, 2, "b1", "b2"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 10);
        assert!(
            !plan.is_empty(),
            "plan should be non-empty for imbalanced cluster"
        );
        for (shard_id, from_node, to_node) in &plan {
            assert_eq!(*from_node, 1, "moves should come from overloaded node 1");
            assert_eq!(*to_node, 2, "moves should go to underloaded node 2");
            assert!(
                *shard_id >= 1 && *shard_id <= 6,
                "only shards on node 1 should be moved"
            );
        }
    }

    /// An evenly split cluster produces no moves.
    #[test]
    fn test_no_rebalance_when_balanced() {
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 2, "b0", "b1"),
            (6, 2, "b1", "b2"),
            (7, 2, "b2", "b3"),
            (8, 2, "b3", "b4"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 10);
        assert!(
            plan.is_empty(),
            "plan should be empty for balanced cluster, got {:?}",
            plan
        );
    }

    /// Completed and failed migrations are excluded from the active list.
    #[test]
    fn test_active_migrations_excludes_terminal() {
        let tracker = MigrationTracker::new();
        let id1 = tracker.begin_migration(10, 1, 2).expect("begin 10");
        let id2 = tracker.begin_migration(11, 1, 2).expect("begin 11");
        let id3 = tracker.begin_migration(12, 1, 2).expect("begin 12");
        tracker.complete_migration(id1);
        tracker.fail_migration(id2, "oops".to_string());
        let active = tracker.active_migrations();
        let active_ids: Vec<Uuid> = active.iter().map(|m| m.id).collect();
        assert!(
            !active_ids.contains(&id1),
            "completed migration should not appear in active list"
        );
        assert!(
            !active_ids.contains(&id2),
            "failed migration should not appear in active list"
        );
        assert!(
            active_ids.contains(&id3),
            "pending migration should appear in active list"
        );
    }

    /// The plan never contains more moves than `max_concurrent_migrations`.
    #[test]
    fn test_max_concurrent_migrations_respected() {
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 1, "a4", "a5"),
            (6, 1, "a5", "a6"),
            (7, 2, "b0", "b1"),
            (8, 2, "b1", "b2"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 2);
        assert!(
            plan.len() <= 2,
            "plan must not exceed max_concurrent_migrations=2, got {}",
            plan.len()
        );
    }
}