use std::sync::Arc;
use rand::distributions::{Distribution, Standard};
use rand::{seq::SliceRandom, Rng};
use serde::{Deserialize, Serialize};
use super::bool_multimap::BoolMultimap;
use super::bool_set::{self, BoolSet};
use super::{FaultKind, Result};
use crate::fault_log::Fault;
use crate::{NetworkInfo, NodeIdT, Target};
/// The step type returned by this module: it carries `Message`s, outputs a `BoolSet`, and reports
/// faults as `FaultKind`, for nodes identified by `N`.
pub type Step<N> = crate::Step<Message, BoolSet, N, FaultKind>;
/// A message exchanged by `SbvBroadcast` instances; each variant carries one boolean value.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, PartialOrd, Eq, Ord)]
pub enum Message {
/// A `BVal` message for the contained value; processed by `SbvBroadcast::handle_bval`.
BVal(bool),
/// An `Aux` message for the contained value; processed by `SbvBroadcast::handle_aux`.
Aux(bool),
}
/// Allows random `Message`s to be generated with `rng.gen()` / `Standard.sample(rng)`.
impl Distribution<Message> for Standard {
    /// Samples a random message: first picks one of the two variants via `SliceRandom::choose`,
    /// then draws a random `bool` payload for it.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Message {
        // A typed table of variant constructors replaces the former string tags, so no
        // catch-all `unreachable!()` arm is needed. The RNG is still consulted exactly twice
        // and in the same order: one `choose`, then one `gen::<bool>()`.
        let variants: [fn(bool) -> Message; 2] = [Message::BVal, Message::Aux];
        let make = *variants
            .choose(rng)
            .expect("the variant table is a non-empty array");
        make(rng.gen())
    }
}
/// The state of one SBV-broadcast instance (NOTE(review): presumably "synchronized binary value
/// broadcast" — confirm against the module documentation).
#[derive(Debug)]
pub struct SbvBroadcast<N> {
/// Shared network information; queried for our own ID, validator status and the
/// `num_faulty` / `num_correct` thresholds.
netinfo: Arc<NetworkInfo<N>>,
/// The values for which exactly `2 * num_faulty + 1` `BVal` messages have been recorded.
bin_values: BoolSet,
/// For each boolean value, the senders from which a `BVal` with that value was recorded.
received_bval: BoolMultimap<N>,
/// The values for which `send_bval` has already been called; used to send each at most once.
sent_bval: BoolSet,
/// For each boolean value, the senders from which an `Aux` with that value was recorded.
received_aux: BoolMultimap<N>,
/// Set to `true` once this instance has produced its output; no further output follows.
terminated: bool,
}
impl<N: NodeIdT> SbvBroadcast<N> {
    /// Creates a fresh instance: nothing received, nothing sent, not terminated.
    pub fn new(netinfo: Arc<NetworkInfo<N>>) -> Self {
        SbvBroadcast {
            netinfo,
            bin_values: bool_set::NONE,
            sent_bval: bool_set::NONE,
            received_bval: BoolMultimap::default(),
            received_aux: BoolMultimap::default(),
            terminated: false,
        }
    }

    /// Resets the instance so it can be reused.
    ///
    /// `bin_values` and `sent_bval` are emptied, the `terminated` flag is cleared, and both the
    /// received-`BVal` and received-`Aux` records are replaced by a copy of `init`.
    pub fn clear(&mut self, init: &BoolMultimap<N>) {
        self.terminated = false;
        self.bin_values = bool_set::NONE;
        self.sent_bval = bool_set::NONE;
        self.received_bval = init.clone();
        self.received_aux = init.clone();
    }

    /// Dispatches an incoming message from `sender_id` to `handle_bval` or `handle_aux`,
    /// depending on its variant, and returns that handler's result.
    pub fn handle_message(&mut self, sender_id: &N, msg: &Message) -> Result<Step<N>> {
        match *msg {
            Message::BVal(value) => self.handle_bval(sender_id, value),
            Message::Aux(value) => self.handle_aux(sender_id, value),
        }
    }

    /// Returns the current set of values that reached the `2 * num_faulty + 1` `BVal` threshold.
    pub fn bin_values(&self) -> BoolSet {
        self.bin_values
    }

    /// Sends `BVal(b)` unless it has been sent before.
    ///
    /// The value is recorded in `sent_bval`; if it was already present, an empty step is
    /// returned and nothing is sent.
    pub fn send_bval(&mut self, b: bool) -> Result<Step<N>> {
        let first_time = self.sent_bval.insert(b);
        if first_time {
            self.send(&Message::BVal(b))
        } else {
            Ok(Step::default())
        }
    }

