1#![forbid(unsafe_code)]
2
3use std::collections::{HashMap, HashSet};
10
/// A balanced-ternary digit: negative, zero, or positive.
///
/// The explicit discriminants (-1, 0, 1) are the values returned by
/// [`Trit::to_i8`] and accepted by [`Trit::from_i8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Trit {
    Neg = -1,
    Zero = 0,
    Pos = 1,
}

impl Trit {
    /// Converts a raw integer into a trit.
    ///
    /// Returns `None` for any value other than -1, 0, or 1.
    pub fn from_i8(v: i8) -> Option<Self> {
        [Trit::Neg, Trit::Zero, Trit::Pos]
            .iter()
            .copied()
            .find(|trit| trit.to_i8() == v)
    }

    /// Returns the numeric value of this trit: -1, 0, or 1.
    pub fn to_i8(self) -> i8 {
        match self {
            Trit::Neg => -1,
            Trit::Zero => 0,
            Trit::Pos => 1,
        }
    }
}
33
34pub type NodeId = u64;
36
37#[derive(Debug, Clone)]
42pub struct TernaryNode {
43 pub id: NodeId,
44 pub state: Trit,
45 pub peers: HashSet<NodeId>,
46 pub vector_clock: VectorClock,
47 pub is_alive: bool,
48}
49
50impl TernaryNode {
51 pub fn new(id: NodeId) -> Self {
52 Self {
53 id,
54 state: Trit::Zero,
55 peers: HashSet::new(),
56 vector_clock: VectorClock::new(),
57 is_alive: true,
58 }
59 }
60
61 pub fn with_state(id: NodeId, state: Trit) -> Self {
62 Self {
63 id,
64 state,
65 peers: HashSet::new(),
66 vector_clock: VectorClock::new(),
67 is_alive: true,
68 }
69 }
70
71 pub fn add_peer(&mut self, peer_id: NodeId) {
72 if peer_id != self.id {
73 self.peers.insert(peer_id);
74 }
75 }
76
77 pub fn remove_peer(&mut self, peer_id: NodeId) {
78 self.peers.remove(&peer_id);
79 }
80
81 pub fn set_state(&mut self, state: Trit) {
82 self.state = state;
83 self.vector_clock.increment(self.id);
84 }
85}
86
87#[derive(Debug, Clone)]
94pub struct GossipProtocol {
95 pub nodes: HashMap<NodeId, TernaryNode>,
96 pub round: u64,
97}
98
99impl GossipProtocol {
100 pub fn new() -> Self {
101 Self {
102 nodes: HashMap::new(),
103 round: 0,
104 }
105 }
106
107 pub fn add_node(&mut self, node: TernaryNode) {
108 self.nodes.insert(node.id, node);
109 }
110
111 pub fn run_round(&mut self) -> u32 {
112 let states: HashMap<NodeId, (Trit, VectorClock)> = self
113 .nodes
114 .iter()
115 .map(|(id, n)| (*id, (n.state, n.vector_clock.clone())))
116 .collect();
117
118 let mut updates = 0u32;
119 let node_ids: Vec<NodeId> = self.nodes.keys().copied().collect();
120
121 for node_id in &node_ids {
122 let node = self.nodes.get(node_id).unwrap();
123 let mut peer_states: Vec<Trit> = Vec::new();
124
125 for peer_id in &node.peers {
126 if let Some((state, vc)) = states.get(peer_id) {
127 if vc >= &node.vector_clock {
129 peer_states.push(*state);
130 }
131 }
132 }
133
134 if let Some(new_state) = dominant_trit(&peer_states) {
135 if new_state != node.state {
136 if let Some(n) = self.nodes.get_mut(node_id) {
137 n.state = new_state;
138 n.vector_clock.increment(*node_id);
139 updates += 1;
140 }
141 }
142 }
143 }
144
145 self.round += 1;
146 updates
147 }
148
149 pub fn run_until_converged(&mut self, max_rounds: u64) -> u64 {
150 for i in 0..max_rounds {
151 if self.run_round() == 0 {
152 return i + 1;
153 }
154 }
155 max_rounds
156 }
157
158 pub fn is_converged(&self) -> bool {
159 let states: HashSet<Trit> = self.nodes.values().map(|n| n.state).collect();
160 states.len() <= 1
161 }
162}
163
164fn dominant_trit(trits: &[Trit]) -> Option<Trit> {
166 let mut neg = 0u32;
167 let mut pos = 0u32;
168 for t in trits {
169 match t {
170 Trit::Neg => neg += 1,
171 Trit::Pos => pos += 1,
172 Trit::Zero => {}
173 }
174 }
175 if pos == 0 && neg == 0 {
176 return None;
177 }
178 if pos >= neg {
179 Some(Trit::Pos)
180 } else {
181 Some(Trit::Neg)
182 }
183}
184
185#[derive(Debug, Clone, PartialEq, Eq)]
190pub struct VectorClock {
191 pub counters: HashMap<NodeId, u64>,
192}
193
194impl VectorClock {
195 pub fn new() -> Self {
196 Self {
197 counters: HashMap::new(),
198 }
199 }
200
201 pub fn increment(&mut self, node_id: NodeId) -> u64 {
202 let counter = self.counters.entry(node_id).or_insert(0);
203 *counter += 1;
204 *counter
205 }
206
207 pub fn get(&self, node_id: NodeId) -> u64 {
208 *self.counters.get(&node_id).unwrap_or(&0)
209 }
210
211 pub fn merge(&self, other: &VectorClock) -> VectorClock {
212 let mut merged = self.counters.clone();
213 for (node_id, counter) in &other.counters {
214 let entry = merged.entry(*node_id).or_insert(0);
215 *entry = (*entry).max(*counter);
216 }
217 VectorClock { counters: merged }
218 }
219
220 pub fn happened_before(&self, other: &VectorClock) -> bool {
222 let all_keys: HashSet<NodeId> = self
223 .counters
224 .keys()
225 .chain(other.counters.keys())
226 .copied()
227 .collect();
228
229 let mut at_least_one_less = false;
230 for key in &all_keys {
231 let s = self.get(*key);
232 let o = other.get(*key);
233 if s > o {
234 return false;
235 }
236 if s < o {
237 at_least_one_less = true;
238 }
239 }
240 at_least_one_less
241 }
242
243 pub fn is_concurrent(&self, other: &VectorClock) -> bool {
245 !self.happened_before(other) && !other.happened_before(self) && self != other
246 }
247}
248
249impl PartialOrd for VectorClock {
250 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
251 if self == other {
252 Some(std::cmp::Ordering::Equal)
253 } else if self.happened_before(other) {
254 Some(std::cmp::Ordering::Less)
255 } else if other.happened_before(self) {
256 Some(std::cmp::Ordering::Greater)
257 } else {
258 None }
260 }
261}
262
263#[derive(Debug, Clone)]
269pub struct PartitionDetector {
270 pub last_seen: HashMap<NodeId, u64>,
271 pub timeout_rounds: u64,
272 pub current_round: u64,
273 pub total_nodes: usize,
274}
275
276impl PartitionDetector {
277 pub fn new(total_nodes: usize, timeout_rounds: u64) -> Self {
278 Self {
279 last_seen: HashMap::new(),
280 timeout_rounds,
281 current_round: 0,
282 total_nodes,
283 }
284 }
285
286 pub fn heartbeat(&mut self, node_id: NodeId) {
287 self.last_seen.insert(node_id, self.current_round);
288 }
289
290 pub fn advance_round(&mut self) {
291 self.current_round += 1;
292 }
293
294 pub fn is_alive(&self, node_id: NodeId) -> bool {
295 self.last_seen
296 .get(&node_id)
297 .map(|&r| self.current_round.saturating_sub(r) <= self.timeout_rounds)
298 .unwrap_or(false)
299 }
300
301 pub fn alive_nodes(&self) -> Vec<NodeId> {
302 self.last_seen
303 .keys()
304 .filter(|&&id| self.is_alive(id))
305 .copied()
306 .collect()
307 }
308
309 pub fn partitioned_nodes(&self) -> Vec<NodeId> {
310 self.last_seen
311 .keys()
312 .filter(|&&id| !self.is_alive(id))
313 .copied()
314 .collect()
315 }
316
317 pub fn has_quorum(&self) -> bool {
318 let alive = self.alive_nodes().len();
319 alive * 2 > self.total_nodes
320 }
321
322 pub fn is_partitioned(&self) -> bool {
323 !self.has_quorum()
324 }
325}
326
327#[derive(Debug, Clone, Copy, PartialEq, Eq)]
329pub enum Vote {
330 Negative,
331 Abstain,
332 Positive,
333}
334
335impl Vote {
336 pub fn to_trit(self) -> Trit {
337 match self {
338 Vote::Negative => Trit::Neg,
339 Vote::Abstain => Trit::Zero,
340 Vote::Positive => Trit::Pos,
341 }
342 }
343
344 pub fn from_trit(t: Trit) -> Self {
345 match t {
346 Trit::Neg => Vote::Negative,
347 Trit::Zero => Vote::Abstain,
348 Trit::Pos => Vote::Positive,
349 }
350 }
351}
352
353#[derive(Debug, Clone)]
359pub struct ConsensusProtocol {
360 pub proposers: HashSet<NodeId>,
361 pub acceptors: HashSet<NodeId>,
362 pub learners: HashSet<NodeId>,
363 pub promised_proposal: HashMap<NodeId, u64>,
364 pub accepted_value: HashMap<NodeId, (u64, Vote)>,
365 pub proposal_counter: u64,
366 pub quorum_size: usize,
367}
368
369impl ConsensusProtocol {
370 pub fn new(nodes: &[NodeId]) -> Self {
371 let node_set: HashSet<NodeId> = nodes.iter().copied().collect();
372 let quorum_size = nodes.len() / 2 + 1;
373 Self {
374 proposers: node_set.clone(),
375 acceptors: node_set.clone(),
376 learners: node_set,
377 promised_proposal: HashMap::new(),
378 accepted_value: HashMap::new(),
379 proposal_counter: 0,
380 quorum_size,
381 }
382 }
383
384 pub fn prepare(&mut self, proposer: NodeId) -> u64 {
385 self.proposal_counter += 1;
386 let proposal_num = self.proposal_counter;
387 if let Some(&promised) = self.promised_proposal.get(&proposer) {
389 if promised >= proposal_num {
390 return 0; }
392 }
393 proposal_num
394 }
395
396 pub fn promise(&mut self, acceptor: NodeId, proposal_num: u64) -> bool {
397 if let Some(&promised) = self.promised_proposal.get(&acceptor) {
398 if promised > proposal_num {
399 return false;
400 }
401 }
402 self.promised_proposal.insert(acceptor, proposal_num);
403 true
404 }
405
406 pub fn accept(&mut self, acceptor: NodeId, proposal_num: u64, value: Vote) -> bool {
407 if let Some(&promised) = self.promised_proposal.get(&acceptor) {
408 if promised > proposal_num {
409 return false;
410 }
411 }
412 self.accepted_value
413 .insert(acceptor, (proposal_num, value));
414 true
415 }
416
417 pub fn decide(&self) -> Option<Vote> {
418 let values: Vec<&(u64, Vote)> = self.accepted_value.values().collect();
419 if values.len() < self.quorum_size {
420 return None;
421 }
422
423 let mut proposal_counts: HashMap<u64, Vec<Vote>> = HashMap::new();
425 for (num, vote) in &values {
426 proposal_counts.entry(*num).or_default().push(*vote);
427 }
428
429 let max_proposal = proposal_counts.keys().max()?;
430 let votes = proposal_counts.get(max_proposal)?;
431 if votes.len() < self.quorum_size {
432 return None;
433 }
434
435 let sum: i32 = votes.iter().map(|v| v.to_trit().to_i8() as i32).sum();
436 if sum < 0 {
437 Some(Vote::Negative)
438 } else if sum > 0 {
439 Some(Vote::Positive)
440 } else {
441 Some(Vote::Abstain)
442 }
443 }
444}
445
446#[derive(Debug, Clone)]
451pub struct AntiEntropySync {
452 pub nodes: HashMap<NodeId, TernaryNode>,
453}
454
455impl AntiEntropySync {
456 pub fn new() -> Self {
457 Self {
458 nodes: HashMap::new(),
459 }
460 }
461
462 pub fn add_node(&mut self, node: TernaryNode) {
463 self.nodes.insert(node.id, node);
464 }
465
466 pub fn sync_pair(&mut self, node_a: NodeId, node_b: NodeId) -> bool {
468 let (state_a, vc_a, peers_a) = {
469 let a = self.nodes.get(&node_a).unwrap();
470 (a.state, a.vector_clock.clone(), a.peers.clone())
471 };
472 let (state_b, vc_b, peers_b) = {
473 let b = self.nodes.get(&node_b).unwrap();
474 (b.state, b.vector_clock.clone(), b.peers.clone())
475 };
476
477 let mut changed = false;
478
479 if vc_a.happened_before(&vc_b) {
480 if let Some(a) = self.nodes.get_mut(&node_a) {
482 if a.state != state_b {
483 a.state = state_b;
484 a.vector_clock = vc_a.merge(&vc_b);
485 a.vector_clock.increment(node_a);
486 changed = true;
487 }
488 }
489 } else if vc_b.happened_before(&vc_a) {
490 if let Some(b) = self.nodes.get_mut(&node_b) {
492 if b.state != state_a {
493 b.state = state_a;
494 b.vector_clock = vc_b.merge(&vc_a);
495 b.vector_clock.increment(node_b);
496 changed = true;
497 }
498 }
499 } else if vc_a.is_concurrent(&vc_b) {
500 let merged = dominant_trit(&[state_a, state_b]).unwrap_or(Trit::Zero);
502 let merged_vc = vc_a.merge(&vc_b);
503 if let Some(a) = self.nodes.get_mut(&node_a) {
504 if a.state != merged {
505 a.state = merged;
506 a.vector_clock = merged_vc.clone();
507 a.vector_clock.increment(node_a);
508 changed = true;
509 }
510 }
511 if let Some(b) = self.nodes.get_mut(&node_b) {
512 if b.state != merged {
513 b.state = merged;
514 b.vector_clock = merged_vc;
515 b.vector_clock.increment(node_b);
516 changed = true;
517 }
518 }
519 }
520
521 if let Some(a) = self.nodes.get_mut(&node_a) {
523 for p in &peers_b {
524 a.peers.insert(*p);
525 }
526 }
527 if let Some(b) = self.nodes.get_mut(&node_b) {
528 for p in &peers_a {
529 b.peers.insert(*p);
530 }
531 }
532
533 changed
534 }
535
536 pub fn sync_all(&mut self) -> u32 {
538 let node_ids: Vec<NodeId> = self.nodes.keys().copied().collect();
539 let mut changes = 0u32;
540
541 for i in 0..node_ids.len() {
542 for j in (i + 1)..node_ids.len() {
543 let a_id = node_ids[i];
544 let b_id = node_ids[j];
545 let a = self.nodes.get(&a_id).unwrap();
546 let b = self.nodes.get(&b_id).unwrap();
547 if a.peers.contains(&b_id) || b.peers.contains(&a_id) {
548 if self.sync_pair(a_id, b_id) {
549 changes += 1;
550 }
551 }
552 }
553 }
554
555 changes
556 }
557}
558
#[cfg(test)]
mod tests {
    //! Unit tests for the ternary gossip, vector clock, partition detection,
    //! consensus, and anti-entropy components defined in this file.

