// Crate-wide opt-out of clippy's `significant_drop_tightening` lint, which flags
// values with a significant `Drop` (such as lock guards) that could be released earlier.
// NOTE(review): the rationale for disabling it is not recorded here — presumably guards
// are deliberately held for a whole scope; confirm.
#![allow(clippy::significant_drop_tightening)]
// Submodules of this crate. Selected items from most of them are re-exported at the
// crate root by the `pub use` lines further down.
pub mod backup;
pub mod log_store_variant;
pub mod network;
pub mod persistent_store;
pub mod state_machine;
pub mod store;
pub mod transport;
pub mod transport_server;
/// Code generated for the `raft_hpc.v1` protobuf package, pulled in via
/// `tonic::include_proto!`.
pub mod proto {
// The generated code lives in a private `inner` module so that a single `allow`
// silences every clippy lint group for it; its contents are then re-exported
// unchanged from `proto` itself.
#[allow(clippy::all, clippy::pedantic, clippy::nursery)]
mod inner {
tonic::include_proto!("raft_hpc.v1");
}
pub use inner::*;
}
use std::fmt;
use openraft::RaftTypeConfig;
use serde::Serialize;
use serde::de::DeserializeOwned;
/// State type that is mutated by applying the commands of a Raft type
/// configuration `C`.
///
/// Implementors must be serde-serializable and deserializable, constructible via
/// `Default`, and `Send + Sync + 'static`.
pub trait StateMachineState<C: RaftTypeConfig>:
Serialize + DeserializeOwned + Default + Send + Sync + 'static
{
/// Applies one command (`C::D`) to the state in place and returns the response
/// (`C::R`) for that command.
fn apply(&mut self, cmd: C::D) -> C::R;
/// Produces a response (`C::R`) without looking at the state or at any command.
///
/// NOTE(review): presumably used for log entries that carry no application
/// command (e.g. blank or membership entries) — confirm against `state_machine`.
fn blank_response() -> C::R;
}
/// A type that can describe itself with a piece of backup metadata.
///
/// NOTE(review): how the metadata is consumed is not visible in this file —
/// presumably by the `backup` module; confirm.
pub trait BackupMetadataSource {
/// The metadata value produced. It must be serde-serializable and deserializable,
/// `Debug`, and `Clone`.
type Metadata: Serialize + DeserializeOwned + fmt::Debug + Clone;
/// Returns the metadata for `self` as an owned value.
fn backup_metadata(&self) -> Self::Metadata;
}
pub use backup::{BackupMetadata, export_backup, restore_backup, verify_backup};
pub use log_store_variant::{LogReaderVariant, LogStoreVariant};
pub use network::MemNetworkFactory;
pub use persistent_store::FileLogStore;
pub use state_machine::HpcStateMachine;
pub use store::MemLogStore;
pub use transport::{GrpcNetworkFactory, PeerTlsConfig};
pub use transport_server::RaftTransportServer;
#[cfg(test)]
pub(crate) mod test_types {
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::Cursor;
/// Command type used by the test Raft configuration.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum TestCommand {
    /// Carries a key (first field) and the value to store under it (second field).
    Set(String, String),
}

impl fmt::Display for TestCommand {
    /// Renders the command as `Set(<key>, <value>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Set` is the only variant, so the pattern is irrefutable.
        let Self::Set(key, value) = self;
        write!(f, "Set({key}, {value})")
    }
}
/// Response type used by the test Raft configuration; it has a single, data-free variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum TestResponse {
    /// The command was applied.
    Ok,
}

impl fmt::Display for TestResponse {
    /// Always renders as the literal text `Ok`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Ok")
    }
}
// Declares `TestTypeConfig`, the Raft type configuration used by the tests:
// commands are `TestCommand`, responses are `TestResponse`, node ids are `u64`,
// node info is openraft's `BasicNode`, and snapshots are in-memory byte cursors.
// Associated types not listed here are left to the macro's defaults.
openraft::declare_raft_types!(
pub TestTypeConfig:
D = TestCommand,
R = TestResponse,
NodeId = u64,
Node = openraft::impls::BasicNode,
SnapshotData = Cursor<Vec<u8>>,
);
/// Key/value state used as the state-machine payload in tests.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TestState {
    /// Everything stored so far by `TestCommand::Set`.
    pub data: std::collections::HashMap<String, String>,
}

