1use async_trait::async_trait;
15use atomr_persistence::Eventsourced;
16use serde::{Deserialize, Serialize};
17use thiserror::Error;
18
19use crate::coordinator::ShardCoordinator;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[non_exhaustive]
23pub enum CoordinatorEvent {
24 ShardAllocated { shard_id: String, region: String },
25 ShardRebalanced { shard_id: String, from_region: String, to_region: String },
26 ShardRemoved { shard_id: String },
27}
28
/// Request handled by the persistent shard coordinator.
///
/// `Rebalance` and `Remove` are rejected by `command_to_events` when the named
/// shard has no current allocation; `Allocate` is always accepted.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum CoordinatorCommand {
    /// Assign `shard_id` to `region`.
    Allocate {
        shard_id: String,
        region: String,
    },
    /// Move the currently allocated `shard_id` to `to_region`.
    Rebalance {
        shard_id: String,
        to_region: String,
    },
    /// Drop the currently allocated `shard_id`.
    Remove {
        shard_id: String,
    },
}
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum CoordinatorError {
40 #[error("shard `{0}` is unknown")]
41 UnknownShard(String),
42}
43
/// Folded coordinator state: which region currently hosts each shard.
///
/// Entries are written by `apply_event` (keys are shard ids, values are region
/// names). `PartialEq`/`Eq` allow a live state to be compared directly with one
/// rebuilt from the journal.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct CoordinatorState {
    /// Shard id → region the shard is currently allocated to.
    pub allocations: std::collections::HashMap<String, String>,
}
52
/// Event-sourced shard coordinator whose allocation table is journaled under a
/// single persistence id.
#[derive(Debug, Clone)]
pub struct PersistentShardCoordinator {
    /// Journal stream id returned by `Eventsourced::persistence_id`.
    persistence_id: String,
}

