1use std::collections::BTreeMap;
8use std::fmt::Debug;
9use std::io::Cursor;
10use std::net::SocketAddr;
11use std::ops::RangeBounds;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15use openraft::error::{
16 ClientWriteError, InstallSnapshotError, NetworkError, RPCError, RaftError, Unreachable,
17};
18use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory as RaftNetworkFactoryTrait};
19use openraft::raft::{
20 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
21 VoteRequest, VoteResponse,
22};
23use openraft::storage::{Adaptor, LogState, RaftLogReader, RaftSnapshotBuilder, Snapshot};
24use openraft::{
25 BasicNode, Config, Entry, EntryPayload, LogId, OptionalSend, Raft, RaftStorage, RaftTypeConfig,
26 ServerState, SnapshotMeta, StorageError, StorageIOError, StoredMembership, Vote,
27};
28use serde::{Deserialize, Serialize};
29use tokio::net::{TcpListener, TcpStream};
30use tokio::sync::{watch, RwLock};
31use tracing::{debug, warn};
32
33use crate::raft_log::{RaftDisk, RaftDiskError};
34
35use crate::auth::ClusterSecret;
36use crate::raft_transport::{
37 read_frame, read_frame_authenticated, write_frame, write_frame_authenticated, RaftRpc,
38 RaftRpcResponse,
39};
40use crate::slots::SLOT_COUNT;
41use crate::{NodeId, SlotRange};
42
/// Marker type that wires this crate's application types into openraft
/// via the [`RaftTypeConfig`] implementation below.
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct TypeConfig;
46
impl RaftTypeConfig for TypeConfig {
    // Application payload carried in log entries, and the state machine's reply.
    type D = ClusterCommand;
    type R = ClusterResponse;
    type Node = BasicNode;
    type NodeId = u64;
    type Entry = Entry<TypeConfig>;
    // Snapshots are transferred as an in-memory byte buffer.
    type SnapshotData = Cursor<Vec<u8>>;
    type AsyncRuntime = openraft::TokioRuntime;
    type Responder = openraft::impls::OneshotResponder<TypeConfig>;
}
57
/// Commands replicated through the raft log and applied deterministically on
/// every node by `Storage::apply_command`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterCommand {
    /// Register (or overwrite) a node in the cluster map.
    AddNode {
        node_id: NodeId,
        raft_id: u64,
        addr: String,
        is_primary: bool,
    },
    /// Remove a node and release every slot it owned.
    RemoveNode { node_id: NodeId },
    /// Set a node's slot ranges and record slot ownership.
    AssignSlots {
        node_id: NodeId,
        slots: Vec<SlotRange>,
    },
    /// Release the given slot ranges from a node (only slots it actually owns).
    RemoveSlots {
        node_id: NodeId,
        slots: Vec<SlotRange>,
    },
    /// Mark a replica node as primary.
    PromoteReplica { replica_id: NodeId },
    /// Record that a slot is being migrated between two nodes.
    BeginMigration { slot: u16, from: NodeId, to: NodeId },
    /// Finish an in-flight migration and transfer slot ownership.
    CompleteMigration { slot: u16, new_owner: NodeId },
}
90
/// Outcome of applying a [`ClusterCommand`] to the state machine.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterResponse {
    Ok,
    /// Human-readable rejection reason; the command was still consumed by raft.
    Error(String),
}
97
/// Serialized form of the state machine stored inside a raft snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterSnapshot {
    pub last_applied: Option<LogId<u64>>,
    pub last_membership: StoredMembership<u64, BasicNode>,
    /// JSON-encoded [`ClusterStateData`].
    pub state_data: Vec<u8>,
}
106
/// The replicated cluster state machine: node registry, slot ownership, and
/// in-flight slot migrations.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClusterStateData {
    /// Nodes keyed by their `NodeId::as_key()` string form.
    pub nodes: BTreeMap<String, NodeInfo>,
    /// Slot number -> owning node key.
    pub slots: BTreeMap<u16, String>,
    /// Slot number -> migration currently in progress for that slot.
    pub migrations: BTreeMap<u16, MigrationState>,
}
117
/// Per-node metadata tracked by the state machine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// String key form of the node's id (matches the `nodes` map key).
    pub node_id: String,
    /// The node's identifier inside the raft cluster.
    pub raft_id: u64,
    pub addr: String,
    pub is_primary: bool,
    /// Slot ranges this node owns, as last set by AssignSlots/RemoveSlots.
    pub slots: Vec<SlotRange>,
}
127
/// Endpoints (node keys) of an in-flight slot migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationState {
    pub from: String,
    pub to: String,
}
134
/// In-memory raft log + state machine, optionally mirrored to disk.
#[derive(Debug)]
pub struct Storage {
    // Last vote persisted for raft safety.
    vote: RwLock<Option<Vote<u64>>>,
    // Raft log entries keyed by log index.
    log: RwLock<BTreeMap<u64, Entry<TypeConfig>>>,
    last_purged: RwLock<Option<LogId<u64>>>,
    last_applied: RwLock<Option<LogId<u64>>>,
    last_membership: RwLock<StoredMembership<u64, BasicNode>>,
    // Most recent snapshot, if any has been built or installed.
    snapshot: RwLock<Option<StoredSnapshot>>,
    // Live state machine contents, shared with readers via `state()`.
    state: Arc<RwLock<ClusterStateData>>,
    // Broadcasts a copy of the state after every apply / snapshot install.
    state_tx: watch::Sender<ClusterStateData>,
    // Durable backing store; `None` means purely in-memory operation.
    disk: Option<std::sync::Mutex<RaftDisk>>,
}
154
/// A snapshot kept in memory: its metadata plus the serialized payload.
#[derive(Debug, Clone)]
struct StoredSnapshot {
    meta: SnapshotMeta<u64, BasicNode>,
    data: Vec<u8>,
}
160
161impl Default for Storage {
162 fn default() -> Self {
163 let (state_tx, _) = watch::channel(ClusterStateData::default());
165 Self {
166 vote: RwLock::new(None),
167 log: RwLock::new(BTreeMap::new()),
168 last_purged: RwLock::new(None),
169 last_applied: RwLock::new(None),
170 last_membership: RwLock::new(StoredMembership::default()),
171 snapshot: RwLock::new(None),
172 state: Arc::new(RwLock::new(ClusterStateData::default())),
173 state_tx,
174 disk: None,
175 }
176 }
177}
178
impl Storage {
    /// Creates a fresh in-memory storage plus a watch receiver that is
    /// notified with a copy of the state after every apply.
    pub fn new() -> (Arc<Self>, watch::Receiver<ClusterStateData>) {
        let (state_tx, state_rx) = watch::channel(ClusterStateData::default());
        let storage = Arc::new(Self {
            vote: RwLock::new(None),
            log: RwLock::new(BTreeMap::new()),
            last_purged: RwLock::new(None),
            last_applied: RwLock::new(None),
            last_membership: RwLock::new(StoredMembership::default()),
            snapshot: RwLock::new(None),
            state: Arc::new(RwLock::new(ClusterStateData::default())),
            state_tx,
            disk: None,
        });
        (storage, state_rx)
    }

