1use std::{
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use amaters_core::Key;
13
14use crate::{
15 cluster_command::ClusterCommand,
16 error::{RaftError, RaftResult},
17 log::{LogEntry, StateMachine},
18 shard::{
19 KeyRange, ShardId, ShardMerge, ShardMetadata, ShardRegistry, ShardSplit, ShardState,
20 ShardTransfer,
21 },
22 types::NodeId,
23};
24
/// Flat, serde-serialisable mirror of a [`ShardMetadata`] record.
///
/// Instances are produced by `ShardMetadataDto::from_meta` and turned back
/// into registry records by `ShardMetadataDto::into_meta`; a `Vec` of them is
/// the payload written by `PlacementStateMachine::snapshot`.
///
/// NOTE(review): the snapshot encoding is presumably positional, so field
/// order and types should be treated as part of the on-disk format — confirm
/// before reordering or retyping fields.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ShardMetadataDto {
    /// Shard identifier, copied from `ShardMetadata::id`.
    id: ShardId,
    /// Raw bytes of the key at the start of the shard's range.
    range_start: Vec<u8>,
    /// Raw bytes of the key at the end of the shard's range.
    range_end: Vec<u8>,
    /// `ShardState` encoded as a one-byte tag (see `state_to_u8` / `u8_to_state`).
    state: u8,
    /// Node recorded in `ShardMetadata::node_id`.
    node_id: NodeId,
    /// Copy of `ShardMetadata::replicas`.
    replicas: Vec<NodeId>,
    /// Copy of `ShardMetadata::estimated_keys`.
    estimated_keys: u64,
    /// Copy of `ShardMetadata::estimated_size_bytes`.
    estimated_size_bytes: u64,
    /// `ShardMetadata::last_updated` as milliseconds since the Unix epoch.
    last_updated_ms: u64,
    /// `ShardMetadata::created_at` as milliseconds since the Unix epoch.
    created_at_ms: u64,
    /// Copy of `ShardMetadata::version`.
    version: u64,
}
47
48fn state_to_u8(s: &ShardState) -> u8 {
49 match s {
50 ShardState::Active => 0,
51 ShardState::Splitting => 1,
52 ShardState::Merging => 2,
53 ShardState::Transferring => 3,
54 ShardState::Offline => 4,
55 }
56}
57
58fn u8_to_state(v: u8) -> RaftResult<ShardState> {
59 match v {
60 0 => Ok(ShardState::Active),
61 1 => Ok(ShardState::Splitting),
62 2 => Ok(ShardState::Merging),
63 3 => Ok(ShardState::Transferring),
64 4 => Ok(ShardState::Offline),
65 other => Err(RaftError::StateMachineError {
66 message: format!("unknown ShardState discriminant {}", other),
67 }),
68 }
69}
70
/// Converts a [`SystemTime`] to whole milliseconds since the Unix epoch.
///
/// * Times before the epoch map to `0`.
/// * Times too far in the future to fit in a `u64` millisecond count saturate
///   to `u64::MAX` instead of silently wrapping.
fn system_time_to_ms(t: SystemTime) -> u64 {
    let since_epoch = t
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    // `as_millis` yields a `u128`; clamp first so the narrowing cast below is
    // lossless rather than a truncation.
    let clamped = since_epoch.as_millis().min(u128::from(u64::MAX));
    clamped as u64
}
76
/// Converts a count of milliseconds since the Unix epoch into a [`SystemTime`].
fn ms_to_system_time(ms: u64) -> SystemTime {
    let since_epoch = Duration::from_millis(ms);
    SystemTime::UNIX_EPOCH + since_epoch
}
80
81impl ShardMetadataDto {
82 fn from_meta(m: &ShardMetadata) -> Self {
83 Self {
84 id: m.id,
85 range_start: m.range.start.as_bytes().to_vec(),
86 range_end: m.range.end.as_bytes().to_vec(),
87 state: state_to_u8(&m.state),
88 node_id: m.node_id,
89 replicas: m.replicas.clone(),
90 estimated_keys: m.estimated_keys,
91 estimated_size_bytes: m.estimated_size_bytes,
92 last_updated_ms: system_time_to_ms(m.last_updated),
93 created_at_ms: system_time_to_ms(m.created_at),
94 version: m.version,
95 }
96 }
97
98 fn into_meta(self) -> RaftResult<ShardMetadata> {
99 let start = Key::from_slice(&self.range_start);
100 let end = Key::from_slice(&self.range_end);
101 let range = KeyRange::new(start, end)?;
102 let mut meta = ShardMetadata::new(self.id, range, self.node_id);
103 meta.state = u8_to_state(self.state)?;
104 meta.replicas = self.replicas;
105 meta.estimated_keys = self.estimated_keys;
106 meta.estimated_size_bytes = self.estimated_size_bytes;
107 meta.last_updated = ms_to_system_time(self.last_updated_ms);
108 meta.created_at = ms_to_system_time(self.created_at_ms);
109 meta.version = self.version;
110 Ok(meta)
111 }
112}
113
/// Raft state machine that applies shard-placement commands to a shared
/// [`ShardRegistry`].
///
/// Its `StateMachine` implementation handles the `PlaceSplit`, `PlaceMerge`
/// and `PlaceTransfer` variants of `ClusterCommand`, and snapshots/restores
/// the registry's shard metadata.
pub struct PlacementStateMachine {
    /// Registry mutated by applied commands; shared via `Arc` with whoever
    /// constructed the state machine.
    registry: Arc<ShardRegistry>,
}
126
127impl PlacementStateMachine {
128 pub fn new(registry: Arc<ShardRegistry>) -> Self {
130 Self { registry }
131 }
132}
133
134impl StateMachine for PlacementStateMachine {
135 fn apply(&mut self, entry: &LogEntry) -> RaftResult<Vec<u8>> {
136 let cmd = match ClusterCommand::decode(&entry.command.data) {
137 Ok(c) => c,
138 Err(_) => {
139 return Ok(Vec::new());
142 }
143 };
144
145 match cmd {
146 ClusterCommand::PlaceSplit {
147 shard_id,
148 split_key,
149 } => {
150 let left_id = self.registry.allocate_shard_id();
151 let right_id = self.registry.allocate_shard_id();
152 let key = Key::from_slice(&split_key);
153 let split = ShardSplit::new(shard_id, left_id, right_id, key);
154 self.registry.execute_split(&split)?;
155 Ok(Vec::new())
156 }
157 ClusterCommand::PlaceMerge {
158 left_shard_id,
159 right_shard_id,
160 } => {
161 let target_id = self.registry.allocate_shard_id();
162 let merge = ShardMerge::new(left_shard_id, right_shard_id, target_id);
163 self.registry.execute_merge(&merge)?;
164 Ok(Vec::new())
165 }
166 ClusterCommand::PlaceTransfer {
167 shard_id,
168 from_node,
169 to_node,
170 } => {
171 let transfer = ShardTransfer::new(shard_id, from_node, to_node);
172 self.registry.execute_transfer(&transfer)?;
173 Ok(Vec::new())
174 }
175 _ => Ok(Vec::new()),
178 }
179 }
180
181 fn snapshot(&self) -> RaftResult<Vec<u8>> {
182 let shards = self.registry.get_all();
183 let dtos: Vec<ShardMetadataDto> = shards.iter().map(ShardMetadataDto::from_meta).collect();
184 oxicode::serde::encode_serde(&dtos).map_err(|e| RaftError::StateMachineError {
185 message: format!(
186 "PlacementStateMachine::snapshot: serialisation failed: {}",
187 e
188 ),
189 })
190 }
191
192 fn restore(&mut self, snapshot: &[u8]) -> RaftResult<()> {
193 let dtos: Vec<ShardMetadataDto> =
194 oxicode::serde::decode_serde(snapshot).map_err(|e| RaftError::StateMachineError {
195 message: format!(
196 "PlacementStateMachine::restore: deserialisation failed: {}",
197 e
198 ),
199 })?;
200
201 for shard in self.registry.get_all() {
203 let _ = self.registry.remove(shard.id);
206 }
207
208 for dto in dtos {
209 let meta = dto.into_meta()?;
210 self.registry.update(meta)?;
213 }
214
215 Ok(())
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        log::{Command, LogEntry},
        shard::{KeyRange, ShardMetadata, ShardRegistry},
    };
    use amaters_core::Key;

    /// Wraps raw command bytes in a log entry built with `LogEntry::new(1, 1, ..)`.
    fn make_entry(data: Vec<u8>) -> LogEntry {
        LogEntry::new(1, 1, Command::new(data))
    }

    /// Builds a registry holding a single shard covering `[start, end)` on
    /// `node_id`, returning the registry and the new shard's id.
    fn make_registry_with_shard(
        start: &str,
        end: &str,
        node_id: NodeId,
    ) -> (Arc<ShardRegistry>, ShardId) {
        let registry = Arc::new(ShardRegistry::new());
        let shard_id = registry.allocate_shard_id();
        let range = KeyRange::new(Key::from_str(start), Key::from_str(end)).expect("valid range");
        let shard = ShardMetadata::new(shard_id, range, node_id);
        registry.register(shard).expect("register");
        (registry, shard_id)
    }

    /// Applying `PlaceSplit` replaces the parent with two active children
    /// whose ranges meet at the split key.
    #[test]
    fn test_placement_state_machine_applies_split() {
        let (registry, shard_id) = make_registry_with_shard("a", "z", 1);
        let mut sm = PlacementStateMachine::new(Arc::clone(&registry));

        let cmd = ClusterCommand::PlaceSplit {
            shard_id,
            split_key: Key::from_str("m").as_bytes().to_vec(),
        };
        let entry = make_entry(cmd.encode());
        sm.apply(&entry).expect("apply split");

        assert!(registry.get(shard_id).is_none(), "parent should be removed");
        let mut all = registry.get_all();
        assert_eq!(all.len(), 2, "should have two children");

        for shard in &all {
            assert_eq!(shard.state, ShardState::Active);
        }

        // `get_all` order is not relied upon; sort by range start first.
        all.sort_by(|a, b| a.range.start.cmp(&b.range.start));
        assert_eq!(all[0].range.start, Key::from_str("a"));
        assert_eq!(all[0].range.end, Key::from_str("m"));
        assert_eq!(all[1].range.start, Key::from_str("m"));
        assert_eq!(all[1].range.end, Key::from_str("z"));
    }

    /// Applying `PlaceMerge` replaces two adjacent shards with one active
    /// shard spanning both ranges.
    #[test]
    fn test_placement_state_machine_applies_merge() {
        let registry = Arc::new(ShardRegistry::new());
        let left_id = registry.allocate_shard_id();
        let right_id = registry.allocate_shard_id();
        let left_range = KeyRange::new(Key::from_str("a"), Key::from_str("m")).expect("range");
        let right_range = KeyRange::new(Key::from_str("m"), Key::from_str("z")).expect("range");
        registry
            .register(ShardMetadata::new(left_id, left_range, 1))
            .expect("register left");
        registry
            .register(ShardMetadata::new(right_id, right_range, 1))
            .expect("register right");

        let mut sm = PlacementStateMachine::new(Arc::clone(&registry));

        let cmd = ClusterCommand::PlaceMerge {
            left_shard_id: left_id,
            right_shard_id: right_id,
        };
        let entry = make_entry(cmd.encode());
        sm.apply(&entry).expect("apply merge");

        assert!(registry.get(left_id).is_none(), "left should be removed");
        assert!(registry.get(right_id).is_none(), "right should be removed");
        let all = registry.get_all();
        assert_eq!(all.len(), 1, "should have one merged shard");
        assert_eq!(all[0].range.start, Key::from_str("a"));
        assert_eq!(all[0].range.end, Key::from_str("z"));
        assert_eq!(all[0].state, ShardState::Active);
    }

    /// A snapshot taken from one registry restores the same shard into an
    /// empty registry.
    #[test]
    fn test_placement_snapshot_round_trip() {
        let (registry, _) = make_registry_with_shard("a", "z", 42);
        let sm = PlacementStateMachine::new(Arc::clone(&registry));

        let snap = sm.snapshot().expect("snapshot");
        assert!(!snap.is_empty(), "snapshot must not be empty");

        let new_registry = Arc::new(ShardRegistry::new());
        let mut sm2 = PlacementStateMachine::new(Arc::clone(&new_registry));
        sm2.restore(&snap).expect("restore");

        let shards = new_registry.get_all();
        assert_eq!(shards.len(), 1, "restored registry should have one shard");
        assert_eq!(shards[0].range.start, Key::from_str("a"));
        assert_eq!(shards[0].range.end, Key::from_str("z"));
        assert_eq!(shards[0].node_id, 42);
    }

    /// A non-placement command yields an empty response and leaves the
    /// registry untouched.
    #[test]
    fn test_apply_non_placement_command_is_noop() {
        let registry = Arc::new(ShardRegistry::new());
        let mut sm = PlacementStateMachine::new(Arc::clone(&registry));

        let cmd = ClusterCommand::MembershipAdd {
            node_id: 5,
            address: "127.0.0.1:7878".into(),
        };
        let entry = make_entry(cmd.encode());
        let result = sm.apply(&entry).expect("apply membership add");
        assert!(result.is_empty());
        assert_eq!(registry.count(), 0, "registry should be unchanged");
    }
}