    use super::*;

    #[test]
    fn test_trit_from_i8() {
        assert_eq!(Trit::from_i8(-1), Some(Trit::Neg));
        assert_eq!(Trit::from_i8(0), Some(Trit::Zero));
        assert_eq!(Trit::from_i8(1), Some(Trit::Pos));
        assert_eq!(Trit::from_i8(2), None);
    }

    #[test]
    fn test_trit_to_i8() {
        assert_eq!(Trit::Neg.to_i8(), -1);
        assert_eq!(Trit::Zero.to_i8(), 0);
        assert_eq!(Trit::Pos.to_i8(), 1);
    }

    #[test]
    fn test_ternary_node_new() {
        let node = TernaryNode::new(1);
        assert_eq!(node.id, 1);
        assert_eq!(node.state, Trit::Zero);
        assert!(node.peers.is_empty());
        assert!(node.is_alive);
    }

    #[test]
    fn test_ternary_node_add_peer() {
        let mut node = TernaryNode::new(1);
        node.add_peer(2);
        node.add_peer(3);
        assert!(node.peers.contains(&2));
        assert!(node.peers.contains(&3));
        assert_eq!(node.peers.len(), 2);
    }

    #[test]
    fn test_ternary_node_no_self_peer() {
        let mut node = TernaryNode::new(1);
        node.add_peer(1);
        assert!(node.peers.is_empty());
    }