    /// Opens disk-backed storage, recovering the vote, log, purge marker and
    /// snapshot previously written under `raft_dir`.
    ///
    /// `last_applied`/`last_membership`/`state` start empty; they are rebuilt
    /// by openraft re-applying the recovered log (and/or installing the
    /// recovered snapshot) after startup.
    pub fn open(
        raft_dir: PathBuf,
    ) -> Result<(Arc<Self>, watch::Receiver<ClusterStateData>), RaftDiskError> {
        let (raft_disk, recovered) = RaftDisk::open(&raft_dir)?;

        let (state_tx, state_rx) = watch::channel(ClusterStateData::default());

        let snapshot = recovered
            .snapshot
            .map(|(meta, data)| StoredSnapshot { meta, data });

        let storage = Arc::new(Self {
            vote: RwLock::new(recovered.vote),
            log: RwLock::new(recovered.log),
            last_purged: RwLock::new(recovered.last_purged),
            last_applied: RwLock::new(None),
            last_membership: RwLock::new(StoredMembership::default()),
            snapshot: RwLock::new(snapshot),
            state: Arc::new(RwLock::new(ClusterStateData::default())),
            state_tx,
            disk: Some(std::sync::Mutex::new(raft_disk)),
        });

        Ok((storage, state_rx))
    }

    /// Best-effort check for recovered log entries.
    ///
    /// NOTE(review): uses `try_read`, so if the lock is contended this
    /// reports `false` even when entries exist — acceptable only if callers
    /// treat the answer as a hint; confirm.
    pub fn has_log_entries(&self) -> bool {
        self.log
            .try_read()
            .map(|log| !log.is_empty())
            .unwrap_or(false)
    }

    /// Shared handle to the live state machine contents.
    pub fn state(&self) -> Arc<RwLock<ClusterStateData>> {
        Arc::clone(&self.state)
    }

