Skip to main content

raft_hpc_core/
lib.rs

1//! # raft-hpc-core
2//!
3//! Shared Raft consensus infrastructure for HPC systems. Extracted from
4//! lattice-quorum with minimal parameterization — each application provides
5//! its own `TypeConfig` via `openraft::declare_raft_types!`.
6//!
7//! ## What's generic (in this crate)
8//!
9//! - Log stores (in-memory, file-backed, polymorphic variant)
10//! - gRPC transport (network factory, transport server)
11//! - In-memory network (for testing)
12//! - State machine (snapshot management, apply dispatch)
13//! - Backup (export, verify, restore)
14//!
15//! ## What's application-specific (NOT in this crate)
16//!
17//! - `TypeConfig` declaration (`openraft::declare_raft_types!`)
18//! - Command and `CommandResponse` enums
19//! - Application state (`GlobalState`, `JournalState`, etc.)
20//! - `StateMachineState::apply()` implementation
21//! - Client trait implementations
22//! - Factory functions (`create_quorum`, etc.)
23
24#![allow(clippy::significant_drop_tightening)]
25
26pub mod backup;
27pub mod log_store_variant;
28pub mod network;
29pub mod persistent_store;
30pub mod state_machine;
31pub mod store;
32pub mod transport;
33pub mod transport_server;
34
35/// Generated protobuf types for the Raft transport service.
36pub mod proto {
37    #[allow(clippy::all, clippy::pedantic, clippy::nursery)]
38    mod inner {
39        tonic::include_proto!("raft_hpc.v1");
40    }
41    pub use inner::*;
42}
43
44use std::fmt;
45
46use openraft::RaftTypeConfig;
47use serde::Serialize;
48use serde::de::DeserializeOwned;
49
50/// Application state managed by the Raft state machine.
51///
52/// Implement this trait for your application's state type (e.g., `GlobalState`,
53/// `JournalState`). The state machine will call `apply()` for each committed
54/// command and use serde for snapshot serialization.
55pub trait StateMachineState<C: RaftTypeConfig>:
56    Serialize + DeserializeOwned + Default + Send + Sync + 'static
57{
58    /// Apply a committed command to the state, returning a response.
59    fn apply(&mut self, cmd: C::D) -> C::R;
60
61    /// Response value for blank entries and membership changes.
62    fn blank_response() -> C::R;
63}
64
65/// Application state that supports backup metadata extraction.
66///
67/// Implement this to enable backup export/verify/restore with
68/// application-specific metadata (e.g., node count, entry count).
69pub trait BackupMetadataSource {
70    /// Application-specific backup metadata type.
71    type Metadata: Serialize + DeserializeOwned + fmt::Debug + Clone;
72
73    /// Extract metadata from the current state for backup records.
74    fn backup_metadata(&self) -> Self::Metadata;
75}
76
77// Re-exports for convenience.
78pub use backup::{BackupMetadata, export_backup, restore_backup, verify_backup};
79pub use log_store_variant::{LogReaderVariant, LogStoreVariant};
80pub use network::MemNetworkFactory;
81pub use persistent_store::FileLogStore;
82pub use state_machine::HpcStateMachine;
83pub use store::MemLogStore;
84pub use transport::{GrpcNetworkFactory, PeerTlsConfig};
85pub use transport_server::RaftTransportServer;
86
/// Shared fixtures for unit tests across the crate's modules.
///
/// Defines a minimal `TestTypeConfig` whose command and response enums satisfy
/// every bound openraft places on application data.
#[cfg(test)]
pub(crate) mod test_types {
    use std::collections::HashMap;
    use std::fmt;
    use std::io::Cursor;

    use serde::{Deserialize, Serialize};

