#![allow(clippy::result_large_err)]
use crate::config::ClusterConfig;
use crate::error::{ClusterError, Result};
use crate::metadata::{ClusterMetadata, MetadataCommand, MetadataResponse};
use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
use openraft::raft::responder::OneshotResponder;
use openraft::storage::{LogState, RaftLogReader, RaftLogStorage, RaftStateMachine, Snapshot};
use openraft::{
BasicNode, Entry, EntryPayload, LogId, Membership, RaftTypeConfig, SnapshotMeta, StorageError,
StorageIOError, StoredMembership, Vote,
};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::io::Cursor;
use std::ops::RangeBounds;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info};
/// Thin wrapper turning a plain error `String` into a `std::error::Error`,
/// as required by `openraft::error::NetworkError::new`.
#[derive(Debug)]
struct NetworkErrorWrapper(String);
impl std::fmt::Display for NetworkErrorWrapper {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(&self.0)
    }
}
impl std::error::Error for NetworkErrorWrapper {}
/// Numeric node identifier used by openraft; derived by hashing the
/// human-readable string node id (see `hash_node_id`).
pub type NodeId = u64;
/// Application request replicated through the Raft log: one metadata command.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RaftRequest {
    pub command: MetadataCommand,
}
/// Application response produced by the state machine for a [`RaftRequest`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RaftResponse {
    pub response: MetadataResponse,
}
/// openraft type configuration binding this crate's application types
/// (request/response, node id/node, snapshot representation, runtime).
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "raft", derive(Serialize, Deserialize))]
pub struct TypeConfig;
impl RaftTypeConfig for TypeConfig {
    type D = RaftRequest;
    type R = RaftResponse;
    type NodeId = NodeId;
    type Node = BasicNode;
    type Entry = Entry<TypeConfig>;
    // Snapshots are held entirely in memory as a byte cursor.
    type SnapshotData = Cursor<Vec<u8>>;
    type AsyncRuntime = openraft::TokioRuntime;
    type Responder = OneshotResponder<TypeConfig>;
}
/// Convenience aliases specializing openraft's generic types to [`TypeConfig`].
pub type RaftLogId = LogId<NodeId>;
pub type RaftVote = Vote<NodeId>;
pub type RaftEntry = Entry<TypeConfig>;
pub type RaftMembership = Membership<NodeId, BasicNode>;
pub type RaftStoredMembership = StoredMembership<NodeId, BasicNode>;
pub type RaftSnapshot = Snapshot<TypeConfig>;
pub type RaftSnapshotMeta = SnapshotMeta<NodeId, BasicNode>;
/// Durable Raft log storage backed by RocksDB.
///
/// Log entries live in the `raft_logs` column family keyed by big-endian
/// index; the vote / last-purged / committed markers live in `raft_state`.
/// The locked fields cache the persisted markers in memory.
pub struct LogStore {
    db: Arc<rocksdb::DB>,
    // Last vote persisted via save_vote.
    vote: RwLock<Option<RaftVote>>,
    // Highest log id removed by purge (inclusive).
    last_purged: RwLock<Option<RaftLogId>>,
    // Highest committed log id persisted via save_committed.
    committed: RwLock<Option<RaftLogId>>,
}
impl LogStore {
    /// Column family holding serialized log entries keyed by big-endian index.
    const CF_LOGS: &'static str = "raft_logs";
    /// Column family holding small persistent markers (vote/purged/committed).
    const CF_STATE: &'static str = "raft_state";
    const KEY_VOTE: &'static [u8] = b"vote";
    const KEY_LAST_PURGED: &'static [u8] = b"last_purged";
    const KEY_COMMITTED: &'static [u8] = b"committed";
    /// Opens (creating if necessary) the RocksDB store at `path` and loads
    /// the persisted vote / last-purged / committed markers into memory.
    pub fn new(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        std::fs::create_dir_all(path)
            .map_err(|e| ClusterError::RaftStorage(format!("Failed to create dir: {}", e)))?;
        let mut opts = rocksdb::Options::default();
        opts.create_if_missing(true);
        opts.create_missing_column_families(true);
        // Modest buffer/file sizes: the Raft log is write-heavy but entries
        // are small.
        opts.set_write_buffer_size(16 * 1024 * 1024);
        opts.set_max_write_buffer_number(2);
        opts.set_target_file_size_base(16 * 1024 * 1024);
        // Keep the write-ahead log in a subdirectory of the store path.
        opts.set_wal_dir(path.join("wal"));
        let cf_descriptors = vec![
            rocksdb::ColumnFamilyDescriptor::new(Self::CF_LOGS, rocksdb::Options::default()),
            rocksdb::ColumnFamilyDescriptor::new(Self::CF_STATE, rocksdb::Options::default()),
        ];
        let db = rocksdb::DB::open_cf_descriptors(&opts, path, cf_descriptors)
            .map_err(|e| ClusterError::RaftStorage(e.to_string()))?;
        let db = Arc::new(db);
        // Best-effort load: a missing or undecodable marker starts as None.
        let vote = Self::load_state::<RaftVote>(&db, Self::CF_STATE, Self::KEY_VOTE);
        let last_purged = Self::load_state::<RaftLogId>(&db, Self::CF_STATE, Self::KEY_LAST_PURGED);
        let committed = Self::load_state::<RaftLogId>(&db, Self::CF_STATE, Self::KEY_COMMITTED);
        info!(?vote, ?last_purged, ?committed, "Opened Raft log storage");
        Ok(Self {
            db,
            vote: RwLock::new(vote),
            last_purged: RwLock::new(last_purged),
            committed: RwLock::new(committed),
        })
    }
    /// Handle to the logs column family; the `expect` is safe because `new`
    /// always creates both column families.
    fn cf_logs(&self) -> &rocksdb::ColumnFamily {
        self.db
            .cf_handle(Self::CF_LOGS)
            .expect("CF_LOGS must exist")
    }
    /// Handle to the state column family (same invariant as `cf_logs`).
    fn cf_state(&self) -> &rocksdb::ColumnFamily {
        self.db
            .cf_handle(Self::CF_STATE)
            .expect("CF_STATE must exist")
    }
    /// Reads and postcard-decodes one value from `cf_name`/`key`.
    /// Returns None on any miss, read error, or decode failure.
    fn load_state<T: for<'de> Deserialize<'de>>(
        db: &rocksdb::DB,
        cf_name: &str,
        key: &[u8],
    ) -> Option<T> {
        let cf = db.cf_handle(cf_name)?;
        let bytes = db.get_cf(cf, key).ok()??;
        postcard::from_bytes(&bytes).ok()
    }
    /// Postcard-encodes `value` and writes it under `key` in the state CF.
    fn save_state<T: Serialize>(
        &self,
        key: &[u8],
        value: &T,
    ) -> std::result::Result<(), StorageError<NodeId>> {
        let bytes = postcard::to_allocvec(value).map_err(|e| StorageError::IO {
            source: StorageIOError::write_logs(openraft::AnyError::new(&e)),
        })?;
        self.db
            .put_cf(self.cf_state(), key, bytes)
            .map_err(|e| StorageError::IO {
                source: StorageIOError::write_logs(openraft::AnyError::new(&e)),
            })
    }
    /// Big-endian key encoding keeps RocksDB's byte-wise key order identical
    /// to numeric log-index order.
    fn index_key(index: u64) -> [u8; 8] {
        index.to_be_bytes()
    }
    /// Returns the entry with the highest index, if any (relies on
    /// big-endian keys sorting in index order).
    fn last_log(&self) -> std::result::Result<Option<RaftEntry>, StorageError<NodeId>> {
        let cf = self.cf_logs();
        let mut iter = self.db.raw_iterator_cf(cf);
        iter.seek_to_last();
        if iter.valid() {
            if let Some(value) = iter.value() {
                let entry: RaftEntry =
                    postcard::from_bytes(value).map_err(|e| StorageError::IO {
                        source: StorageIOError::read_logs(openraft::AnyError::new(&e)),
                    })?;
                return Ok(Some(entry));
            }
        }
        Ok(None)
    }
    /// Fetches and decodes the entry at `index`, or None if absent.
    fn get_log(&self, index: u64) -> std::result::Result<Option<RaftEntry>, StorageError<NodeId>> {
        let key = Self::index_key(index);
        match self.db.get_cf(self.cf_logs(), key) {
            Ok(Some(bytes)) => {
                let entry: RaftEntry =
                    postcard::from_bytes(&bytes).map_err(|e| StorageError::IO {
                        source: StorageIOError::read_logs(openraft::AnyError::new(&e)),
                    })?;
                Ok(Some(entry))
            }
            Ok(None) => Ok(None),
            Err(e) => Err(StorageError::IO {
                source: StorageIOError::read_logs(openraft::AnyError::new(&e)),
            }),
        }
    }
    /// Encodes and writes a single entry keyed by its log index.
    fn append_log(&self, entry: &RaftEntry) -> std::result::Result<(), StorageError<NodeId>> {
        let key = Self::index_key(entry.log_id.index);
        let value = postcard::to_allocvec(entry).map_err(|e| StorageError::IO {
            source: StorageIOError::write_logs(openraft::AnyError::new(&e)),
        })?;
        self.db
            .put_cf(self.cf_logs(), key, value)
            .map_err(|e| StorageError::IO {
                source: StorageIOError::write_logs(openraft::AnyError::new(&e)),
            })
    }
    /// Deletes entries in the half-open index range `[start, end)` in one
    /// atomic write batch.
    fn delete_logs_range(
        &self,
        start: u64,
        end: u64,
    ) -> std::result::Result<(), StorageError<NodeId>> {
        let cf = self.cf_logs();
        let mut batch = rocksdb::WriteBatch::default();
        for index in start..end {
            batch.delete_cf(cf, Self::index_key(index));
        }
        self.db.write(batch).map_err(|e| StorageError::IO {
            source: StorageIOError::write_logs(openraft::AnyError::new(&e)),
        })
    }
}
impl RaftLogReader<TypeConfig> for LogStore {
    /// Reads a contiguous run of log entries within `range`.
    ///
    /// Scans index-by-index and stops at the first missing index, so the
    /// returned vector is always contiguous even for an unbounded range
    /// (where `end` defaults to `u64::MAX`).
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + Send>(
        &mut self,
        range: RB,
    ) -> std::result::Result<Vec<RaftEntry>, StorageError<NodeId>> {
        let start = match range.start_bound() {
            std::ops::Bound::Included(&n) => n,
            std::ops::Bound::Excluded(&n) => n + 1,
            std::ops::Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            std::ops::Bound::Included(&n) => n + 1,
            std::ops::Bound::Excluded(&n) => n,
            std::ops::Bound::Unbounded => u64::MAX,
        };
        let mut entries = Vec::new();
        for index in start..end {
            match self.get_log(index)? {
                Some(entry) => entries.push(entry),
                // First gap ends the contiguous run.
                None => break,
            }
        }
        Ok(entries)
    }
}
impl RaftLogStorage<TypeConfig> for LogStore {
    type LogReader = Self;
    /// Reports the last purged log id and the last log id; when the log is
    /// empty the purge marker doubles as the last log id.
    async fn get_log_state(
        &mut self,
    ) -> std::result::Result<LogState<TypeConfig>, StorageError<NodeId>> {
        let last_purged = *self.last_purged.read().await;
        let last_log = self.last_log()?;
        let last_log_id = last_log.map(|e| e.log_id).or(last_purged);
        Ok(LogState {
            last_purged_log_id: last_purged,
            last_log_id,
        })
    }
    /// Builds a reader sharing the same RocksDB handle; cached markers are
    /// copied at this moment.
    async fn get_log_reader(&mut self) -> Self::LogReader {
        Self {
            db: self.db.clone(),
            vote: RwLock::new(*self.vote.read().await),
            last_purged: RwLock::new(*self.last_purged.read().await),
            committed: RwLock::new(*self.committed.read().await),
        }
    }
    /// Persists the vote to disk before updating the in-memory cache
    /// (write-ahead ordering required for Raft safety).
    async fn save_vote(
        &mut self,
        vote: &RaftVote,
    ) -> std::result::Result<(), StorageError<NodeId>> {
        self.save_state(Self::KEY_VOTE, vote)?;
        *self.vote.write().await = Some(*vote);
        debug!(?vote, "Saved vote");
        Ok(())
    }
    async fn read_vote(&mut self) -> std::result::Result<Option<RaftVote>, StorageError<NodeId>> {
        Ok(*self.vote.read().await)
    }
    /// Persists the committed log id.
    ///
    /// NOTE(review): a `None` value only updates the in-memory cache; the
    /// previously persisted id remains on disk and would be reloaded on
    /// restart — confirm this asymmetry is intended.
    async fn save_committed(
        &mut self,
        committed: Option<RaftLogId>,
    ) -> std::result::Result<(), StorageError<NodeId>> {
        if let Some(ref c) = committed {
            self.save_state(Self::KEY_COMMITTED, c)?;
        }
        *self.committed.write().await = committed;
        Ok(())
    }
    async fn read_committed(
        &mut self,
    ) -> std::result::Result<Option<RaftLogId>, StorageError<NodeId>> {
        Ok(*self.committed.read().await)
    }
    /// Appends entries, flushes RocksDB, then acknowledges the IO to
    /// openraft via `callback`.
    async fn append<I>(
        &mut self,
        entries: I,
        callback: openraft::storage::LogFlushed<TypeConfig>,
    ) -> std::result::Result<(), StorageError<NodeId>>
    where
        I: IntoIterator<Item = RaftEntry> + Send,
        I::IntoIter: Send,
    {
        for entry in entries {
            self.append_log(&entry)?;
        }
        // Flush before acknowledging so the entries survive a crash.
        self.db.flush().map_err(|e| StorageError::IO {
            source: StorageIOError::write_logs(openraft::AnyError::new(&e)),
        })?;
        callback.log_io_completed(Ok(()));
        Ok(())
    }
    /// Deletes all entries *after* `log_id` (conflict resolution on a
    /// follower whose log diverges from the leader's).
    async fn truncate(
        &mut self,
        log_id: RaftLogId,
    ) -> std::result::Result<(), StorageError<NodeId>> {
        let start = log_id.index + 1;
        let log_state = RaftLogStorage::get_log_state(self).await?;
        if let Some(last) = log_state.last_log_id {
            self.delete_logs_range(start, last.index + 1)?;
        }
        debug!(?log_id, "Truncated logs");
        Ok(())
    }
    /// Deletes all entries up to and including `log_id` and records the
    /// purge marker both on disk and in memory.
    async fn purge(&mut self, log_id: RaftLogId) -> std::result::Result<(), StorageError<NodeId>> {
        let current_purged = *self.last_purged.read().await;
        let start = current_purged.map(|l| l.index + 1).unwrap_or(0);
        self.delete_logs_range(start, log_id.index + 1)?;
        self.save_state(Self::KEY_LAST_PURGED, &log_id)?;
        *self.last_purged.write().await = Some(log_id);
        debug!(?log_id, "Purged logs");
        Ok(())
    }
}
/// In-memory Raft state machine holding the replicated cluster metadata.
pub struct StateMachine {
    metadata: RwLock<ClusterMetadata>,
    // Id of the last log entry applied to `metadata`.
    last_applied: RwLock<Option<RaftLogId>>,
    // Most recent membership config seen in the log (with its log id).
    membership: RwLock<RaftStoredMembership>,
}
impl StateMachine {
    /// Creates an empty state machine: fresh metadata, nothing applied,
    /// empty membership.
    pub fn new() -> Self {
        Self {
            metadata: RwLock::new(ClusterMetadata::new()),
            last_applied: RwLock::new(None),
            membership: RwLock::new(StoredMembership::new(None, Membership::new(vec![], ()))),
        }
    }
    /// Read access to the current cluster metadata.
    pub async fn metadata(&self) -> tokio::sync::RwLockReadGuard<'_, ClusterMetadata> {
        self.metadata.read().await
    }
    /// Applies one command to the metadata and advances `last_applied`.
    async fn apply_command(&self, log_id: &RaftLogId, command: MetadataCommand) -> RaftResponse {
        let mut metadata = self.metadata.write().await;
        let response = metadata.apply(log_id.index, command);
        *self.last_applied.write().await = Some(*log_id);
        RaftResponse { response }
    }
    /// Serializes the current state (metadata, last applied id, membership)
    /// into snapshot metadata plus a postcard-encoded payload.
    async fn create_snapshot(
        &self,
    ) -> std::result::Result<(RaftSnapshotMeta, Vec<u8>), StorageError<NodeId>> {
        let metadata = self.metadata.read().await.clone();
        let last_applied = *self.last_applied.read().await;
        let membership = self.membership.read().await.clone();
        let snapshot_data = SnapshotData {
            metadata: metadata.clone(),
            last_applied,
            membership: membership.clone(),
        };
        let data = postcard::to_allocvec(&snapshot_data).map_err(|e| StorageError::IO {
            source: StorageIOError::read_state_machine(openraft::AnyError::new(&e)),
        })?;
        let meta = SnapshotMeta {
            last_log_id: snapshot_data.last_applied,
            last_membership: membership,
            // Deriving the id from the applied index keeps snapshot ids
            // monotonic.
            snapshot_id: format!("snapshot-{}", metadata.last_applied_index),
        };
        info!(
            snapshot_id = %meta.snapshot_id,
            last_log_id = ?meta.last_log_id,
            "Created snapshot"
        );
        Ok((meta, data))
    }
    /// Replaces the entire state with the decoded snapshot payload.
    async fn install_snapshot_data(
        &self,
        data: &[u8],
    ) -> std::result::Result<(), StorageError<NodeId>> {
        let snapshot_data: SnapshotData =
            postcard::from_bytes(data).map_err(|e| StorageError::IO {
                source: StorageIOError::read_state_machine(openraft::AnyError::new(&e)),
            })?;
        *self.metadata.write().await = snapshot_data.metadata;
        *self.last_applied.write().await = snapshot_data.last_applied;
        *self.membership.write().await = snapshot_data.membership;
        info!("Installed snapshot");
        Ok(())
    }
}
impl Default for StateMachine {
fn default() -> Self {
Self::new()
}
}
/// On-the-wire / on-disk snapshot payload: the complete state-machine state,
/// postcard-encoded by `StateMachine::create_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SnapshotData {
    metadata: ClusterMetadata,
    last_applied: Option<RaftLogId>,
    membership: RaftStoredMembership,
}
impl RaftStateMachine<TypeConfig> for StateMachine {
    type SnapshotBuilder = Self;
    /// Returns the last applied log id and the current stored membership.
    async fn applied_state(
        &mut self,
    ) -> std::result::Result<(Option<RaftLogId>, RaftStoredMembership), StorageError<NodeId>> {
        let last_applied = *self.last_applied.read().await;
        let membership = self.membership.read().await.clone();
        Ok((last_applied, membership))
    }
    /// Applies committed entries in order, producing one response per entry.
    async fn apply<I>(
        &mut self,
        entries: I,
    ) -> std::result::Result<Vec<RaftResponse>, StorageError<NodeId>>
    where
        I: IntoIterator<Item = RaftEntry> + Send,
        I::IntoIter: Send,
    {
        let mut responses = Vec::new();
        for entry in entries {
            let log_id = entry.log_id;
            match entry.payload {
                // Blank entries (e.g. leader no-op) only advance last_applied.
                EntryPayload::Blank => {
                    *self.last_applied.write().await = Some(log_id);
                    responses.push(RaftResponse {
                        response: MetadataResponse::Success,
                    });
                }
                EntryPayload::Normal(req) => {
                    let response = self.apply_command(&log_id, req.command).await;
                    responses.push(response);
                }
                // Membership changes are stored together with the log id
                // that carried them.
                EntryPayload::Membership(membership) => {
                    *self.membership.write().await =
                        StoredMembership::new(Some(log_id), membership);
                    *self.last_applied.write().await = Some(log_id);
                    responses.push(RaftResponse {
                        response: MetadataResponse::Success,
                    });
                }
            }
        }
        Ok(responses)
    }
    /// Provides an empty in-memory buffer into which the leader streams a
    /// snapshot.
    async fn begin_receiving_snapshot(
        &mut self,
    ) -> std::result::Result<Box<Cursor<Vec<u8>>>, StorageError<NodeId>> {
        Ok(Box::new(Cursor::new(Vec::new())))
    }
    /// Installs a fully received snapshot and adopts its membership.
    async fn install_snapshot(
        &mut self,
        meta: &RaftSnapshotMeta,
        snapshot: Box<Cursor<Vec<u8>>>,
    ) -> std::result::Result<(), StorageError<NodeId>> {
        let data = snapshot.into_inner();
        self.install_snapshot_data(&data).await?;
        // The membership from `meta` overrides the one decoded from the
        // payload.
        *self.membership.write().await = meta.last_membership.clone();
        info!(
            snapshot_id = %meta.snapshot_id,
            "Installed snapshot from leader"
        );
        Ok(())
    }
    /// Builds a snapshot of the current state on demand; no snapshot is
    /// cached, so this always returns `Some`.
    async fn get_current_snapshot(
        &mut self,
    ) -> std::result::Result<Option<RaftSnapshot>, StorageError<NodeId>> {
        let (meta, data) = self.create_snapshot().await?;
        Ok(Some(Snapshot {
            meta,
            snapshot: Box::new(Cursor::new(data)),
        }))
    }
    /// The snapshot builder is a deep copy of the state at this instant.
    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
        Self {
            metadata: RwLock::new(self.metadata.read().await.clone()),
            last_applied: RwLock::new(*self.last_applied.read().await),
            membership: RwLock::new(self.membership.read().await.clone()),
        }
    }
}
/// Snapshot builder: serializes the (copied) state into an in-memory cursor.
impl openraft::storage::RaftSnapshotBuilder<TypeConfig> for StateMachine {
    async fn build_snapshot(&mut self) -> std::result::Result<RaftSnapshot, StorageError<NodeId>> {
        let (meta, data) = self.create_snapshot().await?;
        Ok(Snapshot {
            meta,
            snapshot: Box::new(Cursor::new(data)),
        })
    }
}
/// Wire format for Raft RPC payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SerializationFormat {
    /// Human-readable JSON (useful for debugging).
    Json,
    /// Compact postcard binary encoding (the default).
    #[default]
    Binary,
}
/// Controls optional compression of Raft RPC bodies.
#[derive(Debug, Clone)]
pub struct RaftCompressionConfig {
    /// Master switch for compression.
    pub enabled: bool,
    /// Payloads smaller than this many bytes are sent uncompressed.
    pub min_size: usize,
    /// Whether the compressor may adaptively pick its strategy.
    pub adaptive: bool,
}
impl Default for RaftCompressionConfig {
fn default() -> Self {
Self {
enabled: true,
min_size: 1024, adaptive: true,
}
}
}
/// Builds per-peer [`Network`] clients sharing one pooled HTTP client and
/// common serialization/compression settings.
#[derive(Clone)]
pub struct NetworkFactory {
    // node id -> base HTTP address of the peer.
    nodes: Arc<RwLock<BTreeMap<NodeId, String>>>,
    client: reqwest::Client,
    format: SerializationFormat,
    compression: RaftCompressionConfig,
}
impl NetworkFactory {
    /// Creates a factory using the default binary (postcard) wire format.
    pub fn new() -> Self {
        Self::with_format(SerializationFormat::Binary)
    }
    /// Creates a factory with an explicit wire format and a pre-configured
    /// HTTP client: 5 s request timeout, pooled keep-alive connections, TCP
    /// keepalive, and Nagle's algorithm disabled for low-latency RPCs.
    pub fn with_format(format: SerializationFormat) -> Self {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .pool_max_idle_per_host(10)
            .pool_idle_timeout(std::time::Duration::from_secs(60))
            .tcp_keepalive(std::time::Duration::from_secs(30))
            .tcp_nodelay(true)
            .build()
            .expect("Failed to create HTTP client");
        Self {
            nodes: Arc::new(RwLock::new(BTreeMap::new())),
            client,
            format,
            compression: RaftCompressionConfig::default(),
        }
    }
    /// Creates a factory with both an explicit wire format and custom
    /// compression settings.
    pub fn with_compression(
        format: SerializationFormat,
        compression: RaftCompressionConfig,
    ) -> Self {
        let mut factory = Self::with_format(format);
        factory.compression = compression;
        factory
    }
    /// Registers (or updates) the HTTP address of a peer node.
    pub async fn add_node(&self, node_id: NodeId, addr: String) {
        self.nodes.write().await.insert(node_id, addr);
    }
    /// Removes a peer node's address mapping.
    pub async fn remove_node(&self, node_id: NodeId) {
        self.nodes.write().await.remove(&node_id);
    }
}
impl Default for NetworkFactory {
fn default() -> Self {
Self::new()
}
}
/// HTTP client for Raft RPCs to a single target node.
pub struct Network {
    #[allow(dead_code)]
    target: NodeId,
    // Base address (e.g. "http://host:port"); RPC paths are appended.
    target_addr: String,
    client: reqwest::Client,
    format: SerializationFormat,
    compression: RaftCompressionConfig,
}
impl Network {
    /// Creates a client bound to one target node, reusing the factory's
    /// pooled HTTP client.
    pub fn new(
        target: NodeId,
        target_addr: String,
        client: reqwest::Client,
        format: SerializationFormat,
        compression: RaftCompressionConfig,
    ) -> Self {
        Self {
            target,
            target_addr,
            client,
            format,
            compression,
        }
    }
    /// Encodes `data` with the configured wire format.
    fn serialize<T: Serialize>(&self, data: &T) -> std::result::Result<Vec<u8>, String> {
        match self.format {
            SerializationFormat::Json => serde_json::to_vec(data).map_err(|e| e.to_string()),
            SerializationFormat::Binary => postcard::to_allocvec(data).map_err(|e| e.to_string()),
        }
    }
    /// Decodes `data` with the configured wire format.
    fn deserialize<T: serde::de::DeserializeOwned>(
        &self,
        data: &[u8],
    ) -> std::result::Result<T, String> {
        match self.format {
            SerializationFormat::Json => serde_json::from_slice(data).map_err(|e| e.to_string()),
            SerializationFormat::Binary => postcard::from_bytes(data).map_err(|e| e.to_string()),
        }
    }
    /// Content-Type header value matching the wire format.
    fn content_type(&self) -> &'static str {
        match self.format {
            SerializationFormat::Json => "application/json",
            SerializationFormat::Binary => "application/octet-stream",
        }
    }
    /// Compresses `data` when enabled and large enough; returns the payload
    /// plus whether it was actually compressed. Falls back to the original
    /// bytes on compression failure or when compression does not shrink the
    /// payload.
    #[cfg(feature = "compression")]
    fn maybe_compress(&self, data: Vec<u8>) -> (Vec<u8>, bool) {
        use rivven_core::compression::{CompressionConfig, Compressor};
        if !self.compression.enabled || data.len() < self.compression.min_size {
            return (data, false);
        }
        let config = CompressionConfig {
            min_size: self.compression.min_size,
            adaptive: self.compression.adaptive,
            ..Default::default()
        };
        let compressor = Compressor::with_config(config);
        match compressor.compress(&data) {
            Ok(compressed) => {
                if compressed.len() < data.len() {
                    (compressed.to_vec(), true)
                } else {
                    (data, false)
                }
            }
            Err(_) => (data, false),
        }
    }
    /// No-op variant when the `compression` feature is disabled.
    #[cfg(not(feature = "compression"))]
    fn maybe_compress(&self, data: Vec<u8>) -> (Vec<u8>, bool) {
        (data, false)
    }
    /// Decompresses `data` when the peer flagged it as compressed.
    #[cfg(feature = "compression")]
    fn maybe_decompress(
        &self,
        data: &[u8],
        was_compressed: bool,
    ) -> std::result::Result<Vec<u8>, String> {
        use rivven_core::compression::Compressor;
        if !was_compressed {
            return Ok(data.to_vec());
        }
        let compressor = Compressor::new();
        compressor
            .decompress(data)
            .map(|b| b.to_vec())
            .map_err(|e| e.to_string())
    }
    /// Pass-through variant when the `compression` feature is disabled.
    #[cfg(not(feature = "compression"))]
    fn maybe_decompress(
        &self,
        data: &[u8],
        _was_compressed: bool,
    ) -> std::result::Result<Vec<u8>, String> {
        Ok(data.to_vec())
    }
}
impl RaftNetworkFactory<TypeConfig> for NetworkFactory {
    type Network = Network;
    /// Creates a per-target client sharing this factory's pooled HTTP
    /// client and serialization/compression settings.
    async fn new_client(&mut self, target: NodeId, node: &BasicNode) -> Self::Network {
        Network::new(
            target,
            node.addr.clone(),
            self.client.clone(),
            self.format,
            self.compression.clone(),
        )
    }
}
impl RaftNetwork<TypeConfig> for Network {
async fn append_entries(
&mut self,
rpc: openraft::raft::AppendEntriesRequest<TypeConfig>,
_option: RPCOption,
) -> std::result::Result<
openraft::raft::AppendEntriesResponse<NodeId>,
openraft::error::RPCError<NodeId, BasicNode, openraft::error::RaftError<NodeId>>,
> {
use crate::observability::{NetworkMetrics, RaftMetrics};
let start = std::time::Instant::now();
let url = format!("{}/raft/append", self.target_addr);
let serialized = self.serialize(&rpc).map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
let (body, compressed) = self.maybe_compress(serialized);
let uncompressed_size = body.len();
NetworkMetrics::add_bytes_sent(body.len() as u64);
RaftMetrics::increment_append_entries_sent();
let mut request = self.client.post(&url).body(body);
request = request.header("Content-Type", self.content_type());
if compressed {
request = request.header("X-Rivven-Compressed", "1");
request = request.header("X-Rivven-Original-Size", uncompressed_size.to_string());
}
let resp = request.send().await.map_err(|e| {
NetworkMetrics::increment_rpc_errors("append_entries");
openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
})?;
if !resp.status().is_success() {
NetworkMetrics::increment_rpc_errors("append_entries");
return Err(openraft::error::RPCError::Network(
openraft::error::NetworkError::new(&NetworkErrorWrapper(format!(
"HTTP error: {}",
resp.status()
))),
));
}
let resp_compressed = resp
.headers()
.get("X-Rivven-Compressed")
.map(|v| v == "1")
.unwrap_or(false);
let bytes = resp.bytes().await.map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
})?;
NetworkMetrics::add_bytes_received(bytes.len() as u64);
RaftMetrics::record_append_entries_latency(start.elapsed());
let response_data = self
.maybe_decompress(&bytes, resp_compressed)
.map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
let response: openraft::raft::AppendEntriesResponse<NodeId> =
self.deserialize(&response_data).map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
Ok(response)
}
async fn install_snapshot(
&mut self,
rpc: openraft::raft::InstallSnapshotRequest<TypeConfig>,
_option: RPCOption,
) -> std::result::Result<
openraft::raft::InstallSnapshotResponse<NodeId>,
openraft::error::RPCError<
NodeId,
BasicNode,
openraft::error::RaftError<NodeId, openraft::error::InstallSnapshotError>,
>,
> {
use crate::observability::{NetworkMetrics, RaftMetrics};
let start = std::time::Instant::now();
let url = format!("{}/raft/snapshot", self.target_addr);
let serialized = self.serialize(&rpc).map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
let (body, compressed) = self.maybe_compress(serialized);
let uncompressed_size = body.len();
NetworkMetrics::add_bytes_sent(body.len() as u64);
let mut request = self.client.post(&url).body(body);
request = request.header("Content-Type", self.content_type());
if compressed {
request = request.header("X-Rivven-Compressed", "1");
request = request.header("X-Rivven-Original-Size", uncompressed_size.to_string());
}
let resp = request.send().await.map_err(|e| {
NetworkMetrics::increment_rpc_errors("install_snapshot");
openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
})?;
if !resp.status().is_success() {
NetworkMetrics::increment_rpc_errors("install_snapshot");
return Err(openraft::error::RPCError::Network(
openraft::error::NetworkError::new(&NetworkErrorWrapper(format!(
"HTTP error: {}",
resp.status()
))),
));
}
let bytes = resp.bytes().await.map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
})?;
NetworkMetrics::add_bytes_received(bytes.len() as u64);
RaftMetrics::record_snapshot_duration(start.elapsed());
let response: openraft::raft::InstallSnapshotResponse<NodeId> =
self.deserialize(&bytes).map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
Ok(response)
}
async fn vote(
&mut self,
rpc: openraft::raft::VoteRequest<NodeId>,
_option: RPCOption,
) -> std::result::Result<
openraft::raft::VoteResponse<NodeId>,
openraft::error::RPCError<NodeId, BasicNode, openraft::error::RaftError<NodeId>>,
> {
use crate::observability::{NetworkMetrics, RaftMetrics};
let start = std::time::Instant::now();
let url = format!("{}/raft/vote", self.target_addr);
let body = self.serialize(&rpc).map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
NetworkMetrics::add_bytes_sent(body.len() as u64);
let resp = self
.client
.post(&url)
.body(body)
.header("Content-Type", self.content_type())
.send()
.await
.map_err(|e| {
NetworkMetrics::increment_rpc_errors("vote");
openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
})?;
if !resp.status().is_success() {
NetworkMetrics::increment_rpc_errors("vote");
return Err(openraft::error::RPCError::Network(
openraft::error::NetworkError::new(&NetworkErrorWrapper(format!(
"HTTP error: {}",
resp.status()
))),
));
}
let bytes = resp.bytes().await.map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(&e))
})?;
NetworkMetrics::add_bytes_received(bytes.len() as u64);
RaftMetrics::record_vote_latency(start.elapsed());
RaftMetrics::increment_elections();
let response: openraft::raft::VoteResponse<NodeId> =
self.deserialize(&bytes).map_err(|e| {
openraft::error::RPCError::Network(openraft::error::NetworkError::new(
&NetworkErrorWrapper(e),
))
})?;
Ok(response)
}
}
/// Configuration for constructing a [`RaftNode`].
#[derive(Debug, Clone)]
pub struct RaftNodeConfig {
    /// Human-readable node id; hashed into the numeric Raft [`NodeId`].
    pub node_id: String,
    /// When true, consensus is bypassed and commands apply locally.
    pub standalone: bool,
    pub data_dir: std::path::PathBuf,
    pub heartbeat_interval_ms: u64,
    pub election_timeout_min_ms: u64,
    pub election_timeout_max_ms: u64,
    /// Number of log entries since the last snapshot that triggers a new one.
    pub snapshot_threshold: u64,
    pub initial_members: Vec<(NodeId, BasicNode)>,
}
/// A batch of commands accumulated for a single Raft proposal, with the
/// oneshot senders used to answer each caller.
#[allow(dead_code)]
pub(crate) struct PendingBatch {
    commands: Vec<MetadataCommand>,
    // Parallel to `commands`: responders[i] answers commands[i].
    responders: Vec<tokio::sync::oneshot::Sender<Result<MetadataResponse>>>,
    // When the first command entered the batch (drives the wait deadline).
    started: std::time::Instant,
}
/// Tuning knobs for command batching.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Flush when this many commands have accumulated.
    pub max_batch_size: usize,
    /// Flush after this many microseconds even if the batch is not full.
    pub max_wait_us: u64,
    pub enabled: bool,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_batch_size: 100,