    /// Applies one replicated command to the state machine.
    ///
    /// Must be deterministic: every node applies the same command sequence
    /// and must arrive at the same state. Each arm validates its input
    /// before mutating `state`, so an `Error` response leaves `state`
    /// untouched.
    fn apply_command(cmd: &ClusterCommand, state: &mut ClusterStateData) -> ClusterResponse {
        match cmd {
            ClusterCommand::AddNode {
                node_id,
                raft_id,
                addr,
                is_primary,
            } => {
                // Insert unconditionally: re-adding an existing node
                // overwrites its record and clears its slot list.
                let key = node_id.as_key();
                state.nodes.insert(
                    key.clone(),
                    NodeInfo {
                        node_id: key,
                        raft_id: *raft_id,
                        addr: addr.clone(),
                        is_primary: *is_primary,
                        slots: Vec::new(),
                    },
                );
                ClusterResponse::Ok
            }

            ClusterCommand::RemoveNode { node_id } => {
                // Drop the node and any slot ownership it held.
                // NOTE(review): migrations referencing this node are left in
                // place — confirm whether they should be cancelled here.
                let key = node_id.as_key();
                state.nodes.remove(&key);
                state.slots.retain(|_, owner| owner != &key);
                ClusterResponse::Ok
            }

            ClusterCommand::AssignSlots { node_id, slots } => {
                // Validate every range before touching state.
                for range in slots {
                    if range.start > range.end || range.end >= SLOT_COUNT {
                        return ClusterResponse::Error(format!(
                            "invalid slot range {}..={} (max {})",
                            range.start,
                            range.end,
                            SLOT_COUNT - 1
                        ));
                    }
                }
                let key = node_id.as_key();
                if let Some(node) = state.nodes.get_mut(&key) {
                    node.slots = slots.clone();
                    // NOTE(review): slots this node previously owned that are
                    // absent from the new assignment remain in `state.slots`
                    // under this node's key — confirm whether AssignSlots is
                    // meant to be additive or a full replacement.
                    for slot_range in slots {
                        for slot in slot_range.start..=slot_range.end {
                            state.slots.insert(slot, key.clone());
                        }
                    }
                    ClusterResponse::Ok
                } else {
                    ClusterResponse::Error(format!("node {} not found", node_id))
                }
            }

            ClusterCommand::RemoveSlots { node_id, slots } => {
                // Validate every range before touching state.
                for range in slots {
                    if range.start > range.end || range.end >= SLOT_COUNT {
                        return ClusterResponse::Error(format!(
                            "invalid slot range {}..={} (max {})",
                            range.start,
                            range.end,
                            SLOT_COUNT - 1
                        ));
                    }
                }
                let key = node_id.as_key();
                // Only release slots actually owned by this node.
                for slot_range in slots {
                    for slot in slot_range.start..=slot_range.end {
                        if state.slots.get(&slot).map(|s| s.as_str()) == Some(key.as_str()) {
                            state.slots.remove(&slot);
                        }
                    }
                }
                // Recompute the node's range list from the authoritative map.
                let remaining = slots_for_node_in_state(state, &key);
                if let Some(node) = state.nodes.get_mut(&key) {
                    node.slots = remaining;
                }
                ClusterResponse::Ok
            }

            ClusterCommand::PromoteReplica { replica_id } => {
                let key = replica_id.as_key();
                if let Some(node) = state.nodes.get_mut(&key) {
                    node.is_primary = true;
                    ClusterResponse::Ok
                } else {
                    ClusterResponse::Error(format!("replica {} not found", replica_id))
                }
            }

            ClusterCommand::BeginMigration { slot, from, to } => {
                if *slot >= SLOT_COUNT {
                    return ClusterResponse::Error(format!(
                        "slot {slot} out of range (max {})",
                        SLOT_COUNT - 1
                    ));
                }
                // Overwrites any migration already recorded for this slot.
                state.migrations.insert(
                    *slot,
                    MigrationState {
                        from: from.as_key(),
                        to: to.as_key(),
                    },
                );
                ClusterResponse::Ok
            }

            ClusterCommand::CompleteMigration { slot, new_owner } => {
                if !state.migrations.contains_key(slot) {
                    return ClusterResponse::Error(format!(
                        "no migration in progress for slot {slot}"
                    ));
                }
                state.migrations.remove(slot);
                let key = new_owner.as_key();
                // NOTE(review): only `state.slots` is updated; the per-node
                // `slots` range lists are not refreshed here — confirm
                // `state.slots` is the authoritative source for ownership.
                state.slots.insert(*slot, key);
                ClusterResponse::Ok
            }
        }
    }
}
370
371fn slots_for_node_in_state(state: &ClusterStateData, node_key: &str) -> Vec<SlotRange> {
373 let mut slots: Vec<u16> = state
374 .slots
375 .iter()
376 .filter(|(_, v)| v.as_str() == node_key)
377 .map(|(k, _)| *k)
378 .collect();
379 slots.sort_unstable();
380
381 let mut ranges = Vec::new();
383 let mut i = 0;
384 while i < slots.len() {
385 let start = slots[i];
386 let mut end = start;
387 while i + 1 < slots.len() && slots[i + 1] == end + 1 {
388 i += 1;
389 end = slots[i];
390 }
391 ranges.push(SlotRange::new(start, end));
392 i += 1;
393 }
394 ranges
395}
396
397impl RaftLogReader<TypeConfig> for Arc<Storage> {
398 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
399 &mut self,
400 range: RB,
401 ) -> Result<Vec<Entry<TypeConfig>>, StorageError<u64>> {
402 let log = self.log.read().await;
403 Ok(log.range(range).map(|(_, v)| v.clone()).collect())
404 }
405}
406
impl RaftSnapshotBuilder<TypeConfig> for Arc<Storage> {
    /// Serializes the current state machine into a snapshot, caches it in
    /// memory, persists it to disk when a disk backend exists, and returns it.
    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<u64>> {
        let last_applied = *self.last_applied.read().await;
        let membership = self.last_membership.read().await.clone();
        let state = self.state.read().await;

        // Inner payload: the state machine as JSON.
        let state_data =
            serde_json::to_vec(&*state).map_err(|e| StorageIOError::write_snapshot(None, &e))?;

        let snapshot = ClusterSnapshot {
            last_applied,
            last_membership: membership.clone(),
            state_data,
        };

        // Outer payload: ClusterSnapshot (applied id + membership + state).
        let data =
            serde_json::to_vec(&snapshot).map_err(|e| StorageIOError::write_snapshot(None, &e))?;

        // "leader-index" of the last applied entry, or "0-0" for an empty log.
        let snapshot_id = last_applied
            .map(|id| format!("{}-{}", id.leader_id, id.index))
            .unwrap_or_else(|| "0-0".to_string());

        let meta = SnapshotMeta {
            last_log_id: last_applied,
            last_membership: membership,
            snapshot_id,
        };

        // Cache in memory first, then persist; a failed disk write leaves the
        // in-memory copy in place.
        *self.snapshot.write().await = Some(StoredSnapshot {
            meta: meta.clone(),
            data: data.clone(),
        });

        if let Some(disk) = &self.disk {
            // Sync mutex; held only across the synchronous write, never an await.
            let d = disk
                .lock()
                .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?;
            d.write_snapshot(&meta, &data).map_err(StorageError::from)?;
        }

        Ok(Snapshot {
            meta,
            snapshot: Box::new(Cursor::new(data)),
        })
    }
}
454
455impl RaftStorage<TypeConfig> for Arc<Storage> {
456 type LogReader = Self;
457 type SnapshotBuilder = Self;
458
459 async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<u64>> {
460 let log = self.log.read().await;
461 let last = log.iter().next_back().map(|(_, e)| e.log_id);
462 let purged = *self.last_purged.read().await;
463
464 Ok(LogState {
465 last_purged_log_id: purged,
466 last_log_id: last,
467 })
468 }
469
470 async fn save_vote(&mut self, vote: &Vote<u64>) -> Result<(), StorageError<u64>> {
471 *self.vote.write().await = Some(*vote);
472 if let Some(disk) = &self.disk {
473 let last_purged = *self.last_purged.read().await;
474 disk.lock()
475 .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
476 .write_meta(Some(*vote), last_purged)?;
477 }
478 Ok(())
479 }
480
481 async fn read_vote(&mut self) -> Result<Option<Vote<u64>>, StorageError<u64>> {
482 Ok(*self.vote.read().await)
483 }
484
485 async fn get_log_reader(&mut self) -> Self::LogReader {
486 Arc::clone(self)
487 }
488
489 async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<u64>>
490 where
491 I: IntoIterator<Item = Entry<TypeConfig>> + Send,
492 {
493 let mut log = self.log.write().await;
494 let new_entries: Vec<Entry<TypeConfig>> = entries.into_iter().collect();
495 for entry in &new_entries {
496 log.insert(entry.log_id.index, entry.clone());
497 }
498 if let Some(disk) = &self.disk {
499 disk.lock()
500 .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
501 .append_entries(&new_entries)?;
502 }
503 Ok(())
504 }
505
506 async fn delete_conflict_logs_since(
507 &mut self,
508 log_id: LogId<u64>,
509 ) -> Result<(), StorageError<u64>> {
510 let mut log = self.log.write().await;
511 let to_remove: Vec<_> = log.range(log_id.index..).map(|(k, _)| *k).collect();
512 for key in to_remove {
513 log.remove(&key);
514 }
515 if let Some(disk) = &self.disk {
516 disk.lock()
517 .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
518 .rewrite_log(&log)?;
519 }
520 Ok(())
521 }
522
523 async fn purge_logs_upto(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
524 let mut log = self.log.write().await;
525 let to_remove: Vec<_> = log.range(..=log_id.index).map(|(k, _)| *k).collect();
526 for key in to_remove {
527 log.remove(&key);
528 }
529 *self.last_purged.write().await = Some(log_id);
530 if let Some(disk) = &self.disk {
531 let vote = *self.vote.read().await;
532 let mut d = disk
533 .lock()
534 .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?;
535 d.write_meta(vote, Some(log_id))?;
536 d.rewrite_log(&log)?;
537 }
538 Ok(())
539 }
540
541 async fn last_applied_state(
542 &mut self,
543 ) -> Result<(Option<LogId<u64>>, StoredMembership<u64, BasicNode>), StorageError<u64>> {
544 let last_applied = *self.last_applied.read().await;
545 let membership = self.last_membership.read().await.clone();
546 Ok((last_applied, membership))
547 }
548
549 async fn apply_to_state_machine(
550 &mut self,
551 entries: &[Entry<TypeConfig>],
552 ) -> Result<Vec<ClusterResponse>, StorageError<u64>> {
553 let mut results = Vec::new();
554 let mut state = self.state.write().await;
555
556 for entry in entries {
557 *self.last_applied.write().await = Some(entry.log_id);
558
559 match &entry.payload {
560 EntryPayload::Blank => {
561 results.push(ClusterResponse::Ok);
562 }
563 EntryPayload::Normal(cmd) => {
564 let result = Storage::apply_command(cmd, &mut state);
565 results.push(result);
566 }
567 EntryPayload::Membership(m) => {
568 *self.last_membership.write().await =
569 StoredMembership::new(Some(entry.log_id), m.clone());
570 results.push(ClusterResponse::Ok);
571 }
572 }
573 }
574
575 let state_snapshot = state.clone();
577 drop(state);
578 let _ = self.state_tx.send_replace(state_snapshot);
579
580 Ok(results)
581 }
582
583 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
584 Arc::clone(self)
585 }
586
587 async fn begin_receiving_snapshot(
588 &mut self,
589 ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
590 Ok(Box::new(Cursor::new(Vec::new())))
591 }
592
593 async fn install_snapshot(
594 &mut self,
595 meta: &SnapshotMeta<u64, BasicNode>,
596 snapshot: Box<Cursor<Vec<u8>>>,
597 ) -> Result<(), StorageError<u64>> {
598 let data = snapshot.into_inner();
599 let snap: ClusterSnapshot = serde_json::from_slice(&data)
600 .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
601
602 *self.last_applied.write().await = snap.last_applied;
603 *self.last_membership.write().await = snap.last_membership;
604
605 let state_data: ClusterStateData = serde_json::from_slice(&snap.state_data)
606 .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
607 *self.state.write().await = state_data.clone();
608
609 *self.snapshot.write().await = Some(StoredSnapshot {
610 meta: meta.clone(),
611 data: data.clone(),
612 });
613
614 if let Some(disk) = &self.disk {
615 disk.lock()
616 .map_err(|e| StorageIOError::write(&io_error(&format!("lock poisoned: {e}"))))?
617 .write_snapshot(meta, &data)?;
618 }
619
620 let _ = self.state_tx.send_replace(state_data);
622
623 Ok(())
624 }
625
626 async fn get_current_snapshot(
627 &mut self,
628 ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<u64>> {
629 let snap = self.snapshot.read().await;
630 Ok(snap.as_ref().map(|s| Snapshot {
631 meta: s.meta.clone(),
632 snapshot: Box::new(Cursor::new(s.data.clone())),
633 }))
634 }
635}
636
/// Per-peer raft RPC client; opens a fresh TCP connection for each call.
pub struct RaftNetworkClient {
    target_addr: SocketAddr,
    /// When set, frames are authenticated with the shared cluster secret.
    secret: Option<Arc<ClusterSecret>>,
}
647
impl RaftNetwork<TypeConfig> for RaftNetworkClient {
    /// Sends an AppendEntries RPC and unwraps the matching response variant.
    async fn append_entries(
        &mut self,
        rpc: AppendEntriesRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<AppendEntriesResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
        let resp = self.call(RaftRpc::AppendEntries(rpc)).await?;
        match resp {
            RaftRpcResponse::AppendEntries(r) => Ok(r),
            // The peer answered with the wrong frame type: protocol error.
            _ => Err(RPCError::Network(NetworkError::new(&io_error(
                "unexpected response variant",
            )))),
        }
    }

