use std::cmp;
use std::collections::VecDeque;
use rand::Rng;
use hbbft::{ConsensusProtocol, CpStep};
use crate::net::util::randomly;
use crate::net::{CrankError, NetMessage, NetworkMessage, Node, VirtualNet};
/// Read-only handle to a `VirtualNet`.
///
/// Node accessors hand out `NodeHandle`s, which only expose the inner `Node` of
/// faulty nodes.
#[derive(Debug)]
pub struct NetHandle<'a, D, A>(&'a VirtualNet<D, A>)
where
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    A: Adversary<D>;

impl<'a, D: 'a, A> NetHandle<'a, D, A>
where
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    A: Adversary<D>,
{
    /// Returns handles to every node of the network (delegates to `VirtualNet::nodes`).
    #[inline]
    pub fn nodes(&self) -> impl Iterator<Item = NodeHandle<'_, D>> {
        self.0.nodes().map(NodeHandle::new)
    }

    /// Returns plain references to the faulty nodes (delegates to `VirtualNet::faulty_nodes`).
    #[inline]
    pub fn faulty_nodes(&self) -> impl Iterator<Item = &Node<D>> {
        self.0.faulty_nodes()
    }

    /// Returns handles to the correct nodes (delegates to `VirtualNet::correct_nodes`).
    #[inline]
    pub fn correct_nodes(&self) -> impl Iterator<Item = NodeHandle<'_, D>> {
        self.0.correct_nodes().map(NodeHandle::new)
    }

    /// Returns the queued messages (delegates to `VirtualNet::messages`).
    #[inline]
    pub fn messages(&'a self) -> impl Iterator<Item = &'a NetMessage<D>> {
        self.0.messages()
    }

    /// Looks up a node by ID, returning `None` if `VirtualNet::get` finds nothing.
    #[inline]
    pub fn get(&self, id: D::NodeId) -> Option<NodeHandle<'_, D>> {
        let node = self.0.get(id)?;
        Some(NodeHandle::new(node))
    }
}
/// Where `NetMutHandle::inject_message` places an injected message in the queue.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueuePosition {
    /// At the head of the queue.
    Front,
    /// At the tail of the queue.
    Back,
    /// At the given index; the message previously there and every later one move
    /// back by one slot.
    Before(usize),
}
/// Mutable handle to a `VirtualNet`.
///
/// Gives access to the nodes and lets the holder rearrange the message queue or
/// inject messages from faulty nodes.
#[derive(Debug)]
pub struct NetMutHandle<'a, D, A>(&'a mut VirtualNet<D, A>)
where
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    A: Adversary<D>;