impl crate::StateMachineState<TestTypeConfig> for TestState {
    /// Stores the command's value under its key (replacing any previous value)
    /// and answers `TestResponse::Ok`.
    fn apply(&mut self, cmd: TestCommand) -> TestResponse {
        // `Set` is the only variant, so it can be destructured directly.
        let TestCommand::Set(key, value) = cmd;
        self.data.insert(key, value);
        TestResponse::Ok
    }

    /// The response for "no command" is the same `Ok` as for a real one.
    fn blank_response() -> TestResponse {
        TestResponse::Ok
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::BTreeMap;
use std::sync::Arc;
use test_types::*;
use tokio::sync::RwLock;
/// Builds a single-node, in-memory Raft instance (node id `1`) and waits until
/// that node reports itself as the current leader.
///
/// Returns the Raft handle together with the shared [`TestState`] that was handed
/// to the state machine, so tests can inspect what has been applied.
///
/// # Panics
///
/// Panics if the config fails validation, if the node cannot be created or
/// initialized, or if node `1` has not become leader within 10 seconds.
async fn create_test_quorum() -> (openraft::Raft<TestTypeConfig>, Arc<RwLock<TestState>>) {
    let state = Arc::new(RwLock::new(TestState::default()));
    let config = Arc::new(
        openraft::Config {
            heartbeat_interval: 200,
            election_timeout_min: 500,
            election_timeout_max: 1000,
            ..Default::default()
        }
        .validate()
        .unwrap(),
    );
    let log_store = MemLogStore::new();
    let sm = HpcStateMachine::new(Arc::clone(&state));
    let network = MemNetworkFactory::new();
    let raft = openraft::Raft::new(1, config, network, log_store, sm)
        .await
        .unwrap();

    let mut members = BTreeMap::new();
    members.insert(1u64, openraft::impls::BasicNode::new("127.0.0.1:0"));
    raft.initialize(members).await.unwrap();

    // `wait(None)` has no effective deadline, so a failed election would hang the
    // test run instead of failing it. Bound the wait well above the election timeout.
    raft.wait(Some(std::time::Duration::from_secs(10)))
        .metrics(|m| m.current_leader == Some(1), "leader elected")
        .await
        .unwrap();
    (raft, state)
}
/// Builds an in-memory Raft cluster of `node_count` nodes (ids `1..=node_count`)
/// that talk to each other through one shared [`MemNetworkFactory`].
///
/// The cluster is initialized from node `1`, and the function only returns once
/// node `1` is the leader, because callers use `nodes[0]` as the leader handle.
///
/// Returns one `(raft, state)` pair per node, ordered by node id; `state` is the
/// shared [`TestState`] handed to that node's state machine.
///
/// # Panics
///
/// Panics if `node_count` is zero, if any node cannot be created or initialized,
/// or if node `1` has not become leader within 10 seconds.
async fn create_test_cluster(
    node_count: u64,
) -> Vec<(openraft::Raft<TestTypeConfig>, Arc<RwLock<TestState>>)> {
    let network_factory = MemNetworkFactory::new();

    let members: BTreeMap<u64, openraft::impls::BasicNode> = (1..=node_count)
        .map(|id| {
            (
                id,
                openraft::impls::BasicNode::new(format!("127.0.0.1:{}", 5000 + id)),
            )
        })
        .collect();

    // Every node uses the same settings, so validate the config once and share it.
    let config = Arc::new(
        openraft::Config {
            heartbeat_interval: 200,
            election_timeout_min: 500,
            election_timeout_max: 1000,
            ..Default::default()
        }
        .validate()
        .unwrap(),
    );

    let mut nodes = Vec::new();
    for id in 1..=node_count {
        let state = Arc::new(RwLock::new(TestState::default()));
        let log_store = MemLogStore::new();
        let sm = HpcStateMachine::new(Arc::clone(&state));
        let raft = openraft::Raft::new(
            id,
            Arc::clone(&config),
            network_factory.clone(),
            log_store,
            sm,
        )
        .await
        .unwrap();
        // Make the node reachable for its peers through the in-memory network.
        network_factory.register(id, raft.clone()).await;
        nodes.push((raft, state));
    }

    let bootstrap = &nodes[0].0;
    bootstrap.initialize(members).await.unwrap();
    // `wait(None)` has no effective deadline; bound it so a failed election fails
    // the test instead of hanging it. Waiting for node 1 specifically (rather than
    // for any leader) matches how callers use `nodes[0]`.
    bootstrap
        .wait(Some(std::time::Duration::from_secs(10)))
        .metrics(|m| m.current_leader == Some(1), "leader elected")
        .await
        .unwrap();
    nodes
}
/// Builds a Raft cluster of `node_count` nodes (ids `1..=node_count`) that talk
/// to each other over gRPC on loopback TCP ports chosen by the OS.
///
/// For every node a tonic server wrapping [`RaftTransportServer`] is spawned on
/// its own listener. The cluster is initialized from node `1`, and the function
/// only returns once node `1` is the leader, because callers use `nodes[0]` as
/// the leader handle.
///
/// Returns the `(raft, state)` pairs ordered by node id, plus the join handles of
/// the spawned server tasks so the caller can abort them.
///
/// # Panics
///
/// Panics if `node_count` is zero, if a listener cannot be bound, if any node
/// cannot be created or initialized, or if node `1` has not become leader within
/// 10 seconds.
async fn create_test_grpc_cluster(
    node_count: u64,
) -> (
    Vec<(openraft::Raft<TestTypeConfig>, Arc<RwLock<TestState>>)>,
    Vec<tokio::task::JoinHandle<()>>,
) {
    // Bind every listener up front so all peer addresses are known before any
    // node is created.
    let mut listeners = Vec::new();
    let mut addresses = Vec::new();
    for _ in 0..node_count {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        addresses.push(addr.to_string());
        listeners.push(listener);
    }

    let network_factory = GrpcNetworkFactory::new();
    let mut members = BTreeMap::new();
    // Node ids start at 1; zipping with a `u64` range avoids `usize -> u64` casts.
    for (id, addr) in (1u64..).zip(&addresses) {
        members.insert(id, openraft::impls::BasicNode::new(addr.clone()));
        network_factory.register(id, addr.clone()).await;
    }

    // Every node uses the same settings, so validate the config once and share it.
    let config = Arc::new(
        openraft::Config {
            heartbeat_interval: 200,
            election_timeout_min: 500,
            election_timeout_max: 1000,
            ..Default::default()
        }
        .validate()
        .unwrap(),
    );

    let mut nodes = Vec::new();
    let mut server_handles = Vec::new();
    for (id, listener) in (1u64..).zip(listeners) {
        let state = Arc::new(RwLock::new(TestState::default()));
        let log_store = MemLogStore::new();
        let sm = HpcStateMachine::new(Arc::clone(&state));
        let raft = openraft::Raft::new(
            id,
            Arc::clone(&config),
            network_factory.clone(),
            log_store,
            sm,
        )
        .await
        .unwrap();

        let server = RaftTransportServer::new(raft.clone());
        let handle = tokio::spawn(async move {
            let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
            // The serve result is deliberately ignored: the task is expected to
            // end by being aborted from the test.
            let _ = tonic::transport::Server::builder()
                .add_service(proto::raft_service_server::RaftServiceServer::new(server))
                .serve_with_incoming(incoming)
                .await;
        });
        server_handles.push(handle);
        nodes.push((raft, state));
    }

    // Short grace period for the spawned server tasks to start; the listeners
    // themselves are already bound.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;

    let bootstrap = &nodes[0].0;
    bootstrap.initialize(members).await.unwrap();
    // `wait(None)` has no effective deadline; bound it so a failed election fails
    // the test instead of hanging it. Waiting for node 1 specifically (rather than
    // for any leader) matches how callers use `nodes[0]`.
    bootstrap
        .wait(Some(std::time::Duration::from_secs(10)))
        .metrics(|m| m.current_leader == Some(1), "leader elected")
        .await
        .unwrap();
    (nodes, server_handles)
}
/// A write through a single-node cluster must be visible in the shared state as
/// soon as `client_write` returns.
#[tokio::test]
async fn single_node_quorum_works() {
    let (raft, state) = create_test_quorum().await;

    raft.client_write(TestCommand::Set("key1".into(), "value1".into()))
        .await
        .unwrap();

    let guard = state.read().await;
    let stored = guard.data.get("key1").unwrap();
    assert_eq!(stored, "value1");
}
/// A write through the leader of a three-node in-memory cluster must be applied
/// on the leader and must reach every follower's state.
#[tokio::test]
async fn three_node_cluster_works() {
    let nodes = create_test_cluster(3).await;
    let (leader, leader_state) = &nodes[0];

    leader
        .client_write(TestCommand::Set("k".into(), "v".into()))
        .await
        .unwrap();
    assert_eq!(leader_state.read().await.data.get("k").unwrap(), "v");

    // Followers apply the entry asynchronously. Poll with a deadline instead of
    // sleeping a fixed amount: the test finishes as soon as replication is done
    // and does not depend on a guessed delay being long enough.
    for (_, follower_state) in &nodes[1..] {
        let replicated = tokio::time::timeout(std::time::Duration::from_secs(5), async {
            loop {
                let seen = follower_state.read().await.data.contains_key("k");
                if seen {
                    break;
                }
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            }
        })
        .await;
        assert!(
            replicated.is_ok(),
            "Data should be replicated to all nodes"
        );
    }
}
/// After the gRPC cluster helper has elected a leader, a write through that
/// leader must be visible in the leader's own state.
#[tokio::test]
#[ignore = "slow: spins up 3-node gRPC Raft cluster"]
async fn grpc_three_node_cluster_leader_election() {
    let (nodes, handles) = create_test_grpc_cluster(3).await;
    let (leader, leader_state) = &nodes[0];

    leader
        .client_write(TestCommand::Set("grpc-key".into(), "grpc-val".into()))
        .await
        .unwrap();

    let guard = leader_state.read().await;
    assert_eq!(guard.data.get("grpc-key").unwrap(), "grpc-val");

    // Stop the gRPC server tasks spawned by the helper.
    handles.iter().for_each(|server| server.abort());
}
/// A write through the leader of a three-node gRPC cluster must reach every
/// follower's state.
#[tokio::test]
#[ignore = "slow: spins up 3-node gRPC Raft cluster"]
async fn grpc_three_node_cluster_log_replication() {
    let (nodes, handles) = create_test_grpc_cluster(3).await;
    let (leader, _) = &nodes[0];

    leader
        .client_write(TestCommand::Set("replicated".into(), "yes".into()))
        .await
        .unwrap();

    // Followers apply the entry asynchronously. Poll with a deadline instead of
    // sleeping a fixed amount: the test finishes as soon as replication is done
    // and does not depend on a guessed delay being long enough.
    for (_, follower_state) in &nodes[1..] {
        let replicated = tokio::time::timeout(std::time::Duration::from_secs(5), async {
            loop {
                let seen = follower_state.read().await.data.contains_key("replicated");
                if seen {
                    break;
                }
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            }
        })
        .await;
        assert!(
            replicated.is_ok(),
            "Data should be replicated to all nodes"
        );
    }

    // Stop the gRPC server tasks spawned by the helper.
    for h in handles {
        h.abort();
    }
}
}