// amaters_cluster/migration.rs

//! Live data migration tracker for load-balanced shard movement.
//!
//! [`MigrationTracker`] maintains a concurrent registry of in-progress
//! [`Migration`]s and enforces the invariant that a shard can participate in
//! at most one migration at a time.
//!
//! The companion function [`compute_rebalance_plan`] analyses a
//! [`ShardRegistry`] snapshot and proposes `(shard_id, from_node, to_node)`
//! triples that would reduce inter-node imbalance.

use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use dashmap::mapref::entry::Entry;
use dashmap::{DashMap, DashSet};
use tracing::{debug, info, warn};
use uuid::Uuid;

use crate::shard::{ShardId, ShardRegistry};
use crate::types::NodeId;

// ── MigrationStatus ───────────────────────────────────────────────────────────

/// Progress state of a single shard migration.
///
/// [`Complete`](Self::Complete) and [`Failed`](Self::Failed) are the terminal
/// states; every other variant counts as an active migration.
///
/// The type is `Eq` as well as `PartialEq` (its only payload is a `String`),
/// so it can be used wherever total equality is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationStatus {
    /// Accepted, not yet started.
    Pending,
    /// Data is actively being copied.
    InProgress,
    /// Copy complete; checksums are being verified before traffic cutover.
    Verifying,
    /// Migration finished successfully; source shard may be released.
    Complete,
    /// Migration aborted.
    Failed {
        /// Human-readable failure reason.
        reason: String,
    },
}

// ── Migration ─────────────────────────────────────────────────────────────────

43/// Represents one in-progress shard migration (moving data from one node to
44/// another).
45#[derive(Debug, Clone)]
46pub struct Migration {
47    /// Unique identifier for this migration.
48    pub id: Uuid,
49    /// The shard being migrated.
50    pub shard_id: ShardId,
51    /// Source node.
52    pub from_node: NodeId,
53    /// Destination node.
54    pub to_node: NodeId,
55    /// Current status.
56    pub status: MigrationStatus,
57    /// Wall-clock milliseconds when the migration was accepted.
58    pub started_at_ms: u64,
59    /// Bytes copied so far (updated by progress calls).
60    pub bytes_migrated: u64,
61    /// Total bytes to copy (may be 0 if unknown at start).
62    pub total_bytes: u64,
63}
65impl Migration {
66    fn new(shard_id: ShardId, from_node: NodeId, to_node: NodeId) -> Self {
67        let started_at_ms = SystemTime::now()
68            .duration_since(UNIX_EPOCH)
69            .map(|d| d.as_millis() as u64)
70            .unwrap_or(0);
71        Self {
72            id: Uuid::new_v4(),
73            shard_id,
74            from_node,
75            to_node,
76            status: MigrationStatus::Pending,
77            started_at_ms,
78            bytes_migrated: 0,
79            total_bytes: 0,
80        }
81    }
82}

// ── MigrationTracker ─────────────────────────────────────────────────────────

