1#![allow(dead_code)]
4
5use super::detection::ByzantineDetector;
6use super::messages::BftMessage;
7use super::state_machine::RdfStateMachine;
8use super::types::*;
9use anyhow::{anyhow, Result};
10use dashmap::DashMap;
11use parking_lot::{Mutex, RwLock};
12use sha2::{Digest, Sha256};
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::sync::Arc;
15use std::time::Instant;
16use tokio::sync::mpsc;
17
18#[derive(Debug, Clone)]
20pub struct ConsensusState {
21 pub phase: Phase,
22 pub request: Option<BftMessage>,
23 pub digest: Vec<u8>,
24 pub prepares: HashSet<NodeId>,
25 pub commits: HashSet<NodeId>,
26 pub replied: bool,
27}
28
29pub struct BftNode {
31 config: BftConfig,
33
34 node_id: NodeId,
36
37 view: Arc<RwLock<ViewNumber>>,
39
40 phase: Arc<RwLock<Phase>>,
42
43 sequence_counter: Arc<Mutex<SequenceNumber>>,
45
46 states: Arc<DashMap<(ViewNumber, SequenceNumber), ConsensusState>>,
48
49 message_log: Arc<RwLock<VecDeque<BftMessage>>>,
51
52 checkpoints: Arc<RwLock<HashMap<SequenceNumber, CheckpointProof>>>,
54
55 stable_checkpoint: Arc<RwLock<SequenceNumber>>,
57
58 nodes: Arc<RwLock<HashMap<NodeId, NodeInfo>>>,
60
61 message_tx: mpsc::UnboundedSender<(NodeId, BftMessage)>,
63
64 message_rx: Arc<Mutex<mpsc::UnboundedReceiver<(NodeId, BftMessage)>>>,
66
67 state_machine: Arc<RwLock<RdfStateMachine>>,
69
70 view_change_timer: Arc<Mutex<Option<Instant>>>,
72
73 byzantine_detector: Arc<RwLock<ByzantineDetector>>,
75}
76
77impl BftNode {
78 pub fn new(config: BftConfig, node_id: NodeId, nodes: Vec<NodeInfo>) -> Self {
80 let (message_tx, message_rx) = mpsc::unbounded_channel();
81
82 let mut node_map = HashMap::new();
83 for node in nodes {
84 node_map.insert(node.id, node);
85 }
86
87 Self {
88 config: config.clone(),
89 node_id,
90 view: Arc::new(RwLock::new(0)),
91 phase: Arc::new(RwLock::new(Phase::Idle)),
92 sequence_counter: Arc::new(Mutex::new(0)),
93 states: Arc::new(DashMap::new()),
94 message_log: Arc::new(RwLock::new(VecDeque::new())),
95 checkpoints: Arc::new(RwLock::new(HashMap::new())),
96 stable_checkpoint: Arc::new(RwLock::new(0)),
97 nodes: Arc::new(RwLock::new(node_map)),
98 message_tx,
99 message_rx: Arc::new(Mutex::new(message_rx)),
100 state_machine: Arc::new(RwLock::new(RdfStateMachine::new())),
101 view_change_timer: Arc::new(Mutex::new(None)),
102 byzantine_detector: Arc::new(RwLock::new(ByzantineDetector::new(3))), }
104 }
105
106 pub fn is_primary(&self) -> bool {
108 let view = *self.view.read();
109 let num_nodes = self.nodes.read().len() as u64;
110 self.node_id == (view % num_nodes)
111 }
112
113 pub fn get_primary(&self, view: ViewNumber) -> NodeId {
115 let num_nodes = self.nodes.read().len() as u64;
116 view % num_nodes
117 }
118
119 fn calculate_digest(message: &BftMessage) -> Vec<u8> {
121 let serialized =
122 oxicode::serde::encode_to_vec(message, oxicode::config::standard()).unwrap_or_default();
123 let mut hasher = Sha256::new();
124 hasher.update(&serialized);
125 hasher.finalize().to_vec()
126 }
127
128 fn log_message(&self, message: BftMessage) {
130 let mut log = self.message_log.write();
131 log.push_back(message);
132
133 if log.len() > self.config.max_log_size {
135 log.pop_front();
136 }
137 }
138
139 async fn broadcast_message(&self, message: BftMessage) -> Result<()> {
141 let nodes = self.nodes.read();
142 for (&node_id, _) in nodes.iter() {
143 if node_id != self.node_id {
144 self.message_tx
145 .send((node_id, message.clone()))
146 .map_err(|e| anyhow!("Failed to send message: {}", e))?;
147 }
148 }
149 Ok(())
150 }
151
152 pub async fn process_message(&self, from: NodeId, message: BftMessage) -> Result<()> {
154 let start_time = Instant::now();
155
156 {
158 let mut detector = self.byzantine_detector.write();
159
160 let message_hash = Self::calculate_digest(&message);
162 if detector.check_replay_attack(from, message_hash.clone()) {
163 return Err(anyhow!("Replay attack detected from node {}", from));
164 }
165
166 detector.monitor_resource_usage(from);
168
169 detector.check_network_partition(from);
171
172 if let BftMessage::PrePrepare { view, sequence, .. }
174 | BftMessage::Prepare { view, sequence, .. }
175 | BftMessage::Commit { view, sequence, .. } = &message
176 {
177 if detector.check_equivocation(from, *view, *sequence, message_hash) {
178 return Err(anyhow!("Equivocation detected from node {}", from));
179 }
180 }
181 }
182
183 self.log_message(message.clone());
185
186 match message {
187 BftMessage::Request { .. } if self.is_primary() => {
188 self.handle_client_request(message).await?;
189 }
190
191 BftMessage::PrePrepare {
192 view,
193 sequence,
194 digest,
195 request,
196 } => {
197 self.handle_pre_prepare(from, view, sequence, digest, *request)
198 .await?;
199 }
200
201 BftMessage::Prepare {
202 view,
203 sequence,
204 digest,
205 node_id,
206 } => {
207 self.handle_prepare(view, sequence, digest, node_id).await?;
208 }
209
210 BftMessage::Commit {
211 view,
212 sequence,
213 digest,
214 node_id,
215 } => {
216 self.handle_commit(view, sequence, digest, node_id).await?;
217 }
218
219 BftMessage::Checkpoint {
220 sequence,
221 state_digest,
222 node_id,
223 } => {
224 self.handle_checkpoint(sequence, state_digest, node_id)
225 .await?;
226 }
227
228 BftMessage::ViewChange { .. } => {
229 self.handle_view_change(message).await?;
230 }
231
232 BftMessage::NewView { .. } => {
233 self.handle_new_view(message).await?;
234 }
235
236 _ => {}
237 }
238
239 let response_time = start_time.elapsed();
241 {
242 let mut detector = self.byzantine_detector.write();
243 detector.report_timing_anomaly(from, response_time);
244 }
245
246 Ok(())
247 }
248
249 async fn handle_client_request(&self, request: BftMessage) -> Result<()> {
251 let view = *self.view.read();
252 let sequence = {
253 let mut counter = self.sequence_counter.lock();
254 *counter += 1;
255 *counter
256 };
257
258 let digest = Self::calculate_digest(&request);
259
260 let pre_prepare = BftMessage::PrePrepare {
262 view,
263 sequence,
264 digest: digest.clone(),
265 request: Box::new(request.clone()),
266 };
267
268 let state = ConsensusState {
270 phase: Phase::PrePrepare,
271 request: Some(request),
272 digest: digest.clone(),
273 prepares: HashSet::new(),
274 commits: HashSet::new(),
275 replied: false,
276 };
277 self.states.insert((view, sequence), state);
278
279 self.broadcast_message(pre_prepare).await?;
281
282 self.enter_prepare_phase(view, sequence, digest).await?;
284
285 Ok(())
286 }
287
288 async fn handle_pre_prepare(
290 &self,
291 from: NodeId,
292 view: ViewNumber,
293 sequence: SequenceNumber,
294 digest: Vec<u8>,
295 request: BftMessage,
296 ) -> Result<()> {
297 if from != self.get_primary(view) {
299 return Err(anyhow!("Pre-prepare not from primary"));
300 }
301
302 if view != *self.view.read() {
304 return Ok(()); }
306
307 let calculated_digest = Self::calculate_digest(&request);
309 if digest != calculated_digest {
310 return Err(anyhow!("Invalid message digest"));
311 }
312
313 let state = ConsensusState {
315 phase: Phase::PrePrepare,
316 request: Some(request),
317 digest: digest.clone(),
318 prepares: HashSet::new(),
319 commits: HashSet::new(),
320 replied: false,
321 };
322 self.states.insert((view, sequence), state);
323
324 self.enter_prepare_phase(view, sequence, digest).await?;
326
327 Ok(())
328 }
329
330 async fn enter_prepare_phase(
332 &self,
333 view: ViewNumber,
334 sequence: SequenceNumber,
335 digest: Vec<u8>,
336 ) -> Result<()> {
337 let prepare = BftMessage::Prepare {
339 view,
340 sequence,
341 digest,
342 node_id: self.node_id,
343 };
344
345 self.broadcast_message(prepare).await?;
346
347 if let Some(mut state) = self.states.get_mut(&(view, sequence)) {
349 state.phase = Phase::Prepare;
350 }
351
352 Ok(())
353 }
354
355 async fn handle_prepare(
357 &self,
358 view: ViewNumber,
359 sequence: SequenceNumber,
360 digest: Vec<u8>,
361 node_id: NodeId,
362 ) -> Result<()> {
363 if view != *self.view.read() {
365 return Ok(());
366 }
367
368 let should_commit = {
370 match self.states.get_mut(&(view, sequence)) {
371 Some(mut state) if state.digest == digest => {
372 state.prepares.insert(node_id);
373
374 state.prepares.len() >= 2 * self.config.fault_tolerance
376 }
377 _ => false,
378 }
379 };
380
381 if should_commit {
383 self.enter_commit_phase(view, sequence, digest).await?;
384 }
385
386 Ok(())
387 }
388
389 async fn enter_commit_phase(
391 &self,
392 view: ViewNumber,
393 sequence: SequenceNumber,
394 digest: Vec<u8>,
395 ) -> Result<()> {
396 let commit = BftMessage::Commit {
398 view,
399 sequence,
400 digest,
401 node_id: self.node_id,
402 };
403
404 self.broadcast_message(commit).await?;
405
406 if let Some(mut state) = self.states.get_mut(&(view, sequence)) {
408 state.phase = Phase::Commit;
409 }
410
411 Ok(())
412 }
413
414 async fn handle_commit(
416 &self,
417 view: ViewNumber,
418 sequence: SequenceNumber,
419 digest: Vec<u8>,
420 node_id: NodeId,
421 ) -> Result<()> {
422 if view != *self.view.read() {
424 return Ok(());
425 }
426
427 let should_execute = {
429 match self.states.get_mut(&(view, sequence)) {
430 Some(mut state) if state.digest == digest => {
431 state.commits.insert(node_id);
432
433 state.commits.len() > 2 * self.config.fault_tolerance
435 }
436 _ => false,
437 }
438 };
439
440 if should_execute {
442 self.execute_operation(view, sequence).await?;
443 }
444
445 Ok(())
446 }
447
448 async fn execute_operation(&self, view: ViewNumber, sequence: SequenceNumber) -> Result<()> {
450 if let Some(state) = self.states.get(&(view, sequence)) {
451 if let Some(BftMessage::Request {
452 operation,
453 client_id,
454 ..
455 }) = &state.request
456 {
457 let result = {
459 let mut sm = self.state_machine.write();
460 sm.execute(operation.clone())?
461 };
462
463 let reply = BftMessage::Reply {
465 view,
466 sequence,
467 client_id: client_id.clone(),
468 result,
469 timestamp: std::time::SystemTime::now(),
470 };
471
472 self.log_message(reply);
475
476 if let Some(mut state) = self.states.get_mut(&(view, sequence)) {
478 state.replied = true;
479 }
480 }
481 }
482
483 if sequence % self.config.checkpoint_interval == 0 {
485 self.create_checkpoint(sequence).await?;
486 }
487
488 Ok(())
489 }
490
491 async fn create_checkpoint(&self, sequence: SequenceNumber) -> Result<()> {
493 let state_digest = {
494 let sm = self.state_machine.read();
495 sm.get_state_digest()
496 };
497
498 let checkpoint = BftMessage::Checkpoint {
499 sequence,
500 state_digest: state_digest.clone(),
501 node_id: self.node_id,
502 };
503
504 self.broadcast_message(checkpoint).await?;
505
506 let proof = CheckpointProof {
508 sequence,
509 state_digest,
510 signatures: HashMap::new(), };
512
513 self.checkpoints.write().insert(sequence, proof);
514
515 Ok(())
516 }
517
518 async fn handle_checkpoint(
520 &self,
521 _sequence: SequenceNumber,
522 state_digest: Vec<u8>,
523 node_id: NodeId,
524 ) -> Result<()> {
525 let our_digest = {
527 let sm = self.state_machine.read();
528 sm.get_state_digest()
529 };
530
531 if state_digest != our_digest {
532 let mut detector = self.byzantine_detector.write();
534 detector.report_inconsistent_pattern(node_id);
535 return Err(anyhow!("Inconsistent checkpoint from node {}", node_id));
536 }
537
538 Ok(())
539 }
540
541 async fn handle_view_change(&self, _message: BftMessage) -> Result<()> {
543 Ok(())
547 }
548
549 async fn handle_new_view(&self, _message: BftMessage) -> Result<()> {
551 Ok(())
555 }
556
557 pub fn get_status(&self) -> NodeStatus {
559 NodeStatus {
560 node_id: self.node_id,
561 view: *self.view.read(),
562 phase: *self.phase.read(),
563 sequence: *self.sequence_counter.lock(),
564 suspected_nodes: self.byzantine_detector.read().get_suspected_nodes().clone(),
565 }
566 }
567}
568
569#[derive(Debug, Clone)]
571pub struct NodeStatus {
572 pub node_id: NodeId,
573 pub view: ViewNumber,
574 pub phase: Phase,
575 pub sequence: SequenceNumber,
576 pub suspected_nodes: HashSet<NodeId>,
577}
578
579impl Clone for BftNode {
581 fn clone(&self) -> Self {
582 let (message_tx, message_rx) = mpsc::unbounded_channel();
583
584 Self {
585 config: self.config.clone(),
586 node_id: self.node_id,
587 view: self.view.clone(),
588 phase: self.phase.clone(),
589 sequence_counter: self.sequence_counter.clone(),
590 states: self.states.clone(),
591 message_log: self.message_log.clone(),
592 checkpoints: self.checkpoints.clone(),
593 stable_checkpoint: self.stable_checkpoint.clone(),
594 nodes: self.nodes.clone(),
595 message_tx,
596 message_rx: Arc::new(Mutex::new(message_rx)),
597 state_machine: self.state_machine.clone(),
598 view_change_timer: self.view_change_timer.clone(),
599 byzantine_detector: self.byzantine_detector.clone(),
600 }
601 }
602}