#![allow(missing_docs)]
use std::collections::HashMap;
use std::convert::Infallible;
use std::convert::TryInto;
use std::sync::Arc;
use futures::channel::mpsc;
use futures::future::BoxFuture;
use futures::future::FutureExt;
use futures::lock::Mutex;
use futures::sink::SinkExt;
use futures::stream::Stream;
use thiserror::Error;
use crate::append::AppendError;
use crate::communicator::Acceptance;
use crate::communicator::AcceptanceFor;
use crate::communicator::Committed;
use crate::communicator::Communicator;
use crate::communicator::Vote;
use crate::communicator::VoteFor;
use crate::error::ShutDown;
use crate::invocation::AbstainOf;
use crate::invocation::CoordNumOf;
use crate::invocation::Invocation;
use crate::invocation::LogEntryOf;
use crate::invocation::NayOf;
use crate::invocation::NodeIdOf;
use crate::invocation::NodeOf;
use crate::invocation::RoundNumOf;
use crate::invocation::YeaOf;
use crate::retry::RetryPolicy;
use crate::LogEntry;
use crate::NodeInfo;
use crate::RequestHandler;
/// Minimal `NodeInfo` implementation for prototyping: a node is nothing but a
/// `usize` id.
///
/// NOTE(review): the derived `Default` yields id `0`, which is also the first
/// id handed out by `PrototypingNode::new` (the dispenser starts at `0`);
/// confirm that overlap is harmless for callers.
#[derive(
Clone, Copy, Debug, Default, Eq, Hash, PartialEq, serde::Deserialize, serde::Serialize,
)]
pub struct PrototypingNode(usize);
/// Process-wide counter from which `PrototypingNode::new` draws ids; starts at `0`
/// and is incremented once per call.
static NODE_ID_DISPENSER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
impl PrototypingNode {
pub fn new() -> Self {
Self(NODE_ID_DISPENSER.fetch_add(1, std::sync::atomic::Ordering::Relaxed))
}
pub fn with_id(id: usize) -> Self {
Self(id)
}
}
impl NodeInfo for PrototypingNode {
type Id = usize;
fn id(&self) -> Self::Id {
self.0
}
}
/// `RetryPolicy` whose `eval` always resolves to `Ok(())` (its `Error` type is
/// `Infallible`), i.e. it never turns an append error into a final failure.
///
/// The `u64` field is the upper bound, in milliseconds, of the random pause
/// taken before each retry; `0` means no pause at all.
#[derive(Debug)]
pub struct RetryIndefinitely<I>(u64, crate::util::PhantomSend<I>);
impl<I> RetryIndefinitely<I> {
    /// Retries immediately after every failure.
    pub fn without_pausing() -> Self {
        Self(0, crate::util::PhantomSend::new())
    }

    /// Pauses for a uniformly random duration of `0..=duration` (at
    /// millisecond granularity) before every retry.
    ///
    /// Durations longer than `u64::MAX` milliseconds are clamped to that value.
    pub fn pausing_up_to(duration: std::time::Duration) -> Self {
        // `as_millis` returns a `u128`; a bare `as u64` would silently drop the
        // high bits for absurdly long durations, so saturate first. After the
        // `min`, the cast below is lossless.
        let max_pause_ms = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        Self(max_pause_ms, crate::util::PhantomSend::new())
    }
}
impl<I: Invocation> RetryPolicy for RetryIndefinitely<I> {
    type Invocation = I;
    type Error = Infallible;
    type StaticError = ShutDown;
    type Future = BoxFuture<'static, Result<(), Self::Error>>;

