atomr_cluster_sharding/
handoff.rs1use std::collections::HashMap;
16
17use parking_lot::RwLock;
18use thiserror::Error;
19
/// Position of a single shard in its handoff from one region to another.
///
/// The coordinator drives a shard through
/// `Idle → Beginning → HandingOff → Stopped → Started`; a `Started` shard may
/// begin another handoff.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum HandoffState {
    /// No handoff is in progress.
    Idle,
    /// A handoff away from `source_region` was requested and has not yet been
    /// acknowledged.
    Beginning {
        /// Region the shard is leaving.
        source_region: String,
    },
    /// The handoff was acknowledged; `remaining_entities` entities still have
    /// to report that they stopped.
    HandingOff {
        /// Region the shard is leaving.
        source_region: String,
        /// Number of entities that have not yet stopped.
        remaining_entities: usize,
    },
    /// Every entity of the shard has stopped in `source_region`.
    Stopped {
        /// Region the shard left.
        source_region: String,
    },
    /// The shard was started again in `target_region`.
    Started {
        /// Region the shard left.
        source_region: String,
        /// Region now hosting the shard.
        target_region: String,
    },
}
34
35#[derive(Debug, Error)]
36#[non_exhaustive]
37pub enum HandoffError {
38 #[error("invalid transition for shard `{0}` (current state does not allow it)")]
39 InvalidTransition(String),
40}
41
42#[derive(Default)]
44pub struct HandoffCoordinator {
45 states: RwLock<HashMap<String, HandoffState>>,
46}
47
48impl HandoffCoordinator {
49 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn state(&self, shard_id: &str) -> HandoffState {
54 self.states.read().get(shard_id).cloned().unwrap_or(HandoffState::Idle)
55 }
56
57 pub fn begin(&self, shard_id: &str, source_region: &str) -> Result<(), HandoffError> {
59 let mut g = self.states.write();
60 let cur = g.entry(shard_id.into()).or_insert(HandoffState::Idle).clone();
61 if !matches!(cur, HandoffState::Idle | HandoffState::Started { .. }) {
62 return Err(HandoffError::InvalidTransition(shard_id.into()));
63 }
64 g.insert(shard_id.into(), HandoffState::Beginning { source_region: source_region.into() });
65 Ok(())
66 }
67
68 pub fn ack_begin(&self, shard_id: &str, entity_count: usize) -> Result<(), HandoffError> {
71 let mut g = self.states.write();
72 let cur = g.get(shard_id).cloned().unwrap_or(HandoffState::Idle);
73 let HandoffState::Beginning { source_region } = cur else {
74 return Err(HandoffError::InvalidTransition(shard_id.into()));
75 };
76 g.insert(
77 shard_id.into(),
78 HandoffState::HandingOff { source_region, remaining_entities: entity_count },
79 );
80 Ok(())
81 }
82
83 pub fn entity_stopped(&self, shard_id: &str) -> Result<(), HandoffError> {
86 let mut g = self.states.write();
87 let cur = g.get(shard_id).cloned().unwrap_or(HandoffState::Idle);
88 let HandoffState::HandingOff { source_region, remaining_entities } = cur else {
89 return Err(HandoffError::InvalidTransition(shard_id.into()));
90 };
91 let next = if remaining_entities <= 1 {
92 HandoffState::Stopped { source_region }
93 } else {
94 HandoffState::HandingOff { source_region, remaining_entities: remaining_entities - 1 }
95 };
96 g.insert(shard_id.into(), next);
97 Ok(())
98 }
99
100 pub fn start_elsewhere(&self, shard_id: &str, target_region: &str) -> Result<(), HandoffError> {
102 let mut g = self.states.write();
103 let cur = g.get(shard_id).cloned().unwrap_or(HandoffState::Idle);
104 let HandoffState::Stopped { source_region } = cur else {
105 return Err(HandoffError::InvalidTransition(shard_id.into()));
106 };
107 g.insert(
108 shard_id.into(),
109 HandoffState::Started { source_region, target_region: target_region.into() },
110 );
111 Ok(())
112 }
113
114 pub fn forget(&self, shard_id: &str) {
116 self.states.write().remove(shard_id);
117 }
118
119 pub fn snapshot(&self) -> Vec<(String, HandoffState)> {
121 let mut v: Vec<(String, HandoffState)> =
122 self.states.read().iter().map(|(k, v)| (k.clone(), v.clone())).collect();
123 v.sort_by(|a, b| a.0.cmp(&b.0));
124 v
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn full_three_phase_handoff() {
        let h = HandoffCoordinator::new();
        h.begin("s1", "r1").unwrap();
        assert!(matches!(h.state("s1"), HandoffState::Beginning { .. }));
        h.ack_begin("s1", 3).unwrap();
        h.entity_stopped("s1").unwrap();
        h.entity_stopped("s1").unwrap();
        assert!(matches!(h.state("s1"), HandoffState::HandingOff { remaining_entities: 1, .. }));
        h.entity_stopped("s1").unwrap();
        assert!(matches!(h.state("s1"), HandoffState::Stopped { .. }));
        h.start_elsewhere("s1", "r2").unwrap();
        assert!(matches!(h.state("s1"), HandoffState::Started { .. }));
    }

    #[test]
    fn ack_without_begin_errors() {
        let h = HandoffCoordinator::new();
        let r = h.ack_begin("s1", 5);
        assert!(matches!(r, Err(HandoffError::InvalidTransition(_))));
    }

    #[test]
    fn entity_stopped_without_handing_off_errors() {
        let h = HandoffCoordinator::new();
        let r = h.entity_stopped("s1");
        assert!(matches!(r, Err(HandoffError::InvalidTransition(_))));
    }

    #[test]
    fn start_elsewhere_without_stopped_errors() {
        let h = HandoffCoordinator::new();
        let r = h.start_elsewhere("s1", "r2");
        assert!(matches!(r, Err(HandoffError::InvalidTransition(_))));
    }

    #[test]
    fn re_handoff_after_started_is_allowed() {
        let h = HandoffCoordinator::new();
        h.begin("s1", "r1").unwrap();
        h.ack_begin("s1", 1).unwrap();
        h.entity_stopped("s1").unwrap();
        h.start_elsewhere("s1", "r2").unwrap();
        // A shard that finished one handoff can be moved again from its new home.
        h.begin("s1", "r2").unwrap();
        assert!(matches!(h.state("s1"), HandoffState::Beginning { .. }));
    }

    #[test]
    fn forget_drops_state() {
        let h = HandoffCoordinator::new();
        h.begin("s1", "r1").unwrap();
        h.forget("s1");
        assert!(matches!(h.state("s1"), HandoffState::Idle));
    }

    #[test]
    fn unknown_shard_is_idle() {
        let h = HandoffCoordinator::new();
        assert_eq!(h.state("never-seen"), HandoffState::Idle);
        assert!(h.snapshot().is_empty());
    }

    #[test]
    fn begin_rejected_while_handoff_in_progress() {
        let h = HandoffCoordinator::new();
        h.begin("s1", "r1").unwrap();

        let r = h.begin("s1", "r2");
        assert!(matches!(r, Err(HandoffError::InvalidTransition(ref id)) if id.as_str() == "s1"));
        // A rejected transition must not disturb the stored state.
        assert_eq!(h.state("s1"), HandoffState::Beginning { source_region: "r1".into() });

        h.ack_begin("s1", 1).unwrap();
        assert!(h.begin("s1", "r2").is_err());
        h.entity_stopped("s1").unwrap();
        assert!(h.begin("s1", "r2").is_err());
        assert_eq!(h.state("s1"), HandoffState::Stopped { source_region: "r1".into() });
    }

    #[test]
    fn transitions_carry_region_names() {
        let h = HandoffCoordinator::new();
        h.begin("s1", "r1").unwrap();
        h.ack_begin("s1", 2).unwrap();
        assert_eq!(
            h.state("s1"),
            HandoffState::HandingOff { source_region: "r1".into(), remaining_entities: 2 }
        );
        h.entity_stopped("s1").unwrap();
        assert_eq!(
            h.state("s1"),
            HandoffState::HandingOff { source_region: "r1".into(), remaining_entities: 1 }
        );
        h.entity_stopped("s1").unwrap();
        assert_eq!(h.state("s1"), HandoffState::Stopped { source_region: "r1".into() });
        h.start_elsewhere("s1", "r2").unwrap();
        assert_eq!(
            h.state("s1"),
            HandoffState::Started { source_region: "r1".into(), target_region: "r2".into() }
        );
    }

    #[test]
    fn snapshot_is_sorted_by_shard_id() {
        let h = HandoffCoordinator::new();
        h.begin("s2", "r1").unwrap();
        h.begin("s3", "r2").unwrap();
        h.begin("s1", "r1").unwrap();
        let ids: Vec<String> = h.snapshot().into_iter().map(|(id, _)| id).collect();
        assert_eq!(ids, ["s1", "s2", "s3"]);
    }
}