    #[test]
    fn test_ternary_node_set_state() {
        let mut node = TernaryNode::new(1);
        node.set_state(Trit::Pos);
        assert_eq!(node.state, Trit::Pos);
        assert_eq!(node.vector_clock.get(1), 1);
        node.set_state(Trit::Neg);
        assert_eq!(node.state, Trit::Neg);
        assert_eq!(node.vector_clock.get(1), 2);
    }

    #[test]
    fn test_vector_clock_increment() {
        let mut vc = VectorClock::new();
        assert_eq!(vc.increment(1), 1);
        assert_eq!(vc.increment(1), 2);
        assert_eq!(vc.increment(2), 1);
    }

    #[test]
    fn test_vector_clock_happened_before() {
        let mut vc1 = VectorClock::new();
        vc1.increment(1);
        let mut vc2 = VectorClock::new();
        vc2.increment(1);
        vc2.increment(1);
        assert!(vc1.happened_before(&vc2));
        assert!(!vc2.happened_before(&vc1));
    }

    #[test]
    fn test_vector_clock_concurrent() {
        let mut vc1 = VectorClock::new();
        vc1.increment(1);
        let mut vc2 = VectorClock::new();
        vc2.increment(2);
        assert!(vc1.is_concurrent(&vc2));
        assert!(vc2.is_concurrent(&vc1));
    }

