1#![allow(clippy::result_large_err)]
26
27use crate::config::ClusterConfig;
28use crate::error::{ClusterError, Result};
29use crate::metadata::{ClusterMetadata, MetadataCommand, MetadataResponse};
30use crate::storage::RedbLogStore;
31use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
32use openraft::raft::responder::OneshotResponder;
33use openraft::storage::{RaftStateMachine, Snapshot};
34use openraft::{
35 BasicNode, Entry, EntryPayload, LogId, Membership, RaftTypeConfig, SnapshotMeta, StorageError,
36 StorageIOError, StoredMembership, Vote,
37};
38use serde::{Deserialize, Serialize};
39use std::collections::BTreeMap;
40use std::fmt::Debug;
41use std::path::PathBuf;
42use std::sync::Arc;
43use tokio::io::{AsyncReadExt, AsyncSeekExt};
44use tokio::sync::RwLock;
45use tracing::{debug, info, warn};
46
/// Persistent Raft log storage, backed by the redb-based store.
pub type LogStore = RedbLogStore;
49
/// Adapter turning a plain error `String` into a `std::error::Error`,
/// as required by `openraft::error::NetworkError::new`.
#[derive(Debug)]
struct NetworkErrorWrapper(String);

impl std::fmt::Display for NetworkErrorWrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped string is the whole message; emit it verbatim.
        f.write_str(&self.0)
    }
}

impl std::error::Error for NetworkErrorWrapper {}
65
/// Numeric identifier of a Raft cluster member.
pub type NodeId = u64;
72
/// Application request replicated through the Raft log: one metadata command.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RaftRequest {
    pub command: MetadataCommand,
}
78
/// Application response produced when a [`RaftRequest`] is applied.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RaftResponse {
    pub response: MetadataResponse,
}
84
/// Marker type binding this crate's request/response/node types into
/// openraft's `RaftTypeConfig` (see the impl below).
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd)]
// NOTE(review): serde derives are gated on a "raft" feature here while the
// rest of this module derives serde unconditionally — confirm the feature
// name/gating is intentional.
#[cfg_attr(feature = "raft", derive(Serialize, Deserialize))]
pub struct TypeConfig;
89
impl RaftTypeConfig for TypeConfig {
    // Application data carried in log entries.
    type D = RaftRequest;
    // Application response returned from apply().
    type R = RaftResponse;
    type NodeId = NodeId;
    type Node = BasicNode;
    type Entry = Entry<TypeConfig>;
    // Snapshots are streamed through plain files on disk.
    type SnapshotData = tokio::fs::File;
    type AsyncRuntime = openraft::TokioRuntime;
    type Responder = OneshotResponder<TypeConfig>;
}
115
116const MAX_SNAPSHOT_SIZE: usize = 64 * 1024 * 1024; pub type RaftLogId = LogId<NodeId>;
122pub type RaftVote = Vote<NodeId>;
123pub type RaftEntry = Entry<TypeConfig>;
124pub type RaftMembership = Membership<NodeId, BasicNode>;
125pub type RaftStoredMembership = StoredMembership<NodeId, BasicNode>;
126pub type RaftSnapshot = Snapshot<TypeConfig>;
127pub type RaftSnapshotMeta = SnapshotMeta<NodeId, BasicNode>;
128
/// In-memory Raft state machine holding cluster metadata, with snapshots
/// persisted as CRC-trailed files under `snapshot_dir`.
pub struct StateMachine {
    // Replicated cluster metadata — the actual application state.
    metadata: RwLock<ClusterMetadata>,
    // Log id of the most recently applied entry, if any.
    last_applied: RwLock<Option<RaftLogId>>,
    // Most recent membership configuration observed in the log.
    membership: RwLock<RaftStoredMembership>,
    // Directory where `snapshot-<index>.snap` files are written.
    snapshot_dir: PathBuf,
    // Id of the latest snapshot on disk, if known.
    current_snapshot_id: RwLock<Option<String>>,
    // Cached meta for the latest snapshot; `None` forces a rebuild on demand.
    current_snapshot_meta: RwLock<Option<RaftSnapshotMeta>>,
}
148
impl StateMachine {
    /// Creates a state machine whose snapshots live under `snapshot_dir`.
    ///
    /// Failure to create the directory is logged but not fatal: the in-memory
    /// state machine still works; only snapshot persistence may fail later.
    pub fn new(snapshot_dir: impl Into<PathBuf>) -> Self {
        let dir = snapshot_dir.into();
        if let Err(e) = std::fs::create_dir_all(&dir) {
            warn!(
                path = %dir.display(),
                error = %e,
                "Failed to create snapshot directory; snapshots may not persist"
            );
        }
        Self {
            metadata: RwLock::new(ClusterMetadata::new()),
            last_applied: RwLock::new(None),
            // Empty membership until the first membership entry is applied.
            membership: RwLock::new(StoredMembership::new(None, Membership::new(vec![], ()))),
            snapshot_dir: dir,
            current_snapshot_id: RwLock::new(None),
            current_snapshot_meta: RwLock::new(None),
        }
    }

    /// Convenience constructor storing snapshots under the OS temp dir.
    pub fn new_default() -> Self {
        Self::new(std::env::temp_dir().join("rivven-snapshots"))
    }

