1use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use dashmap::{DashMap, DashSet};
15use tracing::{debug, info, warn};
16use uuid::Uuid;
17
18use crate::shard::{ShardId, ShardRegistry};
19use crate::types::NodeId;
20
/// Lifecycle state of a single shard migration.
///
/// `Complete` and `Failed` are the terminal states: migrations in either of
/// them are excluded from `MigrationTracker::active_migrations`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationStatus {
    /// Initial state assigned when a migration is created.
    Pending,
    /// Entered on the first progress update of a `Pending` migration.
    InProgress,
    /// Not assigned anywhere in this file; presumably set by code that
    /// validates the copied data before completion (TODO confirm).
    Verifying,
    /// The migration finished successfully.
    Complete,
    /// The migration was aborted.
    Failed {
        /// Human-readable explanation supplied by the caller that reported
        /// the failure.
        reason: String,
    },
}
40
41#[derive(Debug, Clone)]
46pub struct Migration {
47 pub id: Uuid,
49 pub shard_id: ShardId,
51 pub from_node: NodeId,
53 pub to_node: NodeId,
55 pub status: MigrationStatus,
57 pub started_at_ms: u64,
59 pub bytes_migrated: u64,
61 pub total_bytes: u64,
63}
64
65impl Migration {
66 fn new(shard_id: ShardId, from_node: NodeId, to_node: NodeId) -> Self {
67 let started_at_ms = SystemTime::now()
68 .duration_since(UNIX_EPOCH)
69 .map(|d| d.as_millis() as u64)
70 .unwrap_or(0);
71 Self {
72 id: Uuid::new_v4(),
73 shard_id,
74 from_node,
75 to_node,
76 status: MigrationStatus::Pending,
77 started_at_ms,
78 bytes_migrated: 0,
79 total_bytes: 0,
80 }
81 }
82}
83
84pub struct MigrationTracker {
98 migrations: DashMap<Uuid, Migration>,
100 active_shard_migrations: DashMap<ShardId, Uuid>,
102}
103
104impl MigrationTracker {
105 pub fn new() -> Self {
107 Self {
108 migrations: DashMap::new(),
109 active_shard_migrations: DashMap::new(),
110 }
111 }
112
113 pub fn begin_migration(
120 &self,
121 shard_id: ShardId,
122 from_node: NodeId,
123 to_node: NodeId,
124 ) -> Result<Uuid, String> {
125 if self.active_shard_migrations.contains_key(&shard_id) {
127 return Err(format!(
128 "shard {} already has an active migration",
129 shard_id
130 ));
131 }
132
133 let migration = Migration::new(shard_id, from_node, to_node);
134 let id = migration.id;
135
136 self.active_shard_migrations.insert(shard_id, id);
137 self.migrations.insert(id, migration);
138
139 info!(
140 migration_id = %id,
141 shard_id = shard_id,
142 from_node = from_node,
143 to_node = to_node,
144 "Migration begun"
145 );
146
147 Ok(id)
148 }
149
150 pub fn update_progress(&self, id: Uuid, bytes_migrated: u64, total_bytes: u64) -> bool {
156 match self.migrations.get_mut(&id) {
157 None => {
158 warn!(migration_id = %id, "update_progress: migration not found");
159 false
160 }
161 Some(mut m) => {
162 m.bytes_migrated = bytes_migrated;
163 m.total_bytes = total_bytes;
164 if m.status == MigrationStatus::Pending {
165 m.status = MigrationStatus::InProgress;
166 }
167 debug!(
168 migration_id = %id,
169 bytes_migrated = bytes_migrated,
170 total_bytes = total_bytes,
171 "Migration progress updated"
172 );
173 true
174 }
175 }
176 }
177
178 pub fn complete_migration(&self, id: Uuid) -> bool {
183 match self.migrations.get_mut(&id) {
184 None => {
185 warn!(migration_id = %id, "complete_migration: migration not found");
186 false
187 }
188 Some(mut m) => {
189 let shard_id = m.shard_id;
190 m.status = MigrationStatus::Complete;
191 drop(m);
192 self.active_shard_migrations
194 .remove_if(&shard_id, |_, v| *v == id);
195 info!(migration_id = %id, shard_id = shard_id, "Migration completed");
196 true
197 }
198 }
199 }
200
201 pub fn fail_migration(&self, id: Uuid, reason: String) -> bool {
206 match self.migrations.get_mut(&id) {
207 None => {
208 warn!(migration_id = %id, "fail_migration: migration not found");
209 false
210 }
211 Some(mut m) => {
212 let shard_id = m.shard_id;
213 m.status = MigrationStatus::Failed {
214 reason: reason.clone(),
215 };
216 drop(m);
217 self.active_shard_migrations
218 .remove_if(&shard_id, |_, v| *v == id);
219 warn!(
220 migration_id = %id,
221 shard_id = shard_id,
222 reason = %reason,
223 "Migration failed"
224 );
225 true
226 }
227 }
228 }
229
230 pub fn get_migration(&self, id: Uuid) -> Option<Migration> {
232 self.migrations.get(&id).map(|m| m.clone())
233 }
234
235 pub fn active_migrations(&self) -> Vec<Migration> {
237 self.migrations
238 .iter()
239 .filter(|r| {
240 !matches!(
241 r.status,
242 MigrationStatus::Complete | MigrationStatus::Failed { .. }
243 )
244 })
245 .map(|r| r.clone())
246 .collect()
247 }
248
249 pub fn is_shard_migrating(&self, shard_id: &ShardId) -> bool {
251 self.active_shard_migrations.contains_key(shard_id)
252 }
253}
254
255impl Default for MigrationTracker {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
261pub fn compute_rebalance_plan(
272 registry: &ShardRegistry,
273 tracker: &MigrationTracker,
274 imbalance_threshold: f64,
275 max_concurrent_migrations: usize,
276) -> Vec<(ShardId, NodeId, NodeId)> {
277 use std::collections::HashMap;
278
279 let shards = registry.get_all();
280 if shards.is_empty() || max_concurrent_migrations == 0 {
281 return Vec::new();
282 }
283
284 let mut node_shards: HashMap<NodeId, Vec<ShardId>> = HashMap::new();
286 for shard in &shards {
287 node_shards.entry(shard.node_id).or_default().push(shard.id);
288 }
289
290 if node_shards.len() < 2 {
291 return Vec::new();
292 }
293
294 let mean = shards.len() as f64 / node_shards.len() as f64;
295
296 let mut overloaded: Vec<(NodeId, Vec<ShardId>)> = node_shards
298 .iter()
299 .filter(|(_, ids)| ids.len() as f64 > mean * (1.0 + imbalance_threshold))
300 .map(|(nid, ids)| (*nid, ids.clone()))
301 .collect();
302
303 let mut underloaded: Vec<(NodeId, usize)> = node_shards
304 .iter()
305 .filter(|(_, ids)| (ids.len() as f64) < mean * (1.0 - imbalance_threshold))
306 .map(|(nid, ids)| (*nid, ids.len()))
307 .collect();
308
309 if overloaded.is_empty() || underloaded.is_empty() {
310 return Vec::new();
311 }
312
313 overloaded.sort_by_key(|(nid, _)| *nid);
315 underloaded.sort_by_key(|(nid, _)| *nid);
316
317 let mut plan: Vec<(ShardId, NodeId, NodeId)> = Vec::new();
318
319 'outer: for (from_node, shard_ids) in &overloaded {
320 for shard_id in shard_ids {
321 if tracker.is_shard_migrating(shard_id) {
322 continue;
323 }
324 if let Some((to_node, _)) = underloaded.first_mut() {
325 plan.push((*shard_id, *from_node, *to_node));
326 if plan.len() >= max_concurrent_migrations {
327 break 'outer;
328 }
329 }
330 }
331 }
332
333 plan
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shard::{KeyRange, ShardMetadata, ShardRegistry};
    use amaters_core::Key;

    /// Builds a registry containing one shard per
    /// `(shard_id, node_id, start_key, end_key)` tuple.
    fn make_registry_with_distribution(
        distribution: &[(ShardId, NodeId, &str, &str)],
    ) -> ShardRegistry {
        let registry = ShardRegistry::new();
        for &(shard_id, node_id, start, end) in distribution {
            let range =
                KeyRange::new(Key::from_str(start), Key::from_str(end)).expect("valid range");
            let shard = ShardMetadata::new(shard_id, range, node_id);
            registry.register(shard).expect("register");
        }
        registry
    }

    /// A shard with an active migration must reject a second one.
    #[test]
    fn test_begin_migration_prevents_duplicate() {
        let tracker = MigrationTracker::new();

        let result = tracker.begin_migration(1, 10, 20);
        assert!(result.is_ok(), "first migration should succeed");

        let result2 = tracker.begin_migration(1, 10, 20);
        assert!(
            result2.is_err(),
            "duplicate migration for shard 1 should be rejected"
        );
        let err_msg = result2.expect_err("second migration should fail");
        assert!(
            err_msg.contains("shard 1"),
            "error message should mention the shard id"
        );
    }

    /// Pending -> InProgress -> Complete, after which the shard is free again.
    #[test]
    fn test_migration_lifecycle() {
        let tracker = MigrationTracker::new();

        let id = tracker.begin_migration(2, 10, 20).expect("begin_migration");

        let m = tracker.get_migration(id).expect("get migration");
        assert_eq!(m.status, MigrationStatus::Pending);
        assert!(tracker.is_shard_migrating(&2));

        assert!(tracker.update_progress(id, 512, 1024));
        let m = tracker.get_migration(id).expect("get migration");
        assert_eq!(m.status, MigrationStatus::InProgress);
        assert_eq!(m.bytes_migrated, 512);
        assert_eq!(m.total_bytes, 1024);

        assert!(tracker.complete_migration(id));
        let m = tracker.get_migration(id).expect("get migration");
        assert_eq!(m.status, MigrationStatus::Complete);
        assert!(!tracker.is_shard_migrating(&2));

        // A completed shard can be migrated again.
        assert!(tracker.begin_migration(2, 20, 10).is_ok());
    }

    /// Failing a migration records the reason and frees the shard.
    #[test]
    fn test_migration_failed_state() {
        let tracker = MigrationTracker::new();
        let id = tracker.begin_migration(3, 10, 20).expect("begin_migration");

        assert!(tracker.fail_migration(id, "disk full".to_string()));
        let m = tracker.get_migration(id).expect("get migration");
        assert!(
            matches!(m.status, MigrationStatus::Failed { ref reason } if reason == "disk full"),
            "expected Failed with reason 'disk full', got {:?}",
            m.status
        );
        assert!(!tracker.is_shard_migrating(&3));
    }

    /// With 6 shards on node 1 and 2 on node 2, moves must go from 1 to 2.
    #[test]
    fn test_rebalance_plan_targets_overloaded_node() {
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 1, "a4", "a5"),
            (6, 1, "a5", "a6"),
            (7, 2, "b0", "b1"),
            (8, 2, "b1", "b2"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 10);

        assert!(
            !plan.is_empty(),
            "plan should be non-empty for imbalanced cluster"
        );
        for (shard_id, from_node, to_node) in &plan {
            assert_eq!(*from_node, 1, "moves should come from overloaded node 1");
            assert_eq!(*to_node, 2, "moves should go to underloaded node 2");
            assert!(
                *shard_id >= 1 && *shard_id <= 6,
                "only shards on node 1 should be moved"
            );
        }
    }

    /// An even 4/4 split must not produce any moves.
    #[test]
    fn test_no_rebalance_when_balanced() {
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 2, "b0", "b1"),
            (6, 2, "b1", "b2"),
            (7, 2, "b2", "b3"),
            (8, 2, "b3", "b4"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 10);

        assert!(
            plan.is_empty(),
            "plan should be empty for balanced cluster, got {:?}",
            plan
        );
    }

    /// Completed and failed migrations are not reported as active.
    #[test]
    fn test_active_migrations_excludes_terminal() {
        let tracker = MigrationTracker::new();
        let id1 = tracker.begin_migration(10, 1, 2).expect("begin 10");
        let id2 = tracker.begin_migration(11, 1, 2).expect("begin 11");
        let id3 = tracker.begin_migration(12, 1, 2).expect("begin 12");

        tracker.complete_migration(id1);
        tracker.fail_migration(id2, "oops".to_string());

        let active = tracker.active_migrations();
        let active_ids: Vec<Uuid> = active.iter().map(|m| m.id).collect();

        assert!(
            !active_ids.contains(&id1),
            "completed migration should not appear in active list"
        );
        assert!(
            !active_ids.contains(&id2),
            "failed migration should not appear in active list"
        );
        assert!(
            active_ids.contains(&id3),
            "pending migration should appear in active list"
        );
    }

    /// The plan length never exceeds the requested limit.
    #[test]
    fn test_max_concurrent_migrations_respected() {
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 1, "a4", "a5"),
            (6, 1, "a5", "a6"),
            (7, 2, "b0", "b1"),
            (8, 2, "b1", "b2"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 2);

        assert!(
            plan.len() <= 2,
            "plan must not exceed max_concurrent_migrations=2, got {}",
            plan.len()
        );
    }
}