1use std::collections::HashSet;
11use std::time::Instant;
12
13use crate::error::{RaftError, Result};
14use crate::log::RaftLog;
15use crate::message::{AppendEntriesRequest, LogEntry};
16use crate::state::{HardState, LeaderState, NodeRole, VolatileState};
17use crate::storage::LogStorage;
18
19use super::config::RaftConfig;
20
/// Batch of outputs accumulated by a [`RaftNode`] for the driving code to act on.
///
/// A node fills this in as it works; `RaftNode::take_ready` hands the whole
/// batch to the caller and leaves an empty (`Default`) one in its place.
#[derive(Debug, Default)]
pub struct Ready {
    /// Hard state reported by the node; `None` when there is nothing to report.
    /// NOTE(review): presumably set when term/vote changed and must be persisted
    /// before messages are sent — confirm against the code that populates it.
    pub hard_state: Option<HardState>,
    /// Outgoing AppendEntries requests. The `u64` is presumably the destination
    /// peer id — TODO confirm.
    pub messages: Vec<(u64, AppendEntriesRequest)>,
    /// Outgoing RequestVote requests. The `u64` is presumably the destination
    /// peer id — TODO confirm.
    pub vote_requests: Vec<(u64, crate::message::RequestVoteRequest)>,
    /// Log entries that became committed and are ready for the application to
    /// apply (the caller then reports progress via `advance_applied`).
    pub committed_entries: Vec<LogEntry>,
    /// Ids of peers that need a snapshot, e.g. because the entries they are
    /// missing have already been compacted out of the local log.
    pub snapshots_needed: Vec<u64>,
}
39
40impl Ready {
41 pub fn is_empty(&self) -> bool {
42 self.hard_state.is_none()
43 && self.messages.is_empty()
44 && self.vote_requests.is_empty()
45 && self.committed_entries.is_empty()
46 && self.snapshots_needed.is_empty()
47 }
48}
49
/// A single member of a Raft group.
///
/// Holds the node's configuration, current role, persistent (`hard_state`) and
/// in-memory (`volatile`) state, the replicated log, election/heartbeat timers,
/// and the [`Ready`] batch of pending outputs.
pub struct RaftNode<S: LogStorage> {
    /// Static configuration: ids, membership lists and timing parameters.
    pub(super) config: RaftConfig,
    /// Current role (follower, candidate, leader, learner or observer).
    pub(super) role: NodeRole,
    /// State loaded from storage by `restore` (includes the current term).
    pub(super) hard_state: HardState,
    /// Commit/applied indices; created fresh in `new` and not loaded by `restore`.
    pub(super) volatile: VolatileState,
    /// Per-peer replication progress; `None` makes `match_index_for` return `None`.
    /// NOTE(review): presumably `Some` only while this node is leader — confirm.
    pub(super) leader_state: Option<LeaderState>,
    /// The replicated log, backed by the storage `S`.
    pub(super) log: RaftLog<S>,
    /// When a follower/candidate reaches this instant, `tick` starts an election.
    pub(super) election_deadline: Instant,
    /// When a leader reaches this instant, `tick` replicates to all peers.
    pub(super) heartbeat_deadline: Instant,
    /// Ids of peers that granted a vote; presumably cleared per election — confirm.
    pub(super) votes_received: HashSet<u64>,
    /// Outputs accumulated since the last `take_ready`.
    pub(super) ready: Ready,
    /// Id of the known leader; `0` means no leader is known.
    pub(super) leader_id: u64,
}
73
74impl<S: LogStorage> RaftNode<S> {
75 pub fn new(config: RaftConfig, storage: S) -> Self {
81 let now = Instant::now();
82 let role = if config.starts_as_observer {
83 NodeRole::Observer
84 } else if config.starts_as_learner {
85 NodeRole::Learner
86 } else {
87 NodeRole::Follower
88 };
89 Self {
90 log: RaftLog::new(storage),
91 role,
92 hard_state: HardState::new(),
93 volatile: VolatileState::new(),
94 leader_state: None,
95 election_deadline: now + config.election_timeout_max,
96 heartbeat_deadline: now,
97 votes_received: HashSet::new(),
98 ready: Ready::default(),
99 leader_id: 0,
100 config,
101 }
102 }
103
104 pub fn restore(&mut self) -> Result<()> {
106 self.hard_state = self.log.storage().load_hard_state()?;
107 self.log.restore()?;
108 self.reset_election_timeout();
109 Ok(())
110 }
111
112 pub fn node_id(&self) -> u64 {
113 self.config.node_id
114 }
115
116 pub fn group_id(&self) -> u64 {
117 self.config.group_id
118 }
119
120 pub fn role(&self) -> NodeRole {
121 self.role
122 }
123
124 pub fn leader_id(&self) -> u64 {
125 self.leader_id
126 }
127
128 pub fn current_term(&self) -> u64 {
129 self.hard_state.current_term
130 }
131
132 pub fn commit_index(&self) -> u64 {
133 self.volatile.commit_index
134 }
135
136 pub fn last_applied(&self) -> u64 {
137 self.volatile.last_applied
138 }
139
140 pub fn election_deadline_override(&mut self, deadline: Instant) {
142 self.election_deadline = deadline;
143 }
144
145 pub fn take_ready(&mut self) -> Ready {
148 std::mem::take(&mut self.ready)
149 }
150
151 pub fn advance_applied(&mut self, applied_to: u64) {
153 self.volatile.last_applied = applied_to;
154 }
155
156 pub fn match_index_for(&self, peer: u64) -> Option<u64> {
159 self.leader_state
160 .as_ref()
161 .map(|ls| ls.match_index_for(peer))
162 }
163
164 pub fn log_snapshot_index(&self) -> u64 {
165 self.log.snapshot_index()
166 }
167
168 pub fn log_snapshot_term(&self) -> u64 {
169 self.log.snapshot_term()
170 }
171
172 pub fn log_entries_range(
178 &self,
179 lo: u64,
180 hi: u64,
181 ) -> crate::error::Result<&[crate::message::LogEntry]> {
182 let hi = hi.min(self.volatile.commit_index);
183 self.log.entries_range(lo, hi)
184 }
185
186 pub fn peers(&self) -> &[u64] {
188 &self.config.peers
189 }
190
191 pub fn voters(&self) -> &[u64] {
194 &self.config.peers
195 }
196
197 pub fn learners(&self) -> &[u64] {
199 &self.config.learners
200 }
201
202 pub fn observers(&self) -> &[u64] {
204 &self.config.observers
205 }
206
207 pub fn is_learner_peer(&self, peer: u64) -> bool {
209 self.config.learners.contains(&peer)
210 }
211
212 pub fn tick(&mut self) {
214 let now = Instant::now();
215
216 match self.role {
217 NodeRole::Follower | NodeRole::Candidate => {
218 if now >= self.election_deadline {
219 self.start_election();
220 }
221 }
222 NodeRole::Leader => {
223 if now >= self.heartbeat_deadline {
224 self.replicate_to_all();
225 self.heartbeat_deadline = now + self.config.heartbeat_interval;
226 }
227 }
228 NodeRole::Learner => {
229 }
232 NodeRole::Observer => {
233 }
237 }
238 }
239
240 pub fn propose(&mut self, data: Vec<u8>) -> Result<u64> {
242 if self.role != NodeRole::Leader {
243 return Err(RaftError::NotLeader {
244 leader_hint: if self.leader_id != 0 {
245 Some(self.leader_id)
246 } else {
247 None
248 },
249 });
250 }
251
252 let index = self.log.last_index() + 1;
253 let entry = LogEntry {
254 term: self.hard_state.current_term,
255 index,
256 data,
257 };
258
259 self.log.append(entry)?;
260 self.replicate_to_all();
261
262 if self.config.cluster_size() == 1 {
264 self.volatile.commit_index = index;
265 self.collect_committed_entries();
266 }
267
268 Ok(index)
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemStorage;
    use std::time::Duration;

    /// Config for `node_id` with the given voting `peers`, no learners or
    /// observers, 150–300 ms election timeouts and a 50 ms heartbeat.
    fn test_config(node_id: u64, peers: Vec<u64>) -> RaftConfig {
        RaftConfig {
            node_id,
            group_id: 1,
            peers,
            learners: Vec::new(),
            observers: Vec::new(),
            starts_as_learner: false,
            starts_as_observer: false,
            election_timeout_min: Duration::from_millis(150),
            election_timeout_max: Duration::from_millis(300),
            heartbeat_interval: Duration::from_millis(50),
        }
    }

    /// Config for node 2 (peer 1) that starts in the learner role.
    fn learner_config() -> RaftConfig {
        RaftConfig {
            starts_as_learner: true,
            ..test_config(2, vec![1])
        }
    }

    /// Moves the election deadline 1 ms into the past so the next `tick`
    /// treats it as expired.
    fn expire_election_timer(node: &mut RaftNode<MemStorage>) {
        node.election_deadline = Instant::now() - Duration::from_millis(1);
    }

    #[test]
    fn single_node_election() {
        let mut node = RaftNode::new(test_config(1, vec![]), MemStorage::new());
        expire_election_timer(&mut node);

        node.tick();

        assert_eq!(node.role(), NodeRole::Leader);
        assert_eq!(node.current_term(), 1);
        assert_eq!(node.leader_id(), 1);
    }

    #[test]
    fn single_node_propose_and_commit() {
        let mut node = RaftNode::new(test_config(1, vec![]), MemStorage::new());
        expire_election_timer(&mut node);
        node.tick();
        assert_eq!(node.role(), NodeRole::Leader);

        // Winning the election commits at least one entry (presumably the
        // leader's initial entry); mark everything so far as applied.
        let election_ready = node.take_ready();
        let last_committed = election_ready
            .committed_entries
            .last()
            .expect("winning the election should commit at least one entry");
        node.advance_applied(last_committed.index);

        let proposed_at = node.propose(b"hello".to_vec()).unwrap();
        assert_eq!(proposed_at, 2);

        let ready = node.take_ready();
        assert_eq!(ready.committed_entries.len(), 1);
        assert_eq!(ready.committed_entries[0].data, b"hello");
    }

    #[test]
    fn propose_as_follower_fails() {
        let mut node = RaftNode::new(test_config(1, vec![2, 3]), MemStorage::new());

        let err = node.propose(b"data".to_vec()).unwrap_err();

        assert!(matches!(err, RaftError::NotLeader { .. }));
    }

    #[test]
    fn snapshot_needed_after_compaction() {
        let mut node = RaftNode::new(test_config(1, vec![2, 3]), MemStorage::new());

        // Become candidate, then collect one granted vote to win a 3-node election.
        expire_election_timer(&mut node);
        node.tick();
        node.take_ready();
        let grant = crate::message::RequestVoteResponse {
            term: 1,
            vote_granted: true,
        };
        node.handle_request_vote_response(2, &grant);
        assert_eq!(node.role(), NodeRole::Leader);
        node.take_ready();

        for payload in 0..9u8 {
            node.propose(vec![payload]).unwrap();
        }
        node.take_ready();

        // Compact away everything up to index 8, then try to replicate.
        node.log.apply_snapshot(8, 1);
        node.replicate_to_all();

        let ready = node.take_ready();
        assert!(
            !ready.snapshots_needed.is_empty(),
            "expected snapshots_needed to be non-empty"
        );
    }

    #[test]
    fn starts_as_learner_role() {
        let node = RaftNode::new(learner_config(), MemStorage::new());

        assert_eq!(node.role(), NodeRole::Learner);
    }

    #[test]
    fn learner_tick_does_not_start_election() {
        let mut node = RaftNode::new(learner_config(), MemStorage::new());
        expire_election_timer(&mut node);

        node.tick();

        assert_eq!(node.role(), NodeRole::Learner);
        assert_eq!(node.current_term(), 0);
    }
}