    /// Always resolves to `Ok(())`, regardless of the error.
    ///
    /// When a maximum pause is configured, it first sleeps for a uniformly
    /// random number of milliseconds in `0..=max`.
    fn eval(&mut self, _err: AppendError<Self::Invocation>) -> Self::Future {
        let max_pause_ms = self.0;
        async move {
            if max_pause_ms == 0 {
                return Ok(());
            }
            let pause = {
                use rand::Rng;
                std::time::Duration::from_millis(rand::thread_rng().gen_range(0..=max_pause_ms))
            };
            sleep(pause).await;
            Ok(())
        }
        .boxed()
    }
}
/// Request handler of each registered node, keyed by node id.
type RequestHandlers<I> = HashMap<NodeIdOf<I>, RequestHandler<I>>;
/// Senders feeding the streams handed out by `DirectCommunicators::events`.
type EventListeners<I> = Vec<mpsc::Sender<DirectCommunicatorEvent<I>>>;
/// Per-link packet loss overrides, keyed by the directed `(from, to)` pair.
type PacketLossRates<I> = HashMap<(NodeIdOf<I>, NodeIdOf<I>), f32>;
/// Per-link end-to-end delay distributions (samples are milliseconds), keyed by
/// the directed `(from, to)` pair.
type E2eDelays<I> = HashMap<(NodeIdOf<I>, NodeIdOf<I>), rand_distr::Normal<f32>>;
/// An in-process "network" that connects nodes by calling each other's
/// registered `RequestHandler` directly, with configurable packet loss and
/// end-to-end delay per directed link.
///
/// All tables sit behind `Arc<Mutex<_>>`, so clones share them. Use
/// `create_communicator_for` to obtain a `Communicator` for one node.
#[derive(Debug)]
pub struct DirectCommunicators<I: Invocation> {
// NOTE(review): with the `RequestHandlers` alias this allow looks redundant.
#[allow(clippy::type_complexity)]
request_handlers: Arc<Mutex<RequestHandlers<I>>>,
// Loss rate used for links that have no entry in `packet_loss`.
default_packet_loss: f32,
// Delay distribution used for links that have no entry in `e2e_delay`.
default_e2e_delay: rand_distr::Normal<f32>,
packet_loss: Arc<Mutex<PacketLossRates<I>>>,
e2e_delay: Arc<Mutex<E2eDelays<I>>>,
event_listeners: Arc<Mutex<EventListeners<I>>>,
}
impl<I: Invocation> DirectCommunicators<I> {
    /// Creates a set of communicators whose links are lossless and add no
    /// end-to-end delay.
    pub fn new() -> Self {
        let no_delay =
            rand_distr::Normal::new(0.0, 0.0).expect("a zero standard deviation is valid");
        Self::with_characteristics(0.0, no_delay)
    }