impl PersistentShardCoordinator {
    /// Creates a coordinator that reads and writes the journal stream `persistence_id`.
    pub fn new(persistence_id: impl Into<String>) -> Self {
        Self { persistence_id: persistence_id.into() }
    }
}
65
66#[async_trait]
67impl Eventsourced for PersistentShardCoordinator {
68 type Command = CoordinatorCommand;
69 type Event = CoordinatorEvent;
70 type State = CoordinatorState;
71 type Error = CoordinatorError;
72
73 fn persistence_id(&self) -> String {
74 self.persistence_id.clone()
75 }
76
77 fn command_to_events(
78 &self,
79 state: &Self::State,
80 cmd: Self::Command,
81 ) -> Result<Vec<Self::Event>, Self::Error> {
82 match cmd {
83 CoordinatorCommand::Allocate { shard_id, region } => {
84 Ok(vec![CoordinatorEvent::ShardAllocated { shard_id, region }])
85 }
86 CoordinatorCommand::Rebalance { shard_id, to_region } => {
87 let Some(from) = state.allocations.get(&shard_id).cloned() else {
88 return Err(CoordinatorError::UnknownShard(shard_id));
89 };
90 Ok(vec![CoordinatorEvent::ShardRebalanced { shard_id, from_region: from, to_region }])
91 }
92 CoordinatorCommand::Remove { shard_id } => {
93 if !state.allocations.contains_key(&shard_id) {
94 return Err(CoordinatorError::UnknownShard(shard_id));
95 }
96 Ok(vec![CoordinatorEvent::ShardRemoved { shard_id }])
97 }
98 }
99 }
100
101 fn apply_event(state: &mut Self::State, event: &Self::Event) {
102 match event {
103 CoordinatorEvent::ShardAllocated { shard_id, region } => {
104 state.allocations.insert(shard_id.clone(), region.clone());
105 }
106 CoordinatorEvent::ShardRebalanced { shard_id, to_region, .. } => {
107 state.allocations.insert(shard_id.clone(), to_region.clone());
108 }
109 CoordinatorEvent::ShardRemoved { shard_id } => {
110 state.allocations.remove(shard_id);
111 }
112 }
113 }
114
115 fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String> {
116 let cfg = bincode::config::standard();
117 bincode::serde::encode_to_vec(event, cfg).map_err(|e| e.to_string())
118 }
119
120 fn decode_event(bytes: &[u8]) -> Result<Self::Event, String> {
121 let cfg = bincode::config::standard();
122 bincode::serde::decode_from_slice::<Self::Event, _>(bytes, cfg)
123 .map(|(v, _)| v)
124 .map_err(|e| e.to_string())
125 }
126}
127
128pub fn project_into(state: &CoordinatorState, target: &ShardCoordinator) {
131 for (shard, region) in &state.allocations {
132 target.rebalance(shard, region.clone());
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_persistence::{EventsourcedError, InMemoryJournal, RecoveryPermitter};
    use std::sync::Arc;

    /// Fresh in-memory journal plus a `RecoveryPermitter` built with `new(2)`.
    fn cfg() -> (Arc<InMemoryJournal>, RecoveryPermitter) {
        (Arc::new(InMemoryJournal::default()), RecoveryPermitter::new(2))
    }

    /// Allocate + rebalance updates live state, and a second instance recovers it.
    #[tokio::test]
    async fn allocate_then_rebalance_round_trips() {
        let (journal, permits) = cfg();
        let coord = PersistentShardCoordinator::new("coord-1");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;

        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Allocate { shard_id: "s1".into(), region: "r1".into() },
            )
            .await
            .unwrap();
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Rebalance { shard_id: "s1".into(), to_region: "r2".into() },
            )
            .await
            .unwrap();
        assert_eq!(state.allocations.get("s1"), Some(&"r2".to_string()));

        let mut coord2 = PersistentShardCoordinator::new("coord-1");
        let mut state2 = CoordinatorState::default();
        coord2.recover(journal.clone(), &mut state2, &permits).await.unwrap();
        assert_eq!(state2.allocations.get("s1"), Some(&"r2".to_string()));
    }

    /// Rebalancing a shard that was never allocated surfaces the domain error.
    #[tokio::test]
    async fn rebalance_unknown_shard_errors() {
        let (journal, _) = cfg();
        let coord = PersistentShardCoordinator::new("coord-2");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        let r = coord
            .handle_command(
                journal,
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Rebalance { shard_id: "missing".into(), to_region: "r2".into() },
            )
            .await;
        assert!(matches!(r, Err(EventsourcedError::Domain(CoordinatorError::UnknownShard(_)))));
    }

    /// Recovered state projected into an in-memory coordinator exposes every allocation.
    #[tokio::test]
    async fn project_into_in_memory_coordinator() {
        let (journal, permits) = cfg();
        let coord = PersistentShardCoordinator::new("coord-3");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        for (sid, region) in [("s1", "r1"), ("s2", "r2"), ("s3", "r1")] {
            coord
                .handle_command(
                    journal.clone(),
                    &mut state,
                    &mut seq,
                    "w",
                    CoordinatorCommand::Allocate { shard_id: sid.into(), region: region.into() },
                )
                .await
                .unwrap();
        }

        let mut state2 = CoordinatorState::default();
        let mut coord2 = PersistentShardCoordinator::new("coord-3");
        coord2.recover(journal.clone(), &mut state2, &permits).await.unwrap();
        let local = ShardCoordinator::new();
        project_into(&state2, &local);
        assert_eq!(local.region_for("s1"), Some("r1".to_string()));
        assert_eq!(local.region_for("s2"), Some("r2".to_string()));
        assert_eq!(local.region_for("s3"), Some("r1".to_string()));
    }

    /// Removing an allocated shard drops it from live state.
    #[tokio::test]
    async fn remove_shard_drops_from_state() {
        let (journal, _) = cfg();
        let coord = PersistentShardCoordinator::new("coord-4");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Allocate { shard_id: "s1".into(), region: "r1".into() },
            )
            .await
            .unwrap();
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Remove { shard_id: "s1".into() },
            )
            .await
            .unwrap();
        assert!(!state.allocations.contains_key("s1"));
    }

    /// Removing a shard that was never allocated surfaces the domain error.
    #[tokio::test]
    async fn remove_unknown_shard_errors() {
        let (journal, _) = cfg();
        let coord = PersistentShardCoordinator::new("coord-5");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        let r = coord
            .handle_command(
                journal,
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Remove { shard_id: "missing".into() },
            )
            .await;
        assert!(matches!(r, Err(EventsourcedError::Domain(CoordinatorError::UnknownShard(_)))));
    }

    /// A removal is journaled: recovery does not resurrect the removed shard.
    #[tokio::test]
    async fn removal_survives_recovery() {
        let (journal, permits) = cfg();
        let coord = PersistentShardCoordinator::new("coord-6");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        for (sid, region) in [("s1", "r1"), ("s2", "r2")] {
            coord
                .handle_command(
                    journal.clone(),
                    &mut state,
                    &mut seq,
                    "w",
                    CoordinatorCommand::Allocate { shard_id: sid.into(), region: region.into() },
                )
                .await
                .unwrap();
        }
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Remove { shard_id: "s1".into() },
            )
            .await
            .unwrap();

        let mut coord2 = PersistentShardCoordinator::new("coord-6");
        let mut state2 = CoordinatorState::default();
        coord2.recover(journal.clone(), &mut state2, &permits).await.unwrap();
        assert!(!state2.allocations.contains_key("s1"));
        assert_eq!(state2.allocations.get("s2"), Some(&"r2".to_string()));
    }

    /// `decode_event` reverses `encode_event` field-for-field.
    #[test]
    fn event_codec_round_trips() {
        let original = CoordinatorEvent::ShardRebalanced {
            shard_id: "s1".into(),
            from_region: "r1".into(),
            to_region: "r2".into(),
        };
        let bytes = <PersistentShardCoordinator as Eventsourced>::encode_event(&original).unwrap();
        let decoded = <PersistentShardCoordinator as Eventsourced>::decode_event(&bytes).unwrap();
        match decoded {
            CoordinatorEvent::ShardRebalanced { shard_id, from_region, to_region } => {
                assert_eq!(shard_id, "s1");
                assert_eq!(from_region, "r1");
                assert_eq!(to_region, "r2");
            }
            other => panic!("unexpected event after round trip: {other:?}"),
        }
    }
}