1use crate::raft::{GraphStateMachine, RaftError, RaftNodeId, RaftResult, Request, Response};
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeSet;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::info;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
12pub struct NodeId {
13 #[serde(default)]
15 pub id: RaftNodeId,
16 #[serde(default)]
18 pub addr: String,
19}
20
21impl NodeId {
22 pub fn new(id: RaftNodeId, addr: String) -> Self {
23 Self { id, addr }
24 }
25}
26
27pub mod typ {
29 use super::*;
30
31 pub type NodeIdType = RaftNodeId;
33
34 pub type Node = super::NodeId;
36
37 #[derive(Debug, Clone, Serialize, Deserialize)]
39 pub struct Entry {
40 pub request: Request,
41 }
42
43 pub type SnapshotData = Vec<u8>;
45
46 pub struct TypeConfig;
50}
51
52#[derive(Debug, Clone, Default)]
54pub struct SimpleRaftMetrics {
55 pub current_term: u64,
56 pub current_leader: Option<RaftNodeId>,
57 pub last_log_index: u64,
58 pub last_applied: u64,
59}
60
61pub struct RaftNode {
63 node_id: RaftNodeId,
65 state_machine: Arc<RwLock<GraphStateMachine>>,
67 metrics: Arc<RwLock<SimpleRaftMetrics>>,
69 initialized: Arc<RwLock<bool>>,
71}
72
73impl RaftNode {
74 pub fn new(node_id: RaftNodeId, state_machine: GraphStateMachine) -> Self {
76 info!("Creating Raft node with ID: {}", node_id);
77
78 Self {
79 node_id,
80 state_machine: Arc::new(RwLock::new(state_machine)),
81 metrics: Arc::new(RwLock::new(SimpleRaftMetrics::default())),
82 initialized: Arc::new(RwLock::new(false)),
83 }
84 }
85
86 pub fn id(&self) -> RaftNodeId {
88 self.node_id
89 }
90
91 pub async fn initialize(&mut self, _peers: Vec<NodeId>) -> RaftResult<()> {
93 info!("Initializing Raft node {} with peers", self.node_id);
94
95 let mut init = self.initialized.write().await;
96 *init = true;
97
98 let mut metrics = self.metrics.write().await;
99 metrics.current_leader = Some(self.node_id); Ok(())
102 }
103
104 pub async fn write(&self, request: Request) -> RaftResult<Response> {
106 if *self.initialized.read().await {
107 let sm = self.state_machine.read().await;
109 let response = sm.apply(request).await;
110
111 let mut metrics = self.metrics.write().await;
113 metrics.last_log_index += 1;
114 metrics.last_applied = metrics.last_log_index;
115
116 Ok(response)
117 } else {
118 Err(RaftError::Raft("Raft not initialized".to_string()))
119 }
120 }
121
122 pub async fn read(&self, request: Request) -> RaftResult<Response> {
124 let sm = self.state_machine.read().await;
125 Ok(sm.apply(request).await)
126 }
127
128 pub async fn is_leader(&self) -> bool {
130 let metrics = self.metrics.read().await;
131 metrics.current_leader == Some(self.node_id)
132 }
133
134 pub async fn get_leader(&self) -> Option<RaftNodeId> {
136 self.metrics.read().await.current_leader
137 }
138
139 pub async fn add_learner(&self, node_id: RaftNodeId, _node: NodeId) -> RaftResult<()> {
141 info!("Adding learner {} to cluster", node_id);
142
143 if *self.initialized.read().await {
144 Ok(())
145 } else {
146 Err(RaftError::Raft("Raft not initialized".to_string()))
147 }
148 }
149
150 pub async fn change_membership(&self, members: BTreeSet<RaftNodeId>) -> RaftResult<()> {
152 info!("Changing cluster membership to: {:?}", members);
153
154 if *self.initialized.read().await {
155 Ok(())
156 } else {
157 Err(RaftError::Raft("Raft not initialized".to_string()))
158 }
159 }
160
161 pub async fn metrics(&self) -> SimpleRaftMetrics {
163 self.metrics.read().await.clone()
164 }
165
166 pub async fn shutdown(&self) -> RaftResult<()> {
168 info!("Shutting down Raft node {}", self.node_id);
169
170 let mut init = self.initialized.write().await;
171 *init = false;
172
173 Ok(())
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::PersistenceManager;
    use tempfile::TempDir;

    /// Builds node `1` on top of a fresh temporary directory.
    ///
    /// The `TempDir` comes first in the tuple so that callers binding
    /// `(_dir, node)` drop the node before the directory is removed.
    fn test_node() -> (TempDir, RaftNode) {
        let dir = TempDir::new().unwrap();
        let persistence = Arc::new(PersistenceManager::new(dir.path()).unwrap());
        let node = RaftNode::new(1, GraphStateMachine::new(persistence));
        (dir, node)
    }

    /// The query request shared by the read/write tests.
    fn match_all_query() -> Request {
        Request::ExecuteQuery {
            tenant: "default".to_string(),
            query: "MATCH (n) RETURN n".to_string(),
        }
    }

    #[tokio::test]
    async fn test_raft_node_creation() {
        let (_dir, node) = test_node();

        assert_eq!(node.id(), 1);
        assert!(!node.is_leader().await);
    }

    #[tokio::test]
    async fn test_node_id() {
        let node_id = NodeId::new(1, "127.0.0.1:5000".to_string());
        assert_eq!(node_id.id, 1);
        assert_eq!(node_id.addr, "127.0.0.1:5000");
    }

    #[tokio::test]
    async fn test_raft_node_initialize() {
        let (_dir, mut node) = test_node();

        // Before initialization there is no leader at all.
        assert!(!node.is_leader().await);
        assert_eq!(node.get_leader().await, None);

        let peers = vec![
            NodeId::new(2, "127.0.0.1:5001".to_string()),
            NodeId::new(3, "127.0.0.1:5002".to_string()),
        ];
        node.initialize(peers).await.unwrap();

        // Afterwards the node reports itself as leader.
        assert!(node.is_leader().await);
        assert_eq!(node.get_leader().await, Some(1));
    }

    #[tokio::test]
    async fn test_raft_node_write_before_init() {
        let (_dir, node) = test_node();

        let outcome = node.write(match_all_query()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_write_after_init() {
        let (_dir, mut node) = test_node();
        node.initialize(vec![]).await.unwrap();

        let outcome = node.write(match_all_query()).await;
        assert!(outcome.is_ok());
        assert!(matches!(outcome.unwrap(), Response::QueryResult { .. }));
    }

    #[tokio::test]
    async fn test_raft_node_read() {
        let (_dir, node) = test_node();

        // Reads succeed even though the node was never initialized.
        let outcome = node.read(match_all_query()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_raft_node_metrics() {
        let (_dir, mut node) = test_node();

        let before = node.metrics().await;
        assert_eq!(before.current_term, 0);
        assert_eq!(before.last_log_index, 0);
        assert_eq!(before.last_applied, 0);
        assert_eq!(before.current_leader, None);

        node.initialize(vec![]).await.unwrap();
        node.write(match_all_query()).await.unwrap();

        let after = node.metrics().await;
        assert_eq!(after.last_log_index, 1);
        assert_eq!(after.last_applied, 1);
        assert_eq!(after.current_leader, Some(1));
    }

    #[tokio::test]
    async fn test_raft_node_add_learner_before_init() {
        let (_dir, node) = test_node();

        let learner = NodeId::new(2, "127.0.0.1:5001".to_string());
        assert!(node.add_learner(2, learner).await.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_add_learner_after_init() {
        let (_dir, mut node) = test_node();
        node.initialize(vec![]).await.unwrap();

        let learner = NodeId::new(2, "127.0.0.1:5001".to_string());
        assert!(node.add_learner(2, learner).await.is_ok());
    }

    #[tokio::test]
    async fn test_raft_node_change_membership_before_init() {
        let (_dir, node) = test_node();

        let members = BTreeSet::from([1, 2]);
        assert!(node.change_membership(members).await.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_change_membership_after_init() {
        let (_dir, mut node) = test_node();
        node.initialize(vec![]).await.unwrap();

        let members = BTreeSet::from([1, 2]);
        assert!(node.change_membership(members).await.is_ok());
    }

    #[tokio::test]
    async fn test_raft_node_shutdown() {
        let (_dir, mut node) = test_node();

        node.initialize(vec![]).await.unwrap();
        assert!(node.is_leader().await);

        node.shutdown().await.unwrap();

        // Writes are rejected once the node has been shut down.
        let outcome = node.write(match_all_query()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_multiple_writes_increment_metrics() {
        let (_dir, mut node) = test_node();
        node.initialize(vec![]).await.unwrap();

        for _ in 0..5 {
            node.write(match_all_query()).await.unwrap();
        }

        let snapshot = node.metrics().await;
        assert_eq!(snapshot.last_log_index, 5);
        assert_eq!(snapshot.last_applied, 5);
    }

    #[test]
    fn test_node_id_default() {
        let node_id = NodeId::default();
        assert_eq!(node_id.id, 0);
        assert_eq!(node_id.addr, "");
    }

    #[test]
    fn test_node_id_serialization() {
        let original = NodeId::new(42, "10.0.0.1:8080".to_string());
        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: NodeId = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped.id, 42);
        assert_eq!(round_tripped.addr, "10.0.0.1:8080");
    }

    #[test]
    fn test_node_id_equality() {
        let a = NodeId::new(1, "addr1".to_string());
        let b = NodeId::new(1, "addr1".to_string());
        let c = NodeId::new(2, "addr1".to_string());
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_simple_raft_metrics_default() {
        let metrics = SimpleRaftMetrics::default();
        assert_eq!(metrics.current_term, 0);
        assert_eq!(metrics.current_leader, None);
        assert_eq!(metrics.last_log_index, 0);
        assert_eq!(metrics.last_applied, 0);
    }

    #[test]
    fn test_entry_serialization() {
        let entry = typ::Entry {
            request: match_all_query(),
        };
        let json = serde_json::to_string(&entry).unwrap();
        let round_tripped: typ::Entry = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            round_tripped.request,
            Request::ExecuteQuery { .. }
        ));
    }
}