    /// Read-locked access to the current cluster metadata.
    pub async fn metadata(&self) -> tokio::sync::RwLockReadGuard<'_, ClusterMetadata> {
        self.metadata.read().await
    }

    /// Applies one command and records `log_id` as the last applied entry.
    async fn apply_command(&self, log_id: &RaftLogId, command: MetadataCommand) -> RaftResponse {
        let mut metadata = self.metadata.write().await;
        let response = metadata.apply(log_id.index, command);
        *self.last_applied.write().await = Some(*log_id);
        RaftResponse { response }
    }

    /// Serializes current state to `<dir>/snapshot-<index>.snap` and returns
    /// the snapshot metadata plus the file path.
    ///
    /// On-disk layout: postcard-encoded `SnapshotData` followed by a 4-byte
    /// little-endian CRC32 trailer. The file is written to a `.tmp` path and
    /// renamed into place so readers never observe a partial snapshot.
    async fn create_snapshot(
        &self,
    ) -> std::result::Result<(RaftSnapshotMeta, PathBuf), StorageError<NodeId>> {
        // Acquire all three locks (same order as install_snapshot_data),
        // copy the state out, then release before doing any file I/O.
        let metadata_guard = self.metadata.read().await;
        let last_applied_guard = self.last_applied.read().await;
        let membership_guard = self.membership.read().await;

        let metadata = metadata_guard.clone();
        let last_applied = *last_applied_guard;
        let membership = membership_guard.clone();

        // Release in reverse acquisition order before serialization/IO.
        drop(membership_guard);
        drop(last_applied_guard);
        drop(metadata_guard);

        let snapshot_data = SnapshotData {
            metadata: metadata.clone(),
            last_applied,
            membership: membership.clone(),
        };

        let data = postcard::to_allocvec(&snapshot_data).map_err(|e| StorageError::IO {
            source: StorageIOError::read_state_machine(openraft::AnyError::new(&e)),
        })?;

        // Refuse to create snapshots above the configured cap.
        if data.len() > MAX_SNAPSHOT_SIZE {
            return Err(StorageError::IO {
                source: StorageIOError::read_state_machine(openraft::AnyError::new(
                    &std::io::Error::other(format!(
                        "Snapshot too large: {} bytes > {} byte limit",
                        data.len(),
                        MAX_SNAPSHOT_SIZE
                    )),
                )),
            });
        }

        // The applied index in the filename lets find_latest_snapshot_file
        // pick the newest snapshot by parsing names alone.
        let snapshot_id = format!("snapshot-{}", metadata.last_applied_index);

        let meta = SnapshotMeta {
            last_log_id: snapshot_data.last_applied,
            last_membership: membership,
            snapshot_id: snapshot_id.clone(),
        };

        let snap_path = self.snapshot_dir.join(format!("{}.snap", snapshot_id));
        let tmp_path = self.snapshot_dir.join(format!("{}.snap.tmp", snapshot_id));

        // Append the CRC32 trailer used for corruption detection on load.
        let crc = crc32fast::hash(&data);
        let mut file_data = data.clone();
        file_data.extend_from_slice(&crc.to_le_bytes());

        // Write-then-rename so a crash mid-write leaves no partial `.snap`.
        tokio::fs::write(&tmp_path, &file_data)
            .await
            .map_err(|e| StorageError::IO {
                source: StorageIOError::write_snapshot(Some(meta.signature()), &e),
            })?;

        tokio::fs::rename(&tmp_path, &snap_path)
            .await
            .map_err(|e| StorageError::IO {
                source: StorageIOError::write_snapshot(Some(meta.signature()), &e),
            })?;

        *self.current_snapshot_id.write().await = Some(snapshot_id.clone());
        *self.current_snapshot_meta.write().await = Some(meta.clone());

        // Best-effort retention: keep only the 3 newest snapshot files.
        self.cleanup_old_snapshots(3).await;

        info!(
            snapshot_id = %meta.snapshot_id,
            last_log_id = ?meta.last_log_id,
            size_bytes = data.len(),
            path = %snap_path.display(),
            "Created file-backed snapshot"
        );

        Ok((meta, snap_path))
    }

    /// Replaces in-memory state from raw snapshot bytes.
    ///
    /// Payloads longer than 4 bytes are assumed to carry the CRC32 trailer
    /// and are verified; shorter payloads are passed straight to decoding.
    async fn install_snapshot_data(
        &self,
        data: &[u8],
    ) -> std::result::Result<(), StorageError<NodeId>> {
        let payload = if data.len() > 4 {
            let (payload, crc_bytes) = data.split_at(data.len() - 4);
            // try_into() on a 4-byte slice cannot fail; unwrap is safe here.
            let stored_crc = u32::from_le_bytes(crc_bytes.try_into().unwrap());
            let actual_crc = crc32fast::hash(payload);
            if stored_crc != actual_crc {
                return Err(StorageError::IO {
                    source: StorageIOError::read_state_machine(openraft::AnyError::new(
                        &std::io::Error::other(format!(
                            "Snapshot CRC mismatch: stored={:#010x} actual={:#010x} — file is corrupt",
                            stored_crc, actual_crc
                        )),
                    )),
                });
            }
            payload
        } else {
            data
        };

        let snapshot_data: SnapshotData =
            postcard::from_bytes(payload).map_err(|e| StorageError::IO {
                source: StorageIOError::read_state_machine(openraft::AnyError::new(&e)),
            })?;

        // Take all three locks before mutating so readers never observe a
        // half-installed state (same acquisition order as create_snapshot).
        let mut metadata_guard = self.metadata.write().await;
        let mut last_applied_guard = self.last_applied.write().await;
        let mut membership_guard = self.membership.write().await;

        *metadata_guard = snapshot_data.metadata;
        *last_applied_guard = snapshot_data.last_applied;
        *membership_guard = snapshot_data.membership;

        drop(membership_guard);
        drop(last_applied_guard);
        drop(metadata_guard);

        info!("Installed snapshot from data");
        Ok(())
    }

    /// Restores state from the newest snapshot file, if any.
    ///
    /// Returns `Ok(true)` when a snapshot was found and installed.
    /// NOTE(review): this sets `current_snapshot_id` but never repopulates
    /// `current_snapshot_meta`, so the next `get_current_snapshot` call will
    /// rebuild a fresh snapshot instead of reusing this file — confirm that
    /// is intended.
    pub async fn load_latest_snapshot(&self) -> std::result::Result<bool, StorageError<NodeId>> {
        let latest = self.find_latest_snapshot_file().await;
        let Some(path) = latest else {
            debug!("No snapshot files found in {}", self.snapshot_dir.display());
            return Ok(false);
        };

        let data = tokio::fs::read(&path).await.map_err(|e| StorageError::IO {
            source: StorageIOError::read_state_machine(openraft::AnyError::new(&e)),
        })?;

        self.install_snapshot_data(&data).await?;

        // Recover the snapshot id from the filename (stem of `<id>.snap`).
        let snapshot_id = path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("unknown")
            .to_string();

        info!(
            snapshot_id = %snapshot_id,
            size_bytes = data.len(),
            path = %path.display(),
            "Restored state machine from snapshot file"
        );

        *self.current_snapshot_id.write().await = Some(snapshot_id);
        Ok(true)
    }

    /// Scans `snapshot_dir` for `snapshot-<index>.snap` files and returns the
    /// path with the highest index, or `None` when none exist.
    async fn find_latest_snapshot_file(&self) -> Option<PathBuf> {
        let mut entries = match tokio::fs::read_dir(&self.snapshot_dir).await {
            Ok(e) => e,
            Err(_) => return None,
        };

        let mut best: Option<(u64, PathBuf)> = None;

        while let Ok(Some(entry)) = entries.next_entry().await {
            let path = entry.path();
            if path.extension().and_then(|e| e.to_str()) != Some("snap") {
                continue;
            }
            let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("");
            // Files that don't match the `snapshot-<u64>` pattern are skipped.
            if let Some(idx_str) = stem.strip_prefix("snapshot-") {
                if let Ok(idx) = idx_str.parse::<u64>() {
                    if best.as_ref().is_none_or(|(best_idx, _)| idx > *best_idx) {
                        best = Some((idx, path));
                    }
                }
            }
        }

        best.map(|(_, p)| p)
    }

    /// Deletes all but the `keep` newest snapshot files; failures are logged
    /// and ignored (retention is best-effort).
    async fn cleanup_old_snapshots(&self, keep: usize) {
        let mut entries = match tokio::fs::read_dir(&self.snapshot_dir).await {
            Ok(e) => e,
            Err(_) => return,
        };

        let mut snaps: Vec<(u64, PathBuf)> = Vec::new();
        while let Ok(Some(entry)) = entries.next_entry().await {
            let path = entry.path();
            if path.extension().and_then(|e| e.to_str()) != Some("snap") {
                continue;
            }
            let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("");
            if let Some(idx_str) = stem.strip_prefix("snapshot-") {
                if let Ok(idx) = idx_str.parse::<u64>() {
                    snaps.push((idx, path));
                }
            }
        }

        if snaps.len() <= keep {
            return;
        }

        // Ascending by index: the oldest snapshots come first and are removed.
        snaps.sort_by_key(|(idx, _)| *idx);
        let to_remove = snaps.len() - keep;
        for (_, path) in snaps.into_iter().take(to_remove) {
            if let Err(e) = tokio::fs::remove_file(&path).await {
                warn!(path = %path.display(), error = %e, "Failed to remove old snapshot");
            } else {
                debug!(path = %path.display(), "Removed old snapshot file");
            }
        }
    }
}
442
impl Default for StateMachine {
    /// Defaults to temp-dir snapshot storage via [`StateMachine::new_default`].
    fn default() -> Self {
        Self::new_default()
    }
}
448
/// Serialized payload of a snapshot file: the full cluster metadata plus the
/// Raft bookkeeping (last applied log id and membership) needed to resume.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SnapshotData {
    metadata: ClusterMetadata,
    last_applied: Option<RaftLogId>,
    membership: RaftStoredMembership,
}
456
457impl RaftStateMachine<TypeConfig> for StateMachine {
459 type SnapshotBuilder = Self;
460
461 async fn applied_state(
462 &mut self,
463 ) -> std::result::Result<(Option<RaftLogId>, RaftStoredMembership), StorageError<NodeId>> {
464 let last_applied = *self.last_applied.read().await;
465 let membership = self.membership.read().await.clone();
466 Ok((last_applied, membership))
467 }
468
469 async fn apply<I>(
470 &mut self,
471 entries: I,
472 ) -> std::result::Result<Vec<RaftResponse>, StorageError<NodeId>>
473 where
474 I: IntoIterator<Item = RaftEntry> + Send,
475 I::IntoIter: Send,
476 {
477 let mut responses = Vec::new();
478
479 for entry in entries {
480 let log_id = entry.log_id;
481
482 match entry.payload {
483 EntryPayload::Blank => {
484 *self.last_applied.write().await = Some(log_id);
486 responses.push(RaftResponse {
487 response: MetadataResponse::Success,
488 });
489 }
490 EntryPayload::Normal(req) => {
491 let response = self.apply_command(&log_id, req.command).await;
493 responses.push(response);
494 }
495 EntryPayload::Membership(membership) => {
496 *self.membership.write().await =
498 StoredMembership::new(Some(log_id), membership);
499 *self.last_applied.write().await = Some(log_id);
500 responses.push(RaftResponse {
501 response: MetadataResponse::Success,
502 });
503 }
504 }
505 }
506
507 Ok(responses)
508 }
509
510 async fn begin_receiving_snapshot(
511 &mut self,
512 ) -> std::result::Result<Box<tokio::fs::File>, StorageError<NodeId>> {
513 let unique_id = format!(
518 "incoming-{}-{}.snap.tmp",
519 std::process::id(),
520 uuid::Uuid::new_v4().as_simple()
521 );
522 let tmp_path = self.snapshot_dir.join(&unique_id);
523 let file = tokio::fs::OpenOptions::new()
524 .create(true)
525 .write(true)
526 .read(true)
527 .truncate(true)
528 .open(&tmp_path)
529 .await
530 .map_err(|e| StorageError::IO {
531 source: StorageIOError::write_snapshot(None, &e),
532 })?;
533
534 debug!(path = %tmp_path.display(), "Created temp file for incoming snapshot");
535 Ok(Box::new(file))
536 }
537
538 async fn install_snapshot(
539 &mut self,
540 meta: &RaftSnapshotMeta,
541 mut snapshot: Box<tokio::fs::File>,
542 ) -> std::result::Result<(), StorageError<NodeId>> {
543 snapshot
545 .seek(std::io::SeekFrom::Start(0))
546 .await
547 .map_err(|e| StorageError::IO {
548 source: StorageIOError::read_snapshot(Some(meta.signature()), &e),
549 })?;
550
551 let mut data = Vec::new();
552 snapshot
553 .read_to_end(&mut data)
554 .await
555 .map_err(|e| StorageError::IO {
556 source: StorageIOError::read_snapshot(Some(meta.signature()), &e),
557 })?;
558
559 if data.len() > MAX_SNAPSHOT_SIZE {
561 return Err(StorageError::IO {
562 source: StorageIOError::read_snapshot(
563 Some(meta.signature()),
564 &std::io::Error::new(
565 std::io::ErrorKind::InvalidData,
566 format!(
567 "incoming snapshot {} bytes exceeds maximum {} bytes",
568 data.len(),
569 MAX_SNAPSHOT_SIZE
570 ),
571 ),
572 ),
573 });
574 }
575
576 self.install_snapshot_data(&data).await?;
578
579 *self.membership.write().await = meta.last_membership.clone();
581
582 let snap_path = self.snapshot_dir.join(format!("{}.snap", meta.snapshot_id));
585 let pid_prefix = format!("incoming-{}-", std::process::id());
586 if let Ok(mut entries) = tokio::fs::read_dir(&self.snapshot_dir).await {
587 while let Ok(Some(entry)) = entries.next_entry().await {
588 let name = entry.file_name();
589 let name_str = name.to_string_lossy();
590 if name_str.starts_with(&pid_prefix) && name_str.ends_with(".snap.tmp") {
591 let tmp_path = entry.path();
592 if !snap_path.exists() {
594 let _ = tokio::fs::rename(&tmp_path, &snap_path).await.map_err(|e| {
595 StorageError::IO {
596 source: StorageIOError::write_snapshot(Some(meta.signature()), &e),
597 }
598 });
599 } else {
600 let _ = tokio::fs::remove_file(&tmp_path).await;
601 }
602 }
603 }
604 }
605
606 *self.current_snapshot_id.write().await = Some(meta.snapshot_id.clone());
608 *self.current_snapshot_meta.write().await = Some(meta.clone());
609
610 self.cleanup_old_snapshots(3).await;
611
612 info!(
613 snapshot_id = %meta.snapshot_id,
614 size_bytes = data.len(),
615 "Installed snapshot from leader and persisted to disk"
616 );
617 Ok(())
618 }
619
620 async fn get_current_snapshot(
621 &mut self,
622 ) -> std::result::Result<Option<RaftSnapshot>, StorageError<NodeId>> {
623 if let Some(id) = self.current_snapshot_id.read().await.as_deref() {
625 let existing = self.snapshot_dir.join(format!("{}.snap", id));
626 if existing.exists() {
627 if let Some(meta) = self.current_snapshot_meta.read().await.clone() {
628 let sig = meta.signature();
629 let file =
630 tokio::fs::File::open(&existing)
631 .await
632 .map_err(|e| StorageError::IO {
633 source: StorageIOError::read_snapshot(Some(sig), &e),
634 })?;
635 return Ok(Some(Snapshot {
636 meta,
637 snapshot: Box::new(file),
638 }));
639 }
640 }
641 }
642
643 let (meta, snap_path) = self.create_snapshot().await?;
645 let sig = meta.signature();
646 let file = tokio::fs::File::open(&snap_path)
647 .await
648 .map_err(|e| StorageError::IO {
649 source: StorageIOError::read_snapshot(Some(sig), &e),
650 })?;
651 Ok(Some(Snapshot {
652 meta,
653 snapshot: Box::new(file),
654 }))
655 }
656
657 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
658 let metadata = self.metadata.read().await;
663 let last_applied = self.last_applied.read().await;
664 let membership = self.membership.read().await;
665 let snap_id = self.current_snapshot_id.read().await;
666 let snap_meta = self.current_snapshot_meta.read().await;
667
668 Self {
669 metadata: RwLock::new(metadata.clone()),
670 last_applied: RwLock::new(*last_applied),
671 membership: RwLock::new(membership.clone()),
672 snapshot_dir: self.snapshot_dir.clone(),
673 current_snapshot_id: RwLock::new(snap_id.clone()),
674 current_snapshot_meta: RwLock::new(snap_meta.clone()),
675 }
676 }
677}
678
679impl openraft::storage::RaftSnapshotBuilder<TypeConfig> for StateMachine {
681 async fn build_snapshot(&mut self) -> std::result::Result<RaftSnapshot, StorageError<NodeId>> {
682 let (meta, snap_path) = self.create_snapshot().await?;
683 let sig = meta.signature();
684 let file = tokio::fs::File::open(&snap_path)
685 .await
686 .map_err(|e| StorageError::IO {
687 source: StorageIOError::read_snapshot(Some(sig), &e),
688 })?;
689 Ok(Snapshot {
690 meta,
691 snapshot: Box::new(file),
692 })
693 }
694}
695
/// Wire format used for Raft RPC request/response bodies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SerializationFormat {
    // Human-readable JSON (`application/json`).
    Json,
    // Compact postcard encoding (`application/octet-stream`); the default.
    #[default]
    Binary,
}
710
/// Compression settings for Raft RPC payloads (effective only when the
/// `compression` cargo feature is enabled).
#[derive(Debug, Clone)]
pub struct RaftCompressionConfig {
    // Master switch for payload compression.
    pub enabled: bool,
    // Payloads smaller than this (bytes) are sent uncompressed.
    pub min_size: usize,
    // Forwarded to the compressor's adaptive-selection mode.
    pub adaptive: bool,
}
721
722impl Default for RaftCompressionConfig {
723 fn default() -> Self {
724 Self {
725 enabled: true,
726 min_size: 1024, adaptive: true,
728 }
729 }
730}
731
/// HTTP header carrying the shared cluster secret for node-to-node auth.
pub const CLUSTER_SECRET_HEADER: &str = "X-Rivven-Cluster-Secret";
738
/// Creates per-peer [`Network`] clients that share one pooled HTTP client,
/// wire format, compression settings, and optional cluster secret.
#[derive(Clone)]
pub struct NetworkFactory {
    // Known peer addresses keyed by node id.
    nodes: Arc<RwLock<BTreeMap<NodeId, String>>>,
    // Shared HTTP client; connection pooling is reused across peers.
    client: reqwest::Client,
    format: SerializationFormat,
    compression: RaftCompressionConfig,
    // Optional shared secret sent in `CLUSTER_SECRET_HEADER`.
    cluster_secret: Option<String>,
}
753
754impl NetworkFactory {
755 pub fn new() -> Result<Self> {
757 Self::with_format(SerializationFormat::Binary)
758 }
759
760 pub fn with_format(format: SerializationFormat) -> Result<Self> {
762 Ok(Self {
763 nodes: Arc::new(RwLock::new(BTreeMap::new())),
764 client: reqwest::Client::builder()
765 .timeout(std::time::Duration::from_secs(5))
766 .pool_max_idle_per_host(10) .pool_idle_timeout(std::time::Duration::from_secs(60))
768 .tcp_keepalive(std::time::Duration::from_secs(30))
769 .tcp_nodelay(true) .build()
771 .map_err(|e| {
772 ClusterError::Network(format!("Failed to create HTTP client: {}", e))
773 })?,
774 format,
775 compression: RaftCompressionConfig::default(),
776 cluster_secret: None,
777 })
778 }
779
780 pub fn with_compression(
782 format: SerializationFormat,
783 compression: RaftCompressionConfig,
784 ) -> Result<Self> {
785 Ok(Self {
786 compression,
787 ..Self::with_format(format)?
788 })
789 }
790
791 pub fn with_cluster_secret(mut self, secret: Option<String>) -> Self {
793 self.cluster_secret = secret;
794 self
795 }
796
797 pub async fn add_node(&self, node_id: NodeId, addr: String) {
799 self.nodes.write().await.insert(node_id, addr);
800 }
801
802 pub async fn remove_node(&self, node_id: NodeId) {
804 self.nodes.write().await.remove(&node_id);
805 }
806}
807
/// HTTP client for Raft RPCs to a single peer.
pub struct Network {
    // Kept for debugging/metrics even though no code here reads it yet.
    #[allow(dead_code)]
    target: NodeId,
    // Base URL of the peer, e.g. `http://host:port`.
    target_addr: String,
    // Shared pooled client from the factory.
    client: reqwest::Client,
    format: SerializationFormat,
    compression: RaftCompressionConfig,
    // Optional shared secret sent in `CLUSTER_SECRET_HEADER`.
    cluster_secret: Option<String>,
}
820
821impl Network {
822 pub fn new(
823 target: NodeId,
824 target_addr: String,
825 client: reqwest::Client,
826 format: SerializationFormat,
827 compression: RaftCompressionConfig,
828 cluster_secret: Option<String>,
829 ) -> Self {
830 Self {
831 target,
832 target_addr,
833 client,
834 format,
835 compression,
836 cluster_secret,
837 }
838 }
839
840 fn apply_auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
842 if let Some(ref secret) = self.cluster_secret {
843 request.header(CLUSTER_SECRET_HEADER, secret)
844 } else {
845 request
846 }
847 }
848
849 fn serialize<T: Serialize>(&self, data: &T) -> std::result::Result<Vec<u8>, String> {
851 match self.format {
852 SerializationFormat::Json => serde_json::to_vec(data).map_err(|e| e.to_string()),
853 SerializationFormat::Binary => postcard::to_allocvec(data).map_err(|e| e.to_string()),
854 }
855 }
856
857 fn deserialize<T: serde::de::DeserializeOwned>(
859 &self,
860 data: &[u8],
861 ) -> std::result::Result<T, String> {
862 match self.format {
863 SerializationFormat::Json => serde_json::from_slice(data).map_err(|e| e.to_string()),
864 SerializationFormat::Binary => postcard::from_bytes(data).map_err(|e| e.to_string()),
865 }
866 }
867
868 fn content_type(&self) -> &'static str {
870 match self.format {
871 SerializationFormat::Json => "application/json",
872 SerializationFormat::Binary => "application/octet-stream",
873 }
874 }
875
876 #[cfg(feature = "compression")]
878 fn maybe_compress(&self, data: Vec<u8>) -> (Vec<u8>, bool) {
879 use rivven_core::compression::{CompressionConfig, Compressor};
880
881 if !self.compression.enabled || data.len() < self.compression.min_size {
882 return (data, false);
883 }
884
885 let config = CompressionConfig {
886 min_size: self.compression.min_size,
887 adaptive: self.compression.adaptive,
888 ..Default::default()
889 };
890 let compressor = Compressor::with_config(config);
891
892 match compressor.compress(&data) {
893 Ok(compressed) => {
894 if compressed.len() < data.len() {
896 (compressed.to_vec(), true)
897 } else {
898 (data, false)
899 }
900 }
901 Err(_) => (data, false),
902 }
903 }
904
905 #[cfg(not(feature = "compression"))]
906 fn maybe_compress(&self, data: Vec<u8>) -> (Vec<u8>, bool) {
907 (data, false)
908 }
909
910 #[cfg(feature = "compression")]
912 fn maybe_decompress(
913 &self,
914 data: &[u8],
915 was_compressed: bool,
916 ) -> std::result::Result<Vec<u8>, String> {
917 use rivven_core::compression::Compressor;
918
919 if !was_compressed {
920 return Ok(data.to_vec());
921 }
922
923 let compressor = Compressor::new();
924 compressor
925 .decompress(data)
926 .map(|b| b.to_vec())
927 .map_err(|e| e.to_string())
928 }
929
930 #[cfg(not(feature = "compression"))]
931 fn maybe_decompress(
932 &self,
933 data: &[u8],
934 _was_compressed: bool,
935 ) -> std::result::Result<Vec<u8>, String> {
936 Ok(data.to_vec())
937 }
938}
939
940impl RaftNetworkFactory<TypeConfig> for NetworkFactory {
942 type Network = Network;
943
944 async fn new_client(&mut self, target: NodeId, node: &BasicNode) -> Self::Network {
945 Network::new(
946 target,
947 node.addr.clone(),
948 self.client.clone(),
949 self.format,
950 self.compression.clone(),
951 self.cluster_secret.clone(),
952 )
953 }
954}
955
956impl RaftNetwork<TypeConfig> for Network {
958 async fn append_entries(
959 &mut self,
960 rpc: openraft::raft::AppendEntriesRequest<TypeConfig>,
961 _option: RPCOption,
962 ) -> std::result::Result<
963 openraft::raft::AppendEntriesResponse<NodeId>,
964 openraft::error::RPCError<NodeId, BasicNode, openraft::error::RaftError<NodeId>>,
965 > {
966 use crate::observability::{NetworkMetrics, RaftMetrics};
967 let start = std::time::Instant::now();
968
969 let url = format!("{}/raft/append", self.target_addr);
970 let serialized = self.serialize(&rpc).map_err(|e| {
971 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
972 &NetworkErrorWrapper(e),
973 ))
974 })?;
975
976 let (body, compressed) = self.maybe_compress(serialized);
978 let uncompressed_size = body.len();
979
980 NetworkMetrics::add_bytes_sent(body.len() as u64);
981 RaftMetrics::increment_append_entries_sent();
982
983 let mut request = self.client.post(&url).body(body);
985 request = request.header("Content-Type", self.content_type());
986 request = self.apply_auth(request);
987 if compressed {
988 request = request.header("X-Rivven-Compressed", "1");
989 request = request.header("X-Rivven-Original-Size", uncompressed_size.to_string());
990 }
991
992 let resp = request.send().await.map_err(|e| {
993 NetworkMetrics::increment_rpc_errors("append_entries");
994 openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
995 })?;
996
997 if !resp.status().is_success() {
998 NetworkMetrics::increment_rpc_errors("append_entries");
999 return Err(openraft::error::RPCError::Network(
1000 openraft::error::NetworkError::new(&NetworkErrorWrapper(format!(
1001 "HTTP error: {}",
1002 resp.status()
1003 ))),
1004 ));
1005 }
1006
1007 let resp_compressed = resp
1009 .headers()
1010 .get("X-Rivven-Compressed")
1011 .map(|v| v == "1")
1012 .unwrap_or(false);
1013
1014 let bytes = resp.bytes().await.map_err(|e| {
1015 openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
1016 })?;
1017
1018 NetworkMetrics::add_bytes_received(bytes.len() as u64);
1019 RaftMetrics::record_append_entries_latency(start.elapsed());
1020
1021 let response_data = self
1023 .maybe_decompress(&bytes, resp_compressed)
1024 .map_err(|e| {
1025 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
1026 &NetworkErrorWrapper(e),
1027 ))
1028 })?;
1029
1030 let response: openraft::raft::AppendEntriesResponse<NodeId> =
1031 self.deserialize(&response_data).map_err(|e| {
1032 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
1033 &NetworkErrorWrapper(e),
1034 ))
1035 })?;
1036
1037 Ok(response)
1038 }
1039
1040 async fn install_snapshot(
1041 &mut self,
1042 rpc: openraft::raft::InstallSnapshotRequest<TypeConfig>,
1043 _option: RPCOption,
1044 ) -> std::result::Result<
1045 openraft::raft::InstallSnapshotResponse<NodeId>,
1046 openraft::error::RPCError<
1047 NodeId,
1048 BasicNode,
1049 openraft::error::RaftError<NodeId, openraft::error::InstallSnapshotError>,
1050 >,
1051 > {
1052 use crate::observability::{NetworkMetrics, RaftMetrics};
1053 let start = std::time::Instant::now();
1054
1055 let url = format!("{}/raft/snapshot", self.target_addr);
1056 let serialized = self.serialize(&rpc).map_err(|e| {
1057 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
1058 &NetworkErrorWrapper(e),
1059 ))
1060 })?;
1061
1062 let (body, compressed) = self.maybe_compress(serialized);
1064 let uncompressed_size = body.len();
1065
1066 NetworkMetrics::add_bytes_sent(body.len() as u64);
1067
1068 let mut request = self.client.post(&url).body(body);
1070 request = request.header("Content-Type", self.content_type());
1071 request = self.apply_auth(request);
1072 if compressed {
1073 request = request.header("X-Rivven-Compressed", "1");
1074 request = request.header("X-Rivven-Original-Size", uncompressed_size.to_string());
1075 }
1076
1077 let resp = request.send().await.map_err(|e| {
1078 NetworkMetrics::increment_rpc_errors("install_snapshot");
1079 openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
1080 })?;
1081
1082 if !resp.status().is_success() {
1083 NetworkMetrics::increment_rpc_errors("install_snapshot");
1084 return Err(openraft::error::RPCError::Network(
1085 openraft::error::NetworkError::new(&NetworkErrorWrapper(format!(
1086 "HTTP error: {}",
1087 resp.status()
1088 ))),
1089 ));
1090 }
1091
1092 let bytes = resp.bytes().await.map_err(|e| {
1093 openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
1094 })?;
1095
1096 NetworkMetrics::add_bytes_received(bytes.len() as u64);
1097 RaftMetrics::record_snapshot_duration(start.elapsed());
1098
1099 let response: openraft::raft::InstallSnapshotResponse<NodeId> =
1100 self.deserialize(&bytes).map_err(|e| {
1101 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
1102 &NetworkErrorWrapper(e),
1103 ))
1104 })?;
1105
1106 Ok(response)
1107 }
1108
1109 async fn vote(
1110 &mut self,
1111 rpc: openraft::raft::VoteRequest<NodeId>,
1112 _option: RPCOption,
1113 ) -> std::result::Result<
1114 openraft::raft::VoteResponse<NodeId>,
1115 openraft::error::RPCError<NodeId, BasicNode, openraft::error::RaftError<NodeId>>,
1116 > {
1117 use crate::observability::{NetworkMetrics, RaftMetrics};
1118 let start = std::time::Instant::now();
1119
1120 let url = format!("{}/raft/vote", self.target_addr);
1121 let body = self.serialize(&rpc).map_err(|e| {
1122 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
1123 &NetworkErrorWrapper(e),
1124 ))
1125 })?;
1126
1127 NetworkMetrics::add_bytes_sent(body.len() as u64);
1128
1129 let resp = self
1130 .client
1131 .post(&url)
1132 .body(body)
1133 .header("Content-Type", self.content_type());
1134 let resp = self.apply_auth(resp);
1135 let resp = resp.send().await.map_err(|e| {
1136 NetworkMetrics::increment_rpc_errors("vote");
1137 openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
1138 })?;
1139
1140 if !resp.status().is_success() {
1141 NetworkMetrics::increment_rpc_errors("vote");
1142 return Err(openraft::error::RPCError::Network(
1143 openraft::error::NetworkError::new(&NetworkErrorWrapper(format!(
1144 "HTTP error: {}",
1145 resp.status()
1146 ))),
1147 ));
1148 }
1149
1150 let bytes = resp.bytes().await.map_err(|e| {
1151 openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
1152 })?;
1153
1154 NetworkMetrics::add_bytes_received(bytes.len() as u64);
1155 RaftMetrics::record_vote_latency(start.elapsed());
1156 RaftMetrics::increment_elections();
1157
1158 let response: openraft::raft::VoteResponse<NodeId> =
1159 self.deserialize(&bytes).map_err(|e| {
1160 openraft::error::RPCError::Network(openraft::error::NetworkError::new(
1161 &NetworkErrorWrapper(e),
1162 ))
1163 })?;
1164
1165 Ok(response)
1166 }
1167}
1168
/// Static configuration for a [`RaftNode`].
#[derive(Debug, Clone)]
pub struct RaftNodeConfig {
    // Node identifier string taken from the cluster configuration.
    pub node_id: String,
    // When true the node runs single-node, without replication peers.
    pub standalone: bool,
    // Directory for Raft log and snapshot data.
    pub data_dir: std::path::PathBuf,
    // Leader heartbeat interval, milliseconds.
    pub heartbeat_interval_ms: u64,
    // Minimum election timeout, milliseconds.
    pub election_timeout_min_ms: u64,
    // Maximum election timeout, milliseconds.
    pub election_timeout_max_ms: u64,
    // Number of log entries between automatic snapshots.
    pub snapshot_threshold: u64,
    // Members used when bootstrapping a new cluster.
    pub initial_members: Vec<(NodeId, BasicNode)>,
    // Optional shared secret for authenticating node-to-node RPCs.
    pub cluster_secret: Option<String>,
}
1194
/// A batch of commands waiting to be proposed, with one oneshot responder
/// per command so each caller receives its own result.
#[allow(dead_code)]
pub(crate) struct PendingBatch {
    commands: Vec<MetadataCommand>,
    responders: Vec<tokio::sync::oneshot::Sender<Result<MetadataResponse>>>,
    // When the first command entered this batch; drives age-based flushing.
    started: std::time::Instant,
}
1209
/// Tuning knobs for command batching.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    // Flush as soon as a batch reaches this many commands.
    pub max_batch_size: usize,
    // Flush a non-full batch after this many microseconds.
    pub max_wait_us: u64,
    // Whether batching is enabled at all.
    pub enabled: bool,
}
1220
1221impl Default for BatchConfig {
1222 fn default() -> Self {
1223 Self {
1224 max_batch_size: 100,
1225 max_wait_us: 1000, enabled: true,
1227 }
1228 }
1229}
1230
/// Accumulates metadata commands into batches that are flushed when full or
/// after `BatchConfig::max_wait_us` has elapsed.
pub struct BatchAccumulator {
    // The batch currently being filled, if any.
    pending: tokio::sync::Mutex<Option<PendingBatch>>,
    config: BatchConfig,
    // Wakes the flusher task when commands arrive.
    notify: tokio::sync::Notify,
}
1246
1247impl BatchAccumulator {
1248 pub fn new(config: BatchConfig) -> Self {
1250 Self {
1251 pending: tokio::sync::Mutex::new(None),
1252 config,
1253 notify: tokio::sync::Notify::new(),
1254 }
1255 }
1256
1257 pub async fn add(
1259 &self,
1260 command: MetadataCommand,
1261 ) -> tokio::sync::oneshot::Receiver<Result<MetadataResponse>> {
1262 let (tx, rx) = tokio::sync::oneshot::channel();
1263
1264 let should_flush = {
1265 let mut pending = self.pending.lock().await;
1266
1267 if pending.is_none() {
1268 *pending = Some(PendingBatch {
1269 commands: vec![command],
1270 responders: vec![tx],
1271 started: std::time::Instant::now(),
1272 });
1273 false
1274 } else if let Some(batch) = pending.as_mut() {
1275 batch.commands.push(command);
1276 batch.responders.push(tx);
1277 batch.commands.len() >= self.config.max_batch_size
1278 } else {
1279 unreachable!("pending was checked to be Some")
1280 }
1281 };
1282
1283 self.notify.notify_one();
1284
1285 if should_flush {
1286 self.notify.notify_one();
1288 }
1289
1290 rx
1291 }
1292
1293 #[allow(dead_code)]
1295 pub(crate) async fn take_if_ready(&self) -> Option<PendingBatch> {
1296 let mut pending = self.pending.lock().await;
1297
1298 if let Some(ref batch) = *pending {
1299 let elapsed = batch.started.elapsed();
1300 let size = batch.commands.len();
1301
1302 if size >= self.config.max_batch_size
1303 || elapsed.as_micros() as u64 >= self.config.max_wait_us
1304 {
1305 return pending.take();
1306 }
1307 }
1308 None
1309 }
1310
1311 pub async fn wait_ready(&self) {
1313 let timeout = std::time::Duration::from_micros(self.config.max_wait_us);
1314 let _ = tokio::time::timeout(timeout, self.notify.notified()).await;
1315 }
1316}
1317
1318impl Default for RaftNodeConfig {
1319 fn default() -> Self {
1320 Self {
1321 node_id: "node-1".to_string(),
1322 standalone: true,
1323 data_dir: std::path::PathBuf::from("./data/raft"),
1324 heartbeat_interval_ms: 150,
1325 election_timeout_min_ms: 300,
1326 election_timeout_max_ms: 600,
1327 snapshot_threshold: 10000,
1328 initial_members: vec![],
1329 cluster_secret: None,
1330 }
1331 }
1332}
1333
/// Raft-backed metadata controller for a single node.
pub struct RaftNode {
    /// The openraft instance; `None` until `start` runs in cluster mode, and
    /// it stays `None` in standalone mode (where `start` returns early).
    raft: Option<openraft::Raft<TypeConfig>>,
    // Never populated by the code in this file: `start` hands its freshly
    // created log store directly to openraft. Kept (dead_code-allowed),
    // presumably for future use — TODO confirm.
    #[allow(dead_code)]
    log_store: Option<Arc<LogStore>>,
    /// Applies committed commands to the in-memory cluster metadata.
    state_machine: StateMachine,
    /// Peer address registry; also what gets handed to openraft as the
    /// network layer (a fresh copy is built in `start`).
    network: NetworkFactory,
    /// Numeric id derived from `node_id_str` via `hash_node_id`.
    node_id: NodeId,
    /// Human-readable node id from configuration.
    node_id_str: String,
    /// True when running without consensus (single-node mode).
    standalone: bool,
    /// Next log index minted for standalone proposals (starts at 1).
    next_index: RwLock<u64>,
    /// Root data directory for Raft state and snapshots.
    data_dir: std::path::PathBuf,
    /// The configuration this node was built from.
    raft_config: RaftNodeConfig,
    /// Optional shared secret for authenticating inter-node RPCs.
    cluster_secret: Option<String>,
}
1360
impl RaftNode {
    /// Build a Raft node from the cluster-level configuration.
    ///
    /// Translates the relevant `ClusterConfig` fields into a `RaftNodeConfig`
    /// and delegates to `with_config`.
    pub async fn new(config: &ClusterConfig) -> Result<Self> {
        let raft_config = RaftNodeConfig {
            node_id: config.node_id.clone(),
            standalone: config.mode == crate::config::ClusterMode::Standalone,
            data_dir: config.data_dir.join("raft"),
            heartbeat_interval_ms: config.raft.heartbeat_interval.as_millis() as u64,
            election_timeout_min_ms: config.raft.election_timeout_min.as_millis() as u64,
            election_timeout_max_ms: config.raft.election_timeout_max.as_millis() as u64,
            snapshot_threshold: config.raft.snapshot_threshold,
            initial_members: vec![],
            cluster_secret: config.raft.cluster_secret.clone(),
        };
        Self::with_config(raft_config).await
    }

