Skip to main content

atomr_cluster_sharding/
handoff.rs

1//! 3-phase shard handoff state machine.
2//!
3//! Three phases:
4//!
5//! ```text
6//! BeginHandoff(shard) ── source region acks ──► HandingOff
7//! HandingOff          ── all entities stopped ─► Stopped
8//! Stopped             ── coordinator allocates ─► StartElsewhere(shard, new_region)
9//! ```
10//!
11//! [`HandoffCoordinator`] tracks a per-shard state machine and
12//! exposes pure transition helpers; the runtime driver wires it into
13//! the shard region.
14
15use std::collections::HashMap;
16
17use parking_lot::RwLock;
18use thiserror::Error;
19
/// Position of a single shard within the handoff protocol.
///
/// The variants are listed in the order a handoff moves through them:
/// `Idle` → `Beginning` → `HandingOff` → `Stopped` → `Started`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum HandoffState {
    /// The shard is not being handed off.
    Idle,
    /// The source region has been asked to start draining the shard.
    Beginning {
        /// Region that currently hosts the shard.
        source_region: String,
    },
    /// The shard's entities are being stopped; new messages buffer at the
    /// source in the meantime.
    HandingOff {
        /// Region that currently hosts the shard.
        source_region: String,
        /// Number of entities that have not yet reported stopping.
        remaining_entities: usize,
    },
    /// Every entity has stopped and the shard is waiting to be reassigned.
    Stopped {
        /// Region the shard was drained from.
        source_region: String,
    },
    /// The shard has been re-allocated to `target_region`.
    Started {
        /// Region the shard was drained from.
        source_region: String,
        /// Region the shard now lives in.
        target_region: String,
    },
}
34
35#[derive(Debug, Error)]
36#[non_exhaustive]
37pub enum HandoffError {
38    #[error("invalid transition for shard `{0}` (current state does not allow it)")]
39    InvalidTransition(String),
40}
41
42/// Per-shard handoff state machine.
43#[derive(Default)]
44pub struct HandoffCoordinator {
45    states: RwLock<HashMap<String, HandoffState>>,
46}
47
48impl HandoffCoordinator {
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    pub fn state(&self, shard_id: &str) -> HandoffState {
54        self.states.read().get(shard_id).cloned().unwrap_or(HandoffState::Idle)
55    }
56
57    /// Phase 1: tell `source_region` to start draining `shard_id`.
58    pub fn begin(&self, shard_id: &str, source_region: &str) -> Result<(), HandoffError> {
59        let mut g = self.states.write();
60        let cur = g.entry(shard_id.into()).or_insert(HandoffState::Idle).clone();
61        if !matches!(cur, HandoffState::Idle | HandoffState::Started { .. }) {
62            return Err(HandoffError::InvalidTransition(shard_id.into()));
63        }
64        g.insert(shard_id.into(), HandoffState::Beginning { source_region: source_region.into() });
65        Ok(())
66    }
67
68    /// Phase 2a: source region has acknowledged the begin and is now
69    /// stopping `entity_count` entities.
70    pub fn ack_begin(&self, shard_id: &str, entity_count: usize) -> Result<(), HandoffError> {
71        let mut g = self.states.write();
72        let cur = g.get(shard_id).cloned().unwrap_or(HandoffState::Idle);
73        let HandoffState::Beginning { source_region } = cur else {
74            return Err(HandoffError::InvalidTransition(shard_id.into()));
75        };
76        g.insert(
77            shard_id.into(),
78            HandoffState::HandingOff { source_region, remaining_entities: entity_count },
79        );
80        Ok(())
81    }
82
83    /// Phase 2b: one more entity finished stopping. Auto-transitions
84    /// to `Stopped` when the count reaches zero.
85    pub fn entity_stopped(&self, shard_id: &str) -> Result<(), HandoffError> {
86        let mut g = self.states.write();
87        let cur = g.get(shard_id).cloned().unwrap_or(HandoffState::Idle);
88        let HandoffState::HandingOff { source_region, remaining_entities } = cur else {
89            return Err(HandoffError::InvalidTransition(shard_id.into()));
90        };
91        let next = if remaining_entities <= 1 {
92            HandoffState::Stopped { source_region }
93        } else {
94            HandoffState::HandingOff { source_region, remaining_entities: remaining_entities - 1 }
95        };
96        g.insert(shard_id.into(), next);
97        Ok(())
98    }
99
100    /// Phase 3: coordinator allocated the shard to `target_region`.
101    pub fn start_elsewhere(&self, shard_id: &str, target_region: &str) -> Result<(), HandoffError> {
102        let mut g = self.states.write();
103        let cur = g.get(shard_id).cloned().unwrap_or(HandoffState::Idle);
104        let HandoffState::Stopped { source_region } = cur else {
105            return Err(HandoffError::InvalidTransition(shard_id.into()));
106        };
107        g.insert(
108            shard_id.into(),
109            HandoffState::Started { source_region, target_region: target_region.into() },
110        );
111        Ok(())
112    }
113
114    /// Forget a shard (e.g. it was removed entirely).
115    pub fn forget(&self, shard_id: &str) {
116        self.states.write().remove(shard_id);
117    }
118
119    /// Snapshot for telemetry — `(shard_id, state)` pairs.
120    pub fn snapshot(&self) -> Vec<(String, HandoffState)> {
121        let mut v: Vec<(String, HandoffState)> =
122            self.states.read().iter().map(|(k, v)| (k.clone(), v.clone())).collect();
123        v.sort_by(|a, b| a.0.cmp(&b.0));
124        v
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a transition was refused with `InvalidTransition`.
    fn assert_rejected(outcome: Result<(), HandoffError>) {
        assert!(matches!(outcome, Err(HandoffError::InvalidTransition(_))));
    }

    /// Walks one shard through every phase and checks the state after each.
    #[test]
    fn full_three_phase_handoff() {
        let coordinator = HandoffCoordinator::new();

        coordinator.begin("s1", "r1").unwrap();
        assert!(matches!(coordinator.state("s1"), HandoffState::Beginning { .. }));

        coordinator.ack_begin("s1", 3).unwrap();
        for _ in 0..2 {
            coordinator.entity_stopped("s1").unwrap();
        }
        assert!(matches!(
            coordinator.state("s1"),
            HandoffState::HandingOff { remaining_entities: 1, .. }
        ));

        coordinator.entity_stopped("s1").unwrap();
        assert!(matches!(coordinator.state("s1"), HandoffState::Stopped { .. }));

        coordinator.start_elsewhere("s1", "r2").unwrap();
        assert!(matches!(coordinator.state("s1"), HandoffState::Started { .. }));
    }

    /// Acknowledging a begin that never happened is refused.
    #[test]
    fn ack_without_begin_errors() {
        let coordinator = HandoffCoordinator::new();
        assert_rejected(coordinator.ack_begin("s1", 5));
    }

    /// Reporting a stopped entity outside `HandingOff` is refused.
    #[test]
    fn entity_stopped_without_handing_off_errors() {
        let coordinator = HandoffCoordinator::new();
        assert_rejected(coordinator.entity_stopped("s1"));
    }

    /// Re-allocating a shard that has not reached `Stopped` is refused.
    #[test]
    fn start_elsewhere_without_stopped_errors() {
        let coordinator = HandoffCoordinator::new();
        assert_rejected(coordinator.start_elsewhere("s1", "r2"));
    }

    /// A shard that finished one handoff can begin another.
    #[test]
    fn re_handoff_after_started_is_allowed() {
        let coordinator = HandoffCoordinator::new();

        // First cycle: `r1 → r2`.
        coordinator.begin("s1", "r1").unwrap();
        coordinator.ack_begin("s1", 1).unwrap();
        coordinator.entity_stopped("s1").unwrap();
        coordinator.start_elsewhere("s1", "r2").unwrap();

        // Second cycle starts from the new home, `r2`.
        coordinator.begin("s1", "r2").unwrap();
        assert!(matches!(coordinator.state("s1"), HandoffState::Beginning { .. }));
    }

    /// Forgetting a shard makes it read as `Idle` again.
    #[test]
    fn forget_drops_state() {
        let coordinator = HandoffCoordinator::new();
        coordinator.begin("s1", "r1").unwrap();
        coordinator.forget("s1");
        assert!(matches!(coordinator.state("s1"), HandoffState::Idle));
    }
}