use std::{
any::Any,
cell::RefCell,
collections::{HashMap, HashSet},
convert::Infallible,
fmt::{self, Display, Formatter},
sync::{Arc, RwLock},
};
use rand::seq::IteratorRandom;
use serde::Serialize;
use tokio::sync::mpsc::{self, error::SendError};
use tracing::{debug, error, info, warn};
use crate::{
components::Component,
effect::{
announcements::NetworkAnnouncement, requests::NetworkRequest, EffectBuilder, EffectExt,
Effects,
},
logging,
reactor::{EventQueueHandle, QueueKind},
testing::TestRng,
types::NodeId,
NodeRng,
};
/// Shared registry of the nodes on one in-memory network, mapping each node's id to the sending
/// half of the channel that delivers `(sender id, payload)` pairs to that node.
type Network<P> = Arc<RwLock<HashMap<NodeId, mpsc::UnboundedSender<(NodeId, P)>>>>;
/// Event handled by the in-memory network component: a thin wrapper around a
/// [`NetworkRequest`] addressed by [`NodeId`].
#[derive(Debug, Serialize)]
pub(crate) struct Event<P>(NetworkRequest<NodeId, P>);

impl<P> From<NetworkRequest<NodeId, P>> for Event<P> {
    fn from(request: NetworkRequest<NodeId, P>) -> Self {
        Self(request)
    }
}

impl<P: Display> Display for Event<P> {
    /// Displays the event exactly like the wrapped request, forwarding all formatter options.
    #[inline]
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let Event(request) = self;
        Display::fmt(request, formatter)
    }
}
thread_local! {
    /// The network controller currently active on this thread, type-erased (`dyn Any`) so that
    /// controllers for any payload type can be stored here. `None` when no network is active;
    /// set by `NetworkController::create_active` and cleared by `NetworkController::remove_active`.
    static ACTIVE_NETWORK: RefCell<Option<Box<dyn Any>>> = RefCell::new(None);
}
/// Owner of an in-memory network; nodes created through it all share the same registry.
///
/// `Default` is implemented by hand rather than derived: `#[derive(Default)]` would add a
/// spurious `P: Default` bound even though no value of type `P` is ever constructed here.
#[derive(Debug)]
pub(crate) struct NetworkController<P> {
    /// Registry of every node on this network and the channel used to reach it.
    nodes: Network<P>,
}

