ember_cluster/raft.rs

1//! Raft consensus for cluster configuration.
2//!
3//! Uses openraft to achieve consensus on cluster topology changes.
4//! Only configuration changes go through Raft - data operations use
5//! primary-replica async replication for lower latency.
6
7use std::collections::BTreeMap;
8use std::fmt::Debug;
9use std::io::Cursor;
10use std::net::SocketAddr;
11use std::ops::RangeBounds;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15use openraft::error::{
16    ClientWriteError, InstallSnapshotError, NetworkError, RPCError, RaftError, Unreachable,
17};
18use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory as RaftNetworkFactoryTrait};
19use openraft::raft::{
20    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
21    VoteRequest, VoteResponse,
22};
23use openraft::storage::{Adaptor, LogState, RaftLogReader, RaftSnapshotBuilder, Snapshot};
24use openraft::{
25    BasicNode, Config, Entry, EntryPayload, LogId, OptionalSend, Raft, RaftStorage, RaftTypeConfig,
26    ServerState, SnapshotMeta, StorageError, StorageIOError, StoredMembership, Vote,
27};
28use serde::{Deserialize, Serialize};
29use tokio::net::{TcpListener, TcpStream};
30use tokio::sync::{watch, RwLock};
31use tracing::{debug, warn};
32
33use crate::raft_log::{RaftDisk, RaftDiskError};
34
35use crate::auth::ClusterSecret;
36use crate::raft_transport::{
37    read_frame, read_frame_authenticated, write_frame, write_frame_authenticated, RaftRpc,
38    RaftRpcResponse,
39};
40use crate::slots::SLOT_COUNT;
41use crate::{NodeId, SlotRange};
42
43/// Type configuration for openraft.
44#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
45pub struct TypeConfig;
46
47impl RaftTypeConfig for TypeConfig {
48    type D = ClusterCommand;
49    type R = ClusterResponse;
50    type Node = BasicNode;
51    type NodeId = u64;
52    type Entry = Entry<TypeConfig>;
53    type SnapshotData = Cursor<Vec<u8>>;
54    type AsyncRuntime = openraft::TokioRuntime;
55    type Responder = openraft::impls::OneshotResponder<TypeConfig>;
56}
57
58/// Commands that modify cluster configuration.
59///
60/// These are replicated through Raft to ensure all nodes agree
61/// on the cluster topology.
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
63pub enum ClusterCommand {
64    /// Add a new node to the cluster.
65    AddNode {
66        node_id: NodeId,
67        raft_id: u64,
68        addr: String,
69        is_primary: bool,
70    },
71    /// Remove a node from the cluster.
72    RemoveNode { node_id: NodeId },
73    /// Assign slots to a node.
74    AssignSlots {
75        node_id: NodeId,
76        slots: Vec<SlotRange>,
77    },
78    /// Remove specific slots from a node.
79    RemoveSlots {
80        node_id: NodeId,
81        slots: Vec<SlotRange>,
82    },
83    /// Promote a replica to primary (during failover).
84    PromoteReplica { replica_id: NodeId },
85    /// Mark a slot as migrating.
86    BeginMigration { slot: u16, from: NodeId, to: NodeId },
87    /// Complete a slot migration.
88    CompleteMigration { slot: u16, new_owner: NodeId },
89}
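
// Topology changes reach this enum through `RaftNode::propose` (defined later in
// this file). A minimal sketch, assuming the caller already knows the new node's
// id and advertised address (the literal values below are made up):
//
//     let resp = raft_node
//         .propose(ClusterCommand::AddNode {
//             node_id,
//             raft_id: raft_id_from_node_id(node_id),
//             addr: "10.0.0.2:7000".to_string(),
//             is_primary: true,
//         })
//         .await?;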
90
91/// Response from applying a cluster command.
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
93pub enum ClusterResponse {
94    Ok,
95    Error(String),
96}
97
98/// State machine snapshot.
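///
/// Serialized to JSON by `build_snapshot`; `state_data` is itself a JSON-encoded
/// `ClusterStateData`, so a stored snapshot is one JSON document nested in another.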
99#[derive(Debug, Clone, Serialize, Deserialize, Default)]
100pub struct ClusterSnapshot {
101    pub last_applied: Option<LogId<u64>>,
102    pub last_membership: StoredMembership<u64, BasicNode>,
103    /// Serialized cluster state.
104    pub state_data: Vec<u8>,
105}
106
107/// Internal cluster state managed by the state machine.
108#[derive(Debug, Clone, Serialize, Deserialize, Default)]
109pub struct ClusterStateData {
110    /// Node ID to raft ID mapping.
111    pub nodes: BTreeMap<String, NodeInfo>,
112    /// Slot assignments.
113    pub slots: BTreeMap<u16, String>,
114    /// Ongoing migrations.
115    pub migrations: BTreeMap<u16, MigrationState>,
116}
117
118/// Information about a cluster node.
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct NodeInfo {
121    pub node_id: String,
122    pub raft_id: u64,
123    pub addr: String,
124    pub is_primary: bool,
125    pub slots: Vec<SlotRange>,
126}
127
128/// State of an ongoing slot migration.
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct MigrationState {
131    pub from: String,
132    pub to: String,
133}
134
135/// Combined log and state machine storage for Raft.
136///
137/// When `disk` is `Some`, all mutations are persisted to the raft directory
138/// so that state survives restarts. When `None`, storage is purely in-memory
139/// (used when cluster mode is disabled or no data directory is configured).
140#[derive(Debug)]
141pub struct Storage {
142    vote: RwLock<Option<Vote<u64>>>,
143    log: RwLock<BTreeMap<u64, Entry<TypeConfig>>>,
144    last_purged: RwLock<Option<LogId<u64>>>,
145    last_applied: RwLock<Option<LogId<u64>>>,
146    last_membership: RwLock<StoredMembership<u64, BasicNode>>,
147    snapshot: RwLock<Option<StoredSnapshot>>,
148    state: Arc<RwLock<ClusterStateData>>,
149    /// Notifies watchers whenever `apply_to_state_machine` applies committed entries.
150    state_tx: watch::Sender<ClusterStateData>,
151    /// Disk persistence layer. `None` for in-memory-only mode.
152    disk: Option<std::sync::Mutex<RaftDisk>>,
153}
154
155#[derive(Debug, Clone)]
156struct StoredSnapshot {
157    meta: SnapshotMeta<u64, BasicNode>,
158    data: Vec<u8>,
159}
160
161impl Default for Storage {
162    fn default() -> Self {
163        // creates a disconnected watch channel (state changes are not observed externally)
164        let (state_tx, _) = watch::channel(ClusterStateData::default());
165        Self {
166            vote: RwLock::new(None),
167            log: RwLock::new(BTreeMap::new()),
168            last_purged: RwLock::new(None),
169            last_applied: RwLock::new(None),
170            last_membership: RwLock::new(StoredMembership::default()),
171            snapshot: RwLock::new(None),
172            state: Arc::new(RwLock::new(ClusterStateData::default())),
173            state_tx,
174            disk: None,
175        }
176    }
177}
178
179impl Storage {
180    /// Creates a new in-memory storage instance and returns a receiver that
181    /// fires whenever the Raft state machine applies committed entries.
182    pub fn new() -> (Arc<Self>, watch::Receiver<ClusterStateData>) {
183        let (state_tx, state_rx) = watch::channel(ClusterStateData::default());
184        let storage = Arc::new(Self {
185            vote: RwLock::new(None),
186            log: RwLock::new(BTreeMap::new()),
187            last_purged: RwLock::new(None),
188            last_applied: RwLock::new(None),
189            last_membership: RwLock::new(StoredMembership::default()),
190            snapshot: RwLock::new(None),
191            state: Arc::new(RwLock::new(ClusterStateData::default())),
192            state_tx,
193            disk: None,
194        });
195        (storage, state_rx)
196    }
197
198    /// Opens persistent storage at `raft_dir`, recovering any existing state.
199    ///
200    /// On a fresh start the directory is created and empty files are written.
201    /// On recovery, the vote, log entries, and snapshot are loaded from disk
202    /// into memory so the Raft node can resume where it left off.
203    pub fn open(
204        raft_dir: PathBuf,
205    ) -> Result<(Arc<Self>, watch::Receiver<ClusterStateData>), RaftDiskError> {
206        let (raft_disk, recovered) = RaftDisk::open(&raft_dir)?;
207
208        let (state_tx, state_rx) = watch::channel(ClusterStateData::default());
209
210        let snapshot = recovered
211            .snapshot
212            .map(|(meta, data)| StoredSnapshot { meta, data });
213
214        let storage = Arc::new(Self {
215            vote: RwLock::new(recovered.vote),
216            log: RwLock::new(recovered.log),
217            last_purged: RwLock::new(recovered.last_purged),
218            last_applied: RwLock::new(None),
219            last_membership: RwLock::new(StoredMembership::default()),
220            snapshot: RwLock::new(snapshot),
221            state: Arc::new(RwLock::new(ClusterStateData::default())),
222            state_tx,
223            disk: Some(std::sync::Mutex::new(raft_disk)),
224        });
225
226        Ok((storage, state_rx))
227    }
228
229    /// Returns `true` if the log has any entries (indicating a prior run).
230    ///
231    /// Used by the bootstrap path to decide whether to call `bootstrap_single()`
232    /// (fresh start) or skip it (recovery from persisted state).
233    pub fn has_log_entries(&self) -> bool {
234        // safe to call from sync context — the RwLock is only contended
235        // by the Raft runtime which hasn't started yet during bootstrap
236        self.log
237            .try_read()
238            .map(|log| !log.is_empty())
239            .unwrap_or(false)
240    }
241
242    pub fn state(&self) -> Arc<RwLock<ClusterStateData>> {
243        Arc::clone(&self.state)
244    }
245
246    fn apply_command(cmd: &ClusterCommand, state: &mut ClusterStateData) -> ClusterResponse {
247        match cmd {
248            ClusterCommand::AddNode {
249                node_id,
250                raft_id,
251                addr,
252                is_primary,
253            } => {
254                let key = node_id.as_key();
255                state.nodes.insert(
256                    key.clone(),
257                    NodeInfo {
258                        node_id: key,
259                        raft_id: *raft_id,
260                        addr: addr.clone(),
261                        is_primary: *is_primary,
262                        slots: Vec::new(),
263                    },
264                );
265                ClusterResponse::Ok
266            }
267
268            ClusterCommand::RemoveNode { node_id } => {
269                let key = node_id.as_key();
270                state.nodes.remove(&key);
271                state.slots.retain(|_, owner| owner != &key);
272                ClusterResponse::Ok
273            }
274
275            ClusterCommand::AssignSlots { node_id, slots } => {
276                // validate all slot ranges before applying
277                for range in slots {
278                    if range.start > range.end || range.end >= SLOT_COUNT {
279                        return ClusterResponse::Error(format!(
280                            "invalid slot range {}..={} (max {})",
281                            range.start,
282                            range.end,
283                            SLOT_COUNT - 1
284                        ));
285                    }
286                }
287                let key = node_id.as_key();
288                if let Some(node) = state.nodes.get_mut(&key) {
289                    node.slots = slots.clone();
290                    for slot_range in slots {
291                        for slot in slot_range.start..=slot_range.end {
292                            state.slots.insert(slot, key.clone());
293                        }
294                    }
295                    ClusterResponse::Ok
296                } else {
297                    ClusterResponse::Error(format!("node {} not found", node_id))
298                }
299            }
300
301            ClusterCommand::RemoveSlots { node_id, slots } => {
302                for range in slots {
303                    if range.start > range.end || range.end >= SLOT_COUNT {
304                        return ClusterResponse::Error(format!(
305                            "invalid slot range {}..={} (max {})",
306                            range.start,
307                            range.end,
308                            SLOT_COUNT - 1
309                        ));
310                    }
311                }
312                let key = node_id.as_key();
313                for slot_range in slots {
314                    for slot in slot_range.start..=slot_range.end {
315                        // only remove if this node is the current owner
316                        if state.slots.get(&slot).map(|s| s.as_str()) == Some(key.as_str()) {
317                            state.slots.remove(&slot);
318                        }
319                    }
320                }
321                // rebuild node's slot list from what remains; split the borrow
322                let remaining = slots_for_node_in_state(state, &key);
323                if let Some(node) = state.nodes.get_mut(&key) {
324                    node.slots = remaining;
325                }
326                ClusterResponse::Ok
327            }
328
329            ClusterCommand::PromoteReplica { replica_id } => {
330                let key = replica_id.as_key();
331                if let Some(node) = state.nodes.get_mut(&key) {
332                    node.is_primary = true;
333                    ClusterResponse::Ok
334                } else {
335                    ClusterResponse::Error(format!("replica {} not found", replica_id))
336                }
337            }
338
339            ClusterCommand::BeginMigration { slot, from, to } => {
340                if *slot >= SLOT_COUNT {
341                    return ClusterResponse::Error(format!(
342                        "slot {slot} out of range (max {})",
343                        SLOT_COUNT - 1
344                    ));
345                }
346                state.migrations.insert(
347                    *slot,
348                    MigrationState {
349                        from: from.as_key(),
350                        to: to.as_key(),
351                    },
352                );
353                ClusterResponse::Ok
354            }
355
356            ClusterCommand::CompleteMigration { slot, new_owner } => {
357                if !state.migrations.contains_key(slot) {
358                    return ClusterResponse::Error(format!(
359                        "no migration in progress for slot {slot}"
360                    ));
361                }
362                state.migrations.remove(slot);
363                let key = new_owner.as_key();
364                state.slots.insert(*slot, key);
365                ClusterResponse::Ok
366            }
367        }
368    }
369}
370
371/// Returns the slot ranges owned by `node_key` according to the slot map.
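///
/// Worked example: if the slot map gives the node slots {0, 1, 2, 7, 8, 10}, the
/// result is the ranges 0..=2, 7..=8 and 10..=10.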
372fn slots_for_node_in_state(state: &ClusterStateData, node_key: &str) -> Vec<SlotRange> {
373    let mut slots: Vec<u16> = state
374        .slots
375        .iter()
376        .filter(|(_, v)| v.as_str() == node_key)
377        .map(|(k, _)| *k)
378        .collect();
379    slots.sort_unstable();
380
381    // compress into contiguous ranges
382    let mut ranges = Vec::new();
383    let mut i = 0;
384    while i < slots.len() {
385        let start = slots[i];
386        let mut end = start;
387        while i + 1 < slots.len() && slots[i + 1] == end + 1 {
388            i += 1;
389            end = slots[i];
390        }
391        ranges.push(SlotRange::new(start, end));
392        i += 1;
393    }
394    ranges
395}
396
397impl RaftLogReader<TypeConfig> for Arc<Storage> {
398    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
399        &mut self,
400        range: RB,
401    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<u64>> {
402        let log = self.log.read().await;
403        Ok(log.range(range).map(|(_, v)| v.clone()).collect())
404    }
405}
406
407impl RaftSnapshotBuilder<TypeConfig> for Arc<Storage> {
408    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<u64>> {
409        let last_applied = *self.last_applied.read().await;
410        let membership = self.last_membership.read().await.clone();
411        let state = self.state.read().await;
412
413        let state_data =
414            serde_json::to_vec(&*state).map_err(|e| StorageIOError::write_snapshot(None, &e))?;
415
416        let snapshot = ClusterSnapshot {
417            last_applied,
418            last_membership: membership.clone(),
419            state_data,
420        };
421
422        let data =
423            serde_json::to_vec(&snapshot).map_err(|e| StorageIOError::write_snapshot(None, &e))?;
424
425        let snapshot_id = last_applied
426            .map(|id| format!("{}-{}", id.leader_id, id.index))
427            .unwrap_or_else(|| "0-0".to_string());
428
429        let meta = SnapshotMeta {
430            last_log_id: last_applied,
431            last_membership: membership,
432            snapshot_id,
433        };
434
435        // Store the snapshot
436        *self.snapshot.write().await = Some(StoredSnapshot {
437            meta: meta.clone(),
438            data: data.clone(),
439        });
440
441        if let Some(disk) = &self.disk {
442            let d = disk
443                .lock()
444                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?;
445            d.write_snapshot(&meta, &data).map_err(StorageError::from)?;
446        }
447
448        Ok(Snapshot {
449            meta,
450            snapshot: Box::new(Cursor::new(data)),
451        })
452    }
453}
454
455impl RaftStorage<TypeConfig> for Arc<Storage> {
456    type LogReader = Self;
457    type SnapshotBuilder = Self;
458
459    async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<u64>> {
460        let log = self.log.read().await;
461        let last = log.iter().next_back().map(|(_, e)| e.log_id);
462        let purged = *self.last_purged.read().await;
463
464        Ok(LogState {
465            last_purged_log_id: purged,
466            last_log_id: last,
467        })
468    }
469
470    async fn save_vote(&mut self, vote: &Vote<u64>) -> Result<(), StorageError<u64>> {
471        *self.vote.write().await = Some(*vote);
472        if let Some(disk) = &self.disk {
473            let last_purged = *self.last_purged.read().await;
474            disk.lock()
475                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
476                .write_meta(Some(*vote), last_purged)?;
477        }
478        Ok(())
479    }
480
481    async fn read_vote(&mut self) -> Result<Option<Vote<u64>>, StorageError<u64>> {
482        Ok(*self.vote.read().await)
483    }
484
485    async fn get_log_reader(&mut self) -> Self::LogReader {
486        Arc::clone(self)
487    }
488
489    async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<u64>>
490    where
491        I: IntoIterator<Item = Entry<TypeConfig>> + Send,
492    {
493        let mut log = self.log.write().await;
494        let new_entries: Vec<Entry<TypeConfig>> = entries.into_iter().collect();
495        for entry in &new_entries {
496            log.insert(entry.log_id.index, entry.clone());
497        }
498        if let Some(disk) = &self.disk {
499            disk.lock()
500                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
501                .append_entries(&new_entries)?;
502        }
503        Ok(())
504    }
505
506    async fn delete_conflict_logs_since(
507        &mut self,
508        log_id: LogId<u64>,
509    ) -> Result<(), StorageError<u64>> {
510        let mut log = self.log.write().await;
511        let to_remove: Vec<_> = log.range(log_id.index..).map(|(k, _)| *k).collect();
512        for key in to_remove {
513            log.remove(&key);
514        }
515        if let Some(disk) = &self.disk {
516            disk.lock()
517                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
518                .rewrite_log(&log)?;
519        }
520        Ok(())
521    }
522
523    async fn purge_logs_upto(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
524        let mut log = self.log.write().await;
525        let to_remove: Vec<_> = log.range(..=log_id.index).map(|(k, _)| *k).collect();
526        for key in to_remove {
527            log.remove(&key);
528        }
529        *self.last_purged.write().await = Some(log_id);
530        if let Some(disk) = &self.disk {
531            let vote = *self.vote.read().await;
532            let mut d = disk
533                .lock()
534                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?;
535            d.write_meta(vote, Some(log_id))?;
536            d.rewrite_log(&log)?;
537        }
538        Ok(())
539    }
540
541    async fn last_applied_state(
542        &mut self,
543    ) -> Result<(Option<LogId<u64>>, StoredMembership<u64, BasicNode>), StorageError<u64>> {
544        let last_applied = *self.last_applied.read().await;
545        let membership = self.last_membership.read().await.clone();
546        Ok((last_applied, membership))
547    }
548
549    async fn apply_to_state_machine(
550        &mut self,
551        entries: &[Entry<TypeConfig>],
552    ) -> Result<Vec<ClusterResponse>, StorageError<u64>> {
553        let mut results = Vec::new();
554        let mut state = self.state.write().await;
555
556        for entry in entries {
557            *self.last_applied.write().await = Some(entry.log_id);
558
559            match &entry.payload {
560                EntryPayload::Blank => {
561                    results.push(ClusterResponse::Ok);
562                }
563                EntryPayload::Normal(cmd) => {
564                    let result = Storage::apply_command(cmd, &mut state);
565                    results.push(result);
566                }
567                EntryPayload::Membership(m) => {
568                    *self.last_membership.write().await =
569                        StoredMembership::new(Some(entry.log_id), m.clone());
570                    results.push(ClusterResponse::Ok);
571                }
572            }
573        }
574
575        // notify the reconciliation watcher after releasing the write lock
576        let state_snapshot = state.clone();
577        drop(state);
578        let _ = self.state_tx.send_replace(state_snapshot);
579
580        Ok(results)
581    }
582
583    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
584        Arc::clone(self)
585    }
586
587    async fn begin_receiving_snapshot(
588        &mut self,
589    ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
590        Ok(Box::new(Cursor::new(Vec::new())))
591    }
592
593    async fn install_snapshot(
594        &mut self,
595        meta: &SnapshotMeta<u64, BasicNode>,
596        snapshot: Box<Cursor<Vec<u8>>>,
597    ) -> Result<(), StorageError<u64>> {
598        let data = snapshot.into_inner();
599        let snap: ClusterSnapshot = serde_json::from_slice(&data)
600            .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
601
602        *self.last_applied.write().await = snap.last_applied;
603        *self.last_membership.write().await = snap.last_membership;
604
605        let state_data: ClusterStateData = serde_json::from_slice(&snap.state_data)
606            .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
607        *self.state.write().await = state_data.clone();
608
609        *self.snapshot.write().await = Some(StoredSnapshot {
610            meta: meta.clone(),
611            data: data.clone(),
612        });
613
614        if let Some(disk) = &self.disk {
615            disk.lock()
616                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
617                .write_snapshot(meta, &data)?;
618        }
619
620        // notify after snapshot install, same as after apply
621        let _ = self.state_tx.send_replace(state_data);
622
623        Ok(())
624    }
625
626    async fn get_current_snapshot(
627        &mut self,
628    ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<u64>> {
629        let snap = self.snapshot.read().await;
630        Ok(snap.as_ref().map(|s| Snapshot {
631            meta: s.meta.clone(),
632            snapshot: Box::new(Cursor::new(s.data.clone())),
633        }))
634    }
635}
636
637// -- network implementation --
638
639/// Per-peer network handle. Opens a short-lived TCP connection per RPC call.
640///
641/// One connection per RPC is acceptable because Raft RPCs are infrequent:
642/// one heartbeat per 500 ms per follower.
643pub struct RaftNetworkClient {
644    target_addr: SocketAddr,
645    secret: Option<Arc<ClusterSecret>>,
646}
647
648impl RaftNetwork<TypeConfig> for RaftNetworkClient {
649    async fn append_entries(
650        &mut self,
651        rpc: AppendEntriesRequest<TypeConfig>,
652        _option: RPCOption,
653    ) -> Result<AppendEntriesResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
654        let resp = self.call(RaftRpc::AppendEntries(rpc)).await?;
655        match resp {
656            RaftRpcResponse::AppendEntries(r) => Ok(r),
657            _ => Err(RPCError::Network(NetworkError::new(&io_error(
658                "unexpected response variant",
659            )))),
660        }
661    }
662
663    async fn vote(
664        &mut self,
665        rpc: VoteRequest<u64>,
666        _option: RPCOption,
667    ) -> Result<VoteResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
668        let resp = self.call(RaftRpc::Vote(rpc)).await?;
669        match resp {
670            RaftRpcResponse::Vote(r) => Ok(r),
671            _ => Err(RPCError::Network(NetworkError::new(&io_error(
672                "unexpected response variant",
673            )))),
674        }
675    }
676
677    async fn install_snapshot(
678        &mut self,
679        rpc: InstallSnapshotRequest<TypeConfig>,
680        _option: RPCOption,
681    ) -> Result<
682        InstallSnapshotResponse<u64>,
683        RPCError<u64, BasicNode, RaftError<u64, InstallSnapshotError>>,
684    > {
685        let resp = self
686            .call_snapshot(RaftRpc::InstallSnapshot(rpc))
687            .await
688            .map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
689        match resp {
690            RaftRpcResponse::InstallSnapshot(r) => Ok(r),
691            _ => Err(RPCError::Network(NetworkError::new(&io_error(
692                "unexpected response variant",
693            )))),
694        }
695    }
696}
697
698impl RaftNetworkClient {
699    /// Sends a Raft RPC to the target and returns the response.
700    ///
701    /// Opens a TCP connection, sends one frame, reads one frame, and closes.
702    async fn call(
703        &self,
704        rpc: RaftRpc,
705    ) -> Result<RaftRpcResponse, RPCError<u64, BasicNode, RaftError<u64>>> {
706        self.send_rpc(rpc)
707            .await
708            .map_err(|e| RPCError::Unreachable(Unreachable::new(&e)))
709    }
710
711    /// Like `call`, but returns the raw I/O error so the caller can map it to the snapshot RPC error type.
712    async fn call_snapshot(&self, rpc: RaftRpc) -> std::io::Result<RaftRpcResponse> {
713        self.send_rpc(rpc).await
714    }
715
716    async fn send_rpc(&self, rpc: RaftRpc) -> std::io::Result<RaftRpcResponse> {
717        let mut stream = TcpStream::connect(self.target_addr).await?;
718        // disable Nagle's algorithm so Raft heartbeats and vote requests are
719        // sent immediately rather than buffered for up to 200ms.
720        stream.set_nodelay(true)?;
721        match &self.secret {
722            Some(secret) => {
723                write_frame_authenticated(&mut stream, &rpc, secret).await?;
724                read_frame_authenticated(&mut stream, secret).await
725            }
726            None => {
727                write_frame(&mut stream, &rpc).await?;
728                read_frame(&mut stream).await
729            }
730        }
731    }
732}
733
734/// Factory that creates per-peer `RaftNetworkClient` instances.
735///
736/// The `node.addr` field in `BasicNode` must contain `"ip:raft_port"`.
737pub struct RaftNetworkFactory {
738    secret: Option<Arc<ClusterSecret>>,
739}
740
741impl RaftNetworkFactoryTrait<TypeConfig> for RaftNetworkFactory {
742    type Network = RaftNetworkClient;
743
744    async fn new_client(&mut self, _target: u64, node: &BasicNode) -> RaftNetworkClient {
745        let target_addr = node
746            .addr
747            .parse()
748            .unwrap_or_else(|_| "127.0.0.1:0".parse().unwrap());
749        RaftNetworkClient {
750            target_addr,
751            secret: self.secret.clone(),
752        }
753    }
754}
755
756// -- TCP listener for inbound Raft RPCs --
757
758/// Spawns a task that accepts incoming Raft RPC connections.
759///
760/// Reads one `RaftRpc` frame, dispatches to the local Raft instance,
761/// writes one `RaftRpcResponse` frame, then closes the connection.
762/// When `secret` is `Some`, authenticated framing is used.
763pub(crate) fn spawn_raft_listener(
764    raft: Raft<TypeConfig>,
765    bind_addr: SocketAddr,
766    secret: Option<Arc<ClusterSecret>>,
767) {
768    tokio::spawn(async move {
769        let listener = match TcpListener::bind(bind_addr).await {
770            Ok(l) => l,
771            Err(e) => {
772                warn!("raft listener failed to bind on {bind_addr}: {e}");
773                return;
774            }
775        };
776
777        tracing::info!("raft listener on {bind_addr}");
778
779        loop {
780            let (mut stream, peer) = match listener.accept().await {
781                Ok(pair) => pair,
782                Err(e) => {
783                    warn!("raft accept error: {e}");
784                    continue;
785                }
786            };
787
788            let raft = raft.clone();
789            let secret = secret.clone();
790            tokio::spawn(async move {
791                let rpc: RaftRpc = match &secret {
792                    Some(s) => match read_frame_authenticated(&mut stream, s).await {
793                        Ok(r) => r,
794                        Err(e) => {
795                            debug!("raft auth/read error from {peer}: {e}");
796                            return;
797                        }
798                    },
799                    None => match read_frame(&mut stream).await {
800                        Ok(r) => r,
801                        Err(e) => {
802                            debug!("raft read error from {peer}: {e}");
803                            return;
804                        }
805                    },
806                };
807
808                let response = match rpc {
809                    RaftRpc::AppendEntries(req) => match raft.append_entries(req).await {
810                        Ok(r) => RaftRpcResponse::AppendEntries(r),
811                        Err(e) => {
812                            debug!("append_entries error: {e}");
813                            return;
814                        }
815                    },
816                    RaftRpc::Vote(req) => match raft.vote(req).await {
817                        Ok(r) => RaftRpcResponse::Vote(r),
818                        Err(e) => {
819                            debug!("vote error: {e}");
820                            return;
821                        }
822                    },
823                    RaftRpc::InstallSnapshot(req) => {
824                        // convert chunked snapshot to full snapshot for install
825                        let vote = req.vote;
826                        let meta = req.meta.clone();
827                        let data = req.data.clone();
828                        let snapshot = Snapshot {
829                            meta,
830                            snapshot: Box::new(Cursor::new(data)),
831                        };
832                        match raft.install_full_snapshot(vote, snapshot).await {
833                            Ok(r) => RaftRpcResponse::InstallSnapshot(InstallSnapshotResponse {
834                                vote: r.vote,
835                            }),
836                            Err(e) => {
837                                debug!("install_snapshot error: {e}");
838                                return;
839                            }
840                        }
841                    }
842                };
843
844                let write_result = match &secret {
845                    Some(s) => write_frame_authenticated(&mut stream, &response, s).await,
846                    None => write_frame(&mut stream, &response).await,
847                };
848                if let Err(e) = write_result {
849                    debug!("raft write error to {peer}: {e}");
850                }
851            });
852        }
853    });
854}
855
856// -- RaftNode wrapper --
857
858/// Error from proposing a command through Raft.
859#[derive(Debug)]
860pub enum RaftProposalError {
861    /// This node is not the leader. The leader's address is provided when known.
862    NotLeader(Option<BasicNode>),
863    /// Fatal Raft error.
864    Fatal(String),
865}
866
867impl std::fmt::Display for RaftProposalError {
868    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
869        match self {
870            RaftProposalError::NotLeader(Some(node)) => {
871                write!(f, "not leader, leader at {}", node.addr)
872            }
873            RaftProposalError::NotLeader(None) => write!(f, "no leader elected"),
874            RaftProposalError::Fatal(msg) => write!(f, "raft fatal: {msg}"),
875        }
876    }
877}
878
879/// High-level Raft node wrapper.
880///
881/// Owns the `Raft<TypeConfig>` instance and exposes the operations needed
882/// by `ClusterCoordinator`: proposing mutations and checking leader status.
883pub struct RaftNode {
884    raft: Raft<TypeConfig>,
885    local_raft_id: u64,
886    local_raft_addr: SocketAddr,
887}
888
889impl RaftNode {
890    /// Starts a Raft node bound to `raft_addr`.
891    ///
892    /// Creates the Raft consensus engine, wraps `storage` with the openraft
893    /// adaptor, and spawns the TCP listener for inbound RPCs. When `secret`
894    /// is `Some`, all Raft frames are authenticated with HMAC-SHA256.
895    pub async fn start(
896        local_raft_id: u64,
897        raft_addr: SocketAddr,
898        storage: Arc<Storage>,
899        secret: Option<Arc<ClusterSecret>>,
900    ) -> Result<Self, openraft::error::Fatal<u64>> {
901        let config = Arc::new(
902            Config {
903                cluster_name: "ember".to_string(),
904                heartbeat_interval: 500,
905                election_timeout_min: 1500,
906                election_timeout_max: 3000,
907                ..Config::default()
908            }
909            .validate()
910            .expect("raft config validation failed"),
911        );
912
913        let (log_store, state_machine) = Adaptor::new(Arc::clone(&storage));
914
915        let network_factory = RaftNetworkFactory {
916            secret: secret.clone(),
917        };
918
919        let raft = Raft::new(
920            local_raft_id,
921            config,
922            network_factory,
923            log_store,
924            state_machine,
925        )
926        .await?;
927
928        spawn_raft_listener(raft.clone(), raft_addr, secret);
929
930        Ok(Self {
931            raft,
932            local_raft_id,
933            local_raft_addr: raft_addr,
934        })
935    }
936
937    /// Initializes a single-node cluster.
938    ///
939    /// Must only be called once, on first boot, before any log entries exist.
940    /// Subsequent boots should NOT call this — the existing log is sufficient.
941    pub async fn bootstrap_single(&self) -> Result<(), String> {
942        let mut members = BTreeMap::new();
943        members.insert(
944            self.local_raft_id,
945            BasicNode {
946                addr: self.local_raft_addr.to_string(),
947            },
948        );
949
950        self.raft
951            .initialize(members)
952            .await
953            .map_err(|e| e.to_string())
954    }
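
    // A minimal sketch of the intended first-boot vs. recovery flow (the caller
    // and variable names here are assumptions; only the functions are from this
    // module):
    //
    //     let (storage, state_rx) = Storage::open(raft_dir)?;
    //     let node = RaftNode::start(raft_id, raft_addr, Arc::clone(&storage), secret).await?;
    //     if !storage.has_log_entries() {
    //         node.bootstrap_single().await?; // fresh start: form a single-node cluster
    //     }                                   // otherwise the persisted log is replayed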
955
956    /// Proposes a cluster configuration change through Raft.
957    ///
958    /// Blocks until the entry is committed (replicated to a quorum of nodes) and
959    /// applied to the local state machine. Returns `NotLeader` if this node is not the leader.
960    pub async fn propose(&self, cmd: ClusterCommand) -> Result<ClusterResponse, RaftProposalError> {
961        match self.raft.client_write(cmd).await {
962            Ok(resp) => Ok(resp.data),
963            Err(e) => match e {
964                openraft::error::RaftError::APIError(ClientWriteError::ForwardToLeader(fwd)) => {
965                    Err(RaftProposalError::NotLeader(fwd.leader_node))
966                }
967                other => Err(RaftProposalError::Fatal(other.to_string())),
968            },
969        }
970    }
971
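    // How a caller might handle the not-leader case; the redirect policy here is
    // an assumption of this sketch, not something this module prescribes:
    //
    //     match node.propose(cmd).await {
    //         Ok(ClusterResponse::Ok) => {}
    //         Ok(ClusterResponse::Error(msg)) => warn!("rejected by state machine: {msg}"),
    //         Err(RaftProposalError::NotLeader(Some(leader))) => {
    //             // forward the request to the leader at `leader.addr`
    //         }
    //         Err(RaftProposalError::NotLeader(None)) => warn!("no leader elected yet"),
    //         Err(RaftProposalError::Fatal(msg)) => warn!("raft fatal: {msg}"),
    //     }
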
972    /// Returns `true` if this node is currently the Raft leader.
973    pub fn is_leader(&self) -> bool {
974        self.raft.metrics().borrow().state == ServerState::Leader
975    }
976
977    /// Returns the current leader's BasicNode info, if known.
978    pub fn current_leader_node(&self) -> Option<BasicNode> {
979        let m = self.raft.metrics().borrow().clone();
980        let leader_id = m.current_leader?;
981        m.membership_config
982            .membership()
983            .get_node(&leader_id)
984            .cloned()
985    }
986
987    /// Exposes the underlying `Raft` handle for membership management.
988    pub fn raft_handle(&self) -> &Raft<TypeConfig> {
989        &self.raft
990    }
991
992    pub fn local_raft_id(&self) -> u64 {
993        self.local_raft_id
994    }
995
996    pub fn raft_addr(&self) -> SocketAddr {
997        self.local_raft_addr
998    }
999}
1000
1001/// Derives a stable `u64` raft ID from a `NodeId` UUID.
1002///
1003/// Uses the upper 64 bits of the UUID, which for randomly generated node IDs are effectively unique.
1004pub fn raft_id_from_node_id(node_id: NodeId) -> u64 {
1005    node_id.0.as_u64_pair().0
1006}
1007
1008fn io_error(msg: &str) -> std::io::Error {
1009    std::io::Error::other(msg)
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014    use super::*;
1015    use openraft::CommittedLeaderId;
1016
1017    /// Helper to create a LogId for tests.
1018    fn log_id(term: u64, index: u64) -> LogId<u64> {
1019        LogId::new(CommittedLeaderId::new(term, 0), index)
1020    }
1021
1022    #[tokio::test]
1023    async fn storage_add_node() {
1024        let (storage, _rx) = Storage::new();
1025        let mut storage_clone = Arc::clone(&storage);
1026
1027        let node_id = NodeId::new();
1028        let entry = Entry {
1029            log_id: log_id(1, 1),
1030            payload: EntryPayload::Normal(ClusterCommand::AddNode {
1031                node_id,
1032                raft_id: 1,
1033                addr: "127.0.0.1:6379".to_string(),
1034                is_primary: true,
1035            }),
1036        };
1037
1038        let results = storage_clone
1039            .apply_to_state_machine(&[entry])
1040            .await
1041            .unwrap();
1042        assert_eq!(results, vec![ClusterResponse::Ok]);
1043
1044        let state_arc = storage.state();
1045        let state = state_arc.read().await;
1046        assert!(state.nodes.contains_key(&node_id.as_key()));
1047    }
1048
1049    #[tokio::test]
1050    async fn storage_assign_slots() {
1051        let (storage, _rx) = Storage::new();
1052        let mut storage_clone = Arc::clone(&storage);
1053
1054        let node_id = NodeId::new();
1055
1056        // Add node first
1057        let add_entry = Entry {
1058            log_id: log_id(1, 1),
1059            payload: EntryPayload::Normal(ClusterCommand::AddNode {
1060                node_id,
1061                raft_id: 1,
1062                addr: "127.0.0.1:6379".to_string(),
1063                is_primary: true,
1064            }),
1065        };
1066        storage_clone
1067            .apply_to_state_machine(&[add_entry])
1068            .await
1069            .unwrap();
1070
1071        // Assign slots
1072        let assign_entry = Entry {
1073            log_id: log_id(1, 2),
1074            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
1075                node_id,
1076                slots: vec![SlotRange::new(0, 5460)],
1077            }),
1078        };
1079        let results = storage_clone
1080            .apply_to_state_machine(&[assign_entry])
1081            .await
1082            .unwrap();
1083        assert_eq!(results, vec![ClusterResponse::Ok]);
1084
1085        let state_arc = storage.state();
1086        let state = state_arc.read().await;
1087        assert_eq!(state.slots.get(&0), Some(&node_id.as_key()));
1088        assert_eq!(state.slots.get(&5460), Some(&node_id.as_key()));
1089    }
1090
1091    #[tokio::test]
1092    async fn storage_remove_slots() {
1093        let (storage, _rx) = Storage::new();
1094        let mut s = Arc::clone(&storage);
1095        let node_id = NodeId::new();
1096
1097        let add = Entry {
1098            log_id: log_id(1, 1),
1099            payload: EntryPayload::Normal(ClusterCommand::AddNode {
1100                node_id,
1101                raft_id: 1,
1102                addr: "127.0.0.1:6379".into(),
1103                is_primary: true,
1104            }),
1105        };
1106        s.apply_to_state_machine(&[add]).await.unwrap();
1107
1108        let assign = Entry {
1109            log_id: log_id(1, 2),
1110            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
1111                node_id,
1112                slots: vec![SlotRange::new(0, 10)],
1113            }),
1114        };
1115        s.apply_to_state_machine(&[assign]).await.unwrap();
1116
1117        let remove = Entry {
1118            log_id: log_id(1, 3),
1119            payload: EntryPayload::Normal(ClusterCommand::RemoveSlots {
1120                node_id,
1121                slots: vec![SlotRange::new(0, 5)],
1122            }),
1123        };
1124        let results = s.apply_to_state_machine(&[remove]).await.unwrap();
1125        assert_eq!(results, vec![ClusterResponse::Ok]);
1126
1127        let state = storage.state();
1128        let state = state.read().await;
1129        assert!(!state.slots.contains_key(&0));
1130        assert!(state.slots.contains_key(&6));
1131    }
1132
1133    #[tokio::test]
1134    async fn storage_migration() {
1135        let (storage, _rx) = Storage::new();
1136        let mut storage_clone = Arc::clone(&storage);
1137
1138        let node1 = NodeId::new();
1139        let node2 = NodeId::new();
1140
1141        // Add nodes
1142        let entries: Vec<Entry<TypeConfig>> = [node1, node2]
1143            .iter()
1144            .enumerate()
1145            .map(|(i, node_id)| Entry {
1146                log_id: log_id(1, i as u64 + 1),
1147                payload: EntryPayload::Normal(ClusterCommand::AddNode {
1148                    node_id: *node_id,
1149                    raft_id: i as u64 + 1,
1150                    addr: format!("127.0.0.1:{}", 6379 + i),
1151                    is_primary: true,
1152                }),
1153            })
1154            .collect();
1155        storage_clone
1156            .apply_to_state_machine(&entries)
1157            .await
1158            .unwrap();
1159
1160        // Begin migration
1161        let begin_entry = Entry {
1162            log_id: log_id(1, 3),
1163            payload: EntryPayload::Normal(ClusterCommand::BeginMigration {
1164                slot: 100,
1165                from: node1,
1166                to: node2,
1167            }),
1168        };
1169        storage_clone
1170            .apply_to_state_machine(&[begin_entry])
1171            .await
1172            .unwrap();
1173
1174        {
1175            let state_arc = storage.state();
1176            let state = state_arc.read().await;
1177            assert!(state.migrations.contains_key(&100));
1178        }
1179
1180        // Complete migration
1181        let complete_entry = Entry {
1182            log_id: log_id(1, 4),
1183            payload: EntryPayload::Normal(ClusterCommand::CompleteMigration {
1184                slot: 100,
1185                new_owner: node2,
1186            }),
1187        };
1188        storage_clone
1189            .apply_to_state_machine(&[complete_entry])
1190            .await
1191            .unwrap();
1192
1193        {
1194            let state_arc = storage.state();
1195            let state = state_arc.read().await;
1196            assert!(!state.migrations.contains_key(&100));
1197            assert_eq!(state.slots.get(&100), Some(&node2.as_key()));
1198        }
1199    }
1200
1201    #[tokio::test]
1202    async fn assign_slots_rejects_invalid_range() {
1203        let (storage, _rx) = Storage::new();
1204        let mut s = Arc::clone(&storage);
1205
1206        let node_id = NodeId::new();
1207        let add = Entry {
1208            log_id: log_id(1, 1),
1209            payload: EntryPayload::Normal(ClusterCommand::AddNode {
1210                node_id,
1211                raft_id: 1,
1212                addr: "127.0.0.1:6379".into(),
1213                is_primary: true,
1214            }),
1215        };
1216        s.apply_to_state_machine(&[add]).await.unwrap();
1217
1218        // craft a SlotRange with start > end (bypassing SlotRange::new)
1219        let bad_range = SlotRange {
1220            start: 100,
1221            end: 50,
1222        };
1223        let assign = Entry {
1224            log_id: log_id(1, 2),
1225            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
1226                node_id,
1227                slots: vec![bad_range],
1228            }),
1229        };
1230        let results = s.apply_to_state_machine(&[assign]).await.unwrap();
1231        assert!(
1232            matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("invalid slot range"))
1233        );
1234    }
1235
1236    #[tokio::test]
1237    async fn assign_slots_rejects_out_of_range() {
1238        let (storage, _rx) = Storage::new();
1239        let mut s = Arc::clone(&storage);
1240
1241        let node_id = NodeId::new();
1242        let add = Entry {
1243            log_id: log_id(1, 1),
1244            payload: EntryPayload::Normal(ClusterCommand::AddNode {
1245                node_id,
1246                raft_id: 1,
1247                addr: "127.0.0.1:6379".into(),
1248                is_primary: true,
1249            }),
1250        };
1251        s.apply_to_state_machine(&[add]).await.unwrap();
1252
1253        // slot end >= SLOT_COUNT
1254        let bad_range = SlotRange {
1255            start: 0,
1256            end: 16384,
1257        };
1258        let assign = Entry {
1259            log_id: log_id(1, 2),
1260            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
1261                node_id,
1262                slots: vec![bad_range],
1263            }),
1264        };
1265        let results = s.apply_to_state_machine(&[assign]).await.unwrap();
1266        assert!(
1267            matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("invalid slot range"))
1268        );
1269    }
1270
1271    #[tokio::test]
1272    async fn complete_migration_without_begin_errors() {
1273        let (storage, _rx) = Storage::new();
1274        let mut s = Arc::clone(&storage);
1275
1276        let node_id = NodeId::new();
1277        let complete = Entry {
1278            log_id: log_id(1, 1),
1279            payload: EntryPayload::Normal(ClusterCommand::CompleteMigration {
1280                slot: 100,
1281                new_owner: node_id,
1282            }),
1283        };
1284        let results = s.apply_to_state_machine(&[complete]).await.unwrap();
1285        assert!(matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("no migration")));
1286    }
1287
1288    #[tokio::test]
1289    async fn begin_migration_rejects_invalid_slot() {
1290        let (storage, _rx) = Storage::new();
1291        let mut s = Arc::clone(&storage);
1292
1293        let node1 = NodeId::new();
1294        let node2 = NodeId::new();
1295        let begin = Entry {
1296            log_id: log_id(1, 1),
1297            payload: EntryPayload::Normal(ClusterCommand::BeginMigration {
1298                slot: 16384,
1299                from: node1,
1300                to: node2,
1301            }),
1302        };
1303        let results = s.apply_to_state_machine(&[begin]).await.unwrap();
1304        assert!(matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("out of range")));
1305    }
1306
1307    #[tokio::test]
1308    async fn storage_log_operations() {
1309        let (storage, _rx) = Storage::new();
1310        let mut storage_clone = Arc::clone(&storage);
1311
1312        let entry = Entry::<TypeConfig> {
1313            log_id: log_id(1, 1),
1314            payload: EntryPayload::Blank,
1315        };
1316
1317        storage_clone.append_to_log(vec![entry]).await.unwrap();
1318
1319        let state = storage_clone.get_log_state().await.unwrap();
1320        assert_eq!(state.last_log_id, Some(log_id(1, 1)));
1321    }
1322
1323    #[tokio::test]
1324    async fn storage_vote() {
1325        let (storage, _rx) = Storage::new();
1326        let mut storage_clone = Arc::clone(&storage);
1327
1328        let vote = Vote::new(1, 1);
1329        storage_clone.save_vote(&vote).await.unwrap();
1330
1331        let read_vote = storage_clone.read_vote().await.unwrap();
1332        assert_eq!(read_vote, Some(vote));
1333    }
1334
1335    #[tokio::test]
1336    async fn watch_channel_notified_on_apply() {
1337        let (storage, mut rx) = Storage::new();
1338        let mut s = Arc::clone(&storage);
1339
1340        let node_id = NodeId::new();
1341        let entry = Entry {
1342            log_id: log_id(1, 1),
1343            payload: EntryPayload::Normal(ClusterCommand::AddNode {
1344                node_id,
1345                raft_id: 1,
1346                addr: "127.0.0.1:6379".into(),
1347                is_primary: true,
1348            }),
1349        };
1350
1351        // the watch channel starts with the initial state, so borrow it first to
1352        // mark it as seen, then apply — changed() should fire
1353        let _ = rx.borrow_and_update();
1354
1355        s.apply_to_state_machine(&[entry]).await.unwrap();
1356
1357        assert!(
1358            rx.changed().await.is_ok(),
1359            "watch channel should have fired"
1360        );
1361        let data = rx.borrow();
1362        assert!(data.nodes.contains_key(&node_id.as_key()));
1363    }
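
    // Added check (not in the original suite): exercises the private
    // range-compression helper directly; the node keys are arbitrary strings.
    #[test]
    fn slots_for_node_compress_into_contiguous_ranges() {
        let mut state = ClusterStateData::default();
        for slot in [0u16, 1, 2, 7, 8, 10] {
            state.slots.insert(slot, "node-a".to_string());
        }
        state.slots.insert(3, "node-b".to_string());

        let ranges = slots_for_node_in_state(&state, "node-a");
        let pairs: Vec<(u16, u16)> = ranges.iter().map(|r| (r.start, r.end)).collect();
        assert_eq!(pairs, vec![(0, 2), (7, 8), (10, 10)]);
    }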
1364}