1#![allow(clippy::significant_drop_tightening)]
25
26pub mod backup;
27pub mod log_store_variant;
28pub mod network;
29pub mod persistent_store;
30pub mod state_machine;
31pub mod store;
32pub mod transport;
33pub mod transport_server;
34
35pub mod proto {
37 #[allow(clippy::all, clippy::pedantic, clippy::nursery)]
38 mod inner {
39 tonic::include_proto!("raft_hpc.v1");
40 }
41 pub use inner::*;
42}
43
44use std::fmt;
45
46use openraft::RaftTypeConfig;
47use serde::Serialize;
48use serde::de::DeserializeOwned;
49
50pub trait StateMachineState<C: RaftTypeConfig>:
56 Serialize + DeserializeOwned + Default + Send + Sync + 'static
57{
58 fn apply(&mut self, cmd: C::D) -> C::R;
60
61 fn blank_response() -> C::R;
63}
64
65pub trait BackupMetadataSource {
70 type Metadata: Serialize + DeserializeOwned + fmt::Debug + Clone;
72
73 fn backup_metadata(&self) -> Self::Metadata;
75}
76
77pub use backup::{BackupMetadata, export_backup, restore_backup, verify_backup};
79pub use log_store_variant::{LogReaderVariant, LogStoreVariant};
80pub use network::MemNetworkFactory;
81pub use persistent_store::FileLogStore;
82pub use state_machine::HpcStateMachine;
83pub use store::MemLogStore;
84pub use transport::{GrpcNetworkFactory, PeerTlsConfig};
85pub use transport_server::RaftTransportServer;
86
/// Minimal Raft type configuration and key/value state used by this crate's
/// unit tests.
#[cfg(test)]
pub(crate) mod test_types {
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;
    use std::fmt;
    use std::io::Cursor;

    /// The single command understood by [`TestState`].
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    pub enum TestCommand {
        /// Insert (or overwrite) the key in field 0 with the value in field 1.
        Set(String, String),
    }

