Skip to main content

graphmind/raft/
node.rs

1//! Raft node implementation
2
3use crate::raft::{GraphStateMachine, RaftError, RaftNodeId, RaftResult, Request, Response};
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeSet;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::info;
9
10/// Node identifier with address
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
12pub struct NodeId {
13    /// Unique node ID
14    #[serde(default)]
15    pub id: RaftNodeId,
16    /// Node address (host:port)
17    #[serde(default)]
18    pub addr: String,
19}
20
21impl NodeId {
22    pub fn new(id: RaftNodeId, addr: String) -> Self {
23        Self { id, addr }
24    }
25}
26
27/// Raft type definitions for openraft
28pub mod typ {
29    use super::*;
30
31    /// Node ID type
32    pub type NodeIdType = RaftNodeId;
33
34    /// Node type containing address information
35    pub type Node = super::NodeId;
36
37    /// Entry type for log entries
38    #[derive(Debug, Clone, Serialize, Deserialize)]
39    pub struct Entry {
40        pub request: Request,
41    }
42
43    /// Snapshot data type
44    pub type SnapshotData = Vec<u8>;
45
46    /// Type configuration for compatibility
47    /// Note: This is a simplified implementation. Full Raft integration
48    /// would use openraft::declare_raft_types! macro with proper trait bounds.
49    pub struct TypeConfig;
50}
51
52/// Raft metrics (simplified)
53#[derive(Debug, Clone, Default)]
54pub struct SimpleRaftMetrics {
55    pub current_term: u64,
56    pub current_leader: Option<RaftNodeId>,
57    pub last_log_index: u64,
58    pub last_applied: u64,
59}
60
61/// Raft node managing consensus
62pub struct RaftNode {
63    /// Node ID
64    node_id: RaftNodeId,
65    /// State machine
66    state_machine: Arc<RwLock<GraphStateMachine>>,
67    /// Current metrics
68    metrics: Arc<RwLock<SimpleRaftMetrics>>,
69    /// Is initialized?
70    initialized: Arc<RwLock<bool>>,
71}
72
73impl RaftNode {
74    /// Create a new Raft node
75    pub fn new(node_id: RaftNodeId, state_machine: GraphStateMachine) -> Self {
76        info!("Creating Raft node with ID: {}", node_id);
77
78        Self {
79            node_id,
80            state_machine: Arc::new(RwLock::new(state_machine)),
81            metrics: Arc::new(RwLock::new(SimpleRaftMetrics::default())),
82            initialized: Arc::new(RwLock::new(false)),
83        }
84    }
85
86    /// Get node ID
87    pub fn id(&self) -> RaftNodeId {
88        self.node_id
89    }
90
91    /// Initialize the Raft instance
92    pub async fn initialize(&mut self, _peers: Vec<NodeId>) -> RaftResult<()> {
93        info!("Initializing Raft node {} with peers", self.node_id);
94
95        let mut init = self.initialized.write().await;
96        *init = true;
97
98        let mut metrics = self.metrics.write().await;
99        metrics.current_leader = Some(self.node_id); // Simplified: this node is leader
100
101        Ok(())
102    }
103
104    /// Submit a write request (goes through Raft consensus)
105    pub async fn write(&self, request: Request) -> RaftResult<Response> {
106        if *self.initialized.read().await {
107            // Apply directly to state machine
108            let sm = self.state_machine.read().await;
109            let response = sm.apply(request).await;
110
111            // Update metrics
112            let mut metrics = self.metrics.write().await;
113            metrics.last_log_index += 1;
114            metrics.last_applied = metrics.last_log_index;
115
116            Ok(response)
117        } else {
118            Err(RaftError::Raft("Raft not initialized".to_string()))
119        }
120    }
121
122    /// Execute a read request (can be served locally if leader)
123    pub async fn read(&self, request: Request) -> RaftResult<Response> {
124        let sm = self.state_machine.read().await;
125        Ok(sm.apply(request).await)
126    }
127
128    /// Check if this node is the leader
129    pub async fn is_leader(&self) -> bool {
130        let metrics = self.metrics.read().await;
131        metrics.current_leader == Some(self.node_id)
132    }
133
134    /// Get current leader ID
135    pub async fn get_leader(&self) -> Option<RaftNodeId> {
136        self.metrics.read().await.current_leader
137    }
138
139    /// Add a new node to the cluster
140    pub async fn add_learner(&self, node_id: RaftNodeId, _node: NodeId) -> RaftResult<()> {
141        info!("Adding learner {} to cluster", node_id);
142
143        if *self.initialized.read().await {
144            Ok(())
145        } else {
146            Err(RaftError::Raft("Raft not initialized".to_string()))
147        }
148    }
149
150    /// Change cluster membership
151    pub async fn change_membership(&self, members: BTreeSet<RaftNodeId>) -> RaftResult<()> {
152        info!("Changing cluster membership to: {:?}", members);
153
154        if *self.initialized.read().await {
155            Ok(())
156        } else {
157            Err(RaftError::Raft("Raft not initialized".to_string()))
158        }
159    }
160
161    /// Get Raft metrics
162    pub async fn metrics(&self) -> SimpleRaftMetrics {
163        self.metrics.read().await.clone()
164    }
165
166    /// Shutdown the Raft node
167    pub async fn shutdown(&self) -> RaftResult<()> {
168        info!("Shutting down Raft node {}", self.node_id);
169
170        let mut init = self.initialized.write().await;
171        *init = false;
172
173        Ok(())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::PersistenceManager;
    use tempfile::TempDir;

    /// Builds a fresh, uninitialized node backed by a temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for the
    /// whole test — bind it to `_dir`, never to `_` (which drops it at once).
    fn fresh_node(id: RaftNodeId) -> (TempDir, RaftNode) {
        let dir = TempDir::new().unwrap();
        let persistence = Arc::new(PersistenceManager::new(dir.path()).unwrap());
        let node = RaftNode::new(id, GraphStateMachine::new(persistence));
        (dir, node)
    }

    /// The read-everything query used as a generic request throughout.
    fn match_all() -> Request {
        Request::ExecuteQuery {
            tenant: "default".to_string(),
            query: "MATCH (n) RETURN n".to_string(),
        }
    }

    #[tokio::test]
    async fn test_raft_node_creation() {
        let (_dir, node) = fresh_node(1);

        assert_eq!(node.id(), 1);
        assert!(!node.is_leader().await);
    }

    #[tokio::test]
    async fn test_node_id() {
        let node_id = NodeId::new(1, "127.0.0.1:5000".to_string());
        assert_eq!(node_id.id, 1);
        assert_eq!(node_id.addr, "127.0.0.1:5000");
    }

    // ========== Additional RaftNode Coverage Tests ==========

    #[tokio::test]
    async fn test_raft_node_initialize() {
        let (_dir, mut node) = fresh_node(1);

        assert!(!node.is_leader().await);
        assert_eq!(node.get_leader().await, None);

        let peers = vec![
            NodeId::new(2, "127.0.0.1:5001".to_string()),
            NodeId::new(3, "127.0.0.1:5002".to_string()),
        ];
        node.initialize(peers).await.unwrap();

        // The simplified implementation promotes itself to leader on init.
        assert!(node.is_leader().await);
        assert_eq!(node.get_leader().await, Some(1));
    }

    #[tokio::test]
    async fn test_raft_node_write_before_init() {
        let (_dir, node) = fresh_node(1);

        assert!(node.write(match_all()).await.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_write_after_init() {
        let (_dir, mut node) = fresh_node(1);
        node.initialize(vec![]).await.unwrap();

        let outcome = node.write(match_all()).await;
        assert!(outcome.is_ok());
        assert!(matches!(outcome.unwrap(), Response::QueryResult { .. }));
    }

    #[tokio::test]
    async fn test_raft_node_read() {
        let (_dir, node) = fresh_node(1);

        // Reads are served without initialization.
        assert!(node.read(match_all()).await.is_ok());
    }

    #[tokio::test]
    async fn test_raft_node_metrics() {
        let (_dir, mut node) = fresh_node(1);

        let before = node.metrics().await;
        assert_eq!(before.current_term, 0);
        assert_eq!(before.last_log_index, 0);
        assert_eq!(before.last_applied, 0);
        assert_eq!(before.current_leader, None);

        node.initialize(vec![]).await.unwrap();

        // A single write advances the log and apply indices.
        node.write(match_all()).await.unwrap();

        let after = node.metrics().await;
        assert_eq!(after.last_log_index, 1);
        assert_eq!(after.last_applied, 1);
        assert_eq!(after.current_leader, Some(1));
    }

    #[tokio::test]
    async fn test_raft_node_add_learner_before_init() {
        let (_dir, node) = fresh_node(1);

        let learner = NodeId::new(2, "127.0.0.1:5001".to_string());
        assert!(node.add_learner(2, learner).await.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_add_learner_after_init() {
        let (_dir, mut node) = fresh_node(1);
        node.initialize(vec![]).await.unwrap();

        let learner = NodeId::new(2, "127.0.0.1:5001".to_string());
        assert!(node.add_learner(2, learner).await.is_ok());
    }

    #[tokio::test]
    async fn test_raft_node_change_membership_before_init() {
        let (_dir, node) = fresh_node(1);

        let members: BTreeSet<RaftNodeId> = (1..=2).collect();
        assert!(node.change_membership(members).await.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_change_membership_after_init() {
        let (_dir, mut node) = fresh_node(1);
        node.initialize(vec![]).await.unwrap();

        let members: BTreeSet<RaftNodeId> = (1..=2).collect();
        assert!(node.change_membership(members).await.is_ok());
    }

    #[tokio::test]
    async fn test_raft_node_shutdown() {
        let (_dir, mut node) = fresh_node(1);
        node.initialize(vec![]).await.unwrap();
        assert!(node.is_leader().await);

        node.shutdown().await.unwrap();

        // Once shut down the node is no longer initialized, so writes fail.
        assert!(node.write(match_all()).await.is_err());
    }

    #[tokio::test]
    async fn test_raft_node_multiple_writes_increment_metrics() {
        let (_dir, mut node) = fresh_node(1);
        node.initialize(vec![]).await.unwrap();

        for _ in 0..5 {
            node.write(match_all()).await.unwrap();
        }

        let snapshot = node.metrics().await;
        assert_eq!(snapshot.last_log_index, 5);
        assert_eq!(snapshot.last_applied, 5);
    }

    #[test]
    fn test_node_id_default() {
        let blank = NodeId::default();
        assert_eq!(blank.id, 0);
        assert_eq!(blank.addr, "");
    }

    #[test]
    fn test_node_id_serialization() {
        let original = NodeId::new(42, "10.0.0.1:8080".to_string());
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: NodeId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, 42);
        assert_eq!(decoded.addr, "10.0.0.1:8080");
    }

    #[test]
    fn test_node_id_equality() {
        let first = NodeId::new(1, "addr1".to_string());
        let twin = NodeId::new(1, "addr1".to_string());
        let other = NodeId::new(2, "addr1".to_string());
        assert_eq!(first, twin);
        assert_ne!(first, other);
    }

    #[test]
    fn test_simple_raft_metrics_default() {
        let zeroed = SimpleRaftMetrics::default();
        assert_eq!(zeroed.current_term, 0);
        assert_eq!(zeroed.current_leader, None);
        assert_eq!(zeroed.last_log_index, 0);
        assert_eq!(zeroed.last_applied, 0);
    }

    #[test]
    fn test_entry_serialization() {
        let entry = typ::Entry {
            request: match_all(),
        };
        let encoded = serde_json::to_string(&entry).unwrap();
        let decoded: typ::Entry = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded.request, Request::ExecuteQuery { .. }));
    }
}