use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use tracing::info;

/// Minimal BFT-style quorum bookkeeping: with `3f + 1` total nodes the
/// system tolerates up to `f` Byzantine faults.
#[derive(Debug, Clone)]
pub struct ByzantineFaultTolerance {
    #[allow(dead_code)]
    node_id: String,
    nodes: Arc<RwLock<HashSet<String>>>,
    f: usize,
    #[allow(dead_code)]
    view: Arc<RwLock<u64>>,
}

impl ByzantineFaultTolerance {
    pub fn new(node_id: String, total_nodes: usize) -> Self {
        // Classic BFT bound: f faults tolerated out of n >= 3f + 1 nodes.
        // saturating_sub avoids an underflow panic when total_nodes == 0.
        let f = total_nodes.saturating_sub(1) / 3;
        Self {
            node_id,
            nodes: Arc::new(RwLock::new(HashSet::new())),
            f,
            view: Arc::new(RwLock::new(0)),
        }
    }

    pub async fn propose(&self, _value: Vec<u8>) -> Result<bool> {
        info!("BFT proposing value from node {}", self.node_id);
        let nodes = self.nodes.read().await;
        // A proposal passes only once a 2f + 1 quorum of nodes is known.
        let required = 2 * self.f + 1;
        Ok(nodes.len() >= required)
    }

    pub async fn add_node(&self, node_id: String) {
        self.nodes.write().await.insert(node_id);
    }
}

/// A vector clock for tracking the causal ordering of events across nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorClock {
    clock: HashMap<String, u64>,
}

impl Default for VectorClock {
    fn default() -> Self {
        Self::new()
    }
}

impl VectorClock {
    pub fn new() -> Self {
        Self {
            clock: HashMap::new(),
        }
    }

    pub fn increment(&mut self, node_id: String) {
        *self.clock.entry(node_id).or_insert(0) += 1;
    }

    /// Merge by taking the element-wise maximum of the two clocks.
    pub fn merge(&mut self, other: &VectorClock) {
        for (node, &timestamp) in &other.clock {
            let entry = self.clock.entry(node.clone()).or_insert(0);
            *entry = (*entry).max(timestamp);
        }
    }

    /// Strict happens-before: every component of `self` is <= the matching
    /// component of `other`, and at least one is strictly smaller.
    pub fn happens_before(&self, other: &VectorClock) -> bool {
        let mut strictly_less = false;

        for (node, &my_time) in &self.clock {
            let other_time = other.clock.get(node).copied().unwrap_or(0);
            if my_time > other_time {
                return false;
            }
            if my_time < other_time {
                strictly_less = true;
            }
        }

        // Nodes present only in `other` count as 0 in `self`, so any positive
        // entry there makes `self` strictly smaller on that component.
        for (node, &other_time) in &other.clock {
            if !self.clock.contains_key(node) && other_time > 0 {
                strictly_less = true;
            }
        }

        strictly_less
    }
}

/// G-Counter CRDT: a grow-only counter with one slot per node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GCounter {
    counts: HashMap<String, u64>,
}

impl Default for GCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl GCounter {
    pub fn new() -> Self {
        Self {
            counts: HashMap::new(),
        }
    }

    pub fn increment(&mut self, node_id: String, amount: u64) {
        *self.counts.entry(node_id).or_insert(0) += amount;
    }

    /// The counter's value is the sum of every node's local count.
    pub fn value(&self) -> u64 {
        self.counts.values().sum()
    }

    /// Merge by element-wise maximum; this is commutative, associative, and
    /// idempotent, which is what makes the type convergent.
    pub fn merge(&mut self, other: &GCounter) {
        for (node, &count) in &other.counts {
            let entry = self.counts.entry(node.clone()).or_insert(0);
            *entry = (*entry).max(count);
        }
    }
}

/// PN-Counter CRDT: supports decrements by pairing two G-Counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PNCounter {
    positive: GCounter,
    negative: GCounter,
}

impl Default for PNCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl PNCounter {
    pub fn new() -> Self {
        Self {
            positive: GCounter::new(),
            negative: GCounter::new(),
        }
    }

    pub fn increment(&mut self, node_id: String, amount: u64) {
        self.positive.increment(node_id, amount);
    }

    /// Decrements are recorded as increments to the negative counter.
    pub fn decrement(&mut self, node_id: String, amount: u64) {
        self.negative.increment(node_id, amount);
    }

    pub fn value(&self) -> i64 {
        self.positive.value() as i64 - self.negative.value() as i64
    }

    pub fn merge(&mut self, other: &PNCounter) {
        self.positive.merge(&other.positive);
        self.negative.merge(&other.negative);
    }
}

/// A lease-style lock with a TTL. State lives behind shared `Arc`s, so
/// clones observe the same holder; expired leases are reclaimed on `acquire`.
#[derive(Debug, Clone)]
pub struct DistributedLock {
    lock_id: String,
    holder: Arc<RwLock<Option<String>>>,
    acquired_at: Arc<RwLock<Option<SystemTime>>>,
    ttl: std::time::Duration,
}

impl DistributedLock {
    pub fn new(lock_id: String, ttl: std::time::Duration) -> Self {
        Self {
            lock_id,
            holder: Arc::new(RwLock::new(None)),
            acquired_at: Arc::new(RwLock::new(None)),
            ttl,
        }
    }

    pub async fn acquire(&self, node_id: String) -> Result<bool> {
        let mut holder = self.holder.write().await;

        // Expire a stale holder whose lease has outlived the TTL.
        if let Some(acquired_time) = *self.acquired_at.read().await {
            if acquired_time.elapsed().unwrap_or_default() > self.ttl {
                *holder = None;
            }
        }

        if holder.is_none() {
            *holder = Some(node_id);
            *self.acquired_at.write().await = Some(SystemTime::now());
            info!("Lock {} acquired", self.lock_id);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Only the current holder may release; anyone else gets an error.
    pub async fn release(&self, node_id: &str) -> Result<()> {
        let mut holder = self.holder.write().await;
        if let Some(ref current_holder) = *holder {
            if current_holder == node_id {
                *holder = None;
                *self.acquired_at.write().await = None;
                info!("Lock {} released", self.lock_id);
                return Ok(());
            }
        }
        Err(anyhow!("Not the lock holder"))
    }
}

/// Tracks peer heartbeats and flags nodes that have gone quiet.
#[derive(Debug, Clone)]
pub struct NetworkPartitionDetector {
    #[allow(dead_code)]
    node_id: String,
    heartbeats: Arc<RwLock<HashMap<String, SystemTime>>>,
    timeout: std::time::Duration,
}

impl NetworkPartitionDetector {
    pub fn new(node_id: String, timeout: std::time::Duration) -> Self {
        Self {
            node_id,
            heartbeats: Arc::new(RwLock::new(HashMap::new())),
            timeout,
        }
    }

    pub async fn record_heartbeat(&self, node_id: String) {
        self.heartbeats
            .write()
            .await
            .insert(node_id, SystemTime::now());
    }

    /// Returns the nodes whose last recorded heartbeat is older than `timeout`.
    pub async fn detect_partition(&self) -> Vec<String> {
        let heartbeats = self.heartbeats.read().await;
        let now = SystemTime::now();

        heartbeats
            .iter()
            .filter_map(|(node, &last_heartbeat)| {
                if now.duration_since(last_heartbeat).unwrap_or_default() > self.timeout {
                    Some(node.clone())
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Facade tying the BFT quorum, vector clock, locks, and partition
/// detection together behind one node-scoped API.
#[derive(Debug)]
pub struct AdvancedConsensusSystem {
    bft: Option<Arc<ByzantineFaultTolerance>>,
    vector_clock: Arc<RwLock<VectorClock>>,
    locks: Arc<RwLock<HashMap<String, DistributedLock>>>,
    partition_detector: Arc<NetworkPartitionDetector>,
}

impl AdvancedConsensusSystem {
    pub fn new(node_id: String, total_nodes: usize) -> Self {
        Self {
            bft: Some(Arc::new(ByzantineFaultTolerance::new(
                node_id.clone(),
                total_nodes,
            ))),
            vector_clock: Arc::new(RwLock::new(VectorClock::new())),
            locks: Arc::new(RwLock::new(HashMap::new())),
            partition_detector: Arc::new(NetworkPartitionDetector::new(
                node_id,
                std::time::Duration::from_secs(30),
            )),
        }
    }

    pub async fn propose_value(&self, value: Vec<u8>) -> Result<bool> {
        if let Some(ref bft) = self.bft {
            bft.propose(value).await
        } else {
            Err(anyhow!("BFT not enabled"))
        }
    }

    pub async fn increment_clock(&self, node_id: String) {
        self.vector_clock.write().await.increment(node_id);
    }

    /// Lazily creates the named lock with a default 30-second TTL on first use.
    pub async fn acquire_lock(&self, lock_id: String, node_id: String) -> Result<bool> {
        let mut locks = self.locks.write().await;
        let lock = locks
            .entry(lock_id.clone())
            .or_insert_with(|| DistributedLock::new(lock_id, std::time::Duration::from_secs(30)));
        lock.acquire(node_id).await
    }

    pub async fn detect_partitions(&self) -> Vec<String> {
        self.partition_detector.detect_partition().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bft() {
        let bft = ByzantineFaultTolerance::new("node1".to_string(), 4);
        bft.add_node("node2".to_string()).await;
        bft.add_node("node3".to_string()).await;
        bft.add_node("node4".to_string()).await;

        let result = bft.propose(vec![1, 2, 3]).await;
        assert!(result.is_ok());
        // With three peers registered, the 2f + 1 = 3 quorum is met.
        assert!(result.unwrap());
    }
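
    // A minimal sketch of the failure path: with n = 4 we get f = 1, so a
    // proposal needs a 2f + 1 = 3 node quorum and should not pass while only
    // one peer is known.
    #[tokio::test]
    async fn test_bft_quorum_not_met() {
        let bft = ByzantineFaultTolerance::new("node1".to_string(), 4);
        bft.add_node("node2".to_string()).await;

        assert!(!bft.propose(vec![1]).await.unwrap());
    }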

    #[test]
    fn test_vector_clock() {
        let mut clock1 = VectorClock::new();
        let mut clock2 = VectorClock::new();

        clock1.increment("node1".to_string());
        clock2.increment("node2".to_string());

        // Concurrent clocks: neither happens-before the other.
        assert!(!clock1.happens_before(&clock2));
        assert!(!clock2.happens_before(&clock1));

        clock1.merge(&clock2);
        clock1.increment("node1".to_string());

        assert!(clock2.happens_before(&clock1));
        assert!(!clock1.happens_before(&clock2));
    }
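
    // A minimal sanity sketch of the ordering defined above: happens-before
    // is strict (a clock never precedes an identical copy of itself), and
    // merging with a snapshot of itself is idempotent.
    #[test]
    fn test_vector_clock_strictness() {
        let mut clock = VectorClock::new();
        clock.increment("node1".to_string());

        let snapshot = clock.clone();
        assert!(!clock.happens_before(&snapshot));

        clock.merge(&snapshot);
        assert!(!clock.happens_before(&snapshot));
        assert!(!snapshot.happens_before(&clock));
    }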

    #[test]
    fn test_crdt_gcounter() {
        let mut counter = GCounter::new();
        counter.increment("node1".to_string(), 5);
        counter.increment("node2".to_string(), 3);
        assert_eq!(counter.value(), 8);
    }
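
    // A sketch of CRDT convergence for the G-Counter: the element-wise max
    // merge is commutative, so both merge orders agree on the final value.
    #[test]
    fn test_crdt_gcounter_merge_commutes() {
        let mut a = GCounter::new();
        let mut b = GCounter::new();
        a.increment("node1".to_string(), 5);
        b.increment("node1".to_string(), 2);
        b.increment("node2".to_string(), 3);

        let mut ab = a.clone();
        ab.merge(&b);
        let mut ba = b.clone();
        ba.merge(&a);

        assert_eq!(ab.value(), 8); // max(5, 2) + 3
        assert_eq!(ba.value(), ab.value());
    }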

    #[test]
    fn test_crdt_pncounter() {
        let mut counter = PNCounter::new();
        counter.increment("node1".to_string(), 10);
        counter.decrement("node1".to_string(), 3);
        assert_eq!(counter.value(), 7);
    }
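
    // A sketch of replica convergence for the PN-Counter: after merging in
    // both directions, two replicas that saw different increments and
    // decrements report the same net value.
    #[test]
    fn test_crdt_pncounter_merge() {
        let mut a = PNCounter::new();
        let mut b = PNCounter::new();
        a.increment("node1".to_string(), 10);
        b.decrement("node2".to_string(), 4);

        a.merge(&b);
        b.merge(&a);

        assert_eq!(a.value(), 6);
        assert_eq!(b.value(), a.value());
    }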

    #[tokio::test]
    async fn test_distributed_lock() {
        let lock =
            DistributedLock::new("test_lock".to_string(), std::time::Duration::from_secs(60));

        let acquired = lock.acquire("node1".to_string()).await;
        assert!(acquired.is_ok());
        assert!(acquired.unwrap());

        let acquired2 = lock.acquire("node2".to_string()).await;
        assert!(acquired2.is_ok());
        assert!(!acquired2.unwrap());
    }
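
    // A sketch of the release path: only the holder may release, after which
    // the lock is free for another node to acquire.
    #[tokio::test]
    async fn test_distributed_lock_release() {
        let lock =
            DistributedLock::new("test_lock".to_string(), std::time::Duration::from_secs(60));

        assert!(lock.acquire("node1".to_string()).await.unwrap());
        assert!(lock.release("node2").await.is_err());
        lock.release("node1").await.unwrap();
        assert!(lock.acquire("node2".to_string()).await.unwrap());
    }

    // A sketch of the healthy path for partition detection: a node that just
    // heartbeated is not flagged. (Exercising the stale path would require
    // advancing the clock, so it is left out here.)
    #[tokio::test]
    async fn test_partition_detector_healthy() {
        let detector = NetworkPartitionDetector::new(
            "node1".to_string(),
            std::time::Duration::from_secs(30),
        );
        detector.record_heartbeat("node2".to_string()).await;
        assert!(detector.detect_partition().await.is_empty());
    }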
}