impl<'a, D, A> NetMutHandle<'a, D, A>
where
    D: ConsensusProtocol,
    A: Adversary<D>,
    D::NodeId: Clone,
    D::Message: Clone,
    D::Output: Clone,
{
    /// Wraps a mutable reference to a virtual network.
    pub fn new(net: &'a mut VirtualNet<D, A>) -> Self {
        NetMutHandle(net)
    }

    /// Returns handles to every node (delegates to `VirtualNet::nodes_mut`).
    #[inline]
    pub fn nodes_mut(&mut self) -> impl Iterator<Item = NodeMutHandle<'_, D>> {
        self.0.nodes_mut().map(NodeMutHandle::new)
    }

    /// Returns mutable references to the faulty nodes.
    #[inline]
    pub fn faulty_nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<D>> {
        self.0.faulty_nodes_mut()
    }

    /// Returns handles to the correct nodes (delegates to `VirtualNet::correct_nodes_mut`).
    #[inline]
    pub fn correct_nodes_mut(&mut self) -> impl Iterator<Item = NodeMutHandle<'_, D>> {
        self.0.correct_nodes_mut().map(NodeMutHandle::new)
    }

    /// Hands `msg` to `VirtualNet::dispatch_message` and returns its result.
    pub fn dispatch_message<R: Rng>(
        &mut self,
        msg: NetMessage<D>,
        rng: &mut R,
    ) -> Result<CpStep<D>, CrankError<D>> {
        self.0.dispatch_message(msg, rng)
    }

    /// Inserts `msg` into the network's message queue at `position`.
    ///
    /// # Panics
    ///
    /// Panics if the sender is unknown or not faulty, if the recipient is unknown,
    /// or if `position` is `Before(idx)` with `idx` greater than the queue length.
    #[inline]
    pub fn inject_message(&mut self, position: QueuePosition, msg: NetMessage<D>) {
        // Only faulty nodes may originate injected messages.
        let sender = self
            .0
            .get(msg.from.clone())
            .expect("inject: unknown sender node");
        assert!(
            sender.is_faulty(),
            "Tried to inject message not originating from a faulty node."
        );

        // The recipient must exist; the looked-up node itself is not needed.
        let _ = self
            .0
            .get(msg.to.clone())
            .expect("inject: unknown recipient node");

        let queue = &mut self.0.messages;
        match position {
            QueuePosition::Front => queue.push_front(msg),
            QueuePosition::Back => queue.push_back(msg),
            // `VecDeque::insert` panics if `idx` is past the end of the queue.
            QueuePosition::Before(idx) => queue.insert(idx, msg),
        }
    }

    /// Swaps the queued messages at positions `i` and `j`.
    ///
    /// NOTE(review): presumably panics on out-of-range indices like `VecDeque::swap`;
    /// confirm in `VirtualNet::swap_messages`.
    #[inline]
    pub fn swap_messages(&mut self, i: usize, j: usize) {
        self.0.swap_messages(i, j);
    }

    /// Reorders the message queue using the comparator `f`.
    #[inline]
    pub fn sort_messages_by<F>(&mut self, f: F)
    where
        F: FnMut(&NetMessage<D>, &NetMessage<D>) -> cmp::Ordering,
    {
        self.0.sort_messages_by(f)
    }

    /// Returns the raw message queue.
    #[inline]
    pub fn get_messages(&self) -> &VecDeque<NetMessage<D>> {
        &self.0.messages
    }
}
impl<'a, D, A> From<NetMutHandle<'a, D, A>> for NetHandle<'a, D, A>
where
D: ConsensusProtocol,
A: Adversary<D>,
D::Message: Clone,
D::Output: Clone,
{
#[inline]
fn from(n: NetMutHandle<'_, D, A>) -> NetHandle<'_, D, A> {
NetHandle(n.0)
}
}
/// Read-only handle to a single node.
///
/// Only exposes the underlying `Node` if that node is faulty.
#[derive(Debug)]
pub struct NodeHandle<'a, D>(&'a Node<D>)
where
    D: ConsensusProtocol;

impl<'a, D> NodeHandle<'a, D>
where
    D: ConsensusProtocol,
{
    /// Wraps a node reference.
    #[inline]
    fn new(inner: &'a Node<D>) -> Self {
        NodeHandle(inner)
    }

    /// Returns the node's ID.
    #[inline]
    pub fn id(&self) -> D::NodeId {
        self.0.id().clone()
    }

    /// Returns a reference to the faulty node.
    ///
    /// # Panics
    ///
    /// Panics if the node is not faulty.
    #[inline]
    pub fn node(&self) -> &'a Node<D> {
        self.try_node()
            .expect("could not access inner node of handle, node is not faulty")
    }

    /// Returns the node if it is faulty, `None` otherwise.
    #[inline]
    pub fn try_node(&self) -> Option<&'a Node<D>> {
        Some(self.0).filter(|node| node.is_faulty())
    }
}
/// Mutable handle to a single node.
///
/// Only exposes mutable access to the underlying `Node` if that node is faulty.
#[derive(Debug)]
pub struct NodeMutHandle<'a, D>(&'a mut Node<D>)
where
    D: ConsensusProtocol;