    #[test]
    fn test_vector_clock_merge() {
        let mut vc1 = VectorClock::new();
        vc1.increment(1);
        let mut vc2 = VectorClock::new();
        vc2.increment(2);
        let merged = vc1.merge(&vc2);
        assert_eq!(merged.get(1), 1);
        assert_eq!(merged.get(2), 1);
    }

    #[test]
    fn test_vector_clock_partial_ord() {
        let mut vc1 = VectorClock::new();
        vc1.increment(1);
        let mut vc2 = VectorClock::new();
        vc2.increment(1);
        vc2.increment(2);
        assert!(vc1 < vc2);
        assert!(vc2 > vc1);
    }

    #[test]
    fn test_gossip_single_round() {
        let mut gossip = GossipProtocol::new();
        let mut n1 = TernaryNode::with_state(1, Trit::Pos);
        n1.add_peer(2);
        let mut n2 = TernaryNode::new(2);
        n2.add_peer(1);
        gossip.add_node(n1);
        gossip.add_node(n2);
        let updates = gossip.run_round();
        assert!(updates > 0);
        assert!(gossip.is_converged());
    }

    #[test]
    fn test_gossip_converged() {
        let mut gossip = GossipProtocol::new();
        let mut n1 = TernaryNode::with_state(1, Trit::Pos);
        n1.add_peer(2);
        let mut n2 = TernaryNode::with_state(2, Trit::Pos);
        n2.add_peer(1);
        gossip.add_node(n1);
        gossip.add_node(n2);
        let rounds = gossip.run_until_converged(10);
        assert_eq!(rounds, 1);
        assert!(gossip.is_converged());
    }

