Skip to main content

atomr_cluster_sharding/
persistent_coordinator.rs

1//! `PersistentShardCoordinator` — event-sourced allocation table.
2//!
//! Persists every shard allocation / rebalance / removal decision through
//! `atomr_persistence::Eventsourced`, so a coordinator restarted on a
//! different node replays the journal and restores the exact same
//! allocation table. [`project_into`] then copies that table onto an
//! in-memory [`ShardCoordinator`].
6//!
7//! Events:
8//! ```text
9//! ShardAllocated  { shard_id, region }
10//! ShardRebalanced { shard_id, from_region, to_region }
11//! ShardRemoved    { shard_id }
12//! ```
13
14use async_trait::async_trait;
15use atomr_persistence::Eventsourced;
16use serde::{Deserialize, Serialize};
17use thiserror::Error;
18
19use crate::coordinator::ShardCoordinator;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[non_exhaustive]
23pub enum CoordinatorEvent {
24    ShardAllocated { shard_id: String, region: String },
25    ShardRebalanced { shard_id: String, from_region: String, to_region: String },
26    ShardRemoved { shard_id: String },
27}
28
/// Requests handled by [`PersistentShardCoordinator`].
///
/// Each command is validated against the current [`CoordinatorState`] and,
/// when accepted, becomes a single [`CoordinatorEvent`] to journal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum CoordinatorCommand {
    /// Assign `shard_id` to `region`.
    Allocate {
        shard_id: String,
        region: String,
    },
    /// Move an already-allocated `shard_id` to `to_region`.
    Rebalance {
        shard_id: String,
        to_region: String,
    },
    /// Drop `shard_id` from the allocation table.
    Remove {
        shard_id: String,
    },
}
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum CoordinatorError {
40    #[error("shard `{0}` is unknown")]
41    UnknownShard(String),
42}
43
/// Event-sourced allocation table, mapping each shard id to its region.
///
/// It is kept separate from [`ShardCoordinator`] so it can be rebuilt
/// purely from journal replay. The in-memory `ShardCoordinator` is the local
/// projection, and this struct mirrors it through the persistence layer.
///
/// `PartialEq`/`Eq` let callers and tests check that a replayed table
/// matches the live one as a whole.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct CoordinatorState {
    /// Current region of every allocated shard, keyed by shard id.
    pub allocations: std::collections::HashMap<String, String>,
}
52
/// Event-sourcing front end for the shard allocation table.
///
/// This type holds only the persistence id. The allocation table itself
/// lives in a caller-owned [`CoordinatorState`].
///
/// Usage:
/// - Call `recover` on boot to rebuild that state from the journal.
/// - Call `handle_command` for every allocation, rebalance or removal.
/// - Use [`project_into`] to copy a recovered table onto an in-memory
///   [`ShardCoordinator`].
#[derive(Debug, Clone)]
pub struct PersistentShardCoordinator {
    /// Identifier reported through `Eventsourced::persistence_id`.
    persistence_id: String,
}