    /// Sends a Vote RPC and unwraps the matching response variant.
    async fn vote(
        &mut self,
        rpc: VoteRequest<u64>,
        _option: RPCOption,
    ) -> Result<VoteResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
        let resp = self.call(RaftRpc::Vote(rpc)).await?;
        match resp {
            RaftRpcResponse::Vote(r) => Ok(r),
            _ => Err(RPCError::Network(NetworkError::new(&io_error(
                "unexpected response variant",
            )))),
        }
    }

    /// Sends an InstallSnapshot RPC.
    ///
    /// NOTE(review): unlike append_entries/vote (which report transport
    /// failures as `Unreachable` via `call`), this maps failures to
    /// `NetworkError` — confirm the asymmetry is intentional (it affects
    /// openraft's backoff behavior for this peer).
    async fn install_snapshot(
        &mut self,
        rpc: InstallSnapshotRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<
        InstallSnapshotResponse<u64>,
        RPCError<u64, BasicNode, RaftError<u64, InstallSnapshotError>>,
    > {
        let resp = self
            .call_snapshot(RaftRpc::InstallSnapshot(rpc))
            .await
            .map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
        match resp {
            RaftRpcResponse::InstallSnapshot(r) => Ok(r),
            _ => Err(RPCError::Network(NetworkError::new(&io_error(
                "unexpected response variant",
            )))),
        }
    }
}
697
698impl RaftNetworkClient {
699 async fn call(
703 &self,
704 rpc: RaftRpc,
705 ) -> Result<RaftRpcResponse, RPCError<u64, BasicNode, RaftError<u64>>> {
706 self.send_rpc(rpc)
707 .await
708 .map_err(|e| RPCError::Unreachable(Unreachable::new(&e)))
709 }
710
711 async fn call_snapshot(&self, rpc: RaftRpc) -> std::io::Result<RaftRpcResponse> {
713 self.send_rpc(rpc).await
714 }
715
716 async fn send_rpc(&self, rpc: RaftRpc) -> std::io::Result<RaftRpcResponse> {
717 let mut stream = TcpStream::connect(self.target_addr).await?;
718 stream.set_nodelay(true)?;
721 match &self.secret {
722 Some(secret) => {
723 write_frame_authenticated(&mut stream, &rpc, secret).await?;
724 read_frame_authenticated(&mut stream, secret).await
725 }
726 None => {
727 write_frame(&mut stream, &rpc).await?;
728 read_frame(&mut stream).await
729 }
730 }
731 }
732}
733
/// Builds a [`RaftNetworkClient`] per peer, propagating the optional
/// cluster authentication secret to every client.
pub struct RaftNetworkFactory {
    secret: Option<Arc<ClusterSecret>>,
}
740
741impl RaftNetworkFactoryTrait<TypeConfig> for RaftNetworkFactory {
742 type Network = RaftNetworkClient;
743
744 async fn new_client(&mut self, _target: u64, node: &BasicNode) -> RaftNetworkClient {
745 let target_addr = node
746 .addr
747 .parse()
748 .unwrap_or_else(|_| "127.0.0.1:0".parse().unwrap());
749 RaftNetworkClient {
750 target_addr,
751 secret: self.secret.clone(),
752 }
753 }
754}
755
756pub(crate) fn spawn_raft_listener(
764 raft: Raft<TypeConfig>,
765 bind_addr: SocketAddr,
766 secret: Option<Arc<ClusterSecret>>,
767) {
768 tokio::spawn(async move {
769 let listener = match TcpListener::bind(bind_addr).await {
770 Ok(l) => l,
771 Err(e) => {
772 warn!("raft listener failed to bind on {bind_addr}: {e}");
773 return;
774 }
775 };
776
777 tracing::info!("raft listener on {bind_addr}");
778
779 loop {
780 let (mut stream, peer) = match listener.accept().await {
781 Ok(pair) => pair,
782 Err(e) => {
783 warn!("raft accept error: {e}");
784 continue;
785 }
786 };
787
788 let raft = raft.clone();
789 let secret = secret.clone();
790 tokio::spawn(async move {
791 let rpc: RaftRpc = match &secret {
792 Some(s) => match read_frame_authenticated(&mut stream, s).await {
793 Ok(r) => r,
794 Err(e) => {
795 debug!("raft auth/read error from {peer}: {e}");
796 return;
797 }
798 },
799 None => match read_frame(&mut stream).await {
800 Ok(r) => r,
801 Err(e) => {
802 debug!("raft read error from {peer}: {e}");
803 return;
804 }
805 },
806 };
807
808 let response = match rpc {
809 RaftRpc::AppendEntries(req) => match raft.append_entries(req).await {
810 Ok(r) => RaftRpcResponse::AppendEntries(r),
811 Err(e) => {
812 debug!("append_entries error: {e}");
813 return;
814 }
815 },
816 RaftRpc::Vote(req) => match raft.vote(req).await {
817 Ok(r) => RaftRpcResponse::Vote(r),
818 Err(e) => {
819 debug!("vote error: {e}");
820 return;
821 }
822 },
823 RaftRpc::InstallSnapshot(req) => {
824 let vote = req.vote;
826 let meta = req.meta.clone();
827 let data = req.data.clone();
828 let snapshot = Snapshot {
829 meta,
830 snapshot: Box::new(Cursor::new(data)),
831 };
832 match raft.install_full_snapshot(vote, snapshot).await {
833 Ok(r) => RaftRpcResponse::InstallSnapshot(InstallSnapshotResponse {
834 vote: r.vote,
835 }),
836 Err(e) => {
837 debug!("install_snapshot error: {e}");
838 return;
839 }
840 }
841 }
842 };
843
844 let write_result = match &secret {
845 Some(s) => write_frame_authenticated(&mut stream, &response, s).await,
846 None => write_frame(&mut stream, &response).await,
847 };
848 if let Err(e) = write_result {
849 debug!("raft write error to {peer}: {e}");
850 }
851 });
852 }
853 });
854}
855
/// Why a [`RaftNode::propose`] call failed.
#[derive(Debug)]
pub enum RaftProposalError {
    /// This node is not the leader; the payload is the known leader, if any,
    /// so the caller can redirect.
    NotLeader(Option<BasicNode>),
    /// Any other raft error, stringified.
    Fatal(String),
}
866
867impl std::fmt::Display for RaftProposalError {
868 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
869 match self {
870 RaftProposalError::NotLeader(Some(node)) => {
871 write!(f, "not leader, leader at {}", node.addr)
872 }
873 RaftProposalError::NotLeader(None) => write!(f, "no leader elected"),
874 RaftProposalError::Fatal(msg) => write!(f, "raft fatal: {msg}"),
875 }
876 }
877}
878
/// Owns this process's raft instance plus its identity on the raft network.
pub struct RaftNode {
    raft: Raft<TypeConfig>,
    local_raft_id: u64,
    local_raft_addr: SocketAddr,
}
888
impl RaftNode {
    /// Creates the raft instance over `storage`, spawns the RPC listener on
    /// `raft_addr`, and returns the running node.
    ///
    /// # Panics
    /// Panics if the hard-coded raft `Config` fails validation (a bug, not
    /// a runtime condition).
    pub async fn start(
        local_raft_id: u64,
        raft_addr: SocketAddr,
        storage: Arc<Storage>,
        secret: Option<Arc<ClusterSecret>>,
    ) -> Result<Self, openraft::error::Fatal<u64>> {
        let config = Arc::new(
            Config {
                cluster_name: "ember".to_string(),
                // Milliseconds; election timeouts are 3-6 heartbeats.
                heartbeat_interval: 500,
                election_timeout_min: 1500,
                election_timeout_max: 3000,
                ..Config::default()
            }
            .validate()
            .expect("raft config validation failed"),
        );

        // Split the combined Storage into openraft's log-store and
        // state-machine halves.
        let (log_store, state_machine) = Adaptor::new(Arc::clone(&storage));

        let network_factory = RaftNetworkFactory {
            secret: secret.clone(),
        };

        let raft = Raft::new(
            local_raft_id,
            config,
            network_factory,
            log_store,
            state_machine,
        )
        .await?;

        spawn_raft_listener(raft.clone(), raft_addr, secret);

        Ok(Self {
            raft,
            local_raft_id,
            local_raft_addr: raft_addr,
        })
    }