impl<'a, D: 'a> NodeMutHandle<'a, D>
where
    D: ConsensusProtocol,
{
    /// Wraps a mutable node reference.
    fn new(inner: &'a mut Node<D>) -> Self {
        NodeMutHandle(inner)
    }

    /// Returns the node's ID.
    #[inline]
    pub fn id(&self) -> D::NodeId {
        self.0.id().clone()
    }

    /// Returns a mutable reference to the faulty node.
    ///
    /// `self` is borrowed only for as long as the returned reference is used, not for
    /// the handle's whole lifetime `'a`, so the handle stays usable afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the node is not faulty.
    #[inline]
    pub fn node_mut(&mut self) -> &mut Node<D> {
        self.try_node_mut()
            .expect("could not access inner node of handle, node is not faulty")
    }

    /// Returns a mutable reference to the node if it is faulty, `None` otherwise.
    #[inline]
    pub fn try_node_mut(&mut self) -> Option<&mut Node<D>> {
        if self.0.is_faulty() {
            // Reborrow for the duration of `&mut self` only.
            Some(&mut *self.0)
        } else {
            None
        }
    }
}
pub trait Adversary<D>
where
Self: Sized,
D: ConsensusProtocol,
D::Message: Clone,
D::Output: Clone,
{
#[inline]
fn pre_crank<R: Rng>(&mut self, _net: NetMutHandle<'_, D, Self>, _rng: &mut R) {}
#[inline]
fn tamper<R: Rng>(
&mut self,
mut net: NetMutHandle<'_, D, Self>,
msg: NetMessage<D>,
rng: &mut R,
) -> Result<CpStep<D>, CrankError<D>> {
net.dispatch_message(msg, rng)
}
}
/// Sorts the message queue by recipient ID, lowest first.
#[inline]
pub fn sort_ascending<D, A>(net: &mut NetMutHandle<'_, D, A>)
where
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    A: Adversary<D>,
{
    net.sort_messages_by(|lhs, rhs| {
        let (left, right) = (lhs.to(), rhs.to());
        left.cmp(&right)
    })
}
/// Swaps the first queued message with a randomly chosen one (possibly itself).
///
/// Does nothing, and draws nothing from `rng`, when the queue is empty.
#[inline]
pub fn swap_random<R, D, A>(net: &mut NetMutHandle<'_, D, A>, rng: &mut R)
where
    R: Rng,
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    A: Adversary<D>,
{
    let queue_len = net.get_messages().len();
    if queue_len == 0 {
        return;
    }
    let target = rng.gen_range(0, queue_len);
    net.swap_messages(0, target);
}
/// Picks a random node among those yielded by `nodes_mut` and returns its ID.
///
/// Returns `None`, without drawing from `rng`, if there are no nodes.
#[inline]
pub fn random_node<R, D, A>(net: &mut NetMutHandle<'_, D, A>, rng: &mut R) -> Option<D::NodeId>
where
    R: Rng,
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    A: Adversary<D>,
{
    let node_count = net.nodes_mut().count();
    if node_count == 0 {
        return None;
    }

    let index = rng.gen_range(0, node_count);
    // The node set cannot change between the two `nodes_mut` calls above.
    let picked = net
        .nodes_mut()
        .nth(index)
        .expect("nodes list changed since last call");
    Some(picked.id())
}
#[inline]
pub fn sort_by_random_node<R, D, A>(net: &mut NetMutHandle<'_, D, A>, rng: &mut R)
where
R: Rng,
D: ConsensusProtocol,
D::Message: Clone,
D::Output: Clone,
A: Adversary<D>,
{
if let Some(picked_node) = random_node(net, rng) {
net.sort_messages_by(|a, b| {
let a = a.to().clone();
let b = b.to().clone();
if a == b {
cmp::Ordering::Equal
} else if a == picked_node {
cmp::Ordering::Less
} else if b == picked_node {
cmp::Ordering::Greater
} else {
a.cmp(&b)
}
});
}
}
/// Adversary that relies entirely on the default `Adversary` hooks, so it never
/// interferes with the network.
#[derive(Debug, Default)]
pub struct NullAdversary;

impl NullAdversary {
    /// Creates a new `NullAdversary`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }
}
// Uses the default hooks: `pre_crank` does nothing and `tamper` dispatches every
// message unchanged.
impl<D: ConsensusProtocol> Adversary<D> for NullAdversary
where
    D::Message: Clone,
    D::Output: Clone,
{
}
/// Adversary whose `pre_crank` hook sorts the message queue by recipient ID.
#[derive(Debug, Default)]
pub struct NodeOrderAdversary;

impl NodeOrderAdversary {
    /// Creates a new `NodeOrderAdversary`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }
}
impl<D: ConsensusProtocol> Adversary<D> for NodeOrderAdversary
where
    D::Message: Clone,
    D::Output: Clone,
{
    /// Sorts the queued messages by recipient ID (via `sort_ascending`).
    /// `tamper` keeps its default behaviour.
    #[inline]
    fn pre_crank<R: Rng>(&mut self, net: NetMutHandle<'_, D, Self>, _rng: &mut R) {
        let mut handle = net;
        sort_ascending(&mut handle);
    }
}
/// Adversary whose `pre_crank` hook swaps the first queued message with a random one.
#[derive(Copy, Clone, Debug, Default)]
pub struct ReorderingAdversary {}

impl ReorderingAdversary {
    /// Creates a new `ReorderingAdversary`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<D: ConsensusProtocol> Adversary<D> for ReorderingAdversary
where
    D::Message: Clone,
    D::Output: Clone,
{
    /// Swaps the first queued message with a random one (via `swap_random`).
    /// `tamper` keeps its default behaviour.
    #[inline]
    fn pre_crank<R: Rng>(&mut self, net: NetMutHandle<'_, D, Self>, rng: &mut R) {
        let mut handle = net;
        swap_random(&mut handle, rng);
    }
}
/// Adversary that makes random scheduling, replay and injection decisions.
#[derive(Copy, Clone, Debug, Default)]
pub struct RandomAdversary {
    /// Probability, per tampered message, of replaying it to a randomly picked node.
    p_replay: f32,
    /// Probability of each additional round of random-message injection.
    p_inject: f32,
}

impl RandomAdversary {
    /// Creates a new random adversary with the given probabilities.
    ///
    /// # Panics
    ///
    /// Panics unless `p_inject < 1.0` (this also rejects NaN). Injection repeats for
    /// as long as `randomly(p_inject, ..)` succeeds, so a probability of 1.0 or more
    /// would (assuming `randomly` succeeds with probability `p`) never terminate.
    pub fn new(p_replay: f32, p_inject: f32) -> Self {
        assert!(
            p_inject < 1.0,
            "p_inject must be smaller than 1.0, otherwise message injection never terminates"
        );
        RandomAdversary { p_replay, p_inject }
    }
}
impl<D> Adversary<D> for RandomAdversary
where
    D: ConsensusProtocol,
    D::Message: Clone,
    D::Output: Clone,
    rand::distributions::Standard:
        rand::distributions::Distribution<<D as ConsensusProtocol>::Message>,
{
    /// Reorders the queue so that messages to a randomly picked node come first.
    #[inline]
    fn pre_crank<R: Rng>(&mut self, mut net: NetMutHandle<'_, D, Self>, rng: &mut R) {
        sort_by_random_node(&mut net, rng);
    }

    /// Randomly replays and injects messages, then dispatches `msg` itself.
    ///
    /// * When `randomly(self.p_replay, rng)` succeeds, a copy of `msg` is re-sent
    ///   from its recipient to a node picked by `random_node`.
    /// * Then, for as long as `randomly(self.p_inject, rng)` succeeds, a fresh random
    ///   message is sent from `msg`'s recipient to every other node.
    ///
    /// Injected payloads are drawn from `rng` rather than a thread-local RNG, so a run
    /// driven by a seeded `rng` stays reproducible.
    ///
    /// # Panics
    ///
    /// `inject_message` panics unless `msg.to` is a faulty node.
    /// NOTE(review): this relies on the caller only handing over messages addressed
    /// to faulty nodes; confirm against `VirtualNet`.
    #[inline]
    fn tamper<R: Rng>(
        &mut self,
        mut net: NetMutHandle<'_, D, Self>,
        msg: NetMessage<D>,
        rng: &mut R,
    ) -> Result<CpStep<D>, CrankError<D>> {
        if randomly(self.p_replay, rng) {
            // Replay: re-send the message from its recipient to a random node.
            if let Some(picked_node) = random_node(&mut net, rng) {
                let mut new_msg = msg.clone();
                new_msg.from = new_msg.to;
                new_msg.to = picked_node;
                net.inject_message(QueuePosition::Back, new_msg);
            }
        }

        // Injection: broadcast random messages from the recipient to all other nodes.
        // Injecting does not change the node set, so the recipient list is built at
        // most once, and only if an injection round actually happens.
        let sender = &msg.to;
        let mut cached_recipients: Option<Vec<<D as ConsensusProtocol>::NodeId>> = None;
        while randomly(self.p_inject, rng) {
            let message: D::Message = rng.gen();
            let recipients = cached_recipients.get_or_insert_with(|| {
                net.nodes_mut()
                    .map(|node| node.id())
                    .filter(|node_id| *node_id != *sender)
                    .collect()
            });
            for node_id in recipients.iter() {
                let new_msg =
                    NetworkMessage::new(sender.clone(), message.clone(), node_id.clone());
                net.inject_message(QueuePosition::Back, new_msg);
            }
        }

        net.dispatch_message(msg, rng)
    }
}