    impl fmt::Display for TestCommand {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Single-variant enum: the destructuring is irrefutable.
            let Self::Set(key, value) = self;
            write!(f, "Set({key}, {value})")
        }
    }

    /// Response for every applied [`TestCommand`] and for blank entries.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    pub enum TestResponse {
        Ok,
    }

    impl fmt::Display for TestResponse {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("Ok")
        }
    }

    openraft::declare_raft_types!(
        pub TestTypeConfig:
            D = TestCommand,
            R = TestResponse,
            NodeId = u64,
            Node = openraft::impls::BasicNode,
            SnapshotData = Cursor<Vec<u8>>,
    );

    /// In-memory key/value map mutated by applied [`TestCommand`]s.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    pub struct TestState {
        pub data: HashMap<String, String>,
    }

    impl crate::StateMachineState<TestTypeConfig> for TestState {
        fn apply(&mut self, cmd: TestCommand) -> TestResponse {
            let TestCommand::Set(key, value) = cmd;
            // Any previous value for `key` is simply replaced.
            self.data.insert(key, value);
            TestResponse::Ok
        }

        fn blank_response() -> TestResponse {
            TestResponse::Ok
        }
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use std::time::Duration;
    use test_types::*;
    use tokio::sync::RwLock;

    /// Upper bound for every metrics wait below. `Raft::wait(None)` falls back to a
    /// short built-in default, which is tight for elections over real sockets.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// A running node: its Raft handle plus the state its state machine writes into.
    type TestNode = (openraft::Raft<TestTypeConfig>, Arc<RwLock<TestState>>);

    /// Raft timing shared by every test cluster (heartbeat below election timeout).
    fn test_config() -> Arc<openraft::Config> {
        let config = openraft::Config {
            heartbeat_interval: 200,
            election_timeout_min: 500,
            election_timeout_max: 1000,
            ..Default::default()
        };
        Arc::new(config.validate().expect("test Raft config must be valid"))
    }

    /// Waits until `raft` reports node 1 as the leader.
    ///
    /// Every test writes through `nodes[0]` (node 1, the node that calls
    /// `initialize`), so wait for that specific node instead of "any leader";
    /// otherwise a write could be sent to a non-leader.
    async fn wait_for_leader_1(raft: &openraft::Raft<TestTypeConfig>) {
        raft.wait(Some(WAIT_TIMEOUT))
            .metrics(|m| m.current_leader == Some(1), "node 1 elected leader")
            .await
            .expect("node 1 should become leader");
    }

    /// Waits until every node in `followers` has applied the log up to `index`.
    ///
    /// This replaces fixed sleeps, which are both slow and flaky under load.
    async fn wait_for_applied(followers: &[TestNode], index: u64) {
        for (raft, _) in followers {
            raft.wait(Some(WAIT_TIMEOUT))
                .applied_index_at_least(Some(index), "entry applied on follower")
                .await
                .expect("follower should apply the replicated entry");
        }
    }

    /// Starts a single-node cluster (node 1) and waits until it leads.
    async fn create_test_quorum() -> TestNode {
        let state = Arc::new(RwLock::new(TestState::default()));
        let log_store = MemLogStore::new();
        let sm = HpcStateMachine::new(Arc::clone(&state));
        // A lone voter never sends RPCs, so the node need not be registered.
        let network = MemNetworkFactory::new();

        let raft = openraft::Raft::new(1, test_config(), network, log_store, sm)
            .await
            .unwrap();

        let members = BTreeMap::from([(1u64, openraft::impls::BasicNode::new("127.0.0.1:0"))]);
        raft.initialize(members).await.unwrap();
        wait_for_leader_1(&raft).await;

        (raft, state)
    }

    /// Starts `node_count` nodes (ids `1..=node_count`) on the in-memory network.
    ///
    /// The cluster is initialized from node 1, and the function waits until node 1 leads.
    ///
    /// # Panics
    /// Panics if `node_count` is zero or any node fails to start.
    async fn create_test_cluster(node_count: u64) -> Vec<TestNode> {
        assert!(node_count >= 1, "a cluster needs at least one node");

        let network_factory = MemNetworkFactory::new();
        let members: BTreeMap<_, _> = (1..=node_count)
            .map(|id| {
                let addr = format!("127.0.0.1:{}", 5000 + id);
                (id, openraft::impls::BasicNode::new(addr))
            })
            .collect();

        let mut nodes = Vec::new();
        for id in 1..=node_count {
            let state = Arc::new(RwLock::new(TestState::default()));
            let log_store = MemLogStore::new();
            let sm = HpcStateMachine::new(Arc::clone(&state));

            let raft =
                openraft::Raft::new(id, test_config(), network_factory.clone(), log_store, sm)
                    .await
                    .unwrap();

            network_factory.register(id, raft.clone()).await;
            nodes.push((raft, state));
        }

        nodes[0].0.initialize(members).await.unwrap();
        wait_for_leader_1(&nodes[0].0).await;

        nodes
    }

    /// Starts `node_count` nodes that talk real gRPC over loopback sockets.
    ///
    /// The cluster is initialized from node 1, and the function waits until node 1 leads.
    /// Returns the nodes and the server task handles; callers abort the handles
    /// when done.
    ///
    /// # Panics
    /// Panics if `node_count` is zero or any socket/node fails to start.
    async fn create_test_grpc_cluster(
        node_count: u64,
    ) -> (Vec<TestNode>, Vec<tokio::task::JoinHandle<()>>) {
        assert!(node_count >= 1, "a cluster needs at least one node");

        // Bind every listener first so all addresses are known before any node
        // (and its network factory) is created.
        let mut listeners = Vec::new();
        for _ in 0..node_count {
            listeners.push(tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap());
        }

        let network_factory = GrpcNetworkFactory::new();
        let mut members = BTreeMap::new();
        for (id, listener) in (1u64..).zip(&listeners) {
            let addr = listener.local_addr().unwrap().to_string();
            members.insert(id, openraft::impls::BasicNode::new(addr.clone()));
            network_factory.register(id, addr).await;
        }

        let mut nodes = Vec::new();
        let mut server_handles = Vec::new();
        for (id, listener) in (1u64..).zip(listeners) {
            let state = Arc::new(RwLock::new(TestState::default()));
            let log_store = MemLogStore::new();
            let sm = HpcStateMachine::new(Arc::clone(&state));

            let raft =
                openraft::Raft::new(id, test_config(), network_factory.clone(), log_store, sm)
                    .await
                    .unwrap();

            let server = RaftTransportServer::new(raft.clone());
            server_handles.push(tokio::spawn(async move {
                let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
                // A serve failure shows up in the tests as failed RPCs / wait
                // timeouts; the task itself is aborted when each test finishes.
                let _ = tonic::transport::Server::builder()
                    .add_service(proto::raft_service_server::RaftServiceServer::new(server))
                    .serve_with_incoming(incoming)
                    .await;
            }));

            nodes.push((raft, state));
        }

        // Let the spawned server tasks start accepting before node 1's election RPCs.
        tokio::time::sleep(Duration::from_millis(50)).await;

        nodes[0].0.initialize(members).await.unwrap();
        wait_for_leader_1(&nodes[0].0).await;

        (nodes, server_handles)
    }

    #[tokio::test]
    async fn single_node_quorum_works() {
        let (raft, state) = create_test_quorum().await;

        raft.client_write(TestCommand::Set("key1".into(), "value1".into()))
            .await
            .unwrap();

        // `client_write` resolves once the entry is applied on the leader.
        let s = state.read().await;
        assert_eq!(s.data.get("key1").map(String::as_str), Some("value1"));
    }

    #[tokio::test]
    async fn three_node_cluster_works() {
        let nodes = create_test_cluster(3).await;
        let (leader, leader_state) = &nodes[0];

        let resp = leader
            .client_write(TestCommand::Set("k".into(), "v".into()))
            .await
            .unwrap();

        {
            let s = leader_state.read().await;
            assert_eq!(s.data.get("k").map(String::as_str), Some("v"));
        }

        wait_for_applied(&nodes[1..], resp.log_id.index).await;
        for (_, follower_state) in &nodes[1..] {
            let s = follower_state.read().await;
            assert_eq!(
                s.data.get("k").map(String::as_str),
                Some("v"),
                "Data should be replicated to all nodes"
            );
        }
    }

    #[tokio::test]
    #[ignore = "slow: spins up 3-node gRPC Raft cluster"]
    async fn grpc_three_node_cluster_leader_election() {
        let (nodes, handles) = create_test_grpc_cluster(3).await;
        let (leader, state) = &nodes[0];

        leader
            .client_write(TestCommand::Set("grpc-key".into(), "grpc-val".into()))
            .await
            .unwrap();

        {
            let s = state.read().await;
            assert_eq!(s.data.get("grpc-key").map(String::as_str), Some("grpc-val"));
        }

        for h in handles {
            h.abort();
        }
    }

    #[tokio::test]
    #[ignore = "slow: spins up 3-node gRPC Raft cluster"]
    async fn grpc_three_node_cluster_log_replication() {
        let (nodes, handles) = create_test_grpc_cluster(3).await;
        let (leader, _) = &nodes[0];

        let resp = leader
            .client_write(TestCommand::Set("replicated".into(), "yes".into()))
            .await
            .unwrap();

        wait_for_applied(&nodes[1..], resp.log_id.index).await;
        for (_, state) in &nodes[1..] {
            let s = state.read().await;
            assert_eq!(
                s.data.get("replicated").map(String::as_str),
                Some("yes"),
                "Data should be replicated to all nodes"
            );
        }

        for h in handles {
            h.abort();
        }
    }
}