    /// Initializes a brand-new single-node cluster with this node as the
    /// only voter. Errors (e.g. already initialized) are stringified.
    pub async fn bootstrap_single(&self) -> Result<(), String> {
        let mut members = BTreeMap::new();
        members.insert(
            self.local_raft_id,
            BasicNode {
                addr: self.local_raft_addr.to_string(),
            },
        );

        self.raft
            .initialize(members)
            .await
            .map_err(|e| e.to_string())
    }

    /// Proposes a command through raft and waits for it to be applied.
    ///
    /// Returns `NotLeader` (with the leader's address when known) if this
    /// node cannot accept writes, so callers can redirect.
    pub async fn propose(&self, cmd: ClusterCommand) -> Result<ClusterResponse, RaftProposalError> {
        match self.raft.client_write(cmd).await {
            Ok(resp) => Ok(resp.data),
            Err(e) => match e {
                openraft::error::RaftError::APIError(ClientWriteError::ForwardToLeader(fwd)) => {
                    Err(RaftProposalError::NotLeader(fwd.leader_node))
                }
                other => Err(RaftProposalError::Fatal(other.to_string())),
            },
        }
    }

    /// Whether this node currently believes it is the leader.
    pub fn is_leader(&self) -> bool {
        self.raft.metrics().borrow().state == ServerState::Leader
    }