max_wait_us: 1000, enabled: true,
}
}
}
/// Accumulates metadata commands into batches; a flusher task is woken via
/// `notify` and drains batches with `take_if_ready`.
pub struct BatchAccumulator {
    // The batch currently being filled, if any.
    pending: RwLock<Option<PendingBatch>>,
    config: BatchConfig,
    notify: tokio::sync::Notify,
}
impl BatchAccumulator {
    pub fn new(config: BatchConfig) -> Self {
        Self {
            pending: RwLock::new(None),
            config,
            notify: tokio::sync::Notify::new(),
        }
    }
    /// Adds a command to the current batch (starting a new one if empty) and
    /// returns the receiver on which the caller awaits its response.
    pub async fn add(
        &self,
        command: MetadataCommand,
    ) -> tokio::sync::oneshot::Receiver<Result<MetadataResponse>> {
        let (tx, rx) = tokio::sync::oneshot::channel();
        let should_flush = {
            let mut pending = self.pending.write().await;
            if pending.is_none() {
                *pending = Some(PendingBatch {
                    commands: vec![command],
                    responders: vec![tx],
                    started: std::time::Instant::now(),
                });
                false
            } else {
                let batch = pending.as_mut().unwrap();
                batch.commands.push(command);
                batch.responders.push(tx);
                batch.commands.len() >= self.config.max_batch_size
            }
        };
        // NOTE(review): the flusher is notified on *every* add, which makes
        // the extra notify below redundant (`Notify` stores at most one
        // permit) — confirm whether per-add wakeups are intended or whether
        // notification should only happen when the batch is full.
        self.notify.notify_one();
        if should_flush {
            self.notify.notify_one();
        }
        rx
    }
    /// Takes the pending batch if it is full or has waited at least
    /// `max_wait_us`; otherwise leaves it in place and returns None.
    #[allow(dead_code)]
    pub(crate) async fn take_if_ready(&self) -> Option<PendingBatch> {
        let mut pending = self.pending.write().await;
        if let Some(ref batch) = *pending {
            let elapsed = batch.started.elapsed();
            let size = batch.commands.len();
            if size >= self.config.max_batch_size
                || elapsed.as_micros() as u64 >= self.config.max_wait_us
            {
                return pending.take();
            }
        }
        None
    }
    /// Waits until notified or until the max batch wait elapses, whichever
    /// comes first.
    pub async fn wait_ready(&self) {
        let timeout = std::time::Duration::from_micros(self.config.max_wait_us);
        let _ = tokio::time::timeout(timeout, self.notify.notified()).await;
    }
}
impl Default for RaftNodeConfig {
fn default() -> Self {
Self {
node_id: "node-1".to_string(),
standalone: true,
data_dir: std::path::PathBuf::from("./data/raft"),
heartbeat_interval_ms: 150,
election_timeout_min_ms: 300,
election_timeout_max_ms: 600,
snapshot_threshold: 10000,
initial_members: vec![],
}
}
}
/// A Raft participant that can also run in standalone (no-consensus) mode.
///
/// NOTE(review): in cluster mode `start` hands *new* `StateMachine` and
/// `LogStore` instances to openraft while these fields keep the originals,
/// so `state_machine` here is not the instance Raft applies entries to —
/// see `start` for details.
pub struct RaftNode {
    // None until `start` succeeds in cluster mode; always None standalone.
    raft: Option<openraft::Raft<TypeConfig>>,
    #[allow(dead_code)]
    log_store: Option<Arc<LogStore>>,
    state_machine: StateMachine,
    network: NetworkFactory,
    // Numeric id derived by hashing `node_id_str`.
    node_id: NodeId,
    node_id_str: String,
    standalone: bool,
    // Next log index assigned by standalone-mode proposals.
    next_index: RwLock<u64>,
    data_dir: std::path::PathBuf,
    raft_config: RaftNodeConfig,
}
impl RaftNode {
/// Builds a node from the cluster-level configuration; standalone mode is
/// selected when `config.mode` is `Standalone`.
pub async fn new(config: &ClusterConfig) -> Result<Self> {
    let raft_config = RaftNodeConfig {
        node_id: config.node_id.clone(),
        standalone: config.mode == crate::config::ClusterMode::Standalone,
        data_dir: config.data_dir.join("raft"),
        heartbeat_interval_ms: config.raft.heartbeat_interval.as_millis() as u64,
        election_timeout_min_ms: config.raft.election_timeout_min.as_millis() as u64,
        election_timeout_max_ms: config.raft.election_timeout_max.as_millis() as u64,
        snapshot_threshold: config.raft.snapshot_threshold,
        initial_members: vec![],
    };
    Self::with_config(raft_config).await
}
/// Creates the node's data directory and in-memory components. Raft itself
/// is not started until [`RaftNode::start`] is called.
pub async fn with_config(config: RaftNodeConfig) -> Result<Self> {
    std::fs::create_dir_all(&config.data_dir)
        .map_err(|e| ClusterError::RaftStorage(e.to_string()))?;
    let state_machine = StateMachine::new();
    let network = NetworkFactory::new();
    // Numeric Raft id is a hash of the configured string id.
    let node_id = hash_node_id(&config.node_id);
    info!(
        node_id,
        node_id_str = %config.node_id,
        standalone = config.standalone,
        data_dir = %config.data_dir.display(),
        "Created Raft node"
    );
    Ok(Self {
        raft: None,
        log_store: None,
        state_machine,
        network,
        node_id,
        node_id_str: config.node_id.clone(),
        standalone: config.standalone,
        next_index: RwLock::new(1),
        data_dir: config.data_dir.clone(),
        raft_config: config,
    })
}
/// Starts the Raft engine (no-op in standalone mode).
///
/// NOTE(review): fresh `StateMachine` and `NetworkFactory` instances are
/// created here and moved into openraft, while `self.state_machine` and
/// `self.log_store` keep/remain the originals. In cluster mode
/// `self.state_machine` therefore never sees applied entries, so
/// `metadata()` would read empty state — confirm whether these should be
/// shared (e.g. via `Arc`) with the instances passed to
/// `openraft::Raft::new`.
pub async fn start(&mut self) -> Result<()> {
    if self.standalone {
        info!(node_id = self.node_id, "Starting in standalone mode");
        return Ok(());
    }
    let log_store = LogStore::new(&self.data_dir)
        .map_err(|e| ClusterError::RaftStorage(format!("Failed to create log store: {}", e)))?;
    let raft_config = openraft::Config {
        cluster_name: "rivven-cluster".to_string(),
        heartbeat_interval: self.raft_config.heartbeat_interval_ms,
        election_timeout_min: self.raft_config.election_timeout_min_ms,
        election_timeout_max: self.raft_config.election_timeout_max_ms,
        snapshot_policy: openraft::SnapshotPolicy::LogsSinceLast(
            self.raft_config.snapshot_threshold,
        ),
        max_in_snapshot_log_to_keep: 1000,
        ..Default::default()
    };
    let raft_config = Arc::new(
        raft_config
            .validate()
            .map_err(|e| ClusterError::RaftStorage(format!("Invalid Raft config: {}", e)))?,
    );
    let state_machine = StateMachine::new();
    let network = NetworkFactory::new();
    // Copy already-registered peer addresses into the factory handed to
    // openraft.
    for (id, addr) in self.network.nodes.read().await.iter() {
        network.add_node(*id, addr.clone()).await;
    }
    let raft =
        openraft::Raft::new(self.node_id, raft_config, network, log_store, state_machine)
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("Failed to create Raft: {}", e)))?;
    self.raft = Some(raft);
    info!(
        node_id = self.node_id,
        node_id_str = %self.node_id_str,
        "Cluster mode Raft initialized and ready"
    );
    Ok(())
}
/// Initializes the cluster with the given member set. No-op in standalone
/// mode or when Raft has not been started.
pub async fn bootstrap(&self, members: BTreeMap<NodeId, BasicNode>) -> Result<()> {
    if self.standalone {
        return Ok(());
    }
    if let Some(ref raft) = self.raft {
        raft.initialize(members)
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("Failed to bootstrap: {}", e)))?;
        info!(node_id = self.node_id, "Bootstrapped Raft cluster");
    }
    Ok(())
}
/// Proposes a single metadata command.
///
/// Standalone: assigns the next local index and applies directly to the
/// local state machine. Cluster: forwards through `client_write`, which
/// returns only after the command is committed and applied.
pub async fn propose(&self, command: MetadataCommand) -> Result<MetadataResponse> {
    use crate::observability::RaftMetrics;
    let start = std::time::Instant::now();
    if self.standalone {
        let index = {
            let mut next = self.next_index.write().await;
            let idx = *next;
            *next += 1;
            idx
        };
        // Standalone entries use a synthetic leader id (term 0, this node).
        let log_id = LogId::new(openraft::CommittedLeaderId::new(0, self.node_id), index);
        let response = self.state_machine.apply_command(&log_id, command).await;
        RaftMetrics::increment_proposals();
        RaftMetrics::increment_commits();
        RaftMetrics::record_proposal_latency(start.elapsed());
        return Ok(response.response);
    }
    if let Some(ref raft) = self.raft {
        let request = RaftRequest { command };
        let result = raft
            .client_write(request)
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("Client write failed: {}", e)))?;
        RaftMetrics::increment_proposals();
        RaftMetrics::increment_commits();
        RaftMetrics::record_proposal_latency(start.elapsed());
        return Ok(result.data.response);
    }
    Err(ClusterError::RaftStorage(
        "Raft not initialized".to_string(),
    ))
}
/// Proposes multiple commands, preserving input order in the responses.
///
/// Standalone: applies each command sequentially. Cluster: submits all
/// `client_write`s concurrently (results keep input order via `join_all`)
/// and fails the whole batch on the first error.
pub async fn propose_batch(
    &self,
    commands: Vec<MetadataCommand>,
) -> Result<Vec<MetadataResponse>> {
    use crate::observability::RaftMetrics;
    if commands.is_empty() {
        return Ok(vec![]);
    }
    let batch_size = commands.len();
    RaftMetrics::record_batch_size(batch_size);
    if self.standalone {
        let mut responses = Vec::with_capacity(commands.len());
        for command in commands {
            let index = {
                let mut next = self.next_index.write().await;
                let idx = *next;
                *next += 1;
                idx
            };
            let log_id = LogId::new(openraft::CommittedLeaderId::new(0, self.node_id), index);
            let response = self.state_machine.apply_command(&log_id, command).await;
            responses.push(response.response);
        }
        return Ok(responses);
    }
    if let Some(ref raft) = self.raft {
        let mut responses = Vec::with_capacity(commands.len());
        let futures: Vec<_> = commands
            .into_iter()
            .map(|command| {
                let raft = raft.clone();
                async move {
                    let request = RaftRequest { command };
                    raft.client_write(request).await
                }
            })
            .collect();
        let results = futures::future::join_all(futures).await;
        for result in results {
            match result {
                Ok(r) => responses.push(r.data.response),
                Err(e) => {
                    return Err(ClusterError::RaftStorage(format!(
                        "Batch write failed: {}",
                        e
                    )))
                }
            }
        }
        return Ok(responses);
    }
    Err(ClusterError::RaftStorage(
        "Raft not initialized".to_string(),
    ))
}
/// Confirms this node may serve a linearizable read before the caller
/// reads local state. Standalone nodes trivially succeed.
///
/// # Errors
/// Returns `ClusterError::RaftStorage` when Raft is uninitialized or the
/// linearizability check fails.
pub async fn ensure_linearizable_read(&self) -> Result<()> {
    if self.standalone {
        return Ok(());
    }
    let raft = self.raft.as_ref().ok_or_else(|| {
        ClusterError::RaftStorage("Raft not initialized".to_string())
    })?;
    let applied = raft
        .ensure_linearizable()
        .await
        .map_err(|e| ClusterError::RaftStorage(format!("Linearizable read failed: {}", e)))?;
    debug!(
        applied_log = %applied.map(|l| l.index.to_string()).unwrap_or_else(|| "none".to_string()),
        "Linearizable read confirmed"
    );
    Ok(())
}
/// Returns a read guard over cluster metadata after first confirming the
/// read is linearizable (see `ensure_linearizable_read`).
pub async fn linearizable_metadata(
    &self,
) -> Result<tokio::sync::RwLockReadGuard<'_, ClusterMetadata>> {
    self.ensure_linearizable_read().await?;
    let guard = self.state_machine.metadata().await;
    Ok(guard)
}
/// Returns a read guard over the current cluster metadata without any
/// linearizability check (may observe stale state on a follower).
pub async fn metadata(&self) -> tokio::sync::RwLockReadGuard<'_, ClusterMetadata> {
    self.state_machine.metadata().await
}
/// Returns true if this node currently believes it is the cluster leader.
/// Standalone nodes are always the leader. The answer is based on locally
/// observed metrics and may be stale.
///
/// Fix: reads `current_leader` through the watch borrow instead of cloning
/// the entire `RaftMetrics` struct on every call.
pub fn is_leader(&self) -> bool {
    if self.standalone {
        return true;
    }
    match self.raft {
        Some(ref raft) => raft.metrics().borrow().current_leader == Some(self.node_id),
        None => false,
    }
}
pub fn leader(&self) -> Option<NodeId> {
if self.standalone {
return Some(self.node_id);
}
if let Some(ref raft) = self.raft {
let metrics = raft.metrics().borrow().clone();
return metrics.current_leader;
}
None
}
/// Returns this node's numeric Raft node id.
pub fn node_id(&self) -> NodeId {
    self.node_id
}
/// Returns this node's string-form id.
pub fn node_id_str(&self) -> &str {
    &self.node_id_str
}
/// Returns the underlying openraft handle, if one has been initialized.
/// (NOTE: `get_` prefix is non-idiomatic Rust but kept for API stability.)
pub fn get_raft(&self) -> Option<&openraft::Raft<TypeConfig>> {
    self.raft.as_ref()
}
/// Registers a peer's address with the Raft network layer.
pub async fn add_peer(&self, node_id: NodeId, addr: String) {
    self.network.add_node(node_id, addr).await;
}
/// Removes a peer from the Raft network layer.
pub async fn remove_peer(&self, node_id: NodeId) {
    self.network.remove_node(node_id).await;
}
/// Takes a snapshot of the metadata state machine.
///
/// When running as part of a cluster with an initialized Raft core, the
/// snapshot is triggered through Raft; in standalone mode (or when the
/// Raft handle is absent) the state machine serializes a snapshot directly.
///
/// # Errors
/// Returns `ClusterError::RaftStorage` if the trigger or serialization fails.
pub async fn snapshot(&self) -> Result<()> {
    // Only route through Raft when clustered AND the handle exists;
    // otherwise fall back to a direct state-machine snapshot.
    let clustered = (!self.standalone).then(|| self.raft.as_ref()).flatten();
    match clustered {
        Some(raft) => {
            raft.trigger().snapshot().await.map_err(|e| {
                ClusterError::RaftStorage(format!("Snapshot trigger failed: {}", e))
            })?;
            info!(node_id = self.node_id, "Triggered Raft snapshot");
            Ok(())
        }
        None => {
            let (_meta, data) = self
                .state_machine
                .create_snapshot()
                .await
                .map_err(|e| ClusterError::RaftStorage(format!("{}", e)))?;
            info!(size = data.len(), "Created standalone snapshot");
            Ok(())
        }
    }
}
/// Returns an owned snapshot of the current Raft metrics, or `None` when
/// Raft is not initialized. (The clone is required here because the value
/// is returned out of the watch borrow.)
pub fn metrics(&self) -> Option<openraft::RaftMetrics<NodeId, BasicNode>> {
    self.raft.as_ref().map(|r| r.metrics().borrow().clone())
}
/// Forwards an incoming AppendEntries RPC to the local Raft core.
///
/// # Errors
/// Returns `ClusterError::RaftStorage` when Raft is uninitialized or the
/// core rejects the request.
pub async fn handle_append_entries(
    &self,
    req: openraft::raft::AppendEntriesRequest<TypeConfig>,
) -> std::result::Result<openraft::raft::AppendEntriesResponse<NodeId>, ClusterError> {
    match self.raft {
        Some(ref raft) => raft
            .append_entries(req)
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("{}", e))),
        None => Err(ClusterError::RaftStorage(
            "Raft not initialized".to_string(),
        )),
    }
}
/// Forwards an incoming InstallSnapshot RPC to the local Raft core.
///
/// # Errors
/// Returns `ClusterError::RaftStorage` when Raft is uninitialized or the
/// core rejects the request.
pub async fn handle_install_snapshot(
    &self,
    req: openraft::raft::InstallSnapshotRequest<TypeConfig>,
) -> std::result::Result<openraft::raft::InstallSnapshotResponse<NodeId>, ClusterError> {
    match self.raft {
        Some(ref raft) => raft
            .install_snapshot(req)
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("{}", e))),
        None => Err(ClusterError::RaftStorage(
            "Raft not initialized".to_string(),
        )),
    }
}
/// Forwards an incoming Vote RPC to the local Raft core.
///
/// # Errors
/// Returns `ClusterError::RaftStorage` when Raft is uninitialized or the
/// core rejects the request.
pub async fn handle_vote(
    &self,
    req: openraft::raft::VoteRequest<NodeId>,
) -> std::result::Result<openraft::raft::VoteResponse<NodeId>, ClusterError> {
    match self.raft {
        Some(ref raft) => raft
            .vote(req)
            .await
            .map_err(|e| ClusterError::RaftStorage(format!("{}", e))),
        None => Err(ClusterError::RaftStorage(
            "Raft not initialized".to_string(),
        )),
    }
}
}
/// Derives a numeric Raft node id from a string node id by hashing it.
///
/// NOTE(review): `DefaultHasher` is deterministic within a single build of
/// the standard library, but its algorithm is explicitly NOT guaranteed to
/// be stable across Rust releases — if these ids are persisted or must
/// match between binaries built with different toolchains, confirm and
/// switch to a stable hash function.
pub fn hash_node_id(node_id: &str) -> NodeId {
    use std::hash::{Hash, Hasher};
    let mut state = std::collections::hash_map::DefaultHasher::default();
    node_id.hash(&mut state);
    state.finish()
}
/// Convenience alias for the numeric Raft node id.
pub type RaftNodeId = NodeId;
/// Alias naming `RaftNode` as the cluster controller type.
pub type RaftController = RaftNode;
/// Re-export so downstream code can name the log-storage trait without
/// importing it from openraft directly.
pub use openraft::storage::RaftLogStorage as RaftLogStorageTrait;
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// The log store should create its backing database under the given path.
    #[tokio::test]
    async fn test_log_storage_creation() {
        let dir = TempDir::new().unwrap();
        let store = LogStore::new(dir.path()).unwrap();
        assert!(store.db.path().exists());
    }

    /// Applying a CreateTopic command registers the topic in metadata and
    /// returns a TopicCreated response.
    #[tokio::test]
    async fn test_state_machine_apply() {
        let machine = StateMachine::new();
        let entry_id = LogId::new(openraft::CommittedLeaderId::new(1, 1), 1);
        let command = MetadataCommand::CreateTopic {
            config: crate::partition::TopicConfig::new("test-topic", 3, 1),
            partition_assignments: vec![
                vec!["node-1".into()],
                vec!["node-1".into()],
                vec!["node-1".into()],
            ],
        };
        let applied = machine.apply_command(&entry_id, command).await;
        assert!(matches!(
            applied.response,
            MetadataResponse::TopicCreated { .. }
        ));
        assert!(machine.metadata().await.topics.contains_key("test-topic"));
    }

    /// A standalone node is always leader and can commit proposals locally.
    #[tokio::test]
    async fn test_raft_node_standalone() {
        let dir = TempDir::new().unwrap();
        let config = ClusterConfig {
            data_dir: dir.path().to_path_buf(),
            ..ClusterConfig::standalone()
        };
        let mut node = RaftNode::new(&config).await.unwrap();
        node.start().await.unwrap();
        assert!(node.is_leader());
        let reply = node.propose(MetadataCommand::Noop).await.unwrap();
        assert!(matches!(reply, MetadataResponse::Success));
    }

    /// Hashing is deterministic and distinguishes distinct ids.
    #[test]
    fn test_hash_node_id() {
        let first = hash_node_id("node-1");
        let second = hash_node_id("node-2");
        assert_ne!(first, second);
        assert_eq!(first, hash_node_id("node-1"));
    }
}