impl PersistentShardCoordinator {
    /// Creates a coordinator whose journal identity is `persistence_id`.
    pub fn new(persistence_id: impl Into<String>) -> Self {
        Self { persistence_id: persistence_id.into() }
    }
}
65
66#[async_trait]
67impl Eventsourced for PersistentShardCoordinator {
68    type Command = CoordinatorCommand;
69    type Event = CoordinatorEvent;
70    type State = CoordinatorState;
71    type Error = CoordinatorError;
72
73    fn persistence_id(&self) -> String {
74        self.persistence_id.clone()
75    }
76
77    fn command_to_events(
78        &self,
79        state: &Self::State,
80        cmd: Self::Command,
81    ) -> Result<Vec<Self::Event>, Self::Error> {
82        match cmd {
83            CoordinatorCommand::Allocate { shard_id, region } => {
84                Ok(vec![CoordinatorEvent::ShardAllocated { shard_id, region }])
85            }
86            CoordinatorCommand::Rebalance { shard_id, to_region } => {
87                let Some(from) = state.allocations.get(&shard_id).cloned() else {
88                    return Err(CoordinatorError::UnknownShard(shard_id));
89                };
90                Ok(vec![CoordinatorEvent::ShardRebalanced { shard_id, from_region: from, to_region }])
91            }
92            CoordinatorCommand::Remove { shard_id } => {
93                if !state.allocations.contains_key(&shard_id) {
94                    return Err(CoordinatorError::UnknownShard(shard_id));
95                }
96                Ok(vec![CoordinatorEvent::ShardRemoved { shard_id }])
97            }
98        }
99    }
100
101    fn apply_event(state: &mut Self::State, event: &Self::Event) {
102        match event {
103            CoordinatorEvent::ShardAllocated { shard_id, region } => {
104                state.allocations.insert(shard_id.clone(), region.clone());
105            }
106            CoordinatorEvent::ShardRebalanced { shard_id, to_region, .. } => {
107                state.allocations.insert(shard_id.clone(), to_region.clone());
108            }
109            CoordinatorEvent::ShardRemoved { shard_id } => {
110                state.allocations.remove(shard_id);
111            }
112        }
113    }
114
115    fn encode_event(event: &Self::Event) -> Result<Vec<u8>, String> {
116        let cfg = bincode::config::standard();
117        bincode::serde::encode_to_vec(event, cfg).map_err(|e| e.to_string())
118    }
119
120    fn decode_event(bytes: &[u8]) -> Result<Self::Event, String> {
121        let cfg = bincode::config::standard();
122        bincode::serde::decode_from_slice::<Self::Event, _>(bytes, cfg)
123            .map(|(v, _)| v)
124            .map_err(|e| e.to_string())
125    }
126}
127
128/// Project a [`CoordinatorState`] (rebuilt from journal replay) onto
129/// a fresh [`ShardCoordinator`]. Useful right after `recover`.
130pub fn project_into(state: &CoordinatorState, target: &ShardCoordinator) {
131    for (shard, region) in &state.allocations {
132        // `allocate` is the right primitive — first-mention sets the
133        // region; we then overwrite via `rebalance` if the journal
134        // shows a later move (which `apply_event` already collapsed
135        // into the final allocation).
136        target.rebalance(shard, region.clone());
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_persistence::{EventsourcedError, InMemoryJournal, RecoveryPermitter};
    use std::sync::Arc;

    /// Fresh in-memory journal plus a recovery permitter for each test.
    fn cfg() -> (Arc<InMemoryJournal>, RecoveryPermitter) {
        (Arc::new(InMemoryJournal::default()), RecoveryPermitter::new(2))
    }

    #[tokio::test]
    async fn allocate_then_rebalance_round_trips() {
        let (journal, permits) = cfg();
        let coord = PersistentShardCoordinator::new("coord-1");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;

        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Allocate { shard_id: "s1".into(), region: "r1".into() },
            )
            .await
            .unwrap();
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Rebalance { shard_id: "s1".into(), to_region: "r2".into() },
            )
            .await
            .unwrap();
        assert_eq!(state.allocations.get("s1"), Some(&"r2".to_string()));

        // Replay must reproduce the identical table, so compare the whole
        // map rather than a single key.
        let mut coord2 = PersistentShardCoordinator::new("coord-1");
        let mut state2 = CoordinatorState::default();
        coord2.recover(journal.clone(), &mut state2, &permits).await.unwrap();
        assert_eq!(state2.allocations, state.allocations);
    }

    #[tokio::test]
    async fn rebalance_unknown_shard_errors() {
        let (journal, _) = cfg();
        let coord = PersistentShardCoordinator::new("coord-2");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        let r = coord
            .handle_command(
                journal,
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Rebalance { shard_id: "missing".into(), to_region: "r2".into() },
            )
            .await;
        assert!(matches!(r, Err(EventsourcedError::Domain(CoordinatorError::UnknownShard(_)))));
    }

    #[tokio::test]
    async fn remove_unknown_shard_errors() {
        let (journal, _) = cfg();
        let coord = PersistentShardCoordinator::new("coord-5");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        let r = coord
            .handle_command(
                journal,
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Remove { shard_id: "missing".into() },
            )
            .await;
        assert!(matches!(r, Err(EventsourcedError::Domain(CoordinatorError::UnknownShard(_)))));
        assert!(state.allocations.is_empty());
    }

    #[tokio::test]
    async fn project_into_in_memory_coordinator() {
        let (journal, permits) = cfg();
        let coord = PersistentShardCoordinator::new("coord-3");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        for (sid, region) in [("s1", "r1"), ("s2", "r2"), ("s3", "r1")] {
            coord
                .handle_command(
                    journal.clone(),
                    &mut state,
                    &mut seq,
                    "w",
                    CoordinatorCommand::Allocate { shard_id: sid.into(), region: region.into() },
                )
                .await
                .unwrap();
        }

        // Replay into a brand-new in-memory coordinator.
        let mut state2 = CoordinatorState::default();
        let mut coord2 = PersistentShardCoordinator::new("coord-3");
        coord2.recover(journal.clone(), &mut state2, &permits).await.unwrap();
        let local = ShardCoordinator::new();
        project_into(&state2, &local);
        assert_eq!(local.region_for("s1"), Some("r1".to_string()));
        assert_eq!(local.region_for("s2"), Some("r2".to_string()));
        assert_eq!(local.region_for("s3"), Some("r1".to_string()));
    }

    #[tokio::test]
    async fn remove_shard_drops_from_state() {
        let (journal, permits) = cfg();
        let coord = PersistentShardCoordinator::new("coord-4");
        let mut state = CoordinatorState::default();
        let mut seq = 0u64;
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Allocate { shard_id: "s1".into(), region: "r1".into() },
            )
            .await
            .unwrap();
        coord
            .handle_command(
                journal.clone(),
                &mut state,
                &mut seq,
                "w",
                CoordinatorCommand::Remove { shard_id: "s1".into() },
            )
            .await
            .unwrap();
        assert!(!state.allocations.contains_key("s1"));

        // The removal must also survive journal replay.
        let mut coord2 = PersistentShardCoordinator::new("coord-4");
        let mut state2 = CoordinatorState::default();
        coord2.recover(journal.clone(), &mut state2, &permits).await.unwrap();
        assert!(state2.allocations.is_empty());
    }

    #[test]
    fn event_codec_round_trips() {
        let event = CoordinatorEvent::ShardRebalanced {
            shard_id: "s1".into(),
            from_region: "r1".into(),
            to_region: "r2".into(),
        };
        let bytes = PersistentShardCoordinator::encode_event(&event).unwrap();
        let back = PersistentShardCoordinator::decode_event(&bytes).unwrap();
        // `CoordinatorEvent` has no `PartialEq`, so compare the Debug forms.
        assert_eq!(format!("{back:?}"), format!("{event:?}"));
    }
}