    /// Construct the node from an explicit `RaftNodeConfig`.
    ///
    /// Creates the data directory, the snapshot-backed state machine, and the
    /// network factory. The openraft instance itself is only created later by
    /// `start` (and only in cluster mode).
    ///
    /// # Errors
    /// Returns `ClusterError::RaftStorage` if the data directory cannot be
    /// created or the network factory fails to initialize.
    pub async fn with_config(config: RaftNodeConfig) -> Result<Self> {
        std::fs::create_dir_all(&config.data_dir)
            .map_err(|e| ClusterError::RaftStorage(e.to_string()))?;

        let snapshot_dir = config.data_dir.join("snapshots");
        let state_machine = StateMachine::new(snapshot_dir);
        let network = NetworkFactory::new()
            .map_err(|e| {
                ClusterError::RaftStorage(format!("Failed to create network factory: {}", e))
            })?
            .with_cluster_secret(config.cluster_secret.clone());
        let node_id = hash_node_id(&config.node_id);

        info!(
            node_id,
            node_id_str = %config.node_id,
            standalone = config.standalone,
            data_dir = %config.data_dir.display(),
            "Created Raft node"
        );

        Ok(Self {
            raft: None,
            log_store: None,
            state_machine,
            network,
            node_id,
            node_id_str: config.node_id.clone(),
            standalone: config.standalone,
            // Standalone proposals mint log indexes starting at 1.
            next_index: RwLock::new(1),
            data_dir: config.data_dir.clone(),
            cluster_secret: config.cluster_secret.clone(),
            raft_config: config,
        })
    }