86/// Concurrent registry of active [`Migration`]s.
87///
88/// # Conflict prevention
89///
90/// A shard can be involved in at most one migration at a time.  Calling
91/// [`begin_migration`][Self::begin_migration] for a shard that already has an
92/// active migration returns an `Err`.
93///
94/// Migrations in terminal states ([`MigrationStatus::Complete`] or
95/// [`MigrationStatus::Failed`]) no longer block new migrations for the same
96/// shard.
97pub struct MigrationTracker {
98    /// All migrations indexed by their UUID.
99    migrations: DashMap<Uuid, Migration>,
100    /// Maps `shard_id → migration_id` for non-terminal migrations only.
101    active_shard_migrations: DashMap<ShardId, Uuid>,
102}
104impl MigrationTracker {
105    /// Construct an empty tracker.
106    pub fn new() -> Self {
107        Self {
108            migrations: DashMap::new(),
109            active_shard_migrations: DashMap::new(),
110        }
111    }
112
113    /// Begin a new migration for `shard_id` from `from_node` to `to_node`.
114    ///
115    /// # Errors
116    ///
117    /// Returns `Err` if the shard is already involved in a non-terminal
118    /// migration.
119    pub fn begin_migration(
120        &self,
121        shard_id: ShardId,
122        from_node: NodeId,
123        to_node: NodeId,
124    ) -> Result<Uuid, String> {
125        // Conflict check: shard must not already have an active migration.
126        if self.active_shard_migrations.contains_key(&shard_id) {
127            return Err(format!(
128                "shard {} already has an active migration",
129                shard_id
130            ));
131        }
132
133        let migration = Migration::new(shard_id, from_node, to_node);
134        let id = migration.id;
135
136        self.active_shard_migrations.insert(shard_id, id);
137        self.migrations.insert(id, migration);
138
139        info!(
140            migration_id = %id,
141            shard_id = shard_id,
142            from_node = from_node,
143            to_node = to_node,
144            "Migration begun"
145        );
146
147        Ok(id)
148    }
149
150    /// Update the byte-level progress of migration `id`.
151    ///
152    /// Returns `true` if the migration existed and its status was updated to
153    /// [`MigrationStatus::InProgress`].  Returns `false` if the migration was
154    /// not found.
155    pub fn update_progress(&self, id: Uuid, bytes_migrated: u64, total_bytes: u64) -> bool {
156        match self.migrations.get_mut(&id) {
157            None => {
158                warn!(migration_id = %id, "update_progress: migration not found");
159                false
160            }
161            Some(mut m) => {
162                m.bytes_migrated = bytes_migrated;
163                m.total_bytes = total_bytes;
164                if m.status == MigrationStatus::Pending {
165                    m.status = MigrationStatus::InProgress;
166                }
167                debug!(
168                    migration_id = %id,
169                    bytes_migrated = bytes_migrated,
170                    total_bytes = total_bytes,
171                    "Migration progress updated"
172                );
173                true
174            }
175        }
176    }
177
178    /// Mark migration `id` as [`MigrationStatus::Complete`] and release the
179    /// shard lock.
180    ///
181    /// Returns `true` if the migration existed.
182    pub fn complete_migration(&self, id: Uuid) -> bool {
183        match self.migrations.get_mut(&id) {
184            None => {
185                warn!(migration_id = %id, "complete_migration: migration not found");
186                false
187            }
188            Some(mut m) => {
189                let shard_id = m.shard_id;
190                m.status = MigrationStatus::Complete;
191                drop(m);
192                // Release shard lock only if it still maps to this migration.
193                self.active_shard_migrations
194                    .remove_if(&shard_id, |_, v| *v == id);
195                info!(migration_id = %id, shard_id = shard_id, "Migration completed");
196                true
197            }
198        }
199    }
200
201    /// Mark migration `id` as [`MigrationStatus::Failed`] and release the
202    /// shard lock.
203    ///
204    /// Returns `true` if the migration existed.
205    pub fn fail_migration(&self, id: Uuid, reason: String) -> bool {
206        match self.migrations.get_mut(&id) {
207            None => {
208                warn!(migration_id = %id, "fail_migration: migration not found");
209                false
210            }
211            Some(mut m) => {
212                let shard_id = m.shard_id;
213                m.status = MigrationStatus::Failed {
214                    reason: reason.clone(),
215                };
216                drop(m);
217                self.active_shard_migrations
218                    .remove_if(&shard_id, |_, v| *v == id);
219                warn!(
220                    migration_id = %id,
221                    shard_id = shard_id,
222                    reason = %reason,
223                    "Migration failed"
224                );
225                true
226            }
227        }
228    }
229
230    /// Return a snapshot of the [`Migration`] with the given `id`, or `None`.
231    pub fn get_migration(&self, id: Uuid) -> Option<Migration> {
232        self.migrations.get(&id).map(|m| m.clone())
233    }
234
235    /// Return snapshots of all migrations that are not in a terminal state.
236    pub fn active_migrations(&self) -> Vec<Migration> {
237        self.migrations
238            .iter()
239            .filter(|r| {
240                !matches!(
241                    r.status,
242                    MigrationStatus::Complete | MigrationStatus::Failed { .. }
243                )
244            })
245            .map(|r| r.clone())
246            .collect()
247    }
248
249    /// Return `true` if `shard_id` currently has a non-terminal migration.
250    pub fn is_shard_migrating(&self, shard_id: &ShardId) -> bool {
251        self.active_shard_migrations.contains_key(shard_id)
252    }
253}
255impl Default for MigrationTracker {
256    fn default() -> Self {
257        Self::new()
258    }
259}

// ── Rebalance plan ────────────────────────────────────────────────────────────