    /// The current leader's node record, if a leader is known and present in
    /// the current membership.
    pub fn current_leader_node(&self) -> Option<BasicNode> {
        let m = self.raft.metrics().borrow().clone();
        let leader_id = m.current_leader?;
        m.membership_config
            .membership()
            .get_node(&leader_id)
            .cloned()
    }

    /// Direct access to the underlying raft handle.
    pub fn raft_handle(&self) -> &Raft<TypeConfig> {
        &self.raft
    }

    pub fn local_raft_id(&self) -> u64 {
        self.local_raft_id
    }

    pub fn raft_addr(&self) -> SocketAddr {
        self.local_raft_addr
    }
}
1000
/// Derives a raft node id from the first u64 half of the node's 128-bit id.
///
/// NOTE(review): truncating 128 bits to 64 can collide in principle —
/// presumably acceptable for the expected cluster sizes; confirm.
pub fn raft_id_from_node_id(node_id: NodeId) -> u64 {
    node_id.0.as_u64_pair().0
}
1007
/// Wraps `msg` in an `ErrorKind::Other` I/O error, for sites that need an
/// `std::io::Error` to carry an ad-hoc message.
fn io_error(msg: &str) -> std::io::Error {
    std::io::Error::new(std::io::ErrorKind::Other, msg.to_string())
}
1011
#[cfg(test)]
mod tests {
    use super::*;
    use openraft::CommittedLeaderId;