    /// Start the node.
    ///
    /// In standalone mode this is a no-op (no consensus needed). In cluster
    /// mode it builds the log store, a state machine (restoring the latest
    /// snapshot if one exists), and the network layer, then creates the
    /// openraft instance.
    ///
    /// # Errors
    /// Returns `ClusterError::RaftStorage` / `ClusterError::Network` if any
    /// of the components fail to initialize or the openraft config is invalid.
    pub async fn start(&mut self) -> Result<()> {
        if self.standalone {
            info!(node_id = self.node_id, "Starting in standalone mode");
            return Ok(());
        }

        // NOTE(review): this log store is moved into openraft below and
        // `self.log_store` is never assigned — it remains `None`.
        let log_store = LogStore::new(&self.data_dir)
            .map_err(|e| ClusterError::RaftStorage(format!("Failed to create log store: {}", e)))?;

        let raft_config = openraft::Config {
            cluster_name: "rivven-cluster".to_string(),
            heartbeat_interval: self.raft_config.heartbeat_interval_ms,
            election_timeout_min: self.raft_config.election_timeout_min_ms,
            election_timeout_max: self.raft_config.election_timeout_max_ms,
            // Snapshot after `snapshot_threshold` new log entries.
            snapshot_policy: openraft::SnapshotPolicy::LogsSinceLast(
                self.raft_config.snapshot_threshold,
            ),
            max_in_snapshot_log_to_keep: 1000,
            ..Default::default()
        };

        let raft_config = Arc::new(
            raft_config
                .validate()
                .map_err(|e| ClusterError::RaftStorage(format!("Invalid Raft config: {}", e)))?,
        );

        // NOTE(review): openraft receives a freshly built StateMachine here,
        // distinct from `self.state_machine`; unless StateMachine shares its
        // metadata internally, cluster-mode reads through `self.metadata()`
        // would not observe applied entries — verify against StateMachine.
        let snapshot_dir = self.data_dir.join("snapshots");
        let state_machine = StateMachine::new(&snapshot_dir);

        // Best-effort restore: a missing or unreadable snapshot just means a
        // fresh start; openraft will replay / re-snapshot as needed.
        match state_machine.load_latest_snapshot().await {
            Ok(true) => info!("Restored state machine from snapshot file"),
            Ok(false) => debug!("No existing snapshot found, starting fresh"),
            Err(e) => warn!(error = %e, "Failed to load snapshot, starting fresh"),
        }

        // Build a fresh network for openraft, seeded with the peers already
        // registered on `self.network` (via `add_peer`).
        let network = NetworkFactory::new()
            .map_err(|e| ClusterError::Network(format!("Failed to create network factory: {}", e)))?
            .with_cluster_secret(self.cluster_secret.clone());
        for (id, addr) in self.network.nodes.read().await.iter() {
            network.add_node(*id, addr.clone()).await;
        }

        let raft =
            openraft::Raft::new(self.node_id, raft_config, network, log_store, state_machine)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("Failed to create Raft: {}", e)))?;

        self.raft = Some(raft);

        info!(
            node_id = self.node_id,
            node_id_str = %self.node_id_str,
            "Cluster mode Raft initialized and ready"
        );
        Ok(())
    }

    /// Initialize cluster membership on a brand-new cluster.
    ///
    /// A no-op in standalone mode. Also silently succeeds if the openraft
    /// instance has not been created yet (i.e. `start` was not called).
    pub async fn bootstrap(&self, members: BTreeMap<NodeId, BasicNode>) -> Result<()> {
        if self.standalone {
            return Ok(());
        }

        if let Some(ref raft) = self.raft {
            raft.initialize(members)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("Failed to bootstrap: {}", e)))?;
            info!(node_id = self.node_id, "Bootstrapped Raft cluster");
        }
        Ok(())
    }

    /// Propose a metadata command and wait for it to be applied.
    ///
    /// Standalone mode bypasses consensus: a log id is minted from the local
    /// `next_index` counter and the command is applied directly to the state
    /// machine. Cluster mode submits through `openraft::Raft::client_write`.
    ///
    /// # Errors
    /// Returns `ClusterError::RaftStorage` if the write fails, or if the node
    /// is in cluster mode but `start` has not been called.
    pub async fn propose(&self, command: MetadataCommand) -> Result<MetadataResponse> {
        use crate::observability::RaftMetrics;
        let start = std::time::Instant::now();

        if self.standalone {
            // Mint a monotonically increasing index under the write lock.
            let index = {
                let mut next = self.next_index.write().await;
                let idx = *next;
                *next += 1;
                idx
            };
            // Synthetic committed-leader id (leader = self) for local apply.
            let log_id = LogId::new(openraft::CommittedLeaderId::new(0, self.node_id), index);
            let response = self.state_machine.apply_command(&log_id, command).await;

            RaftMetrics::increment_proposals();
            RaftMetrics::increment_commits();
            RaftMetrics::record_proposal_latency(start.elapsed());

            return Ok(response.response);
        }

        if let Some(ref raft) = self.raft {
            let request = RaftRequest { command };
            let result = raft
                .client_write(request)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("Client write failed: {}", e)))?;

            RaftMetrics::increment_proposals();
            RaftMetrics::increment_commits();
            RaftMetrics::record_proposal_latency(start.elapsed());

            return Ok(result.data.response);
        }

        Err(ClusterError::RaftStorage(
            "Raft not initialized".to_string(),
        ))
    }

    /// Propose several commands at once.
    ///
    /// Standalone mode applies them sequentially, each with its own log id.
    /// Cluster mode wraps them in a single `MetadataCommand::Batch` so they
    /// commit in one consensus round. If the state machine answers with
    /// anything other than `BatchResponses`, that single response is cloned
    /// once per input command so the output length matches the input.
    pub async fn propose_batch(
        &self,
        commands: Vec<MetadataCommand>,
    ) -> Result<Vec<MetadataResponse>> {
        use crate::observability::RaftMetrics;

        if commands.is_empty() {
            return Ok(vec![]);
        }

        let batch_size = commands.len();
        RaftMetrics::record_batch_size(batch_size);

        if self.standalone {
            let mut responses = Vec::with_capacity(commands.len());
            for command in commands {
                let index = {
                    let mut next = self.next_index.write().await;
                    let idx = *next;
                    *next += 1;
                    idx
                };
                let log_id = LogId::new(openraft::CommittedLeaderId::new(0, self.node_id), index);
                let response = self.state_machine.apply_command(&log_id, command).await;
                responses.push(response.response);
            }
            return Ok(responses);
        }

        if let Some(ref raft) = self.raft {
            let batch_command = MetadataCommand::Batch(commands);
            let request = RaftRequest {
                command: batch_command,
            };

            let result = raft
                .client_write(request)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("Batch write failed: {}", e)))?;

            RaftMetrics::increment_proposals();
            RaftMetrics::increment_commits();

            match result.data.response {
                MetadataResponse::BatchResponses(responses) => return Ok(responses),
                // Fallback: replicate the single response per command.
                other => return Ok(vec![other; batch_size]),
            }
        }

        Err(ClusterError::RaftStorage(
            "Raft not initialized".to_string(),
        ))
    }

    /// Ensure subsequent reads are linearizable.
    ///
    /// Standalone mode is trivially linearizable. In cluster mode this
    /// delegates to `Raft::ensure_linearizable`, which confirms leadership
    /// and waits for the applied index to catch up.
    pub async fn ensure_linearizable_read(&self) -> Result<()> {
        if self.standalone {
            return Ok(());
        }

        if let Some(ref raft) = self.raft {
            let applied = raft.ensure_linearizable().await.map_err(|e| {
                ClusterError::RaftStorage(format!("Linearizable read failed: {}", e))
            })?;

            debug!(
                applied_log = %applied.map(|l| l.index.to_string()).unwrap_or_else(|| "none".to_string()),
                "Linearizable read confirmed"
            );
            return Ok(());
        }

        Err(ClusterError::RaftStorage(
            "Raft not initialized".to_string(),
        ))
    }

    /// Linearizable read of the cluster metadata: confirms linearizability
    /// first, then returns a read guard over the local state machine.
    pub async fn linearizable_metadata(
        &self,
    ) -> Result<tokio::sync::RwLockReadGuard<'_, ClusterMetadata>> {
        self.ensure_linearizable_read().await?;
        Ok(self.state_machine.metadata().await)
    }

    /// Possibly-stale read of the cluster metadata (no linearizability check).
    pub async fn metadata(&self) -> tokio::sync::RwLockReadGuard<'_, ClusterMetadata> {
        self.state_machine.metadata().await
    }

    /// Whether this node is currently the leader.
    /// Standalone nodes are always the leader.
    pub fn is_leader(&self) -> bool {
        if self.standalone {
            return true;
        }

        if let Some(ref raft) = self.raft {
            let metrics = raft.metrics().borrow().clone();
            return metrics.current_leader == Some(self.node_id);
        }
        false
    }

    /// The current leader's node id, if known.
    /// Standalone nodes report themselves.
    pub fn leader(&self) -> Option<NodeId> {
        if self.standalone {
            return Some(self.node_id);
        }

        if let Some(ref raft) = self.raft {
            let metrics = raft.metrics().borrow().clone();
            return metrics.current_leader;
        }
        None
    }

    /// Numeric node id (hash of the configured string id).
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Human-readable node id from configuration.
    pub fn node_id_str(&self) -> &str {
        &self.node_id_str
    }

    /// Access the underlying openraft instance, if started in cluster mode.
    pub fn get_raft(&self) -> Option<&openraft::Raft<TypeConfig>> {
        self.raft.as_ref()
    }

    /// Register a peer's address with the network layer.
    pub async fn add_peer(&self, node_id: NodeId, addr: String) {
        self.network.add_node(node_id, addr).await;
    }

    /// Remove a peer from the network layer.
    pub async fn remove_peer(&self, node_id: NodeId) {
        self.network.remove_node(node_id).await;
    }

    /// Trigger a snapshot.
    ///
    /// In cluster mode (with Raft running) this asks openraft to snapshot.
    /// In standalone mode — or in cluster mode before `start` — it falls
    /// through to writing a snapshot file directly from the state machine.
    pub async fn snapshot(&self) -> Result<()> {
        if !self.standalone {
            if let Some(ref raft) = self.raft {
                raft.trigger().snapshot().await.map_err(|e| {
                    ClusterError::RaftStorage(format!("Snapshot trigger failed: {}", e))
                })?;
                info!(node_id = self.node_id, "Triggered Raft snapshot");
                return Ok(());
            }
        }

        let (_meta, data) = self
            .state_machine
            .create_snapshot()
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("{}", e)))?;

        info!(path = %data.display(), "Created standalone snapshot");
        Ok(())
    }

    /// A snapshot of openraft's metrics, or `None` when Raft is not running.
    pub fn metrics(&self) -> Option<openraft::RaftMetrics<NodeId, BasicNode>> {
        self.raft.as_ref().map(|r| r.metrics().borrow().clone())
    }

    /// Forward an AppendEntries RPC from a peer to the local Raft instance.
    ///
    /// # Errors
    /// `ClusterError::RaftStorage` if Raft is not initialized or rejects the
    /// request.
    pub async fn handle_append_entries(
        &self,
        req: openraft::raft::AppendEntriesRequest<TypeConfig>,
    ) -> std::result::Result<openraft::raft::AppendEntriesResponse<NodeId>, ClusterError> {
        if let Some(ref raft) = self.raft {
            raft.append_entries(req)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("{}", e)))
        } else {
            Err(ClusterError::RaftStorage(
                "Raft not initialized".to_string(),
            ))
        }
    }

    /// Forward an InstallSnapshot RPC from a peer to the local Raft instance.
    pub async fn handle_install_snapshot(
        &self,
        req: openraft::raft::InstallSnapshotRequest<TypeConfig>,
    ) -> std::result::Result<openraft::raft::InstallSnapshotResponse<NodeId>, ClusterError> {
        if let Some(ref raft) = self.raft {
            raft.install_snapshot(req)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("{}", e)))
        } else {
            Err(ClusterError::RaftStorage(
                "Raft not initialized".to_string(),
            ))
        }
    }

    /// Forward a Vote RPC from a peer to the local Raft instance.
    pub async fn handle_vote(
        &self,
        req: openraft::raft::VoteRequest<NodeId>,
    ) -> std::result::Result<openraft::raft::VoteResponse<NodeId>, ClusterError> {
        if let Some(ref raft) = self.raft {
            raft.vote(req)
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("{}", e)))
        } else {
            Err(ClusterError::RaftStorage(
                "Raft not initialized".to_string(),
            ))
        }
    }

    /// Check an inter-node request's cluster-secret header.
    ///
    /// Accepts everything when no secret is configured. Otherwise requires the
    /// header to be present and equal to the configured secret, using a
    /// constant-time comparison (XOR-fold over the longer of the two lengths,
    /// with any length mismatch folded into the accumulator) so timing does
    /// not reveal how much of the secret matched.
    ///
    /// # Errors
    /// `ClusterError::Unauthorized` when the header is missing or wrong.
    pub fn verify_cluster_secret(
        &self,
        header_value: Option<&str>,
    ) -> std::result::Result<(), ClusterError> {
        // No secret configured: authentication is disabled.
        let Some(ref expected) = self.cluster_secret else {
            return Ok(()); };

        let Some(provided) = header_value else {
            return Err(ClusterError::Unauthorized(
                "Missing X-Rivven-Cluster-Secret header".to_string(),
            ));
        };

        let expected_bytes = expected.as_bytes();
        let provided_bytes = provided.as_bytes();
        let max_len = expected_bytes.len().max(provided_bytes.len());
        // Seed with the length mismatch so unequal lengths always fail.
        let mut diff = (expected_bytes.len() != provided_bytes.len()) as u8;
        for i in 0..max_len {
            // Out-of-range bytes read as 0 so both sides walk `max_len` steps.
            let a = expected_bytes.get(i).copied().unwrap_or(0);
            let b = provided_bytes.get(i).copied().unwrap_or(0);
            diff |= a ^ b;
        }
        if diff != 0 {
            return Err(ClusterError::Unauthorized(
                "Invalid cluster secret".to_string(),
            ));
        }

        Ok(())
    }

    /// The configured cluster secret, if any.
    pub fn cluster_secret(&self) -> Option<&str> {
        self.cluster_secret.as_deref()
    }
}
1850
1851pub fn hash_node_id(node_id: &str) -> NodeId {
1861 let mut hash: u64 = 0xcbf29ce484222325; for byte in node_id.as_bytes() {
1864 hash ^= *byte as u64;
1865 hash = hash.wrapping_mul(0x100000001b3); }
1867 hash
1868}
1869
/// Alias for the numeric Raft node identifier.
pub type RaftNodeId = NodeId;