263/// Compute which shards should migrate to rebalance load.
264///
265/// Returns at most `max_concurrent_migrations` `(shard_id, from_node, to_node)`
266/// proposals.  Nodes with a shard count more than `imbalance_threshold` above
267/// the cluster mean are considered overloaded; nodes below the mean by the same
268/// threshold are considered underloaded.
269///
270/// Shards that already have an active migration are excluded.
271pub fn compute_rebalance_plan(
272    registry: &ShardRegistry,
273    tracker: &MigrationTracker,
274    imbalance_threshold: f64,
275    max_concurrent_migrations: usize,
276) -> Vec<(ShardId, NodeId, NodeId)> {
277    use std::collections::HashMap;
278
279    let shards = registry.get_all();
280    if shards.is_empty() || max_concurrent_migrations == 0 {
281        return Vec::new();
282    }
283
284    // Build per-node shard lists.
285    let mut node_shards: HashMap<NodeId, Vec<ShardId>> = HashMap::new();
286    for shard in &shards {
287        node_shards.entry(shard.node_id).or_default().push(shard.id);
288    }
289
290    if node_shards.len() < 2 {
291        return Vec::new();
292    }
293
294    let mean = shards.len() as f64 / node_shards.len() as f64;
295
296    // Identify overloaded and underloaded nodes.
297    let mut overloaded: Vec<(NodeId, Vec<ShardId>)> = node_shards
298        .iter()
299        .filter(|(_, ids)| ids.len() as f64 > mean * (1.0 + imbalance_threshold))
300        .map(|(nid, ids)| (*nid, ids.clone()))
301        .collect();
302
303    let mut underloaded: Vec<(NodeId, usize)> = node_shards
304        .iter()
305        .filter(|(_, ids)| (ids.len() as f64) < mean * (1.0 - imbalance_threshold))
306        .map(|(nid, ids)| (*nid, ids.len()))
307        .collect();
308
309    if overloaded.is_empty() || underloaded.is_empty() {
310        return Vec::new();
311    }
312
313    // Sort for determinism.
314    overloaded.sort_by_key(|(nid, _)| *nid);
315    underloaded.sort_by_key(|(nid, _)| *nid);
316
317    let mut plan: Vec<(ShardId, NodeId, NodeId)> = Vec::new();
318
319    'outer: for (from_node, shard_ids) in &overloaded {
320        for shard_id in shard_ids {
321            if tracker.is_shard_migrating(shard_id) {
322                continue;
323            }
324            if let Some((to_node, _)) = underloaded.first_mut() {
325                plan.push((*shard_id, *from_node, *to_node));
326                if plan.len() >= max_concurrent_migrations {
327                    break 'outer;
328                }
329            }
330        }
331    }
332
333    plan
334}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::shard::{KeyRange, ShardMetadata, ShardRegistry};
    use amaters_core::Key;

    /// Build a registry containing one shard per
    /// `(shard_id, node_id, range_start, range_end)` tuple.
    fn make_registry_with_distribution(
        distribution: &[(ShardId, NodeId, &str, &str)],
    ) -> ShardRegistry {
        let registry = ShardRegistry::new();
        for (shard_id, node_id, start, end) in distribution.iter().copied() {
            let range =
                KeyRange::new(Key::from_str(start), Key::from_str(end)).expect("valid range");
            registry
                .register(ShardMetadata::new(shard_id, range, node_id))
                .expect("register");
        }
        registry
    }

    /// Registry with 6 shards on node 1 and 2 shards on node 2.
    ///
    /// mean = 8/2 = 4.  Node 1 has 6 (50% above mean); node 2 has 2 (50% below).
    fn skewed_registry() -> ShardRegistry {
        make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 1, "a4", "a5"),
            (6, 1, "a5", "a6"),
            (7, 2, "b0", "b1"),
            (8, 2, "b1", "b2"),
        ])
    }

    // ── begin_migration rejects a second migration for the same shard ─────────

    #[test]
    fn test_begin_migration_prevents_duplicate() {
        let tracker = MigrationTracker::new();

        // The first request for shard 1 is accepted.
        let first = tracker.begin_migration(1, 10, 20);
        assert!(first.is_ok(), "first migration should succeed");

        // A second request for the same shard must be refused.
        let second = tracker.begin_migration(1, 10, 20);
        assert!(
            second.is_err(),
            "duplicate migration for shard 1 should be rejected"
        );
        let err_msg = second.expect_err("second migration should fail");
        assert!(
            err_msg.contains("shard 1"),
            "error message should mention the shard id"
        );
    }

    // ── Pending → InProgress → Complete, then the shard is free again ────────

    #[test]
    fn test_migration_lifecycle() {
        let tracker = MigrationTracker::new();
        let id = tracker.begin_migration(2, 10, 20).expect("begin_migration");

        // Freshly accepted: Pending, and the shard is locked.
        let snapshot = tracker.get_migration(id).expect("get migration");
        assert_eq!(snapshot.status, MigrationStatus::Pending);
        assert!(tracker.is_shard_migrating(&2));

        // The first progress report promotes the migration to InProgress.
        assert!(tracker.update_progress(id, 512, 1024));
        let snapshot = tracker.get_migration(id).expect("get migration");
        assert_eq!(snapshot.status, MigrationStatus::InProgress);
        assert_eq!(snapshot.bytes_migrated, 512);
        assert_eq!(snapshot.total_bytes, 1024);

        // Completion flips the status and releases the shard lock.
        assert!(tracker.complete_migration(id));
        let snapshot = tracker.get_migration(id).expect("get migration");
        assert_eq!(snapshot.status, MigrationStatus::Complete);
        assert!(!tracker.is_shard_migrating(&2));

        // With the lock released, the same shard can be migrated again.
        assert!(tracker.begin_migration(2, 20, 10).is_ok());
    }

    // ── Failure records the reason and releases the shard lock ───────────────

    #[test]
    fn test_migration_failed_state() {
        let tracker = MigrationTracker::new();
        let id = tracker.begin_migration(3, 10, 20).expect("begin_migration");

        assert!(tracker.fail_migration(id, "disk full".to_string()));

        let snapshot = tracker.get_migration(id).expect("get migration");
        assert!(
            matches!(snapshot.status, MigrationStatus::Failed { ref reason } if reason == "disk full"),
            "expected Failed with reason 'disk full', got {:?}",
            snapshot.status
        );
        // A failed migration must not keep the shard locked.
        assert!(!tracker.is_shard_migrating(&3));
    }

    // ── Proposals move shards from the overloaded to the underloaded node ────

    #[test]
    fn test_rebalance_plan_targets_overloaded_node() {
        let registry = skewed_registry();
        let tracker = MigrationTracker::new();
        // imbalance_threshold 0.2 → overloaded if > mean*(1+0.2)=4.8,
        // underloaded if < mean*(1-0.2)=3.2
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 10);

        assert!(
            !plan.is_empty(),
            "plan should be non-empty for imbalanced cluster"
        );
        for (shard_id, from_node, to_node) in &plan {
            assert_eq!(*from_node, 1, "moves should come from overloaded node 1");
            assert_eq!(*to_node, 2, "moves should go to underloaded node 2");
            assert!(
                *shard_id >= 1 && *shard_id <= 6,
                "only shards on node 1 should be moved"
            );
        }
    }

    // ── A balanced cluster yields no proposals ───────────────────────────────

    #[test]
    fn test_no_rebalance_when_balanced() {
        // 4 shards on node 1, 4 shards on node 2 → perfectly balanced.
        let registry = make_registry_with_distribution(&[
            (1, 1, "a0", "a1"),
            (2, 1, "a1", "a2"),
            (3, 1, "a2", "a3"),
            (4, 1, "a3", "a4"),
            (5, 2, "b0", "b1"),
            (6, 2, "b1", "b2"),
            (7, 2, "b2", "b3"),
            (8, 2, "b3", "b4"),
        ]);
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 10);

        assert!(
            plan.is_empty(),
            "plan should be empty for balanced cluster, got {:?}",
            plan
        );
    }

    // ── active_migrations omits Complete and Failed records ──────────────────

    #[test]
    fn test_active_migrations_excludes_terminal() {
        let tracker = MigrationTracker::new();
        let completed = tracker.begin_migration(10, 1, 2).expect("begin 10");
        let failed = tracker.begin_migration(11, 1, 2).expect("begin 11");
        let pending = tracker.begin_migration(12, 1, 2).expect("begin 12");

        tracker.complete_migration(completed);
        tracker.fail_migration(failed, "oops".to_string());

        let active_ids: Vec<Uuid> = tracker
            .active_migrations()
            .iter()
            .map(|m| m.id)
            .collect();

        assert!(
            !active_ids.contains(&completed),
            "completed migration should not appear in active list"
        );
        assert!(
            !active_ids.contains(&failed),
            "failed migration should not appear in active list"
        );
        assert!(
            active_ids.contains(&pending),
            "pending migration should appear in active list"
        );
    }

    // ── The plan never exceeds max_concurrent_migrations ─────────────────────

    #[test]
    fn test_max_concurrent_migrations_respected() {
        let registry = skewed_registry();
        let tracker = MigrationTracker::new();
        let plan = compute_rebalance_plan(&registry, &tracker, 0.2, 2);

        assert!(
            plan.len() <= 2,
            "plan must not exceed max_concurrent_migrations=2, got {}",
            plan.len()
        );
    }
}