impl<P> Default for NetworkController<P> {
    fn default() -> Self {
        NetworkController {
            nodes: Default::default(),
        }
    }
}
impl<P> NetworkController<P>
where
    P: 'static + Send,
{
    /// Creates a new controller with an empty node registry.
    ///
    /// Also makes a best-effort attempt to initialize logging; any error is ignored.
    fn new() -> Self {
        let _ = logging::init();
        NetworkController {
            nodes: Default::default(),
        }
    }

    /// Creates a new, empty network and installs it as this thread's active network.
    ///
    /// Any previously active network (of any payload type) is replaced and dropped.
    pub(crate) fn create_active() {
        let _ = logging::init();
        ACTIVE_NETWORK
            .with(|active_network| active_network.borrow_mut().replace(Box::new(Self::new())));
    }

    /// Removes this thread's active network.
    ///
    /// # Panics
    ///
    /// Panics if no network is active, or if the active network carries a payload type other
    /// than `P`. In the latter case the network has already been removed when the panic occurs.
    pub(crate) fn remove_active() {
        assert!(
            ACTIVE_NETWORK.with(|active_network| {
                active_network
                    .borrow_mut()
                    .take()
                    .expect("tried to remove non-existent network")
                    .is::<Self>()
            }),
            "removed network was of wrong type"
        );
    }

    /// Creates a new node with a random id on this thread's active network.
    ///
    /// # Panics
    ///
    /// Panics if no network is active, or if the active network carries a payload type other
    /// than `P`.
    pub(crate) fn create_node<REv>(
        event_queue: EventQueueHandle<REv>,
        rng: &mut TestRng,
    ) -> InMemoryNetwork<P>
    where
        REv: From<NetworkAnnouncement<NodeId, P>> + Send,
    {
        // Only shared access is needed (`create_node_local` takes `&self`), so avoid taking an
        // exclusive `RefCell` borrow.
        ACTIVE_NETWORK.with(|active_network| {
            active_network
                .borrow()
                .as_ref()
                .expect("tried to create node without active network set")
                .downcast_ref::<Self>()
                .expect("active network has wrong message type")
                .create_node_local(event_queue, rng)
        })
    }

    /// Removes `node_id` from this thread's active network, dropping the channel sender stored
    /// for it.
    ///
    /// Does nothing if no network is active.
    ///
    /// # Panics
    ///
    /// Panics if the active network carries a payload type other than `P`, if the registry lock
    /// is poisoned, or if `node_id` is not registered.
    pub(crate) fn remove_node(node_id: &NodeId) {
        ACTIVE_NETWORK.with(|active_network| {
            if let Some(active_network) = active_network.borrow_mut().as_mut() {
                active_network
                    .downcast_mut::<Self>()
                    .expect("active network has wrong message type")
                    .nodes
                    .write()
                    .expect("poisoned lock")
                    .remove(node_id)
                    .expect("node doesn't exist in network");
            }
        })
    }

    /// Creates a new node on this controller's network, with an id drawn from `rng`.
    ///
    /// The node's receiver task is started with `tokio::spawn`, so this must be called from
    /// within a tokio runtime.
    pub(crate) fn create_node_local<REv>(
        &self,
        event_queue: EventQueueHandle<REv>,
        rng: &mut TestRng,
    ) -> InMemoryNetwork<P>
    where
        REv: From<NetworkAnnouncement<NodeId, P>> + Send,
    {
        InMemoryNetwork::new_with_data(event_queue, NodeId::random(rng), self.nodes.clone())
    }
}
/// A single node's handle onto an in-memory network.
#[derive(Debug)]
pub(crate) struct InMemoryNetwork<P> {
    /// This node's own id on the network.
    node_id: NodeId,
    /// Registry shared (via `Arc`) with the other nodes created from the same controller.
    nodes: Network<P>,
}
impl<P> InMemoryNetwork<P>
where
    P: 'static + Send,
{
    /// Creates a new node on this thread's active network.
    ///
    /// # Panics
    ///
    /// Panics if no network is active (see `NetworkController::create_active`) or if the
    /// active network carries a payload type other than `P`.
    pub(crate) fn new<REv>(event_queue: EventQueueHandle<REv>, rng: &mut NodeRng) -> Self
    where
        REv: From<NetworkAnnouncement<NodeId, P>> + Send,
    {
        NetworkController::create_node(event_queue, rng)
    }

    /// Registers `node_id` in `nodes` and spawns the task that forwards its incoming messages
    /// to `event_queue`.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned, if `node_id` is already registered, or (via
    /// `tokio::spawn`) if called outside a tokio runtime.
    fn new_with_data<REv>(
        event_queue: EventQueueHandle<REv>,
        node_id: NodeId,
        nodes: Network<P>,
    ) -> Self
    where
        REv: From<NetworkAnnouncement<NodeId, P>> + Send,
    {
        let (sender, receiver) = mpsc::unbounded_channel();
        {
            // Scoped so that the write lock is released before spawning the receiver task.
            let mut nodes_write = nodes.write().expect("network lock poisoned");
            assert!(
                !nodes_write.contains_key(&node_id),
                "node {} already exists in network",
                node_id
            );
            nodes_write.insert(node_id, sender);
        }
        tokio::spawn(receiver_task(event_queue, receiver));
        InMemoryNetwork { node_id, nodes }
    }

    /// Returns this node's id on the network.
    #[inline]
    pub(crate) fn node_id(&self) -> NodeId {
        self.node_id
    }
}
impl<P> InMemoryNetwork<P>
where
P: Display,
{
fn send(
&self,
nodes: &HashMap<NodeId, mpsc::UnboundedSender<(NodeId, P)>>,
dest: NodeId,
payload: P,
) {
if dest == self.node_id {
panic!("can't send message to self");
}
match nodes.get(&dest) {
Some(sender) => {
if let Err(SendError((_, msg))) = sender.send((self.node_id, payload)) {
warn!(%dest, %msg, "could not send message (send error)");
}
}
None => info!(%dest, %payload, "dropping message to non-existent recipient"),
}
}
}
impl<P, REv> Component<REv> for InMemoryNetwork<P>
where
P: Display + Clone,
{
type Event = Event<P>;
type ConstructionError = Infallible;
fn handle_event(
&mut self,
_effect_builder: EffectBuilder<REv>,
rng: &mut NodeRng,
Event(event): Self::Event,
) -> Effects<Self::Event> {
match event {
NetworkRequest::SendMessage {
dest,
payload,
responder,
} => {
if *dest == self.node_id {
panic!("can't send message to self");
}
if let Ok(guard) = self.nodes.read() {
self.send(&guard, *dest, *payload);
} else {
error!("network lock has been poisoned")
};
responder.respond(()).ignore()
}
NetworkRequest::Broadcast { payload, responder } => {
if let Ok(guard) = self.nodes.read() {
for dest in guard.keys().filter(|&node_id| node_id != &self.node_id) {
self.send(&guard, *dest, *payload.clone());
}
} else {
error!("network lock has been poisoned")
};
responder.respond(()).ignore()
}
NetworkRequest::Gossip {
payload,
count,
exclude,
responder,
} => {
if let Ok(guard) = self.nodes.read() {
let chosen: HashSet<_> = guard
.keys()
.filter(|&node_id| !exclude.contains(node_id) && node_id != &self.node_id)
.cloned()
.choose_multiple(rng, count)
.into_iter()
.collect();
for dest in chosen.iter() {
self.send(&guard, *dest, *payload.clone());
}
responder.respond(chosen).ignore()
} else {
error!("network lock has been poisoned");
responder.respond(Default::default()).ignore()
}
}
}
}
}
/// Forwards every `(sender, payload)` pair arriving on `receiver` to `event_queue` as a
/// `NetworkAnnouncement::MessageReceived`, scheduled on `QueueKind::NetworkIncoming`.
///
/// Runs until `recv` yields `None` (the channel is closed), then logs at debug level.
async fn receiver_task<REv, P>(
    event_queue: EventQueueHandle<REv>,
    mut receiver: mpsc::UnboundedReceiver<(NodeId, P)>,
) where
    REv: From<NetworkAnnouncement<NodeId, P>>,
    P: 'static + Send,
{
    loop {
        let (sender, payload) = match receiver.recv().await {
            Some(incoming) => incoming,
            None => break,
        };
        event_queue
            .schedule(
                NetworkAnnouncement::MessageReceived { sender, payload },
                QueueKind::NetworkIncoming,
            )
            .await;
    }
    debug!("receiver shutting down")
}