    /// Handles a `BVal(b)` message from `sender_id`.
    ///
    /// A second `BVal(b)` from the same sender yields a step containing only a
    /// `FaultKind::DuplicateBVal` fault against that sender. Otherwise, with `f` denoting
    /// `netinfo.num_faulty()`:
    ///
    /// * when the number of recorded senders for `b` is exactly `2 * f + 1`, `b` is added to
    ///   `bin_values`; if `bin_values` does not contain both values afterwards, `Aux(b)` is
    ///   sent, otherwise only the output condition is re-checked;
    /// * when that number is exactly `f + 1`, `BVal(b)` is sent via `send_bval`.
    ///
    /// Both conditions are checked, in this order, against the count taken right after the
    /// sender was recorded.
    pub fn handle_bval(&mut self, sender_id: &N, b: bool) -> Result<Step<N>> {
        let senders = &mut self.received_bval[b];
        if !senders.insert(sender_id.clone()) {
            return Ok(Fault::new(sender_id.clone(), FaultKind::DuplicateBVal).into());
        }
        let count_bval = senders.len();
        let num_faulty = self.netinfo.num_faulty();

        let mut step = Step::default();

        if count_bval == 2 * num_faulty + 1 {
            self.bin_values.insert(b);
            let follow_up = if self.bin_values == bool_set::BOTH {
                // Both values are in `bin_values` now: only re-check the output condition.
                self.try_output()?
            } else {
                // `b` is the only value in `bin_values`: announce it with `Aux(b)`.
                self.send(&Message::Aux(b))?
            };
            step.extend(follow_up);
        }

        if count_bval == num_faulty + 1 {
            step.extend(self.send_bval(b)?);
        }

        Ok(step)
    }

    /// Sends `msg` to all nodes and immediately handles it as if received from ourselves.
    ///
    /// If `netinfo.is_validator()` is `false`, nothing is sent and only the output condition is
    /// checked.
    fn send(&mut self, msg: &Message) -> Result<Step<N>> {
        if !self.netinfo.is_validator() {
            return self.try_output();
        }
        // Own the ID so that `self` can be borrowed mutably by `handle_message` below.
        let our_id = self.netinfo.our_id().clone();
        let outgoing: Step<N> = Target::All.message(msg.clone()).into();
        let own_step = self.handle_message(&our_id, msg)?;
        Ok(outgoing.join(own_step))
    }

    /// Handles an `Aux(b)` message from `sender_id`.
    ///
    /// A second `Aux(b)` from the same sender yields a step containing only a
    /// `FaultKind::DuplicateAux` fault against that sender; otherwise the sender is recorded
    /// and the output condition is checked.
    pub fn handle_aux(&mut self, sender_id: &N, b: bool) -> Result<Step<N>> {
        let newly_recorded = self.received_aux[b].insert(sender_id.clone());
        if newly_recorded {
            self.try_output()
        } else {
            Ok(Fault::new(sender_id.clone(), FaultKind::DuplicateAux).into())
        }
    }

    /// Produces the output if the termination condition holds.
    ///
    /// Returns an empty step if the instance has already terminated, if `bin_values` is empty,
    /// or if fewer than `netinfo.num_correct()` `Aux` messages were counted by `count_aux`.
    /// Otherwise marks the instance as terminated and outputs the set of values reported by
    /// `count_aux`.
    fn try_output(&mut self) -> Result<Step<N>> {
        let nothing_to_report = self.bin_values == bool_set::NONE;
        if self.terminated || nothing_to_report {
            return Ok(Step::default());
        }
        let (aux_count, aux_vals) = self.count_aux();
        if aux_count >= self.netinfo.num_correct() {
            self.terminated = true;
            Ok(Step::default().with_output(aux_vals))
        } else {
            Ok(Step::default())
        }
    }

    /// Counts the recorded `Aux` messages whose value is in `bin_values`.
    ///
    /// Returns the total number of such messages together with the set of values in
    /// `bin_values` for which at least one `Aux` was recorded.
    fn count_aux(&self) -> (usize, BoolSet) {
        let mut seen_values = bool_set::NONE;
        let mut total = 0;
        for value in self.bin_values {
            let senders = &self.received_aux[value];
            if senders.is_empty() {
                continue;
            }
            seen_values.insert(value);
            total += senders.len();
        }
        (total, seen_values)
    }
}