    /// Shorthand for building a `LogId` with node id 0.
    fn log_id(term: u64, index: u64) -> LogId<u64> {
        LogId::new(CommittedLeaderId::new(term, 0), index)
    }

    // Applying AddNode registers the node in the state machine.
    #[tokio::test]
    async fn storage_add_node() {
        let (storage, _rx) = Storage::new();
        let mut storage_clone = Arc::clone(&storage);

        let node_id = NodeId::new();
        let entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".to_string(),
                is_primary: true,
            }),
        };

        let results = storage_clone
            .apply_to_state_machine(&[entry])
            .await
            .unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        let state_arc = storage.state();
        let state = state_arc.read().await;
        assert!(state.nodes.contains_key(&node_id.as_key()));
    }

    // AssignSlots records ownership for every slot in the range.
    #[tokio::test]
    async fn storage_assign_slots() {
        let (storage, _rx) = Storage::new();
        let mut storage_clone = Arc::clone(&storage);

        let node_id = NodeId::new();

        let add_entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".to_string(),
                is_primary: true,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[add_entry])
            .await
            .unwrap();

        let assign_entry = Entry {
            log_id: log_id(1, 2),
            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
                node_id,
                slots: vec![SlotRange::new(0, 5460)],
            }),
        };
        let results = storage_clone
            .apply_to_state_machine(&[assign_entry])
            .await
            .unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        let state_arc = storage.state();
        let state = state_arc.read().await;
        // Check both ends of the assigned range.
        assert_eq!(state.slots.get(&0), Some(&node_id.as_key()));
        assert_eq!(state.slots.get(&5460), Some(&node_id.as_key()));
    }

    // RemoveSlots releases only the requested sub-range.
    #[tokio::test]
    async fn storage_remove_slots() {
        let (storage, _rx) = Storage::new();
        let mut s = Arc::clone(&storage);
        let node_id = NodeId::new();

        let add = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".into(),
                is_primary: true,
            }),
        };
        s.apply_to_state_machine(&[add]).await.unwrap();

        let assign = Entry {
            log_id: log_id(1, 2),
            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
                node_id,
                slots: vec![SlotRange::new(0, 10)],
            }),
        };
        s.apply_to_state_machine(&[assign]).await.unwrap();

        let remove = Entry {
            log_id: log_id(1, 3),
            payload: EntryPayload::Normal(ClusterCommand::RemoveSlots {
                node_id,
                slots: vec![SlotRange::new(0, 5)],
            }),
        };
        let results = s.apply_to_state_machine(&[remove]).await.unwrap();
        assert_eq!(results, vec![ClusterResponse::Ok]);

        let state = storage.state();
        let state = state.read().await;
        // 0..=5 released, 6..=10 retained.
        assert!(!state.slots.contains_key(&0));
        assert!(state.slots.contains_key(&6));
    }

    // Full migration lifecycle: begin records it, complete transfers ownership.
    #[tokio::test]
    async fn storage_migration() {
        let (storage, _rx) = Storage::new();
        let mut storage_clone = Arc::clone(&storage);

        let node1 = NodeId::new();
        let node2 = NodeId::new();

        let entries: Vec<Entry<TypeConfig>> = [node1, node2]
            .iter()
            .enumerate()
            .map(|(i, node_id)| Entry {
                log_id: log_id(1, i as u64 + 1),
                payload: EntryPayload::Normal(ClusterCommand::AddNode {
                    node_id: *node_id,
                    raft_id: i as u64 + 1,
                    addr: format!("127.0.0.1:{}", 6379 + i),
                    is_primary: true,
                }),
            })
            .collect();
        storage_clone
            .apply_to_state_machine(&entries)
            .await
            .unwrap();

        let begin_entry = Entry {
            log_id: log_id(1, 3),
            payload: EntryPayload::Normal(ClusterCommand::BeginMigration {
                slot: 100,
                from: node1,
                to: node2,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[begin_entry])
            .await
            .unwrap();

        {
            let state_arc = storage.state();
            let state = state_arc.read().await;
            assert!(state.migrations.contains_key(&100));
        }

        let complete_entry = Entry {
            log_id: log_id(1, 4),
            payload: EntryPayload::Normal(ClusterCommand::CompleteMigration {
                slot: 100,
                new_owner: node2,
            }),
        };
        storage_clone
            .apply_to_state_machine(&[complete_entry])
            .await
            .unwrap();

        {
            let state_arc = storage.state();
            let state = state_arc.read().await;
            assert!(!state.migrations.contains_key(&100));
            assert_eq!(state.slots.get(&100), Some(&node2.0.to_string()));
        }
    }

    // A range with start > end is rejected without mutating state.
    #[tokio::test]
    async fn assign_slots_rejects_invalid_range() {
        let (storage, _rx) = Storage::new();
        let mut s = Arc::clone(&storage);

        let node_id = NodeId::new();
        let add = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".into(),
                is_primary: true,
            }),
        };
        s.apply_to_state_machine(&[add]).await.unwrap();

        let bad_range = SlotRange {
            start: 100,
            end: 50,
        };
        let assign = Entry {
            log_id: log_id(1, 2),
            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
                node_id,
                slots: vec![bad_range],
            }),
        };
        let results = s.apply_to_state_machine(&[assign]).await.unwrap();
        assert!(
            matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("invalid slot range"))
        );
    }

    // A range ending at SLOT_COUNT (one past the max slot) is rejected.
    #[tokio::test]
    async fn assign_slots_rejects_out_of_range() {
        let (storage, _rx) = Storage::new();
        let mut s = Arc::clone(&storage);

        let node_id = NodeId::new();
        let add = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".into(),
                is_primary: true,
            }),
        };
        s.apply_to_state_machine(&[add]).await.unwrap();

        let bad_range = SlotRange {
            start: 0,
            end: 16384,
        };
        let assign = Entry {
            log_id: log_id(1, 2),
            payload: EntryPayload::Normal(ClusterCommand::AssignSlots {
                node_id,
                slots: vec![bad_range],
            }),
        };
        let results = s.apply_to_state_machine(&[assign]).await.unwrap();
        assert!(
            matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("invalid slot range"))
        );
    }

    // CompleteMigration without a matching BeginMigration is an error.
    #[tokio::test]
    async fn complete_migration_without_begin_errors() {
        let (storage, _rx) = Storage::new();
        let mut s = Arc::clone(&storage);

        let node_id = NodeId::new();
        let complete = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::CompleteMigration {
                slot: 100,
                new_owner: node_id,
            }),
        };
        let results = s.apply_to_state_machine(&[complete]).await.unwrap();
        assert!(matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("no migration")));
    }

    // BeginMigration validates the slot number bound.
    #[tokio::test]
    async fn begin_migration_rejects_invalid_slot() {
        let (storage, _rx) = Storage::new();
        let mut s = Arc::clone(&storage);

        let node1 = NodeId::new();
        let node2 = NodeId::new();
        let begin = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::BeginMigration {
                slot: 16384,
                from: node1,
                to: node2,
            }),
        };
        let results = s.apply_to_state_machine(&[begin]).await.unwrap();
        assert!(matches!(&results[0], ClusterResponse::Error(msg) if msg.contains("out of range")));
    }

    // Appended entries are visible through get_log_state.
    #[tokio::test]
    async fn storage_log_operations() {
        let (storage, _rx) = Storage::new();
        let mut storage_clone = Arc::clone(&storage);

        let entry = Entry::<TypeConfig> {
            log_id: log_id(1, 1),
            payload: EntryPayload::Blank,
        };

        storage_clone.append_to_log(vec![entry]).await.unwrap();

        let state = storage_clone.get_log_state().await.unwrap();
        assert_eq!(state.last_log_id, Some(log_id(1, 1)));
    }

    // save_vote / read_vote round-trip.
    #[tokio::test]
    async fn storage_vote() {
        let (storage, _rx) = Storage::new();
        let mut storage_clone = Arc::clone(&storage);

        let vote = Vote::new(1, 1);
        storage_clone.save_vote(&vote).await.unwrap();

        let read_vote = storage_clone.read_vote().await.unwrap();
        assert_eq!(read_vote, Some(vote));
    }

    // Applying an entry broadcasts the new state over the watch channel.
    #[tokio::test]
    async fn watch_channel_notified_on_apply() {
        let (storage, mut rx) = Storage::new();
        let mut s = Arc::clone(&storage);

        let node_id = NodeId::new();
        let entry = Entry {
            log_id: log_id(1, 1),
            payload: EntryPayload::Normal(ClusterCommand::AddNode {
                node_id,
                raft_id: 1,
                addr: "127.0.0.1:6379".into(),
                is_primary: true,
            }),
        };

        // Mark the initial value as seen so `changed()` waits for the apply.
        let _ = rx.borrow_and_update();

        s.apply_to_state_machine(&[entry]).await.unwrap();

        assert!(
            rx.changed().await.is_ok(),
            "watch channel should have fired"
        );
        let data = rx.borrow();
        assert!(data.nodes.contains_key(&node_id.as_key()));
    }
}