    /// The single test command: a `(key, value)` pair; applying it stores the
    /// value under the key.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    pub enum TestCommand {
        Set(String, String),
    }

    impl fmt::Display for TestCommand {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Irrefutable while `Set` is the only variant; adding a variant
            // turns this into a compile error, just like an exhaustive match.
            let Self::Set(key, value) = self;
            write!(f, "Set({key}, {value})")
        }
    }

    /// Response returned for every applied test command and blank entry.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    pub enum TestResponse {
        Ok,
    }

    impl fmt::Display for TestResponse {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(match self {
                Self::Ok => "Ok",
            })
        }
    }

    openraft::declare_raft_types!(
        pub TestTypeConfig:
            D = TestCommand,
            R = TestResponse,
            NodeId = u64,
            Node = openraft::impls::BasicNode,
            SnapshotData = Cursor<Vec<u8>>,
    );

    /// Key-value state used by the integration tests.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    pub struct TestState {
        pub data: HashMap<String, String>,
    }

    impl crate::StateMachineState<TestTypeConfig> for TestState {
        fn apply(&mut self, cmd: TestCommand) -> TestResponse {
            let TestCommand::Set(key, value) = cmd;
            self.data.insert(key, value);
            TestResponse::Ok
        }

        fn blank_response() -> TestResponse {
            TestResponse::Ok
        }
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use std::time::Duration;
    use test_types::*;
    use tokio::sync::RwLock;

    /// A Raft handle paired with the application state its state machine mutates.
    type TestNode = (openraft::Raft<TestTypeConfig>, Arc<RwLock<TestState>>);

    /// Upper bound on every metrics wait in this module.
    ///
    /// Explicit rather than relying on the library default (`wait(None)`), so
    /// slow machines get headroom well beyond the 500–1000 ms election timeout,
    /// while a stuck cluster still fails the test instead of hanging.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Raft timing shared by every test node: 200 ms heartbeats and a
    /// 500–1000 ms election timeout.
    fn test_config() -> Arc<openraft::Config> {
        Arc::new(
            openraft::Config {
                heartbeat_interval: 200,
                election_timeout_min: 500,
                election_timeout_max: 1000,
                ..Default::default()
            }
            .validate()
            .unwrap(),
        )
    }

    /// Block until `raft` reports node `leader_id` as the current leader.
    ///
    /// # Panics
    /// Panics if that does not happen within [`WAIT_TIMEOUT`].
    async fn wait_for_leader(raft: &openraft::Raft<TestTypeConfig>, leader_id: u64) {
        raft.wait(Some(WAIT_TIMEOUT))
            .metrics(move |m| m.current_leader == Some(leader_id), "leader elected")
            .await
            .unwrap();
    }

    /// Block until `raft` has applied the log entry at `index` (or a later one)
    /// to its state machine.
    ///
    /// # Panics
    /// Panics if that does not happen within [`WAIT_TIMEOUT`].
    async fn wait_for_applied(raft: &openraft::Raft<TestTypeConfig>, index: u64) {
        raft.wait(Some(WAIT_TIMEOUT))
            .metrics(
                move |m| {
                    m.last_applied
                        .as_ref()
                        .is_some_and(|log_id| log_id.index >= index)
                },
                "entry applied",
            )
            .await
            .unwrap();
    }

    /// Create a single-node in-memory quorum for testing.
    ///
    /// Returns only once node 1 is leader.
    async fn create_test_quorum() -> TestNode {
        let state = Arc::new(RwLock::new(TestState::default()));
        let log_store = MemLogStore::new();
        let sm = HpcStateMachine::new(Arc::clone(&state));
        let network = MemNetworkFactory::new();

        let raft = openraft::Raft::new(1, test_config(), network, log_store, sm)
            .await
            .unwrap();

        let members = BTreeMap::from([(1u64, openraft::impls::BasicNode::new("127.0.0.1:0"))]);
        raft.initialize(members).await.unwrap();

        wait_for_leader(&raft, 1).await;

        (raft, state)
    }

    /// Create a multi-node in-memory cluster for testing.
    ///
    /// Node 1 (`nodes[0]`) initializes the cluster, and this helper returns
    /// only once node 1 is leader, so callers may write through `nodes[0]`.
    async fn create_test_cluster(node_count: u64) -> Vec<TestNode> {
        let network_factory = MemNetworkFactory::new();
        let members: BTreeMap<u64, openraft::impls::BasicNode> = (1..=node_count)
            .map(|id| {
                let addr = format!("127.0.0.1:{}", 5000 + id);
                (id, openraft::impls::BasicNode::new(addr))
            })
            .collect();

        let mut nodes = Vec::new();
        for id in 1..=node_count {
            let state = Arc::new(RwLock::new(TestState::default()));
            let log_store = MemLogStore::new();
            let sm = HpcStateMachine::new(Arc::clone(&state));

            let raft =
                openraft::Raft::new(id, test_config(), network_factory.clone(), log_store, sm)
                    .await
                    .unwrap();

            network_factory.register(id, raft.clone()).await;
            nodes.push((raft, state));
        }

        let leader = &nodes[0].0;
        leader.initialize(members).await.unwrap();
        // The tests write through `nodes[0]`; any other leader would make those
        // writes fail, so wait for node 1 specifically.
        wait_for_leader(leader, 1).await;

        nodes
    }

    /// Create a multi-node gRPC cluster for testing.
    ///
    /// Each node serves the Raft transport on its own loopback listener. Node 1
    /// (`nodes[0]`) initializes the cluster, and this helper returns only once
    /// node 1 is leader. The returned join handles belong to the transport
    /// servers; tests abort them when done.
    async fn create_test_grpc_cluster(
        node_count: u64,
    ) -> (Vec<TestNode>, Vec<tokio::task::JoinHandle<()>>) {
        // Bind every listener up front so all peer addresses are known before
        // any Raft instance is created.
        let mut listeners = Vec::new();
        for _ in 0..node_count {
            listeners.push(tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap());
        }

        let network_factory = GrpcNetworkFactory::new();
        let mut members = BTreeMap::new();
        for (id, listener) in (1u64..).zip(&listeners) {
            let addr = listener.local_addr().unwrap().to_string();
            members.insert(id, openraft::impls::BasicNode::new(addr.clone()));
            network_factory.register(id, addr).await;
        }

        let mut nodes = Vec::new();
        let mut server_handles = Vec::new();
        for (id, listener) in (1u64..).zip(listeners) {
            let state = Arc::new(RwLock::new(TestState::default()));
            let log_store = MemLogStore::new();
            let sm = HpcStateMachine::new(Arc::clone(&state));

            let raft =
                openraft::Raft::new(id, test_config(), network_factory.clone(), log_store, sm)
                    .await
                    .unwrap();

            let server = RaftTransportServer::new(raft.clone());
            server_handles.push(tokio::spawn(async move {
                let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
                // Runs until the test aborts the task. A serve error is ignored
                // here; it would show up as failed RPCs on the peers.
                let _ = tonic::transport::Server::builder()
                    .add_service(proto::raft_service_server::RaftServiceServer::new(server))
                    .serve_with_incoming(incoming)
                    .await;
            }));

            nodes.push((raft, state));
        }

        // Short grace period for the spawned servers to start serving before
        // the first RPCs go out.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let leader = &nodes[0].0;
        leader.initialize(members).await.unwrap();
        wait_for_leader(leader, 1).await;

        (nodes, server_handles)
    }

    #[tokio::test]
    async fn single_node_quorum_works() {
        let (raft, state) = create_test_quorum().await;

        // Write a value through Raft.
        let cmd = TestCommand::Set("key1".into(), "value1".into());
        let written = raft.client_write(cmd).await.unwrap();
        assert_eq!(written.data, TestResponse::Ok);

        // Read it back from the state.
        let s = state.read().await;
        assert_eq!(s.data.get("key1").map(String::as_str), Some("value1"));
    }

    #[tokio::test]
    async fn three_node_cluster_works() {
        let nodes = create_test_cluster(3).await;
        let (leader, state) = &nodes[0];

        // Write through the leader.
        let cmd = TestCommand::Set("k".into(), "v".into());
        let written = leader.client_write(cmd).await.unwrap();

        // The leader's state is read in its own statement so the read guard is
        // released before we start waiting on the followers.
        assert_eq!(state.read().await.data.get("k").map(String::as_str), Some("v"));

        // Wait until each follower has applied the entry instead of sleeping a
        // fixed interval, which is slower and flaky under load.
        for (follower, fstate) in &nodes[1..] {
            wait_for_applied(follower, written.log_id.index).await;
            let s = fstate.read().await;
            assert_eq!(
                s.data.get("k").map(String::as_str),
                Some("v"),
                "Data should be replicated to all nodes"
            );
        }
    }

    #[tokio::test]
    #[ignore = "slow: spins up 3-node gRPC Raft cluster"]
    async fn grpc_three_node_cluster_leader_election() {
        let (nodes, handles) = create_test_grpc_cluster(3).await;

        // Every node, not only the initializer, must agree that node 1 leads.
        for (raft, _) in &nodes {
            wait_for_leader(raft, 1).await;
        }

        // Write to prove the cluster is functional.
        let (leader, state) = &nodes[0];
        let cmd = TestCommand::Set("grpc-key".into(), "grpc-val".into());
        leader.client_write(cmd).await.unwrap();

        assert_eq!(
            state.read().await.data.get("grpc-key").map(String::as_str),
            Some("grpc-val")
        );

        for h in handles {
            h.abort();
        }
    }

    #[tokio::test]
    #[ignore = "slow: spins up 3-node gRPC Raft cluster"]
    async fn grpc_three_node_cluster_log_replication() {
        let (nodes, handles) = create_test_grpc_cluster(3).await;
        let (leader, _) = &nodes[0];

        // Write through the leader.
        let cmd = TestCommand::Set("replicated".into(), "yes".into());
        let written = leader.client_write(cmd).await.unwrap();

        // Verify each follower applies the entry, waiting on its metrics
        // rather than a fixed sleep.
        for (follower, state) in &nodes[1..] {
            wait_for_applied(follower, written.log_id.index).await;
            let s = state.read().await;
            assert_eq!(
                s.data.get("replicated").map(String::as_str),
                Some("yes"),
                "Data should be replicated to all nodes"
            );
        }

        for h in handles {
            h.abort();
        }
    }
}