    /// Creates a set of communicators whose links default to the given
    /// characteristics.
    ///
    /// * `packet_loss` – probability that any single message leg (request or
    ///   reply) is dropped; values `<= 0` never drop, values `>= 1` always do.
    /// * `e2e_delay` – distribution each leg's delay is sampled from, in
    ///   milliseconds (negative samples mean no delay).
    ///
    /// Individual directed links can be overridden with
    /// [`set_packet_loss`](Self::set_packet_loss) and
    /// [`set_delay`](Self::set_delay).
    pub fn with_characteristics(packet_loss: f32, e2e_delay: rand_distr::Normal<f32>) -> Self {
        Self {
            request_handlers: Arc::new(Mutex::new(HashMap::new())),
            default_packet_loss: packet_loss,
            default_e2e_delay: e2e_delay,
            packet_loss: Arc::new(Mutex::new(HashMap::new())),
            e2e_delay: Arc::new(Mutex::new(HashMap::new())),
            event_listeners: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Overrides the packet loss rate for messages travelling from `from` to
    /// `to`. Links are directional: replies from `to` back to `from` use the
    /// `(to, from)` entry.
    ///
    /// Only shared access is required because the table is behind a mutex;
    /// callers holding `&mut self` are unaffected.
    pub async fn set_packet_loss(&self, from: NodeIdOf<I>, to: NodeIdOf<I>, packet_loss: f32) {
        let mut per_link = self.packet_loss.lock().await;
        per_link.insert((from, to), packet_loss);
    }

    /// Overrides the end-to-end delay distribution (in milliseconds) for
    /// messages travelling from `from` to `to`. Links are directional.
    ///
    /// Only shared access is required because the table is behind a mutex;
    /// callers holding `&mut self` are unaffected.
    pub async fn set_delay(
        &self,
        from: NodeIdOf<I>,
        to: NodeIdOf<I>,
        delay: rand_distr::Normal<f32>,
    ) {
        let mut per_link = self.e2e_delay.lock().await;
        per_link.insert((from, to), delay);
    }

    /// Registers `handler` as the request handler for `node_id`, replacing
    /// any handler previously registered for that id.
    pub async fn register(&self, node_id: NodeIdOf<I>, handler: RequestHandler<I>) {
        let mut handlers = self.request_handlers.lock().await;
        handlers.insert(node_id, handler);
    }

    /// Subscribes to an event for every message leg sent through this set.
    ///
    /// The stream is backed by a bounded channel created with a buffer of 16.
    ///
    /// NOTE(review): senders await delivery to each listener while holding the
    /// listener lock, so a subscriber that stops polling its stream may stall
    /// all traffic; confirm this backpressure is intended.
    pub fn events(&self) -> impl Stream<Item = DirectCommunicatorEvent<I>> {
        let (send, recv) = mpsc::channel(16);
        // Grab the lock without spinning up a nested executor when it is free
        // (the common case); only fall back to blocking when it is contended.
        let mut listeners = match self.event_listeners.try_lock() {
            Some(guard) => guard,
            None => futures::executor::block_on(self.event_listeners.lock()),
        };
        listeners.push(send);
        recv
    }

    /// Returns a communicator that sends on behalf of `node_id`, sharing this
    /// set's handlers, link characteristics and event listeners.
    pub fn create_communicator_for(&self, node_id: NodeIdOf<I>) -> DirectCommunicator<I> {
        DirectCommunicator {
            set: self.clone(),
            node_id,
        }
    }
}
impl<I: Invocation> Clone for DirectCommunicators<I> {
    /// Produces another handle onto the same network: every `Arc`-wrapped
    /// table is shared, only the default link characteristics are copied.
    ///
    /// Implemented by hand so that `I` itself need not be `Clone`.
    fn clone(&self) -> Self {
        let Self {
            request_handlers,
            default_packet_loss,
            default_e2e_delay,
            packet_loss,
            e2e_delay,
            event_listeners,
        } = self;

        Self {
            request_handlers: Arc::clone(request_handlers),
            default_packet_loss: *default_packet_loss,
            default_e2e_delay: *default_e2e_delay,
            packet_loss: Arc::clone(packet_loss),
            e2e_delay: Arc::clone(e2e_delay),
            event_listeners: Arc::clone(event_listeners),
        }
    }
}
impl<I: Invocation> Default for DirectCommunicators<I> {
fn default() -> Self {
Self::new()
}
}
/// A single message leg observed through `DirectCommunicators::events`.
///
/// Each send produces one event for the request and, if the request was not
/// dropped and the receiver had a registered handler, one for the reply.
#[derive(Clone, Debug)]
pub struct DirectCommunicatorEvent<I: Invocation> {
/// Node that sent this leg.
pub sender: NodeIdOf<I>,
/// Node this leg is addressed to.
pub receiver: NodeIdOf<I>,
/// Delay sampled for this leg; it is waited out whether or not the leg is dropped.
pub e2e_delay: std::time::Duration,
/// Whether this leg was lost to simulated packet loss.
pub dropped: bool,
/// What this leg carries.
pub payload: DirectCommunicatorPayload<I>,
}
/// Payload recorded in a `DirectCommunicatorEvent`.
///
/// Request variants mirror the `Communicator` calls. Reply variants only record
/// whether the receiver answered favourably.
#[derive(Clone, Debug)]
pub enum DirectCommunicatorPayload<I: Invocation> {
/// Request sent by `send_prepare`.
Prepare {
round_num: RoundNumOf<I>,
coord_num: CoordNumOf<I>,
},
/// Reply to `Prepare`: `true` iff the result was `Ok(Vote::Given(_))`.
Promise(bool),
/// Request sent by `send_proposal`.
Propose {
round_num: RoundNumOf<I>,
coord_num: CoordNumOf<I>,
log_entry: Arc<LogEntryOf<I>>,
},
/// Reply to `Propose`: `true` iff the result was `Ok(Acceptance::Given(_))`.
Accept(bool),
/// Request sent by `send_commit`.
Commit {
round_num: RoundNumOf<I>,
coord_num: CoordNumOf<I>,
log_entry: Arc<LogEntryOf<I>>,
},
/// Request sent by `send_commit_by_id`; the log entry id itself is not recorded.
CommitById {
round_num: RoundNumOf<I>,
coord_num: CoordNumOf<I>,
},
/// Reply to `Commit` / `CommitById`: `true` iff the result was `Ok`.
Committed(bool),
}
/// Failure of a single simulated send.
#[derive(Debug, Error)]
pub enum DirectCommunicatorError {
/// The receiver has no registered request handler, or the handler's result
/// could not be converted (`try_into`) into the communicator's result type.
#[error("other")]
Other,
/// The request or the reply was dropped by simulated packet loss.
#[error("timeout")]
Timeout,
}
/// `Communicator` for a single node of a `DirectCommunicators` set, obtained via
/// `DirectCommunicators::create_communicator_for`.
#[derive(Debug)]
pub struct DirectCommunicator<I: Invocation> {
/// Shared network state: handlers, link characteristics and event listeners.
set: DirectCommunicators<I>,
/// Id of the node this communicator sends on behalf of.
node_id: NodeIdOf<I>,
}
impl<I: Invocation> Clone for DirectCommunicator<I> {
fn clone(&self) -> Self {
Self {
set: self.set.clone(),
node_id: self.node_id,
}
}
}
// Shared implementation of the `send_*` methods of `DirectCommunicator`.
//
// Invocation shape (sections separated by `;`):
//   send_fn!(
//       self, receivers, <non-Copy args...>;
//       <RequestHandler method>, <args...>;
//       <request payload expression>;
//       <response payload closure>;
//   )
//
// * Each listed non-Copy argument is cloned once per receiver (via the `@`
//   rule at the bottom) so every returned future owns its own copy. All other
//   `arg`s are moved into each per-receiver future from an `FnMut` closure, so
//   they must be `Copy`.
// * The request payload expression is evaluated once per event listener.
// * The response payload closure receives `&Result<_, DirectCommunicatorError>`.
//
// Expands to a `Vec` of `(receiver, BoxFuture)` pairs. Nothing happens until a
// future is polled; each future then:
//   1. looks up the per-link loss rates and delay distributions for both
//      directions, falling back to the set-wide defaults;
//   2. samples the outbound delay and loss, reports the request to every
//      listener, waits out the delay, and fails with `Timeout` if dropped;
//   3. calls `$method` on the receiver's handler (`Other` if none is registered)
//      and converts the result with `try_into` (`Other` on failure);
//   4. samples the reply's delay and loss, reports the reply, waits, and fails
//      with `Timeout` if the reply was dropped (the handler has run by then).
macro_rules! send_fn {
(
$self:ident, $receivers:ident $(, $non_copy_arg:ident)* ;
$method:ident $(, $arg:ident)* ;
$request_payload:expr;
$response_payload:expr;
) => {{
$receivers
.iter()
.map(move |receiver| {
// Each future owns its own handle on the shared network state.
let this = $self.clone();
let receiver_id = receiver.id();
// Give this receiver's future its own copy of each non-Copy argument.
$( send_fn!(@ $non_copy_arg); )*
(
receiver,
async move {
// Links are directional: "there" is (sender, receiver), "back" is
// (receiver, sender). Per-link overrides win over the defaults.
let (packet_loss_rate_there, packet_loss_rate_back) = {
let per_link = this.set.packet_loss.lock().await;
let there = per_link.get(&(this.node_id, receiver_id)).copied();
let there = there.unwrap_or(this.set.default_packet_loss);
let back = per_link.get(&(receiver_id, this.node_id)).copied();
let back = back.unwrap_or(this.set.default_packet_loss);
(there, back)
};
let (e2e_delay_distr_there, e2e_delay_distr_back) = {
let per_link = this.set.e2e_delay.lock().await;
let there = per_link.get(&(this.node_id, receiver_id)).copied();
let there = there.unwrap_or(this.set.default_e2e_delay);
let back = per_link.get(&(receiver_id, this.node_id)).copied();
let back = back.unwrap_or(this.set.default_e2e_delay);
(there, back)
};
// Outbound leg.
let e2e_delay = delay(&e2e_delay_distr_there);
let dropped = roll_for_failure(packet_loss_rate_there);
// NOTE(review): the listener lock is held while awaiting each bounded
// `send`, so a listener that is not drained may stall every other
// request (and `events()`); confirm this is intended. Send errors
// (e.g. a dropped receiver) are ignored.
{
let listeners = this.set.event_listeners.lock().await;
for mut l in listeners.iter().cloned() {
let _ = l.send(DirectCommunicatorEvent {
sender: this.node_id,
receiver: receiver_id,
e2e_delay,
dropped,
payload: $request_payload,
}).await;
}
}
// The delay is waited out even for a dropped request, so a loss is
// observed as a late failure.
sleep(e2e_delay).await;
if dropped {
return Err(DirectCommunicatorError::Timeout);
}
// The handler map guard is a local of this block and is released when
// the block ends, before the handler's future is awaited.
let result = {
let handlers = this.set.request_handlers.lock().await;
let handler = match handlers.get(&receiver_id) {
Some(handler) => handler,
None => return Err(DirectCommunicatorError::Other),
};
handler.$method($($arg),*)
}
.await;
let response = result
.try_into()
.map_err(|_| DirectCommunicatorError::Other);
// Return leg.
let e2e_delay = delay(&e2e_delay_distr_back);
let dropped = roll_for_failure(packet_loss_rate_back);
{
let listeners = this.set.event_listeners.lock().await;
for mut l in listeners.iter().cloned() {
let _ = l.send(DirectCommunicatorEvent {
sender: receiver_id,
receiver: this.node_id,
e2e_delay,
dropped,
payload: $response_payload(&response),
}).await;
}
}
sleep(e2e_delay).await;
// The handler has already run at this point; only its reply is lost.
if dropped {
return Err(DirectCommunicatorError::Timeout);
}
response
}
.boxed(),
)
})
.collect()
}};
// Internal rule: shadow `$non_copy_arg` with a clone of itself.
(@ $non_copy_arg:ident) => {
let $non_copy_arg = $non_copy_arg.clone();
}
}
// Every `send_*` method delegates to `send_fn!`, which returns one
// `(receiver, future)` pair per receiver. The simulated traffic and the call
// into the receiver's handler happen only once a future is polled.
// NOTE(review): `I: 'static` is presumably required because the returned
// futures are boxed as `'static`.
impl<I: Invocation + 'static> Communicator for DirectCommunicator<I> {
type Node = NodeOf<I>;
type RoundNum = RoundNumOf<I>;
type CoordNum = CoordNumOf<I>;
type LogEntry = LogEntryOf<I>;
type Error = DirectCommunicatorError;
type SendPrepare = BoxFuture<'static, Result<VoteFor<Self>, Self::Error>>;
type Abstain = AbstainOf<I>;
type SendProposal = BoxFuture<'static, Result<AcceptanceFor<Self>, Self::Error>>;
type Yea = YeaOf<I>;
type Nay = NayOf<I>;
type SendCommit = BoxFuture<'static, Result<Committed, Self::Error>>;
type SendCommitById = BoxFuture<'static, Result<Committed, Self::Error>>;
// Calls `handle_prepare(round_num, coord_num)` on each receiver's handler.
// The reply event is `Promise(true)` iff a vote was given.
fn send_prepare<'a>(
&mut self,
receivers: &'a [Self::Node],
round_num: Self::RoundNum,
coord_num: Self::CoordNum,
) -> Vec<(&'a Self::Node, Self::SendPrepare)> {
send_fn!(
self, receivers;
handle_prepare, round_num, coord_num;
DirectCommunicatorPayload::Prepare { round_num, coord_num };
|r| DirectCommunicatorPayload::Promise(matches!(r, &Ok(Vote::Given(_))));
)
}
// Calls `handle_proposal(round_num, coord_num, log_entry)` on each receiver's
// handler; `log_entry` (an `Arc`) is cloned once per receiver. The reply event
// is `Accept(true)` iff acceptance was given.
fn send_proposal<'a>(
&mut self,
receivers: &'a [Self::Node],
round_num: Self::RoundNum,
coord_num: Self::CoordNum,
log_entry: Arc<Self::LogEntry>,
) -> Vec<(&'a Self::Node, Self::SendProposal)> {
send_fn!(
self, receivers, log_entry;
handle_proposal, round_num, coord_num, log_entry;
DirectCommunicatorPayload::Propose { round_num, coord_num, log_entry: log_entry.clone() };
|r| DirectCommunicatorPayload::Accept(matches!(r, &Ok(Acceptance::Given(_))));
)
}
// Calls `handle_commit(round_num, coord_num, log_entry)` on each receiver's
// handler; `log_entry` is cloned once per receiver. The reply event is
// `Committed(true)` iff the result was `Ok`.
fn send_commit<'a>(
&mut self,
receivers: &'a [Self::Node],
round_num: Self::RoundNum,
coord_num: Self::CoordNum,
log_entry: Arc<Self::LogEntry>,
) -> Vec<(&'a Self::Node, Self::SendCommit)> {
send_fn!(
self, receivers, log_entry;
handle_commit, round_num, coord_num, log_entry;
DirectCommunicatorPayload::Commit { round_num, coord_num, log_entry: log_entry.clone() };
|r| DirectCommunicatorPayload::Committed(matches!(r, &Ok(_)));
)
}
// Calls `handle_commit_by_id(round_num, coord_num, log_entry_id)` on each
// receiver's handler. `log_entry_id` is copied per receiver (it is not in the
// non-Copy list) and is not recorded in the event payload.
fn send_commit_by_id<'a>(
&mut self,
receivers: &'a [Self::Node],
round_num: Self::RoundNum,
coord_num: Self::CoordNum,
log_entry_id: <Self::LogEntry as LogEntry>::Id,
) -> Vec<(&'a Self::Node, Self::SendCommitById)> {
send_fn!(
self, receivers;
handle_commit_by_id, round_num, coord_num, log_entry_id;
DirectCommunicatorPayload::CommitById { round_num, coord_num };
|r| DirectCommunicatorPayload::Committed(matches!(r, &Ok(_)));
)
}
}
/// Simulates one packet-loss trial: returns `true` when the message is lost.
///
/// Draws a uniform `f32` in `[0, 1)` and reports a loss when it falls below
/// `rate`. So `rate <= 0` never loses a message and `rate >= 1` always does.
fn roll_for_failure(rate: f32) -> bool {
    use rand::Rng;
    let roll: f32 = rand::thread_rng().gen();
    roll < rate
}
/// Waits for `duration` using a `futures_timer::Delay`.
///
/// A zero duration completes immediately, without creating a timer.
async fn sleep(duration: std::time::Duration) {
    if duration.is_zero() {
        return;
    }
    futures_timer::Delay::new(duration).await;
}
/// Samples one end-to-end delay from `distr`, interpreting the sample as
/// milliseconds.
///
/// Negative and NaN samples become a zero delay, and fractional milliseconds
/// are truncated.
fn delay(distr: &rand_distr::Normal<f32>) -> std::time::Duration {
    use rand::distributions::Distribution;
    let sampled_ms: f32 = distr.sample(&mut rand::thread_rng());
    // `max` maps negatives and NaN to 0.0, which is exactly what the
    // saturating float-to-integer `as` cast would produce anyway.
    let whole_ms = sampled_ms.max(0.0) as u64;
    std::time::Duration::from_millis(whole_ms)
}