1use std::collections::BTreeMap;
8use std::fmt::Debug;
9use std::io::Cursor;
10use std::ops::RangeBounds;
11use std::sync::Arc;
12
13use openraft::storage::{LogState, RaftLogReader, RaftSnapshotBuilder, Snapshot};
14use openraft::{
15 BasicNode, Entry, EntryPayload, LogId, OptionalSend, RaftStorage, RaftTypeConfig, SnapshotMeta,
16 StorageError, StorageIOError, StoredMembership, Vote,
17};
18use serde::{Deserialize, Serialize};
19use tokio::sync::RwLock;
20
21use crate::{NodeId, SlotRange};
22
/// Marker type that carries the openraft type configuration for this
/// cluster (see the `RaftTypeConfig` impl below). Holds no data.
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct TypeConfig;
26
/// Binds the application types (command, response, node identity,
/// snapshot representation, async runtime) into openraft's type system.
impl RaftTypeConfig for TypeConfig {
    /// Application data appended to the raft log.
    type D = ClusterCommand;
    /// Response produced when a log entry is applied.
    type R = ClusterResponse;
    type Node = BasicNode;
    type NodeId = u64;
    type Entry = Entry<TypeConfig>;
    /// Snapshots are transferred as an in-memory byte cursor.
    type SnapshotData = Cursor<Vec<u8>>;
    type AsyncRuntime = openraft::TokioRuntime;
    type Responder = openraft::impls::OneshotResponder<TypeConfig>;
}
37
/// Commands replicated through raft and applied deterministically to the
/// cluster state machine (`ClusterStateData`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterCommand {
    /// Register (or overwrite) a node in the cluster topology.
    AddNode {
        node_id: NodeId,
        raft_id: u64,
        addr: String,
        is_primary: bool,
    },
    /// Remove a node and drop every slot it owned.
    RemoveNode { node_id: NodeId },
    /// Replace a node's slot assignment with the given ranges.
    AssignSlots {
        node_id: NodeId,
        slots: Vec<SlotRange>,
    },
    /// Mark a replica node as primary.
    PromoteReplica { replica_id: NodeId },
    /// Record that a slot is being migrated between two nodes.
    BeginMigration { slot: u16, from: NodeId, to: NodeId },
    /// Finish a migration and transfer slot ownership to `new_owner`.
    CompleteMigration { slot: u16, new_owner: NodeId },
}
65
/// Result of applying a `ClusterCommand` to the state machine.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterResponse {
    /// The command was applied successfully.
    Ok,
    /// The command could not be applied; carries a human-readable reason.
    Error(String),
}
72
/// Serialized form of a state-machine snapshot as written to / read from
/// the snapshot byte stream.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterSnapshot {
    /// Log id of the last entry applied when the snapshot was taken.
    pub last_applied: Option<LogId<u64>>,
    /// Effective membership at snapshot time.
    pub last_membership: StoredMembership<u64, BasicNode>,
    /// JSON-serialized `ClusterStateData` (see `build_snapshot`).
    pub state_data: Vec<u8>,
}
81
/// The replicated cluster state machine: topology, slot ownership, and
/// in-flight slot migrations.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterStateData {
    /// Known nodes, keyed by the string form of their `NodeId`.
    pub nodes: BTreeMap<String, NodeInfo>,
    /// Slot number -> owning node id (string form).
    pub slots: BTreeMap<u16, String>,
    /// Slot number -> migration currently in progress for that slot.
    pub migrations: BTreeMap<u16, MigrationState>,
}
92
/// Per-node metadata tracked in the cluster state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// String form of the node's `NodeId` (also the map key in `nodes`).
    pub node_id: String,
    /// The node's raft-level identifier.
    pub raft_id: u64,
    /// Network address of the node.
    pub addr: String,
    /// Whether the node is a primary (vs. a replica).
    pub is_primary: bool,
    /// Slot ranges assigned to this node via `AssignSlots`.
    pub slots: Vec<SlotRange>,
}
102
/// An in-flight slot migration: source and destination node ids
/// (string form of `NodeId`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationState {
    pub from: String,
    pub to: String,
}
109
/// In-memory openraft storage: raft log, vote, snapshot, and the
/// replicated cluster state machine, each behind its own async `RwLock`.
#[derive(Debug)]
pub struct Storage {
    /// Latest vote stored for raft elections.
    vote: RwLock<Option<Vote<u64>>>,
    /// Raft log entries keyed by log index.
    log: RwLock<BTreeMap<u64, Entry<TypeConfig>>>,
    /// Log id up to which entries have been purged (inclusive).
    last_purged: RwLock<Option<LogId<u64>>>,
    /// Log id of the last entry applied to the state machine.
    last_applied: RwLock<Option<LogId<u64>>>,
    /// Most recent effective membership seen in applied entries.
    last_membership: RwLock<StoredMembership<u64, BasicNode>>,
    /// Most recently built or installed snapshot, if any.
    snapshot: RwLock<Option<StoredSnapshot>>,
    /// Shared state machine data; handed out via [`Storage::state`].
    state: Arc<RwLock<ClusterStateData>>,
}
121
/// A snapshot retained in memory: its metadata plus the serialized
/// `ClusterSnapshot` bytes.
#[derive(Debug, Clone)]
struct StoredSnapshot {
    meta: SnapshotMeta<u64, BasicNode>,
    data: Vec<u8>,
}
127
128impl Default for Storage {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl Storage {
135 pub fn new() -> Self {
136 Self {
137 vote: RwLock::new(None),
138 log: RwLock::new(BTreeMap::new()),
139 last_purged: RwLock::new(None),
140 last_applied: RwLock::new(None),
141 last_membership: RwLock::new(StoredMembership::default()),
142 snapshot: RwLock::new(None),
143 state: Arc::new(RwLock::new(ClusterStateData::default())),
144 }
145 }
146
147 pub fn state(&self) -> Arc<RwLock<ClusterStateData>> {
148 Arc::clone(&self.state)
149 }
150
151 fn apply_command(cmd: &ClusterCommand, state: &mut ClusterStateData) -> ClusterResponse {
152 match cmd {
153 ClusterCommand::AddNode {
154 node_id,
155 raft_id,
156 addr,
157 is_primary,
158 } => {
159 let key = node_id.0.to_string();
160 state.nodes.insert(
161 key.clone(),
162 NodeInfo {
163 node_id: key,
164 raft_id: *raft_id,
165 addr: addr.clone(),
166 is_primary: *is_primary,
167 slots: Vec::new(),
168 },
169 );
170 ClusterResponse::Ok
171 }
172
173 ClusterCommand::RemoveNode { node_id } => {
174 let key = node_id.0.to_string();
175 state.nodes.remove(&key);
176 state.slots.retain(|_, owner| owner != &key);
177 ClusterResponse::Ok
178 }
179
180 ClusterCommand::AssignSlots { node_id, slots } => {
181 let key = node_id.0.to_string();
182 if let Some(node) = state.nodes.get_mut(&key) {
183 node.slots = slots.clone();
184 for slot_range in slots {
185 for slot in slot_range.start..=slot_range.end {
186 state.slots.insert(slot, key.clone());
187 }
188 }
189 ClusterResponse::Ok
190 } else {
191 ClusterResponse::Error(format!("node {} not found", node_id))
192 }
193 }
194
195 ClusterCommand::PromoteReplica { replica_id } => {
196 let key = replica_id.0.to_string();
197 if let Some(node) = state.nodes.get_mut(&key) {
198 node.is_primary = true;
199 ClusterResponse::Ok
200 } else {
201 ClusterResponse::Error(format!("replica {} not found", replica_id))
202 }
203 }
204
205 ClusterCommand::BeginMigration { slot, from, to } => {
206 state.migrations.insert(
207 *slot,
208 MigrationState {
209 from: from.0.to_string(),
210 to: to.0.to_string(),
211 },
212 );
213 ClusterResponse::Ok
214 }
215
216 ClusterCommand::CompleteMigration { slot, new_owner } => {
217 state.migrations.remove(slot);
218 let key = new_owner.0.to_string();
219 state.slots.insert(*slot, key);
220 ClusterResponse::Ok
221 }
222 }
223 }
224}
225
226impl RaftLogReader<TypeConfig> for Arc<Storage> {
227 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
228 &mut self,
229 range: RB,
230 ) -> Result<Vec<Entry<TypeConfig>>, StorageError<u64>> {
231 let log = self.log.read().await;
232 Ok(log.range(range).map(|(_, v)| v.clone()).collect())
233 }
234}
235
236impl RaftSnapshotBuilder<TypeConfig> for Arc<Storage> {
237 async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<u64>> {
238 let last_applied = *self.last_applied.read().await;
239 let membership = self.last_membership.read().await.clone();
240 let state = self.state.read().await;
241
242 let state_data =
243 serde_json::to_vec(&*state).map_err(|e| StorageIOError::write_snapshot(None, &e))?;
244
245 let snapshot = ClusterSnapshot {
246 last_applied,
247 last_membership: membership.clone(),
248 state_data,
249 };
250
251 let data =
252 serde_json::to_vec(&snapshot).map_err(|e| StorageIOError::write_snapshot(None, &e))?;
253
254 let snapshot_id = last_applied
255 .map(|id| format!("{}-{}", id.leader_id, id.index))
256 .unwrap_or_else(|| "0-0".to_string());
257
258 let meta = SnapshotMeta {
259 last_log_id: last_applied,
260 last_membership: membership,
261 snapshot_id,
262 };
263
264 *self.snapshot.write().await = Some(StoredSnapshot {
266 meta: meta.clone(),
267 data: data.clone(),
268 });
269
270 Ok(Snapshot {
271 meta,
272 snapshot: Box::new(Cursor::new(data)),
273 })
274 }
275}
276
277impl RaftStorage<TypeConfig> for Arc<Storage> {
278 type LogReader = Self;
279 type SnapshotBuilder = Self;
280
281 async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<u64>> {
282 let log = self.log.read().await;
283 let last = log.iter().next_back().map(|(_, e)| e.log_id);
284 let purged = *self.last_purged.read().await;
285
286 Ok(LogState {
287 last_purged_log_id: purged,
288 last_log_id: last,
289 })
290 }
291
292 async fn save_vote(&mut self, vote: &Vote<u64>) -> Result<(), StorageError<u64>> {
293 *self.vote.write().await = Some(*vote);
294 Ok(())
295 }
296
297 async fn read_vote(&mut self) -> Result<Option<Vote<u64>>, StorageError<u64>> {
298 Ok(*self.vote.read().await)
299 }
300
301 async fn get_log_reader(&mut self) -> Self::LogReader {
302 Arc::clone(self)
303 }
304
305 async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<u64>>
306 where
307 I: IntoIterator<Item = Entry<TypeConfig>> + Send,
308 {
309 let mut log = self.log.write().await;
310 for entry in entries {
311 log.insert(entry.log_id.index, entry);
312 }
313 Ok(())
314 }
315
316 async fn delete_conflict_logs_since(
317 &mut self,
318 log_id: LogId<u64>,
319 ) -> Result<(), StorageError<u64>> {
320 let mut log = self.log.write().await;
321 let to_remove: Vec<_> = log.range(log_id.index..).map(|(k, _)| *k).collect();
322 for key in to_remove {
323 log.remove(&key);
324 }
325 Ok(())
326 }
327
328 async fn purge_logs_upto(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
329 let mut log = self.log.write().await;
330 let to_remove: Vec<_> = log.range(..=log_id.index).map(|(k, _)| *k).collect();
331 for key in to_remove {
332 log.remove(&key);
333 }
334 *self.last_purged.write().await = Some(log_id);
335 Ok(())
336 }
337
338 async fn last_applied_state(
339 &mut self,
340 ) -> Result<(Option<LogId<u64>>, StoredMembership<u64, BasicNode>), StorageError<u64>> {
341 let last_applied = *self.last_applied.read().await;
342 let membership = self.last_membership.read().await.clone();
343 Ok((last_applied, membership))
344 }
345
346 async fn apply_to_state_machine(
347 &mut self,
348 entries: &[Entry<TypeConfig>],
349 ) -> Result<Vec<ClusterResponse>, StorageError<u64>> {
350 let mut results = Vec::new();
351 let mut state = self.state.write().await;
352
353 for entry in entries {
354 *self.last_applied.write().await = Some(entry.log_id);
355
356 match &entry.payload {
357 EntryPayload::Blank => {
358 results.push(ClusterResponse::Ok);
359 }
360 EntryPayload::Normal(cmd) => {
361 let result = Storage::apply_command(cmd, &mut state);
362 results.push(result);
363 }
364 EntryPayload::Membership(m) => {
365 *self.last_membership.write().await =
366 StoredMembership::new(Some(entry.log_id), m.clone());
367 results.push(ClusterResponse::Ok);
368 }
369 }
370 }
371
372 Ok(results)
373 }
374
375 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
376 Arc::clone(self)
377 }
378
379 async fn begin_receiving_snapshot(
380 &mut self,
381 ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
382 Ok(Box::new(Cursor::new(Vec::new())))
383 }
384
385 async fn install_snapshot(
386 &mut self,
387 meta: &SnapshotMeta<u64, BasicNode>,
388 snapshot: Box<Cursor<Vec<u8>>>,
389 ) -> Result<(), StorageError<u64>> {
390 let data = snapshot.into_inner();
391 let snap: ClusterSnapshot = serde_json::from_slice(&data)
392 .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
393
394 *self.last_applied.write().await = snap.last_applied;
395 *self.last_membership.write().await = snap.last_membership;
396
397 let state_data: ClusterStateData = serde_json::from_slice(&snap.state_data)
398 .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
399 *self.state.write().await = state_data;
400
401 *self.snapshot.write().await = Some(StoredSnapshot {
402 meta: meta.clone(),
403 data,
404 });
405
406 Ok(())
407 }
408
409 async fn get_current_snapshot(
410 &mut self,
411 ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<u64>> {
412 let snap = self.snapshot.read().await;
413 Ok(snap.as_ref().map(|s| Snapshot {
414 meta: s.meta.clone(),
415 snapshot: Box::new(Cursor::new(s.data.clone())),
416 }))
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use openraft::CommittedLeaderId;

    /// Builds a `LogId` for the given term/index with leader node id 0.
    fn log_id(term: u64, index: u64) -> LogId<u64> {
        LogId::new(CommittedLeaderId::new(term, 0), index)
    }

    /// Applying `AddNode` registers the node in the replicated state.
    #[tokio::test]
    async fn storage_add_node() {
        let storage = Arc::new(Storage::new());
        // RaftStorage methods take `&mut self`, so operate on a clone of
        // the Arc while the original is used to read state afterwards.
        let mut storage_clone = Arc::clone(&storage);

        let node_id = NodeId::new();
        let entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".to_string(),
                is_primary: true,
            }),
        };

        let results = storage_clone
            .apply_to_state_machine(&[entry])
            .await
            .unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        let state_arc = storage.state();
        let state = state_arc.read().await;
        assert!(state.nodes.contains_key(&node_id.0.to_string()));
    }

    /// `AssignSlots` maps every slot in the assigned ranges to the node.
    #[tokio::test]
    async fn storage_assign_slots() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let node_id = NodeId::new();

        // The node must exist before slots can be assigned to it.
        let add_entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".to_string(),
                is_primary: true,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[add_entry])
            .await
            .unwrap();

        let assign_entry = Entry {
            log_id: log_id(1, 2),
            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
                node_id,
                slots: vec![SlotRange::new(0, 5460)],
            }),
        };
        let results = storage_clone
            .apply_to_state_machine(&[assign_entry])
            .await
            .unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        // Both endpoints of the inclusive range must be mapped.
        let state_arc = storage.state();
        let state = state_arc.read().await;
        assert_eq!(state.slots.get(&0), Some(&node_id.0.to_string()));
        assert_eq!(state.slots.get(&5460), Some(&node_id.0.to_string()));
    }

    /// Begin/complete migration: the migration entry appears, then is
    /// removed and the slot's ownership moves to the destination node.
    #[tokio::test]
    async fn storage_migration() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let node1 = NodeId::new();
        let node2 = NodeId::new();

        // Register both nodes in a single apply batch.
        let entries: Vec<Entry<TypeConfig>> = [node1, node2]
            .iter()
            .enumerate()
            .map(|(i, node_id)| Entry {
                log_id: log_id(1, i as u64 + 1),
                payload: EntryPayload::Normal(ClusterCommand::AddNode {
                    node_id: *node_id,
                    raft_id: i as u64 + 1,
                    addr: format!("127.0.0.1:{}", 6379 + i),
                    is_primary: true,
                }),
            })
            .collect();
        storage_clone
            .apply_to_state_machine(&entries)
            .await
            .unwrap();

        let begin_entry = Entry {
            log_id: log_id(1, 3),
            payload: EntryPayload::Normal(ClusterCommand::BeginMigration {
                slot: 100,
                from: node1,
                to: node2,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[begin_entry])
            .await
            .unwrap();

        // While in progress, the migration is visible in state.
        {
            let state_arc = storage.state();
            let state = state_arc.read().await;
            assert!(state.migrations.contains_key(&100));
        }

        let complete_entry = Entry {
            log_id: log_id(1, 4),
            payload: EntryPayload::Normal(ClusterCommand::CompleteMigration {
                slot: 100,
                new_owner: node2,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[complete_entry])
            .await
            .unwrap();

        // After completion the migration record is gone and ownership moved.
        {
            let state_arc = storage.state();
            let state = state_arc.read().await;
            assert!(!state.migrations.contains_key(&100));
            assert_eq!(state.slots.get(&100), Some(&node2.0.to_string()));
        }
    }

    /// Appended entries are reflected in `get_log_state`.
    #[tokio::test]
    async fn storage_log_operations() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let entry = Entry::<TypeConfig> {
            log_id: log_id(1, 1),
            payload: EntryPayload::Blank,
        };

        storage_clone.append_to_log(vec![entry]).await.unwrap();

        let state = storage_clone.get_log_state().await.unwrap();
        assert_eq!(state.last_log_id, Some(log_id(1, 1)));
    }

    /// A saved vote is returned by `read_vote`.
    #[tokio::test]
    async fn storage_vote() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let vote = Vote::new(1, 1);
        storage_clone.save_vote(&vote).await.unwrap();

        let read_vote = storage_clone.read_vote().await.unwrap();
        assert_eq!(read_vote, Some(vote));
    }
}