    #[test]
    fn test_partition_detector_alive() {
        let mut pd = PartitionDetector::new(3, 2);
        pd.heartbeat(1);
        pd.heartbeat(2);
        pd.heartbeat(3);
        assert!(pd.is_alive(1));
        assert!(pd.is_alive(2));
        assert!(pd.is_alive(3));
    }

    #[test]
    fn test_partition_detector_timeout() {
        let mut pd = PartitionDetector::new(3, 1);
        pd.heartbeat(1);
        pd.heartbeat(2);
        pd.heartbeat(3);
        pd.advance_round();
        pd.advance_round();
        assert!(!pd.is_alive(1));
        assert_eq!(pd.partitioned_nodes().len(), 3);
    }

    #[test]
    fn test_partition_detector_quorum() {
        let mut pd = PartitionDetector::new(3, 2);
        pd.heartbeat(1);
        pd.heartbeat(2);
        assert!(pd.has_quorum());
        assert!(!pd.is_partitioned());
    }

    #[test]
    fn test_partition_detector_no_quorum() {
        let mut pd = PartitionDetector::new(5, 2);
        pd.heartbeat(1);
        pd.heartbeat(2);
        assert!(!pd.has_quorum());
        assert!(pd.is_partitioned());
    }

    #[test]
    fn test_consensus_prepare_promise() {
        let mut cp = ConsensusProtocol::new(&[1, 2, 3]);
        let proposal = cp.prepare(1);
        assert!(proposal > 0);
        assert!(cp.promise(1, proposal));
        assert!(cp.promise(2, proposal));
        assert!(cp.promise(3, proposal));
    }

