// ember_cluster/raft.rs

1//! Raft consensus for cluster configuration.
2//!
3//! Uses openraft to achieve consensus on cluster topology changes.
4//! Only configuration changes go through Raft - data operations use
5//! primary-replica async replication for lower latency.
6
7use std::collections::BTreeMap;
8use std::fmt::Debug;
9use std::io::Cursor;
10use std::ops::RangeBounds;
11use std::sync::Arc;
12
13use openraft::storage::{LogState, RaftLogReader, RaftSnapshotBuilder, Snapshot};
14use openraft::{
15    BasicNode, Entry, EntryPayload, LogId, OptionalSend, RaftStorage, RaftTypeConfig, SnapshotMeta,
16    StorageError, StorageIOError, StoredMembership, Vote,
17};
18use serde::{Deserialize, Serialize};
19use tokio::sync::RwLock;
20
21use crate::{NodeId, SlotRange};
22
/// Type configuration for openraft.
///
/// Binds the Raft type parameters used throughout this module:
/// application commands, responses, node identity/addressing, and the
/// snapshot transport format (an in-memory byte cursor).
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct TypeConfig;

impl RaftTypeConfig for TypeConfig {
    /// Application data replicated through the log.
    type D = ClusterCommand;
    /// Response produced when an entry is applied.
    type R = ClusterResponse;
    type Node = BasicNode;
    type NodeId = u64;
    type Entry = Entry<TypeConfig>;
    /// Snapshots are streamed through an in-memory buffer.
    type SnapshotData = Cursor<Vec<u8>>;
    type AsyncRuntime = openraft::TokioRuntime;
    type Responder = openraft::impls::OneshotResponder<TypeConfig>;
}
37
/// Commands that modify cluster configuration.
///
/// These are replicated through Raft to ensure all nodes agree
/// on the cluster topology. Each variant is handled deterministically
/// by `Storage::apply_command`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterCommand {
    /// Add a new node to the cluster.
    AddNode {
        node_id: NodeId,
        /// Raft-internal numeric id (distinct from `node_id`).
        raft_id: u64,
        /// Network address of the node, e.g. "127.0.0.1:6379".
        addr: String,
        /// Whether the node joins as a primary (vs. replica).
        is_primary: bool,
    },
    /// Remove a node from the cluster.
    RemoveNode { node_id: NodeId },
    /// Assign slots to a node.
    AssignSlots {
        node_id: NodeId,
        /// Inclusive slot ranges the node should own.
        slots: Vec<SlotRange>,
    },
    /// Promote a replica to primary (during failover).
    PromoteReplica { replica_id: NodeId },
    /// Mark a slot as migrating.
    BeginMigration { slot: u16, from: NodeId, to: NodeId },
    /// Complete a slot migration.
    CompleteMigration { slot: u16, new_owner: NodeId },
}
65
/// Response from applying a cluster command.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterResponse {
    /// Command applied successfully.
    Ok,
    /// Command rejected; carries a human-readable reason.
    Error(String),
}
72
/// State machine snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterSnapshot {
    /// Log id of the last entry applied when the snapshot was taken.
    pub last_applied: Option<LogId<u64>>,
    /// Membership configuration in effect at snapshot time.
    pub last_membership: StoredMembership<u64, BasicNode>,
    /// Serialized cluster state (JSON-encoded `ClusterStateData`).
    pub state_data: Vec<u8>,
}
81
/// Internal cluster state managed by the state machine.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterStateData {
    /// Stringified node id -> node info.
    pub nodes: BTreeMap<String, NodeInfo>,
    /// Slot assignments: slot number -> owning node id (stringified).
    pub slots: BTreeMap<u16, String>,
    /// Ongoing migrations, keyed by slot.
    pub migrations: BTreeMap<u16, MigrationState>,
}
92
/// Information about a cluster node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Stringified `NodeId` (also the key in `ClusterStateData::nodes`).
    pub node_id: String,
    /// Raft-internal numeric id.
    pub raft_id: u64,
    /// Network address of the node.
    pub addr: String,
    /// True for primaries; set by `AddNode` or `PromoteReplica`.
    pub is_primary: bool,
    /// Slot ranges currently assigned to this node.
    pub slots: Vec<SlotRange>,
}
102
/// State of an ongoing slot migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationState {
    /// Stringified id of the node the slot is moving away from.
    pub from: String,
    /// Stringified id of the destination node.
    pub to: String,
}
109
/// Combined log and state machine storage for Raft.
///
/// Everything is held in memory behind async `RwLock`s; nothing here
/// writes to disk. NOTE(review): a process restart therefore loses the
/// vote, log, and state — confirm durability is handled elsewhere.
#[derive(Debug)]
pub struct Storage {
    // Last vote granted, per Raft's voting safety rules.
    vote: RwLock<Option<Vote<u64>>>,
    // Log entries keyed by log index.
    log: RwLock<BTreeMap<u64, Entry<TypeConfig>>>,
    // Watermark of compacted (purged) log entries.
    last_purged: RwLock<Option<LogId<u64>>>,
    // Id of the last entry applied to the state machine.
    last_applied: RwLock<Option<LogId<u64>>>,
    // Most recent membership configuration seen in the log.
    last_membership: RwLock<StoredMembership<u64, BasicNode>>,
    // Most recently built or installed snapshot.
    snapshot: RwLock<Option<StoredSnapshot>>,
    // Cluster topology; shared with non-Raft readers via `state()`.
    state: Arc<RwLock<ClusterStateData>>,
}
121
/// A snapshot retained in memory: metadata plus the serialized bytes.
#[derive(Debug, Clone)]
struct StoredSnapshot {
    meta: SnapshotMeta<u64, BasicNode>,
    // JSON-encoded `ClusterSnapshot`.
    data: Vec<u8>,
}
127
128impl Default for Storage {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl Storage {
135    pub fn new() -> Self {
136        Self {
137            vote: RwLock::new(None),
138            log: RwLock::new(BTreeMap::new()),
139            last_purged: RwLock::new(None),
140            last_applied: RwLock::new(None),
141            last_membership: RwLock::new(StoredMembership::default()),
142            snapshot: RwLock::new(None),
143            state: Arc::new(RwLock::new(ClusterStateData::default())),
144        }
145    }
146
147    pub fn state(&self) -> Arc<RwLock<ClusterStateData>> {
148        Arc::clone(&self.state)
149    }
150
151    fn apply_command(cmd: &ClusterCommand, state: &mut ClusterStateData) -> ClusterResponse {
152        match cmd {
153            ClusterCommand::AddNode {
154                node_id,
155                raft_id,
156                addr,
157                is_primary,
158            } => {
159                let key = node_id.0.to_string();
160                state.nodes.insert(
161                    key.clone(),
162                    NodeInfo {
163                        node_id: key,
164                        raft_id: *raft_id,
165                        addr: addr.clone(),
166                        is_primary: *is_primary,
167                        slots: Vec::new(),
168                    },
169                );
170                ClusterResponse::Ok
171            }
172
173            ClusterCommand::RemoveNode { node_id } => {
174                let key = node_id.0.to_string();
175                state.nodes.remove(&key);
176                state.slots.retain(|_, owner| owner != &key);
177                ClusterResponse::Ok
178            }
179
180            ClusterCommand::AssignSlots { node_id, slots } => {
181                let key = node_id.0.to_string();
182                if let Some(node) = state.nodes.get_mut(&key) {
183                    node.slots = slots.clone();
184                    for slot_range in slots {
185                        for slot in slot_range.start..=slot_range.end {
186                            state.slots.insert(slot, key.clone());
187                        }
188                    }
189                    ClusterResponse::Ok
190                } else {
191                    ClusterResponse::Error(format!("node {} not found", node_id))
192                }
193            }
194
195            ClusterCommand::PromoteReplica { replica_id } => {
196                let key = replica_id.0.to_string();
197                if let Some(node) = state.nodes.get_mut(&key) {
198                    node.is_primary = true;
199                    ClusterResponse::Ok
200                } else {
201                    ClusterResponse::Error(format!("replica {} not found", replica_id))
202                }
203            }
204
205            ClusterCommand::BeginMigration { slot, from, to } => {
206                state.migrations.insert(
207                    *slot,
208                    MigrationState {
209                        from: from.0.to_string(),
210                        to: to.0.to_string(),
211                    },
212                );
213                ClusterResponse::Ok
214            }
215
216            ClusterCommand::CompleteMigration { slot, new_owner } => {
217                state.migrations.remove(slot);
218                let key = new_owner.0.to_string();
219                state.slots.insert(*slot, key);
220                ClusterResponse::Ok
221            }
222        }
223    }
224}
225
226impl RaftLogReader<TypeConfig> for Arc<Storage> {
227    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
228        &mut self,
229        range: RB,
230    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<u64>> {
231        let log = self.log.read().await;
232        Ok(log.range(range).map(|(_, v)| v.clone()).collect())
233    }
234}
235
236impl RaftSnapshotBuilder<TypeConfig> for Arc<Storage> {
237    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<u64>> {
238        let last_applied = *self.last_applied.read().await;
239        let membership = self.last_membership.read().await.clone();
240        let state = self.state.read().await;
241
242        let state_data =
243            serde_json::to_vec(&*state).map_err(|e| StorageIOError::write_snapshot(None, &e))?;
244
245        let snapshot = ClusterSnapshot {
246            last_applied,
247            last_membership: membership.clone(),
248            state_data,
249        };
250
251        let data =
252            serde_json::to_vec(&snapshot).map_err(|e| StorageIOError::write_snapshot(None, &e))?;
253
254        let snapshot_id = last_applied
255            .map(|id| format!("{}-{}", id.leader_id, id.index))
256            .unwrap_or_else(|| "0-0".to_string());
257
258        let meta = SnapshotMeta {
259            last_log_id: last_applied,
260            last_membership: membership,
261            snapshot_id,
262        };
263
264        // Store the snapshot
265        *self.snapshot.write().await = Some(StoredSnapshot {
266            meta: meta.clone(),
267            data: data.clone(),
268        });
269
270        Ok(Snapshot {
271            meta,
272            snapshot: Box::new(Cursor::new(data)),
273        })
274    }
275}
276
277impl RaftStorage<TypeConfig> for Arc<Storage> {
278    type LogReader = Self;
279    type SnapshotBuilder = Self;
280
281    async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<u64>> {
282        let log = self.log.read().await;
283        let last = log.iter().next_back().map(|(_, e)| e.log_id);
284        let purged = *self.last_purged.read().await;
285
286        Ok(LogState {
287            last_purged_log_id: purged,
288            last_log_id: last,
289        })
290    }
291
292    async fn save_vote(&mut self, vote: &Vote<u64>) -> Result<(), StorageError<u64>> {
293        *self.vote.write().await = Some(*vote);
294        Ok(())
295    }
296
297    async fn read_vote(&mut self) -> Result<Option<Vote<u64>>, StorageError<u64>> {
298        Ok(*self.vote.read().await)
299    }
300
301    async fn get_log_reader(&mut self) -> Self::LogReader {
302        Arc::clone(self)
303    }
304
305    async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<u64>>
306    where
307        I: IntoIterator<Item = Entry<TypeConfig>> + Send,
308    {
309        let mut log = self.log.write().await;
310        for entry in entries {
311            log.insert(entry.log_id.index, entry);
312        }
313        Ok(())
314    }
315
316    async fn delete_conflict_logs_since(
317        &mut self,
318        log_id: LogId<u64>,
319    ) -> Result<(), StorageError<u64>> {
320        let mut log = self.log.write().await;
321        let to_remove: Vec<_> = log.range(log_id.index..).map(|(k, _)| *k).collect();
322        for key in to_remove {
323            log.remove(&key);
324        }
325        Ok(())
326    }
327
328    async fn purge_logs_upto(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
329        let mut log = self.log.write().await;
330        let to_remove: Vec<_> = log.range(..=log_id.index).map(|(k, _)| *k).collect();
331        for key in to_remove {
332            log.remove(&key);
333        }
334        *self.last_purged.write().await = Some(log_id);
335        Ok(())
336    }
337
338    async fn last_applied_state(
339        &mut self,
340    ) -> Result<(Option<LogId<u64>>, StoredMembership<u64, BasicNode>), StorageError<u64>> {
341        let last_applied = *self.last_applied.read().await;
342        let membership = self.last_membership.read().await.clone();
343        Ok((last_applied, membership))
344    }
345
346    async fn apply_to_state_machine(
347        &mut self,
348        entries: &[Entry<TypeConfig>],
349    ) -> Result<Vec<ClusterResponse>, StorageError<u64>> {
350        let mut results = Vec::new();
351        let mut state = self.state.write().await;
352
353        for entry in entries {
354            *self.last_applied.write().await = Some(entry.log_id);
355
356            match &entry.payload {
357                EntryPayload::Blank => {
358                    results.push(ClusterResponse::Ok);
359                }
360                EntryPayload::Normal(cmd) => {
361                    let result = Storage::apply_command(cmd, &mut state);
362                    results.push(result);
363                }
364                EntryPayload::Membership(m) => {
365                    *self.last_membership.write().await =
366                        StoredMembership::new(Some(entry.log_id), m.clone());
367                    results.push(ClusterResponse::Ok);
368                }
369            }
370        }
371
372        Ok(results)
373    }
374
375    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
376        Arc::clone(self)
377    }
378
379    async fn begin_receiving_snapshot(
380        &mut self,
381    ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
382        Ok(Box::new(Cursor::new(Vec::new())))
383    }
384
385    async fn install_snapshot(
386        &mut self,
387        meta: &SnapshotMeta<u64, BasicNode>,
388        snapshot: Box<Cursor<Vec<u8>>>,
389    ) -> Result<(), StorageError<u64>> {
390        let data = snapshot.into_inner();
391        let snap: ClusterSnapshot = serde_json::from_slice(&data)
392            .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
393
394        *self.last_applied.write().await = snap.last_applied;
395        *self.last_membership.write().await = snap.last_membership;
396
397        let state_data: ClusterStateData = serde_json::from_slice(&snap.state_data)
398            .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
399        *self.state.write().await = state_data;
400
401        *self.snapshot.write().await = Some(StoredSnapshot {
402            meta: meta.clone(),
403            data,
404        });
405
406        Ok(())
407    }
408
409    async fn get_current_snapshot(
410        &mut self,
411    ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<u64>> {
412        let snap = self.snapshot.read().await;
413        Ok(snap.as_ref().map(|s| Snapshot {
414            meta: s.meta.clone(),
415            snapshot: Box::new(Cursor::new(s.data.clone())),
416        }))
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use openraft::CommittedLeaderId;

    /// Helper to create a LogId for tests.
    fn log_id(term: u64, index: u64) -> LogId<u64> {
        LogId::new(CommittedLeaderId::new(term, 0), index)
    }

    /// Applying `AddNode` registers the node in the state machine.
    #[tokio::test]
    async fn storage_add_node() {
        let storage = Arc::new(Storage::new());
        // `RaftStorage` methods take `&mut self` on the `Arc` wrapper.
        let mut storage_clone = Arc::clone(&storage);

        let node_id = NodeId::new();
        let entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".to_string(),
                is_primary: true,
            }),
        };

        let results = storage_clone
            .apply_to_state_machine(&[entry])
            .await
            .unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        // Node must be visible through the shared state handle,
        // keyed by the stringified node id.
        let state_arc = storage.state();
        let state = state_arc.read().await;
        assert!(state.nodes.contains_key(&node_id.0.to_string()));
    }

    /// `AssignSlots` populates the slot -> owner map for every slot in
    /// the (inclusive) assigned ranges.
    #[tokio::test]
    async fn storage_assign_slots() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let node_id = NodeId::new();

        // Add node first
        let add_entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".to_string(),
                is_primary: true,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[add_entry])
            .await
            .unwrap();

        // Assign slots
        let assign_entry = Entry {
            log_id: log_id(1, 2),
            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
                node_id,
                slots: vec![SlotRange::new(0, 5460)],
            }),
        };
        let results = storage_clone
            .apply_to_state_machine(&[assign_entry])
            .await
            .unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        // Both boundaries of the inclusive range must be owned.
        let state_arc = storage.state();
        let state = state_arc.read().await;
        assert_eq!(state.slots.get(&0), Some(&node_id.0.to_string()));
        assert_eq!(state.slots.get(&5460), Some(&node_id.0.to_string()));
    }

    /// Begin/Complete migration: the migration entry exists only while
    /// in flight, and completion transfers slot ownership.
    #[tokio::test]
    async fn storage_migration() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let node1 = NodeId::new();
        let node2 = NodeId::new();

        // Add nodes
        let entries: Vec<Entry<TypeConfig>> = [node1, node2]
            .iter()
            .enumerate()
            .map(|(i, node_id)| Entry {
                log_id: log_id(1, i as u64 + 1),
                payload: EntryPayload::Normal(ClusterCommand::AddNode {
                    node_id: *node_id,
                    raft_id: i as u64 + 1,
                    addr: format!("127.0.0.1:{}", 6379 + i),
                    is_primary: true,
                }),
            })
            .collect();
        storage_clone
            .apply_to_state_machine(&entries)
            .await
            .unwrap();

        // Begin migration
        let begin_entry = Entry {
            log_id: log_id(1, 3),
            payload: EntryPayload::Normal(ClusterCommand::BeginMigration {
                slot: 100,
                from: node1,
                to: node2,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[begin_entry])
            .await
            .unwrap();

        // Scope the read guard so it is released before the next apply.
        {
            let state_arc = storage.state();
            let state = state_arc.read().await;
            assert!(state.migrations.contains_key(&100));
        }

        // Complete migration
        let complete_entry = Entry {
            log_id: log_id(1, 4),
            payload: EntryPayload::Normal(ClusterCommand::CompleteMigration {
                slot: 100,
                new_owner: node2,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[complete_entry])
            .await
            .unwrap();

        {
            let state_arc = storage.state();
            let state = state_arc.read().await;
            // Migration record is gone and the slot now belongs to node2.
            assert!(!state.migrations.contains_key(&100));
            assert_eq!(state.slots.get(&100), Some(&node2.0.to_string()));
        }
    }

    /// Appended entries are reflected by `get_log_state`.
    #[tokio::test]
    async fn storage_log_operations() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let entry = Entry::<TypeConfig> {
            log_id: log_id(1, 1),
            payload: EntryPayload::Blank,
        };

        storage_clone.append_to_log(vec![entry]).await.unwrap();

        let state = storage_clone.get_log_state().await.unwrap();
        assert_eq!(state.last_log_id, Some(log_id(1, 1)));
    }

    /// A saved vote reads back identically.
    #[tokio::test]
    async fn storage_vote() {
        let storage = Arc::new(Storage::new());
        let mut storage_clone = Arc::clone(&storage);

        let vote = Vote::new(1, 1);
        storage_clone.save_vote(&vote).await.unwrap();

        let read_vote = storage_clone.read_vote().await.unwrap();
        assert_eq!(read_vote, Some(vote));
    }
}
594}