/// Alias: the controller role is fulfilled by `RaftNode`.
pub type RaftController = RaftNode;

/// Re-export of openraft's log-storage trait for downstream users.
pub use openraft::storage::RaftLogStorage as RaftLogStorageTrait;
1882
#[cfg(test)]
mod tests {
    use super::*;
    use openraft::storage::RaftLogStorage;
    use tempfile::TempDir;

    /// A freshly created log store must report an empty log.
    #[tokio::test]
    async fn test_log_storage_creation() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().join("raft.redb");
        let mut store = LogStore::new(&db_path).unwrap();

        let log_state = store.get_log_state().await.unwrap();
        assert!(log_state.last_log_id.is_none());
    }

    /// Applying a CreateTopic command registers the topic in the metadata.
    #[tokio::test]
    async fn test_state_machine_apply() {
        let dir = TempDir::new().unwrap();
        let state_machine = StateMachine::new(dir.path().join("snapshots"));
        let log_id = LogId::new(openraft::CommittedLeaderId::new(1, 1), 1);

        let cmd = MetadataCommand::CreateTopic {
            config: crate::partition::TopicConfig::new("test-topic", 3, 1),
            partition_assignments: vec![
                vec!["node-1".into()],
                vec!["node-1".into()],
                vec!["node-1".into()],
            ],
        };

        let response = state_machine.apply_command(&log_id, cmd).await;
        assert!(matches!(
            response.response,
            MetadataResponse::TopicCreated { .. }
        ));

        let metadata = state_machine.metadata().await;
        assert!(metadata.topics.contains_key("test-topic"));
    }

    /// A standalone node is always leader and can apply proposals directly.
    #[tokio::test]
    async fn test_raft_node_standalone() {
        let dir = TempDir::new().unwrap();
        let config = ClusterConfig {
            data_dir: dir.path().to_path_buf(),
            ..ClusterConfig::standalone()
        };

        let mut node = RaftNode::new(&config).await.unwrap();
        node.start().await.unwrap();

        assert!(node.is_leader());

        let response = node.propose(MetadataCommand::Noop).await.unwrap();
        assert!(matches!(response, MetadataResponse::Success));
    }

    /// The node-id hash is deterministic and distinguishes distinct ids.
    #[test]
    fn test_hash_node_id() {
        let first = hash_node_id("node-1");
        let second = hash_node_id("node-2");

        assert_ne!(first, second);
        assert_eq!(first, hash_node_id("node-1"));
    }
}