Skip to main content

oxirs_core/distributed/bft/
node.rs

1//! BFT Node implementation and consensus logic
2
3#![allow(dead_code)]
4
5use super::detection::ByzantineDetector;
6use super::messages::BftMessage;
7use super::state_machine::RdfStateMachine;
8use super::types::*;
9use anyhow::{anyhow, Result};
10use dashmap::DashMap;
11use parking_lot::{Mutex, RwLock};
12use sha2::{Digest, Sha256};
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::sync::Arc;
15use std::time::Instant;
16use tokio::sync::mpsc;
17
18/// Consensus state for a specific view and sequence
19#[derive(Debug, Clone)]
20pub struct ConsensusState {
21    pub phase: Phase,
22    pub request: Option<BftMessage>,
23    pub digest: Vec<u8>,
24    pub prepares: HashSet<NodeId>,
25    pub commits: HashSet<NodeId>,
26    pub replied: bool,
27}
28
29/// Byzantine fault tolerant node
30pub struct BftNode {
31    /// Node configuration
32    config: BftConfig,
33
34    /// This node's ID
35    node_id: NodeId,
36
37    /// Current view number
38    view: Arc<RwLock<ViewNumber>>,
39
40    /// Current phase
41    phase: Arc<RwLock<Phase>>,
42
43    /// Sequence number counter
44    sequence_counter: Arc<Mutex<SequenceNumber>>,
45
46    /// Node states (for each view and sequence)
47    states: Arc<DashMap<(ViewNumber, SequenceNumber), ConsensusState>>,
48
49    /// Message log
50    message_log: Arc<RwLock<VecDeque<BftMessage>>>,
51
52    /// Checkpoints
53    checkpoints: Arc<RwLock<HashMap<SequenceNumber, CheckpointProof>>>,
54
55    /// Stable checkpoint
56    stable_checkpoint: Arc<RwLock<SequenceNumber>>,
57
58    /// Other nodes in the cluster
59    nodes: Arc<RwLock<HashMap<NodeId, NodeInfo>>>,
60
61    /// Message sender
62    message_tx: mpsc::UnboundedSender<(NodeId, BftMessage)>,
63
64    /// Message receiver
65    message_rx: Arc<Mutex<mpsc::UnboundedReceiver<(NodeId, BftMessage)>>>,
66
67    /// RDF state machine
68    state_machine: Arc<RwLock<RdfStateMachine>>,
69
70    /// View change timer
71    view_change_timer: Arc<Mutex<Option<Instant>>>,
72
73    /// Byzantine behavior detection
74    byzantine_detector: Arc<RwLock<ByzantineDetector>>,
75}
76
77impl BftNode {
78    /// Create a new BFT node
79    pub fn new(config: BftConfig, node_id: NodeId, nodes: Vec<NodeInfo>) -> Self {
80        let (message_tx, message_rx) = mpsc::unbounded_channel();
81
82        let mut node_map = HashMap::new();
83        for node in nodes {
84            node_map.insert(node.id, node);
85        }
86
87        Self {
88            config: config.clone(),
89            node_id,
90            view: Arc::new(RwLock::new(0)),
91            phase: Arc::new(RwLock::new(Phase::Idle)),
92            sequence_counter: Arc::new(Mutex::new(0)),
93            states: Arc::new(DashMap::new()),
94            message_log: Arc::new(RwLock::new(VecDeque::new())),
95            checkpoints: Arc::new(RwLock::new(HashMap::new())),
96            stable_checkpoint: Arc::new(RwLock::new(0)),
97            nodes: Arc::new(RwLock::new(node_map)),
98            message_tx,
99            message_rx: Arc::new(Mutex::new(message_rx)),
100            state_machine: Arc::new(RwLock::new(RdfStateMachine::new())),
101            view_change_timer: Arc::new(Mutex::new(None)),
102            byzantine_detector: Arc::new(RwLock::new(ByzantineDetector::new(3))), // Default threshold of 3
103        }
104    }
105
106    /// Check if this node is the primary for the current view
107    pub fn is_primary(&self) -> bool {
108        let view = *self.view.read();
109        let num_nodes = self.nodes.read().len() as u64;
110        self.node_id == (view % num_nodes)
111    }
112
113    /// Get the primary node ID for a given view
114    pub fn get_primary(&self, view: ViewNumber) -> NodeId {
115        let num_nodes = self.nodes.read().len() as u64;
116        view % num_nodes
117    }
118
119    /// Calculate message digest
120    fn calculate_digest(message: &BftMessage) -> Vec<u8> {
121        let serialized =
122            oxicode::serde::encode_to_vec(message, oxicode::config::standard()).unwrap_or_default();
123        let mut hasher = Sha256::new();
124        hasher.update(&serialized);
125        hasher.finalize().to_vec()
126    }
127
128    /// Log a message
129    fn log_message(&self, message: BftMessage) {
130        let mut log = self.message_log.write();
131        log.push_back(message);
132
133        // Trim log if it gets too large
134        if log.len() > self.config.max_log_size {
135            log.pop_front();
136        }
137    }
138
139    /// Broadcast message to all other nodes
140    async fn broadcast_message(&self, message: BftMessage) -> Result<()> {
141        let nodes = self.nodes.read();
142        for (&node_id, _) in nodes.iter() {
143            if node_id != self.node_id {
144                self.message_tx
145                    .send((node_id, message.clone()))
146                    .map_err(|e| anyhow!("Failed to send message: {}", e))?;
147            }
148        }
149        Ok(())
150    }
151
152    /// Process incoming message with enhanced Byzantine detection
153    pub async fn process_message(&self, from: NodeId, message: BftMessage) -> Result<()> {
154        let start_time = Instant::now();
155
156        // Enhanced Byzantine detection checks
157        {
158            let mut detector = self.byzantine_detector.write();
159
160            // Check for replay attacks
161            let message_hash = Self::calculate_digest(&message);
162            if detector.check_replay_attack(from, message_hash.clone()) {
163                return Err(anyhow!("Replay attack detected from node {}", from));
164            }
165
166            // Monitor resource usage
167            detector.monitor_resource_usage(from);
168
169            // Update network partition status
170            detector.check_network_partition(from);
171
172            // Check for equivocation (view and sequence dependent)
173            if let BftMessage::PrePrepare { view, sequence, .. }
174            | BftMessage::Prepare { view, sequence, .. }
175            | BftMessage::Commit { view, sequence, .. } = &message
176            {
177                if detector.check_equivocation(from, *view, *sequence, message_hash) {
178                    return Err(anyhow!("Equivocation detected from node {}", from));
179                }
180            }
181        }
182
183        // Log message
184        self.log_message(message.clone());
185
186        match message {
187            BftMessage::Request { .. } if self.is_primary() => {
188                self.handle_client_request(message).await?;
189            }
190
191            BftMessage::PrePrepare {
192                view,
193                sequence,
194                digest,
195                request,
196            } => {
197                self.handle_pre_prepare(from, view, sequence, digest, *request)
198                    .await?;
199            }
200
201            BftMessage::Prepare {
202                view,
203                sequence,
204                digest,
205                node_id,
206            } => {
207                self.handle_prepare(view, sequence, digest, node_id).await?;
208            }
209
210            BftMessage::Commit {
211                view,
212                sequence,
213                digest,
214                node_id,
215            } => {
216                self.handle_commit(view, sequence, digest, node_id).await?;
217            }
218
219            BftMessage::Checkpoint {
220                sequence,
221                state_digest,
222                node_id,
223            } => {
224                self.handle_checkpoint(sequence, state_digest, node_id)
225                    .await?;
226            }
227
228            BftMessage::ViewChange { .. } => {
229                self.handle_view_change(message).await?;
230            }
231
232            BftMessage::NewView { .. } => {
233                self.handle_new_view(message).await?;
234            }
235
236            _ => {}
237        }
238
239        // Record timing information for Byzantine detection
240        let response_time = start_time.elapsed();
241        {
242            let mut detector = self.byzantine_detector.write();
243            detector.report_timing_anomaly(from, response_time);
244        }
245
246        Ok(())
247    }
248
249    /// Handle client request (primary only)
250    async fn handle_client_request(&self, request: BftMessage) -> Result<()> {
251        let view = *self.view.read();
252        let sequence = {
253            let mut counter = self.sequence_counter.lock();
254            *counter += 1;
255            *counter
256        };
257
258        let digest = Self::calculate_digest(&request);
259
260        // Create pre-prepare message
261        let pre_prepare = BftMessage::PrePrepare {
262            view,
263            sequence,
264            digest: digest.clone(),
265            request: Box::new(request.clone()),
266        };
267
268        // Store state
269        let state = ConsensusState {
270            phase: Phase::PrePrepare,
271            request: Some(request),
272            digest: digest.clone(),
273            prepares: HashSet::new(),
274            commits: HashSet::new(),
275            replied: false,
276        };
277        self.states.insert((view, sequence), state);
278
279        // Broadcast pre-prepare to all backup nodes
280        self.broadcast_message(pre_prepare).await?;
281
282        // Move to prepare phase
283        self.enter_prepare_phase(view, sequence, digest).await?;
284
285        Ok(())
286    }
287
288    /// Handle pre-prepare message (backup nodes)
289    async fn handle_pre_prepare(
290        &self,
291        from: NodeId,
292        view: ViewNumber,
293        sequence: SequenceNumber,
294        digest: Vec<u8>,
295        request: BftMessage,
296    ) -> Result<()> {
297        // Verify the message is from the primary
298        if from != self.get_primary(view) {
299            return Err(anyhow!("Pre-prepare not from primary"));
300        }
301
302        // Verify view number
303        if view != *self.view.read() {
304            return Ok(()); // Ignore messages from different views
305        }
306
307        // Verify digest
308        let calculated_digest = Self::calculate_digest(&request);
309        if digest != calculated_digest {
310            return Err(anyhow!("Invalid message digest"));
311        }
312
313        // Store state
314        let state = ConsensusState {
315            phase: Phase::PrePrepare,
316            request: Some(request),
317            digest: digest.clone(),
318            prepares: HashSet::new(),
319            commits: HashSet::new(),
320            replied: false,
321        };
322        self.states.insert((view, sequence), state);
323
324        // Enter prepare phase
325        self.enter_prepare_phase(view, sequence, digest).await?;
326
327        Ok(())
328    }
329
330    /// Enter prepare phase
331    async fn enter_prepare_phase(
332        &self,
333        view: ViewNumber,
334        sequence: SequenceNumber,
335        digest: Vec<u8>,
336    ) -> Result<()> {
337        // Send prepare message
338        let prepare = BftMessage::Prepare {
339            view,
340            sequence,
341            digest,
342            node_id: self.node_id,
343        };
344
345        self.broadcast_message(prepare).await?;
346
347        // Update phase
348        if let Some(mut state) = self.states.get_mut(&(view, sequence)) {
349            state.phase = Phase::Prepare;
350        }
351
352        Ok(())
353    }
354
355    /// Handle prepare message
356    async fn handle_prepare(
357        &self,
358        view: ViewNumber,
359        sequence: SequenceNumber,
360        digest: Vec<u8>,
361        node_id: NodeId,
362    ) -> Result<()> {
363        // Verify view
364        if view != *self.view.read() {
365            return Ok(());
366        }
367
368        // Update prepare count
369        let should_commit = {
370            match self.states.get_mut(&(view, sequence)) {
371                Some(mut state) if state.digest == digest => {
372                    state.prepares.insert(node_id);
373
374                    // Check if we have 2f prepares (including our own)
375                    state.prepares.len() >= 2 * self.config.fault_tolerance
376                }
377                _ => false,
378            }
379        };
380
381        // Enter commit phase if we have enough prepares
382        if should_commit {
383            self.enter_commit_phase(view, sequence, digest).await?;
384        }
385
386        Ok(())
387    }
388
389    /// Enter commit phase
390    async fn enter_commit_phase(
391        &self,
392        view: ViewNumber,
393        sequence: SequenceNumber,
394        digest: Vec<u8>,
395    ) -> Result<()> {
396        // Send commit message
397        let commit = BftMessage::Commit {
398            view,
399            sequence,
400            digest,
401            node_id: self.node_id,
402        };
403
404        self.broadcast_message(commit).await?;
405
406        // Update phase
407        if let Some(mut state) = self.states.get_mut(&(view, sequence)) {
408            state.phase = Phase::Commit;
409        }
410
411        Ok(())
412    }
413
414    /// Handle commit message
415    async fn handle_commit(
416        &self,
417        view: ViewNumber,
418        sequence: SequenceNumber,
419        digest: Vec<u8>,
420        node_id: NodeId,
421    ) -> Result<()> {
422        // Verify view
423        if view != *self.view.read() {
424            return Ok(());
425        }
426
427        // Update commit count and execute if ready
428        let should_execute = {
429            match self.states.get_mut(&(view, sequence)) {
430                Some(mut state) if state.digest == digest => {
431                    state.commits.insert(node_id);
432
433                    // Check if we have 2f+1 commits (including our own)
434                    state.commits.len() > 2 * self.config.fault_tolerance
435                }
436                _ => false,
437            }
438        };
439
440        // Execute operation if we have enough commits
441        if should_execute {
442            self.execute_operation(view, sequence).await?;
443        }
444
445        Ok(())
446    }
447
448    /// Execute operation after consensus
449    async fn execute_operation(&self, view: ViewNumber, sequence: SequenceNumber) -> Result<()> {
450        if let Some(state) = self.states.get(&(view, sequence)) {
451            if let Some(BftMessage::Request {
452                operation,
453                client_id,
454                ..
455            }) = &state.request
456            {
457                // Execute operation on state machine
458                let result = {
459                    let mut sm = self.state_machine.write();
460                    sm.execute(operation.clone())?
461                };
462
463                // Send reply to client
464                let reply = BftMessage::Reply {
465                    view,
466                    sequence,
467                    client_id: client_id.clone(),
468                    result,
469                    timestamp: std::time::SystemTime::now(),
470                };
471
472                // In a real implementation, we would send this to the client
473                // For now, we'll just log it
474                self.log_message(reply);
475
476                // Mark as replied
477                if let Some(mut state) = self.states.get_mut(&(view, sequence)) {
478                    state.replied = true;
479                }
480            }
481        }
482
483        // Check if we should create a checkpoint
484        if sequence % self.config.checkpoint_interval == 0 {
485            self.create_checkpoint(sequence).await?;
486        }
487
488        Ok(())
489    }
490
491    /// Create checkpoint
492    async fn create_checkpoint(&self, sequence: SequenceNumber) -> Result<()> {
493        let state_digest = {
494            let sm = self.state_machine.read();
495            sm.get_state_digest()
496        };
497
498        let checkpoint = BftMessage::Checkpoint {
499            sequence,
500            state_digest: state_digest.clone(),
501            node_id: self.node_id,
502        };
503
504        self.broadcast_message(checkpoint).await?;
505
506        // Store checkpoint
507        let proof = CheckpointProof {
508            sequence,
509            state_digest,
510            signatures: HashMap::new(), // Would contain actual signatures in real implementation
511        };
512
513        self.checkpoints.write().insert(sequence, proof);
514
515        Ok(())
516    }
517
518    /// Handle checkpoint message
519    async fn handle_checkpoint(
520        &self,
521        _sequence: SequenceNumber,
522        state_digest: Vec<u8>,
523        node_id: NodeId,
524    ) -> Result<()> {
525        // Verify checkpoint against our state
526        let our_digest = {
527            let sm = self.state_machine.read();
528            sm.get_state_digest()
529        };
530
531        if state_digest != our_digest {
532            // Byzantine detection - inconsistent state
533            let mut detector = self.byzantine_detector.write();
534            detector.report_inconsistent_pattern(node_id);
535            return Err(anyhow!("Inconsistent checkpoint from node {}", node_id));
536        }
537
538        Ok(())
539    }
540
541    /// Handle view change message
542    async fn handle_view_change(&self, _message: BftMessage) -> Result<()> {
543        // View change logic would be implemented here
544        // This is a complex process involving collecting prepared messages
545        // and agreeing on a new primary
546        Ok(())
547    }
548
549    /// Handle new view message
550    async fn handle_new_view(&self, _message: BftMessage) -> Result<()> {
551        // New view logic would be implemented here
552        // This involves processing the new view and starting consensus
553        // with any prepared but uncommitted operations
554        Ok(())
555    }
556
557    /// Get node status information
558    pub fn get_status(&self) -> NodeStatus {
559        NodeStatus {
560            node_id: self.node_id,
561            view: *self.view.read(),
562            phase: *self.phase.read(),
563            sequence: *self.sequence_counter.lock(),
564            suspected_nodes: self.byzantine_detector.read().get_suspected_nodes().clone(),
565        }
566    }
567}
568
569/// Node status information
570#[derive(Debug, Clone)]
571pub struct NodeStatus {
572    pub node_id: NodeId,
573    pub view: ViewNumber,
574    pub phase: Phase,
575    pub sequence: SequenceNumber,
576    pub suspected_nodes: HashSet<NodeId>,
577}
578
579// Clone implementation for BftNode
580impl Clone for BftNode {
581    fn clone(&self) -> Self {
582        let (message_tx, message_rx) = mpsc::unbounded_channel();
583
584        Self {
585            config: self.config.clone(),
586            node_id: self.node_id,
587            view: self.view.clone(),
588            phase: self.phase.clone(),
589            sequence_counter: self.sequence_counter.clone(),
590            states: self.states.clone(),
591            message_log: self.message_log.clone(),
592            checkpoints: self.checkpoints.clone(),
593            stable_checkpoint: self.stable_checkpoint.clone(),
594            nodes: self.nodes.clone(),
595            message_tx,
596            message_rx: Arc::new(Mutex::new(message_rx)),
597            state_machine: self.state_machine.clone(),
598            view_change_timer: self.view_change_timer.clone(),
599            byzantine_detector: self.byzantine_detector.clone(),
600        }
601    }
602}