    #[test]
    fn test_consensus_accept_decide() {
        let mut cp = ConsensusProtocol::new(&[1, 2, 3]);
        let proposal = cp.prepare(1);
        cp.promise(1, proposal);
        cp.promise(2, proposal);
        cp.promise(3, proposal);
        cp.accept(1, proposal, Vote::Positive);
        cp.accept(2, proposal, Vote::Positive);
        cp.accept(3, proposal, Vote::Positive);
        assert_eq!(cp.decide(), Some(Vote::Positive));
    }

    #[test]
    fn test_consensus_negative_decision() {
        let mut cp = ConsensusProtocol::new(&[1, 2, 3]);
        let proposal = cp.prepare(1);
        cp.promise(1, proposal);
        cp.promise(2, proposal);
        cp.promise(3, proposal);
        cp.accept(1, proposal, Vote::Negative);
        cp.accept(2, proposal, Vote::Negative);
        cp.accept(3, proposal, Vote::Abstain);
        assert_eq!(cp.decide(), Some(Vote::Negative));
    }

    #[test]
    fn test_consensus_no_quorum() {
        // Nothing has been accepted, so no decision is possible.
        let cp = ConsensusProtocol::new(&[1, 2, 3]);
        assert_eq!(cp.decide(), None);
    }

    #[test]
    fn test_anti_entropy_sync_pair() {
        let mut sync = AntiEntropySync::new();
        let mut n1 = TernaryNode::with_state(1, Trit::Pos);
        n1.add_peer(2);
        let mut n2 = TernaryNode::new(2);
        n2.add_peer(1);
        // Give node 2 the newer clock so that node 1 is the stale side.
        n2.vector_clock.increment(2);
        sync.add_node(n1);
        sync.add_node(n2);

        let changed = sync.sync_pair(1, 2);

        // The stale node adopts the newer node's state...
        assert!(changed);
        let state1 = sync.nodes.get(&1).unwrap().state;
        assert_eq!(state1, Trit::Zero);
        // ...while the newer node keeps its own.
        let state2 = sync.nodes.get(&2).unwrap().state;
        assert_eq!(state2, Trit::Zero);
    }

    #[test]
    fn test_vote_trit_conversion() {
        assert_eq!(Vote::Negative.to_trit(), Trit::Neg);
        assert_eq!(Vote::Abstain.to_trit(), Trit::Zero);
        assert_eq!(Vote::Positive.to_trit(), Trit::Pos);
        assert_eq!(Vote::from_trit(Trit::Neg), Vote::Negative);
        assert_eq!(Vote::from_trit(Trit::Zero), Vote::Abstain);
        assert_eq!(Vote::from_trit(Trit::Pos), Vote::Positive);
    }

    #[test]
    fn test_dominant_trit() {
        assert_eq!(dominant_trit(&[Trit::Pos, Trit::Pos, Trit::Neg]), Some(Trit::Pos));
        assert_eq!(dominant_trit(&[Trit::Neg, Trit::Neg, Trit::Pos]), Some(Trit::Neg));
        assert_eq!(dominant_trit(&[Trit::Zero, Trit::Zero]), None);
        assert_eq!(dominant_trit(&[]), None);
    }
}