// oxirs_federate/advanced_consensus.rs

//! Advanced Consensus Features
//!
//! Implements advanced distributed consensus mechanisms:
//! - Byzantine Fault Tolerance (BFT)
//! - Conflict-free Replicated Data Types (CRDTs)
//! - Vector clocks for causality tracking
//! - Distributed locking mechanisms
//! - Network partition handling

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use tracing::info;

/// Byzantine Fault Tolerant Consensus
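///
/// With `n` total nodes the constructor derives `f = (n - 1) / 3`, the maximum
/// number of Byzantine nodes tolerated, and `propose` succeeds only once a
/// quorum of `2f + 1` nodes is known. A minimal usage sketch (node ids are
/// illustrative; marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// // Four nodes tolerate f = 1 Byzantine node; the quorum is 2f + 1 = 3.
/// let bft = ByzantineFaultTolerance::new("node1".to_string(), 4);
/// bft.add_node("node2".to_string()).await;
/// bft.add_node("node3".to_string()).await;
/// bft.add_node("node4".to_string()).await;
/// assert!(bft.propose(b"value".to_vec()).await?);
/// ```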
#[derive(Debug, Clone)]
pub struct ByzantineFaultTolerance {
    #[allow(dead_code)]
    node_id: String,
    nodes: Arc<RwLock<HashSet<String>>>,
    f: usize, // Maximum number of Byzantine nodes tolerated
    #[allow(dead_code)]
    view: Arc<RwLock<u64>>,
}

impl ByzantineFaultTolerance {
    pub fn new(node_id: String, total_nodes: usize) -> Self {
        // A cluster of n nodes tolerates f = (n - 1) / 3 Byzantine nodes;
        // saturating_sub avoids an underflow panic when total_nodes is 0.
        let f = total_nodes.saturating_sub(1) / 3;
        Self {
            node_id,
            nodes: Arc::new(RwLock::new(HashSet::new())),
            f,
            view: Arc::new(RwLock::new(0)),
        }
    }

    pub async fn propose(&self, _value: Vec<u8>) -> Result<bool> {
        info!("BFT proposing value from node {}", self.node_id);
        let nodes = self.nodes.read().await;
        // A proposal can only be accepted once a quorum of 2f + 1 nodes is known.
        let required = 2 * self.f + 1;
        Ok(nodes.len() >= required)
    }

    pub async fn add_node(&self, node_id: String) {
        self.nodes.write().await.insert(node_id);
    }
}

/// Vector Clock for causality
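///
/// Each node increments its own entry; `merge` takes the element-wise maximum
/// and `happens_before` reports whether one clock is causally dominated by
/// another. A minimal usage sketch (node ids are illustrative; marked `ignore`
/// so it is not compiled as a doctest):
///
/// ```ignore
/// let mut a = VectorClock::new();
/// let mut b = VectorClock::new();
/// a.increment("node1".to_string());
/// b.increment("node2".to_string());
/// // Independent updates are concurrent: neither happens before the other.
/// assert!(!a.happens_before(&b) && !b.happens_before(&a));
/// b.merge(&a);
/// b.increment("node2".to_string());
/// // After merging and advancing, `a` is in `b`'s causal past.
/// assert!(a.happens_before(&b));
/// ```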
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorClock {
    clock: HashMap<String, u64>,
}

impl Default for VectorClock {
    fn default() -> Self {
        Self::new()
    }
}

impl VectorClock {
    pub fn new() -> Self {
        Self {
            clock: HashMap::new(),
        }
    }

    pub fn increment(&mut self, node_id: String) {
        *self.clock.entry(node_id).or_insert(0) += 1;
    }

    pub fn merge(&mut self, other: &VectorClock) {
        for (node, &timestamp) in &other.clock {
            let entry = self.clock.entry(node.clone()).or_insert(0);
            *entry = (*entry).max(timestamp);
        }
    }

    pub fn happens_before(&self, other: &VectorClock) -> bool {
        let mut strictly_less = false;

        // Check all nodes in self
        for (node, &my_time) in &self.clock {
            let other_time = other.clock.get(node).copied().unwrap_or(0);
            if my_time > other_time {
                return false;
            }
            if my_time < other_time {
                strictly_less = true;
            }
        }

        // Also check nodes that exist in other but not in self
        for (node, &other_time) in &other.clock {
            if !self.clock.contains_key(node) {
                // self[node] is implicitly 0, other[node] > 0
                if other_time > 0 {
                    strictly_less = true;
                }
            }
        }

        strictly_less
    }
}

/// CRDT - Grow-only Counter
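///
/// Each replica increments only its own slot and `merge` takes the per-node
/// maximum, so replicas converge regardless of merge order. A minimal usage
/// sketch (node ids are illustrative; marked `ignore` so it is not compiled
/// as a doctest):
///
/// ```ignore
/// let mut a = GCounter::new();
/// let mut b = GCounter::new();
/// a.increment("node1".to_string(), 5);
/// b.increment("node2".to_string(), 3);
/// a.merge(&b);
/// b.merge(&a);
/// assert_eq!(a.value(), 8);
/// assert_eq!(b.value(), 8);
/// ```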
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GCounter {
    counts: HashMap<String, u64>,
}

impl Default for GCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl GCounter {
    pub fn new() -> Self {
        Self {
            counts: HashMap::new(),
        }
    }

    pub fn increment(&mut self, node_id: String, amount: u64) {
        *self.counts.entry(node_id).or_insert(0) += amount;
    }

    pub fn value(&self) -> u64 {
        self.counts.values().sum()
    }

    pub fn merge(&mut self, other: &GCounter) {
        // Per-node maximum keeps the merge commutative, associative, and idempotent.
        for (node, &count) in &other.counts {
            let entry = self.counts.entry(node.clone()).or_insert(0);
            *entry = (*entry).max(count);
        }
    }
}

/// CRDT - Positive-Negative Counter
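///
/// Composed of two grow-only counters: one accumulates increments, the other
/// decrements, and the value is their difference. A minimal usage sketch
/// (node ids are illustrative; marked `ignore` so it is not compiled as a
/// doctest):
///
/// ```ignore
/// let mut counter = PNCounter::new();
/// counter.increment("node1".to_string(), 10);
/// counter.decrement("node1".to_string(), 3);
/// assert_eq!(counter.value(), 7);
/// ```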
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PNCounter {
    positive: GCounter,
    negative: GCounter,
}

impl Default for PNCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl PNCounter {
    pub fn new() -> Self {
        Self {
            positive: GCounter::new(),
            negative: GCounter::new(),
        }
    }

    pub fn increment(&mut self, node_id: String, amount: u64) {
        self.positive.increment(node_id, amount);
    }

    pub fn decrement(&mut self, node_id: String, amount: u64) {
        self.negative.increment(node_id, amount);
    }

    pub fn value(&self) -> i64 {
        self.positive.value() as i64 - self.negative.value() as i64
    }

    pub fn merge(&mut self, other: &PNCounter) {
        self.positive.merge(&other.positive);
        self.negative.merge(&other.negative);
    }
}

/// Distributed Lock
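///
/// A single holder owns the lock until it releases it or the TTL expires, at
/// which point `acquire` treats the lock as free again. A minimal usage sketch
/// (node and lock ids are illustrative; marked `ignore` so it is not compiled
/// as a doctest):
///
/// ```ignore
/// let lock = DistributedLock::new("catalog".to_string(), std::time::Duration::from_secs(60));
/// assert!(lock.acquire("node1".to_string()).await?); // granted
/// assert!(!lock.acquire("node2".to_string()).await?); // already held
/// lock.release("node1").await?;
/// assert!(lock.acquire("node2".to_string()).await?); // free again
/// ```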
#[derive(Debug, Clone)]
pub struct DistributedLock {
    lock_id: String,
    holder: Arc<RwLock<Option<String>>>,
    acquired_at: Arc<RwLock<Option<SystemTime>>>,
    ttl: std::time::Duration,
}

impl DistributedLock {
    pub fn new(lock_id: String, ttl: std::time::Duration) -> Self {
        Self {
            lock_id,
            holder: Arc::new(RwLock::new(None)),
            acquired_at: Arc::new(RwLock::new(None)),
            ttl,
        }
    }

    pub async fn acquire(&self, node_id: String) -> Result<bool> {
        let mut holder = self.holder.write().await;

        // Check if lock is expired
        if let Some(acquired_time) = *self.acquired_at.read().await {
            if acquired_time.elapsed().unwrap_or_default() > self.ttl {
                *holder = None;
            }
        }

        if holder.is_none() {
            *holder = Some(node_id);
            *self.acquired_at.write().await = Some(SystemTime::now());
            info!("Lock {} acquired", self.lock_id);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    pub async fn release(&self, node_id: &str) -> Result<()> {
        let mut holder = self.holder.write().await;
        if let Some(ref current_holder) = *holder {
            if current_holder == node_id {
                *holder = None;
                *self.acquired_at.write().await = None;
                info!("Lock {} released", self.lock_id);
                return Ok(());
            }
        }
        Err(anyhow!("Not the lock holder"))
    }
}

/// Network Partition Detector
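///
/// Peers call `record_heartbeat` periodically; `detect_partition` reports any
/// node whose last heartbeat is older than the configured timeout. A minimal
/// usage sketch (node ids are illustrative; marked `ignore` so it is not
/// compiled as a doctest):
///
/// ```ignore
/// let detector =
///     NetworkPartitionDetector::new("node1".to_string(), std::time::Duration::from_secs(30));
/// detector.record_heartbeat("node2".to_string()).await;
/// // Nodes that have not sent a heartbeat within the timeout are reported here.
/// let suspected: Vec<String> = detector.detect_partition().await;
/// ```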
#[derive(Debug, Clone)]
pub struct NetworkPartitionDetector {
    #[allow(dead_code)]
    node_id: String,
    heartbeats: Arc<RwLock<HashMap<String, SystemTime>>>,
    timeout: std::time::Duration,
}

impl NetworkPartitionDetector {
    pub fn new(node_id: String, timeout: std::time::Duration) -> Self {
        Self {
            node_id,
            heartbeats: Arc::new(RwLock::new(HashMap::new())),
            timeout,
        }
    }

    pub async fn record_heartbeat(&self, node_id: String) {
        self.heartbeats
            .write()
            .await
            .insert(node_id, SystemTime::now());
    }

    pub async fn detect_partition(&self) -> Vec<String> {
        let heartbeats = self.heartbeats.read().await;
        let now = SystemTime::now();

        heartbeats
            .iter()
            .filter_map(|(node, &last_heartbeat)| {
                if now.duration_since(last_heartbeat).unwrap_or_default() > self.timeout {
                    Some(node.clone())
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Advanced Consensus System
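///
/// Facade wiring together BFT proposals, a shared vector clock, per-key
/// distributed locks (30 s TTL), and partition detection (30 s heartbeat
/// timeout). A minimal usage sketch (node and lock ids are illustrative;
/// marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let system = AdvancedConsensusSystem::new("node1".to_string(), 4);
/// system.increment_clock("node1".to_string()).await;
/// let locked = system.acquire_lock("graph-update".to_string(), "node1".to_string()).await?;
/// // `accepted` stays false until enough peers are registered with the internal BFT node.
/// let accepted = system.propose_value(vec![1, 2, 3]).await?;
/// let suspected = system.detect_partitions().await;
/// ```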
#[derive(Debug)]
pub struct AdvancedConsensusSystem {
    bft: Option<Arc<ByzantineFaultTolerance>>,
    vector_clock: Arc<RwLock<VectorClock>>,
    locks: Arc<RwLock<HashMap<String, DistributedLock>>>,
    partition_detector: Arc<NetworkPartitionDetector>,
}

impl AdvancedConsensusSystem {
    pub fn new(node_id: String, total_nodes: usize) -> Self {
        Self {
            bft: Some(Arc::new(ByzantineFaultTolerance::new(
                node_id.clone(),
                total_nodes,
            ))),
            vector_clock: Arc::new(RwLock::new(VectorClock::new())),
            locks: Arc::new(RwLock::new(HashMap::new())),
            partition_detector: Arc::new(NetworkPartitionDetector::new(
                node_id,
                std::time::Duration::from_secs(30),
            )),
        }
    }

    pub async fn propose_value(&self, value: Vec<u8>) -> Result<bool> {
        if let Some(ref bft) = self.bft {
            bft.propose(value).await
        } else {
            Err(anyhow!("BFT not enabled"))
        }
    }

    pub async fn increment_clock(&self, node_id: String) {
        self.vector_clock.write().await.increment(node_id);
    }

    pub async fn acquire_lock(&self, lock_id: String, node_id: String) -> Result<bool> {
        let mut locks = self.locks.write().await;
        let lock = locks
            .entry(lock_id.clone())
            .or_insert_with(|| DistributedLock::new(lock_id, std::time::Duration::from_secs(30)));
        lock.acquire(node_id).await
    }

    pub async fn detect_partitions(&self) -> Vec<String> {
        self.partition_detector.detect_partition().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bft() {
        let bft = ByzantineFaultTolerance::new("node1".to_string(), 4);
        bft.add_node("node2".to_string()).await;
        bft.add_node("node3".to_string()).await;
        bft.add_node("node4".to_string()).await;

        let result = bft.propose(vec![1, 2, 3]).await;
        assert!(result.is_ok());
        // With f = 1 and three registered peers, the 2f + 1 = 3 quorum is met.
        assert!(result.unwrap());
    }

    #[test]
    fn test_vector_clock() {
        let mut clock1 = VectorClock::new();
        let mut clock2 = VectorClock::new();

        // Test concurrent events (neither happens before the other)
        clock1.increment("node1".to_string());
        clock2.increment("node2".to_string());

        // These are concurrent, so neither should happen before the other
        assert!(!clock1.happens_before(&clock2));
        assert!(!clock2.happens_before(&clock1));

        // Test causality: merge clock2 into clock1 and advance clock1
        clock1.merge(&clock2);
        clock1.increment("node1".to_string());

        // Now clock1 should happen after clock2
        assert!(clock2.happens_before(&clock1));
        assert!(!clock1.happens_before(&clock2));
    }

    #[test]
    fn test_crdt_gcounter() {
        let mut counter = GCounter::new();
        counter.increment("node1".to_string(), 5);
        counter.increment("node2".to_string(), 3);
        assert_eq!(counter.value(), 8);
    }

    #[test]
    fn test_crdt_pncounter() {
        let mut counter = PNCounter::new();
        counter.increment("node1".to_string(), 10);
        counter.decrement("node1".to_string(), 3);
        assert_eq!(counter.value(), 7);
    }

    #[tokio::test]
    async fn test_distributed_lock() {
        let lock =
            DistributedLock::new("test_lock".to_string(), std::time::Duration::from_secs(60));

        let acquired = lock.acquire("node1".to_string()).await;
        assert!(acquired.is_ok());
        assert!(acquired.unwrap());

        let acquired2 = lock.acquire("node2".to_string()).await;
        assert!(acquired2.is_ok());
        assert!(!acquired2.unwrap());
    }
}