use crate::{
PeerId,
muxing::StreamMuxer,
nodes::{
handled_node::{HandledNode, HandledNodeError, NodeHandler},
node::Substream
}
};
use fnv::FnvHashMap;
use futures::{prelude::*, stream, sync::mpsc};
use smallvec::SmallVec;
use std::{
collections::hash_map::{Entry, OccupiedEntry},
error,
fmt,
mem
};
use tokio_executor;
use void::Void;
/// Collection of background tasks, each of which first tries to reach a node and then
/// processes that node's events.
///
/// Events destined for a task travel through a per-task channel (see `tasks`). Every task
/// reports back through a single shared channel (`events_tx` / `events_rx`), tagging each
/// message with its `TaskId`.
pub struct HandledNodesTasks<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr> {
    /// For each task in the collection, the sender used to deliver events to it.
    tasks: FnvHashMap<TaskId, mpsc::UnboundedSender<TInEvent>>,

    /// Identifier that will be given to the next task added by `add_reach_attempt`.
    next_task_id: TaskId,

    /// Tasks created by `add_reach_attempt` that have not been handed to the executor yet.
    /// They are spawned at the start of the next call to `poll`.
    to_spawn: SmallVec<[Box<Future<Item = (), Error = ()> + Send>; 8]>,

    /// Sender cloned into every task so that it can report messages back to the collection.
    events_tx: mpsc::UnboundedSender<(InToExtMessage<TOutEvent, THandler, TReachErr, THandlerErr>, TaskId)>,

    /// Receiving end of the messages reported by the tasks.
    events_rx: mpsc::UnboundedReceiver<(InToExtMessage<TOutEvent, THandler, TReachErr, THandlerErr>, TaskId)>,
}
impl<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr> fmt::Debug
    for HandledNodesTasks<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr>
{
    /// Renders the collection as the list of `TaskId`s it currently tracks.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let mut ids = f.debug_list();
        for id in self.tasks.keys() {
            ids.entry(id);
        }
        ids.finish()
    }
}
/// Reason why a task closed with an error.
#[derive(Debug)]
pub enum TaskClosedEvent<TReachErr, THandlerErr> {
    /// The future that was trying to reach the node returned an error.
    Reach(TReachErr),
    /// The node had been reached, but the `HandledNode` driving it then produced an error.
    Node(HandledNodeError<THandlerErr>),
}
impl<TReachErr, THandlerErr> fmt::Display for TaskClosedEvent<TReachErr, THandlerErr>
where
    TReachErr: fmt::Display,
    THandlerErr: fmt::Display,
{
    /// Displays exactly the message of the wrapped error, without any prefix.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner: &dyn fmt::Display = match self {
            TaskClosedEvent::Reach(err) => err,
            TaskClosedEvent::Node(err) => err,
        };
        write!(f, "{}", inner)
    }
}
impl<TReachErr, THandlerErr> error::Error for TaskClosedEvent<TReachErr, THandlerErr>
where
    TReachErr: error::Error + 'static,
    THandlerErr: error::Error + 'static,
{
    /// Exposes the wrapped error (reach error or node error) as the source.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let cause: &(dyn error::Error + 'static) = match self {
            TaskClosedEvent::Reach(err) => err,
            TaskClosedEvent::Node(err) => err,
        };
        Some(cause)
    }
}
/// Event produced by `HandledNodesTasks::poll`.
#[derive(Debug)]
pub enum HandledNodesEvent<TOutEvent, THandler, TReachErr, THandlerErr> {
    /// A task has closed and has been removed from the collection.
    TaskClosed {
        /// Identifier of the task that closed.
        id: TaskId,
        /// `Ok(())` if the node closed gracefully, otherwise the reason for closing.
        result: Result<(), TaskClosedEvent<TReachErr, THandlerErr>>,
        /// The handler is given back only when reaching the node failed, because it was
        /// never consumed by a `HandledNode`. It is `None` in every other case.
        handler: Option<THandler>,
    },

    /// The reach future of a task resolved successfully.
    NodeReached {
        /// Identifier of the task that reached the node.
        id: TaskId,
        /// Peer id returned by the reach future.
        peer_id: PeerId,
    },

    /// The node handled by a task produced an event.
    NodeEvent {
        /// Identifier of the task that produced the event.
        id: TaskId,
        /// The event itself.
        event: TOutEvent,
    },
}
/// Identifier of a task in a `HandledNodesTasks`.
///
/// Ids are allocated sequentially by `HandledNodesTasks::add_reach_attempt`, starting at 0.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TaskId(usize);
impl<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr> HandledNodesTasks<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr> {
#[inline]
pub fn new() -> Self {
let (events_tx, events_rx) = mpsc::unbounded();
HandledNodesTasks {
tasks: Default::default(),
next_task_id: TaskId(0),
to_spawn: SmallVec::new(),
events_tx,
events_rx,
}
}
pub fn add_reach_attempt<TFut, TMuxer>(&mut self, future: TFut, handler: THandler) -> TaskId
where
TFut: Future<Item = (PeerId, TMuxer), Error = TReachErr> + Send + 'static,
THandler: NodeHandler<Substream = Substream<TMuxer>, InEvent = TInEvent, OutEvent = TOutEvent, Error = THandlerErr> + Send + 'static,
TReachErr: error::Error + Send + 'static,
THandlerErr: error::Error + Send + 'static,
TInEvent: Send + 'static,
TOutEvent: Send + 'static,
THandler::OutboundOpenInfo: Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, TMuxer::OutboundSubstream: Send + 'static, {
let task_id = self.next_task_id;
self.next_task_id.0 += 1;
let (tx, rx) = mpsc::unbounded();
self.tasks.insert(task_id, tx);
let task = Box::new(NodeTask {
inner: NodeTaskInner::Future {
future,
handler,
events_buffer: Vec::new(),
},
events_tx: self.events_tx.clone(),
in_events_rx: rx.fuse(),
id: task_id,
});
self.to_spawn.push(task);
task_id
}
pub fn broadcast_event(&mut self, event: &TInEvent)
where TInEvent: Clone,
{
for sender in self.tasks.values() {
let _ = sender.unbounded_send(event.clone());
}
}
#[inline]
pub fn task(&mut self, id: TaskId) -> Option<Task<TInEvent>> {
match self.tasks.entry(id) {
Entry::Occupied(inner) => Some(Task { inner }),
Entry::Vacant(_) => None,
}
}
#[inline]
pub fn tasks<'a>(&'a self) -> impl Iterator<Item = TaskId> + 'a {
self.tasks.keys().cloned()
}
pub fn poll(&mut self) -> Async<HandledNodesEvent<TOutEvent, THandler, TReachErr, THandlerErr>> {
for to_spawn in self.to_spawn.drain() {
tokio_executor::spawn(to_spawn);
}
loop {
match self.events_rx.poll() {
Ok(Async::Ready(Some((message, task_id)))) => {
if !self.tasks.contains_key(&task_id) {
continue;
};
match message {
InToExtMessage::NodeEvent(event) => {
break Async::Ready(HandledNodesEvent::NodeEvent {
id: task_id,
event,
});
},
InToExtMessage::NodeReached(peer_id) => {
break Async::Ready(HandledNodesEvent::NodeReached {
id: task_id,
peer_id,
});
},
InToExtMessage::TaskClosed(result, handler) => {
let _ = self.tasks.remove(&task_id);
break Async::Ready(HandledNodesEvent::TaskClosed {
id: task_id, result, handler
});
},
}
}
Ok(Async::NotReady) => {
break Async::NotReady;
}
Ok(Async::Ready(None)) => {
unreachable!("The sender is in self as well, therefore the receiver never \
closes.")
},
Err(()) => unreachable!("An unbounded receiver never errors"),
}
}
}
}
/// Handle to a single task of a `HandledNodesTasks`, obtained through
/// `HandledNodesTasks::task`.
///
/// It holds an occupied entry of the collection's task map for `'a`. That is a mutable
/// borrow, so the collection cannot be used in other ways while this handle is alive.
pub struct Task<'a, TInEvent: 'a> {
    inner: OccupiedEntry<'a, TaskId, mpsc::UnboundedSender<TInEvent>>,
}
impl<'a, TInEvent> Task<'a, TInEvent> {
    /// Sends an event to the task.
    ///
    /// If the task has already terminated, the event is silently dropped.
    #[inline]
    pub fn send_event(&mut self, event: TInEvent) {
        let sender = self.inner.get_mut();
        // A failed send only means the task is gone, which is tolerated here.
        let _ = sender.unbounded_send(event);
    }

    /// Returns the identifier of this task.
    #[inline]
    pub fn id(&self) -> TaskId {
        let key: &TaskId = self.inner.key();
        *key
    }

    /// Removes the task from the collection.
    ///
    /// Dropping the task's sender makes the task stop: a pending reach attempt ends
    /// immediately, and a reached node starts shutting down. Any message the task still
    /// reports is ignored by `HandledNodesTasks::poll`.
    pub fn close(self) {
        drop(self.inner.remove());
    }
}
impl<'a, TInEvent> fmt::Debug for Task<'a, TInEvent> {
    /// Renders the handle as `Task(<id>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let id = self.id();
        let mut tuple = f.debug_tuple("Task");
        tuple.field(&id);
        tuple.finish()
    }
}
impl<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr> Stream for HandledNodesTasks<TInEvent, TOutEvent, THandler, TReachErr, THandlerErr> {
type Item = HandledNodesEvent<TOutEvent, THandler, TReachErr, THandlerErr>;
type Error = Void;
#[inline]
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
Ok(self.poll().map(Option::Some))
}
}
/// Message sent by a `NodeTask` to its `HandledNodesTasks`, paired with the task's `TaskId`.
#[derive(Debug)]
enum InToExtMessage<TOutEvent, THandler, TReachErr, THandlerErr> {
    /// The reach future resolved with this peer id.
    NodeReached(PeerId),
    /// The task is finishing. Carries the outcome and, only when the reach future failed,
    /// the handler that was never used.
    TaskClosed(Result<(), TaskClosedEvent<TReachErr, THandlerErr>>, Option<THandler>),
    /// An event produced by the reached node.
    NodeEvent(TOutEvent),
}
/// Future that runs one task: first the attempt to reach a node, then the processing of that
/// node until it closes.
struct NodeTask<TFut, TMuxer, THandler, TInEvent, TOutEvent, TReachErr>
where
    TMuxer: StreamMuxer,
    THandler: NodeHandler<Substream = Substream<TMuxer>>,
{
    /// Sender used to report messages to the `HandledNodesTasks`.
    events_tx: mpsc::UnboundedSender<(InToExtMessage<TOutEvent, THandler, TReachErr, THandler::Error>, TaskId)>,
    /// Receiving end of the events sent by the `HandledNodesTasks`. It is fused so that
    /// `is_done()` can tell when the sending side has gone away.
    in_events_rx: stream::Fuse<mpsc::UnboundedReceiver<TInEvent>>,
    /// Current state of the task.
    inner: NodeTaskInner<TFut, TMuxer, THandler, TInEvent>,
    /// Identifier of this task, attached to every message sent through `events_tx`.
    id: TaskId,
}
/// State machine of a `NodeTask`.
enum NodeTaskInner<TFut, TMuxer, THandler, TInEvent>
where
    TMuxer: StreamMuxer,
    THandler: NodeHandler<Substream = Substream<TMuxer>>,
{
    /// Still trying to reach the node.
    Future {
        /// Future that resolves to the peer id and muxer of the node.
        future: TFut,
        /// Handler used to build the `HandledNode` once the future resolves.
        handler: THandler,
        /// Events received while the future was pending. They are injected into the node
        /// once it exists.
        events_buffer: Vec<TInEvent>,
    },

    /// The node has been reached and is being processed.
    Node(HandledNode<TMuxer, THandler>),

    /// Placeholder left in `inner` while `poll` owns the real state. It remains there after
    /// the task has completed, or if polling panicked part-way.
    Poisoned,
}
/// Drives a single task: resolves the reach future, then processes the resulting node until
/// it closes, relaying everything to the `HandledNodesTasks` through `events_tx`.
impl<TFut, TMuxer, THandler, TInEvent, TOutEvent, TReachErr> Future for
    NodeTask<TFut, TMuxer, THandler, TInEvent, TOutEvent, TReachErr>
where
    TMuxer: StreamMuxer,
    TFut: Future<Item = (PeerId, TMuxer), Error = TReachErr>,
    THandler: NodeHandler<Substream = Substream<TMuxer>, InEvent = TInEvent, OutEvent = TOutEvent>,
{
    type Item = ();
    type Error = ();

    fn poll(&mut self) -> Poll<(), ()> {
        loop {
            // Take ownership of the current state. Every path that keeps the task alive writes
            // a state back into `self.inner`. Paths that finish the task leave `Poisoned`.
            match mem::replace(&mut self.inner, NodeTaskInner::Poisoned) {
                // Still trying to reach the node.
                NodeTaskInner::Future { mut future, handler, mut events_buffer } => {
                    // Buffer the events received so far. If the sending side was dropped
                    // (the task was closed from outside), end the task without reporting.
                    loop {
                        match self.in_events_rx.poll() {
                            Ok(Async::Ready(None)) => return Ok(Async::Ready(())),
                            Ok(Async::Ready(Some(event))) => events_buffer.push(event),
                            Ok(Async::NotReady) => break,
                            Err(_) => unreachable!("An UnboundedReceiver never errors"),
                        }
                    }
                    // Check whether reaching the node succeeded.
                    match future.poll() {
                        Ok(Async::Ready((peer_id, muxer))) => {
                            let event = InToExtMessage::NodeReached(peer_id);
                            let mut node = HandledNode::new(muxer, handler);
                            // Deliver the events buffered while the future was pending.
                            for event in events_buffer {
                                node.inject_event(event);
                            }
                            // If nobody is listening any more, start shutting the node down.
                            if self.events_tx.unbounded_send((event, self.id)).is_err() {
                                node.shutdown();
                            }
                            // No return: the outer loop polls the new node right away.
                            self.inner = NodeTaskInner::Node(node);
                        }
                        Ok(Async::NotReady) => {
                            // Put the state back untouched and wait to be woken up.
                            self.inner = NodeTaskInner::Future { future, handler, events_buffer };
                            return Ok(Async::NotReady);
                        },
                        Err(err) => {
                            // Reaching failed. End the task and give the unused handler back.
                            let event = InToExtMessage::TaskClosed(Err(TaskClosedEvent::Reach(err)), Some(handler));
                            let _ = self.events_tx.unbounded_send((event, self.id));
                            return Ok(Async::Ready(()));
                        }
                    }
                },
                // The node has been reached.
                NodeTaskInner::Node(mut node) => {
                    // Forward the events coming from outside, unless the channel already
                    // reported its end during an earlier poll.
                    if !self.in_events_rx.is_done() {
                        loop {
                            match self.in_events_rx.poll() {
                                Ok(Async::NotReady) => break,
                                Ok(Async::Ready(Some(event))) => {
                                    node.inject_event(event)
                                },
                                Ok(Async::Ready(None)) => {
                                    // The task was closed from outside. Start shutting down.
                                    node.shutdown();
                                    break;
                                }
                                Err(()) => unreachable!("An unbounded receiver never errors"),
                            }
                        }
                    }
                    // Drain the node's events until it has nothing more to report.
                    loop {
                        match node.poll() {
                            Ok(Async::NotReady) => {
                                self.inner = NodeTaskInner::Node(node);
                                return Ok(Async::NotReady);
                            },
                            Ok(Async::Ready(Some(event))) => {
                                let event = InToExtMessage::NodeEvent(event);
                                // Nobody is listening any more. Start shutting the node down.
                                if self.events_tx.unbounded_send((event, self.id)).is_err() {
                                    node.shutdown();
                                }
                            }
                            Ok(Async::Ready(None)) => {
                                // The node closed gracefully.
                                let event = InToExtMessage::TaskClosed(Ok(()), None);
                                let _ = self.events_tx.unbounded_send((event, self.id));
                                return Ok(Async::Ready(()));
                            }
                            Err(err) => {
                                // The node errored. The handler was consumed by the node, so
                                // there is none to give back.
                                let event = InToExtMessage::TaskClosed(Err(TaskClosedEvent::Node(err)), None);
                                let _ = self.events_tx.unbounded_send((event, self.id));
                                return Ok(Async::Ready(()));
                            }
                        }
                    }
                },
                // Only reachable if a previous poll panicked, or if the task is polled again
                // after it completed.
                NodeTaskInner::Poisoned => panic!("the node task panicked or errored earlier")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    use futures::future::{self, FutureResult};
    use futures::sync::mpsc::{UnboundedReceiver, UnboundedSender};
    use nodes::handled_node::NodeHandlerEvent;
    use tests::dummy_handler::{Handler, HandlerState, InEvent, OutEvent, TestHandledNode};
    use tests::dummy_muxer::{DummyMuxer, DummyConnectionState};
    use tokio::runtime::Builder;
    use tokio::runtime::current_thread::Runtime;
    use void::Void;
    use PeerId;

    // `NodeTask` instantiated with the dummy muxer and handler. The reach future is an
    // already-resolved `FutureResult`.
    type TestNodeTask = NodeTask<
        FutureResult<(PeerId, DummyMuxer), io::Error>,
        DummyMuxer,
        Handler,
        InEvent,
        OutEvent,
        io::Error,
    >;

    // Builds a single `NodeTask`, together with the channel ends a test needs to drive it.
    struct NodeTaskTestBuilder {
        task_id: TaskId,
        inner_node: Option<TestHandledNode>,
        inner_fut: Option<FutureResult<(PeerId, DummyMuxer), io::Error>>,
    }

    impl NodeTaskTestBuilder {
        // Defaults: task id 123, and a reach future that succeeds immediately with a random
        // peer id and a fresh `DummyMuxer`.
        fn new() -> Self {
            NodeTaskTestBuilder {
                task_id: TaskId(123),
                inner_node: None,
                inner_fut: {
                    let peer_id = PeerId::random();
                    Some(future::ok((peer_id, DummyMuxer::new())))
                },
            }
        }

        fn with_inner_fut(&mut self, fut: FutureResult<(PeerId, DummyMuxer), io::Error>) -> &mut Self{
            self.inner_fut = Some(fut);
            self
        }

        fn with_task_id(&mut self, id: usize) -> &mut Self {
            self.task_id = TaskId(id);
            self
        }

        // Returns three things: the task, the sender for events going into it, and the
        // receiver for messages it emits.
        fn node_task(&mut self) -> (
            TestNodeTask,
            UnboundedSender<InEvent>,
            UnboundedReceiver<(InToExtMessage<OutEvent, Handler, io::Error, io::Error>, TaskId)>,
        ) {
            let (events_from_node_task_tx, events_from_node_task_rx) = mpsc::unbounded::<(InToExtMessage<OutEvent, Handler, _, _>, TaskId)>();
            let (events_to_node_task_tx, events_to_node_task_rx) = mpsc::unbounded::<InEvent>();
            // Start in the `Node` state if a node was provided, otherwise in the `Future` state.
            let inner = if self.inner_node.is_some() {
                NodeTaskInner::Node(self.inner_node.take().unwrap())
            } else {
                NodeTaskInner::Future {
                    future: self.inner_fut.take().unwrap(),
                    handler: Handler::default(),
                    events_buffer: Vec::new(),
                }
            };
            // Only a clone of the sender goes into the task. The original is dropped when this
            // function returns, so the returned receiver ends once the task is dropped.
            let node_task = NodeTask {
                inner: inner,
                events_tx: events_from_node_task_tx.clone(),
                in_events_rx: events_to_node_task_rx.fuse(),
                id: self.task_id,
            };
            (node_task, events_to_node_task_tx, events_from_node_task_rx)
        }
    }

    type TestHandledNodesTasks = HandledNodesTasks<InEvent, OutEvent, Handler, io::Error, io::Error>;

    // Builds a `HandledNodesTasks` pre-populated with `task_count` reach attempts. Every
    // attempt succeeds immediately with a clone of the same muxer and handler.
    struct HandledNodeTaskTestBuilder {
        muxer: DummyMuxer,
        handler: Handler,
        task_count: usize,
    }

    impl HandledNodeTaskTestBuilder {
        fn new() -> Self {
            HandledNodeTaskTestBuilder {
                muxer: DummyMuxer::new(),
                handler: Handler::default(),
                task_count: 0,
            }
        }

        fn with_tasks(&mut self, amt: usize) -> &mut Self {
            self.task_count = amt;
            self
        }
        fn with_muxer_inbound_state(&mut self, state: DummyConnectionState) -> &mut Self {
            self.muxer.set_inbound_connection_state(state);
            self
        }
        fn with_muxer_outbound_state(&mut self, state: DummyConnectionState) -> &mut Self {
            self.muxer.set_outbound_connection_state(state);
            self
        }
        fn with_handler_state(&mut self, state: HandlerState) -> &mut Self {
            self.handler.state = Some(state);
            self
        }
        fn with_handler_states(&mut self, states: Vec<HandlerState>) -> &mut Self {
            self.handler.next_states = states;
            self
        }
        // All tasks share the same peer id. Returns the collection and the allocated task ids,
        // in allocation order.
        fn handled_nodes_tasks(&mut self) -> (TestHandledNodesTasks, Vec<TaskId>) {
            let mut handled_nodes = HandledNodesTasks::new();
            let peer_id = PeerId::random();
            let mut task_ids = Vec::new();
            for _i in 0..self.task_count {
                let fut = future::ok((peer_id.clone(), self.muxer.clone()));
                task_ids.push(
                    handled_nodes.add_reach_attempt(fut, self.handler.clone())
                );
            }
            (handled_nodes, task_ids)
        }
    }

    // Testing how the NodeTask reacts to its inputs.
    #[test]
    fn task_emits_event_when_things_happen_in_the_node() {
        let (node_task, tx, mut rx) = NodeTaskTestBuilder::new()
            .with_task_id(890)
            .node_task();

        // Sent before the task runs, so the event is buffered while the reach future resolves.
        // It is injected once the node exists.
        tx.unbounded_send(InEvent::Custom("beef")).expect("send to NodeTask should work");
        let mut rt = Runtime::new().unwrap();
        rt.spawn(node_task);
        let events = rt.block_on(rx.by_ref().take(2).collect()).expect("reading on rx should work");

        assert_matches!(events[0], (InToExtMessage::NodeReached(_), TaskId(890)));
        assert_matches!(events[1], (InToExtMessage::NodeEvent(ref outevent), TaskId(890)) => {
            assert_matches!(outevent, OutEvent::Custom(beef) => {
                assert_eq!(beef, &"beef");
            })
        });
    }

    // Despite the name, this test makes the reach future fail: the node is never reached.
    #[test]
    fn task_exits_when_node_errors() {
        let mut rt = Runtime::new().unwrap();
        let (node_task, _tx, rx) = NodeTaskTestBuilder::new()
            .with_inner_fut(future::err(io::Error::new(io::ErrorKind::Other, "nah")))
            .with_task_id(345)
            .node_task();

        rt.spawn(node_task);
        let events = rt.block_on(rx.collect()).expect("rx failed");
        assert!(events.len() == 1);
        assert_matches!(events[0], (InToExtMessage::TaskClosed{..}, TaskId(345)));
    }

    // With both muxer directions closed, the node is expected to finish gracefully right
    // after being reached.
    #[test]
    fn task_exits_when_node_is_done() {
        let mut rt = Runtime::new().unwrap();
        let fut = {
            let peer_id = PeerId::random();
            let mut muxer = DummyMuxer::new();
            muxer.set_inbound_connection_state(DummyConnectionState::Closed);
            muxer.set_outbound_connection_state(DummyConnectionState::Closed);
            future::ok((peer_id, muxer))
        };
        let (node_task, tx, rx) = NodeTaskTestBuilder::new()
            .with_inner_fut(fut)
            .with_task_id(345)
            .node_task();

        // Even though we set up the muxer outbound substream to be `Closed` we
        // still need to open a substream or the outbound state will never
        // be checked (and we didn't set up the substream to return an error).
        let create_outbound_substream_event = InEvent::Substream(Some(135));
        tx.unbounded_send(create_outbound_substream_event).expect("send msg works");
        rt.spawn(node_task);
        let events = rt.block_on(rx.collect()).expect("rx failed");

        assert_eq!(events.len(), 2);
        assert_matches!(events[0].0, InToExtMessage::NodeReached(PeerId{..}));
        assert_matches!(events[1].0, InToExtMessage::TaskClosed(Ok(()), _));
    }

    // Testing the HandledNodesTasks public API.
    #[test]
    fn query_for_tasks() {
        let (mut handled_nodes, task_ids) = HandledNodeTaskTestBuilder::new()
            .with_tasks(3)
            .handled_nodes_tasks();

        assert_eq!(task_ids.len(), 3);
        assert_eq!(handled_nodes.task(TaskId(2)).unwrap().id(), task_ids[2]);
        assert!(handled_nodes.task(TaskId(545534)).is_none());
    }

    // `HandledNodesTasks::poll` spawns via `tokio_executor`, so this test runs it on a
    // `tokio::runtime` built with `Builder`.
    #[test]
    fn send_event_to_task() {
        let (mut handled_nodes, _) = HandledNodeTaskTestBuilder::new()
            .with_tasks(1)
            .handled_nodes_tasks();

        let task_id = {
            let mut task = handled_nodes.task(TaskId(0)).expect("can fetch a Task");
            task.send_event(InEvent::Custom("banana"));
            task.id()
        };

        let mut rt = Builder::new().core_threads(1).build().unwrap();
        let mut events = rt.block_on(handled_nodes.into_future()).unwrap();
        assert_matches!(events.0.unwrap(), HandledNodesEvent::NodeReached{..});

        events = rt.block_on(events.1.into_future()).unwrap();
        assert_matches!(events.0.unwrap(), HandledNodesEvent::NodeEvent{id: event_task_id, event} => {
            assert_eq!(event_task_id, task_id);
            assert_matches!(event, OutEvent::Custom("banana"));
        });
    }

    // The iteration order of `tasks()` is unspecified, so the ids are sorted before comparing.
    #[test]
    fn iterate_over_all_tasks() {
        let (handled_nodes, task_ids) = HandledNodeTaskTestBuilder::new()
            .with_tasks(3)
            .handled_nodes_tasks();

        let mut tasks: Vec<TaskId> = handled_nodes.tasks().collect();
        assert!(tasks.len() == 3);
        tasks.sort_by_key(|t| t.0 );
        assert_eq!(tasks, task_ids);
    }

    // Tasks are only queued by `add_reach_attempt`; nothing is spawned until `poll`.
    #[test]
    fn add_reach_attempt_prepares_a_new_task() {
        let mut handled_nodes = HandledNodesTasks::new();
        assert_eq!(handled_nodes.tasks().count(), 0);
        assert_eq!(handled_nodes.to_spawn.len(), 0);

        handled_nodes.add_reach_attempt( future::empty::<_, Void>(), Handler::default() );

        assert_eq!(handled_nodes.tasks().count(), 1);
        assert_eq!(handled_nodes.to_spawn.len(), 1);
    }

    // Each task is reached and then closes. On every `TaskClosed` the collection shrinks
    // by one.
    #[test]
    fn running_handled_tasks_reaches_the_nodes() {
        let (mut handled_nodes_tasks, _) = HandledNodeTaskTestBuilder::new()
            .with_tasks(5)
            .with_muxer_inbound_state(DummyConnectionState::Closed)
            .with_muxer_outbound_state(DummyConnectionState::Closed)
            .with_handler_state(HandlerState::Err) // stop the loop
            .handled_nodes_tasks();

        let mut rt = Runtime::new().unwrap();
        let mut events: (Option<HandledNodesEvent<_, _, _, _>>, TestHandledNodesTasks);
        // Every task is polled to completion. The order of the emitted events is not
        // asserted beyond alternating NodeReached / TaskClosed.
        for i in 0..5 {
            events = rt.block_on(handled_nodes_tasks.into_future()).unwrap();
            assert_matches!(events, (Some(HandledNodesEvent::NodeReached{..}), ref hnt) => {
                assert_matches!(hnt, HandledNodesTasks{..});
                assert_eq!(hnt.tasks().count(), 5 - i);
            });
            handled_nodes_tasks = events.1;
            events = rt.block_on(handled_nodes_tasks.into_future()).unwrap();
            assert_matches!(events, (Some(HandledNodesEvent::TaskClosed{..}), _));
            handled_nodes_tasks = events.1;
        }
    }

    // The handler's states appear to be consumed from the end of `handler_states`: the last
    // element is observed first. NOTE(review): confirm against `dummy_handler`.
    #[test]
    fn events_in_tasks_are_emitted() {
        // States are pop()'d so they are set in reverse order by the handler
        let handler_states = vec![
            HandlerState::Err,
            HandlerState::Ready(Some(NodeHandlerEvent::Custom(OutEvent::Custom("from handler2") ))),
            HandlerState::Ready(Some(NodeHandlerEvent::Custom(OutEvent::Custom("from handler") ))),
        ];

        let (mut handled_nodes_tasks, _) = HandledNodeTaskTestBuilder::new()
            .with_tasks(1)
            .with_muxer_inbound_state(DummyConnectionState::Pending)
            .with_muxer_outbound_state(DummyConnectionState::Opened)
            .with_handler_states(handler_states)
            .handled_nodes_tasks();

        // Clone the task's sender so events can be sent after the collection has been moved
        // into `into_future`.
        let tx = {
            let mut task0 = handled_nodes_tasks.task(TaskId(0)).unwrap();
            let tx = task0.inner.get_mut();
            tx.clone()
        };

        let mut rt = Builder::new().core_threads(1).build().unwrap();
        let mut events = rt.block_on(handled_nodes_tasks.into_future()).unwrap();
        assert_matches!(events.0.unwrap(), HandledNodesEvent::NodeReached{..});

        tx.unbounded_send(InEvent::NextState).expect("send works");
        events = rt.block_on(events.1.into_future()).unwrap();
        assert_matches!(events.0.unwrap(), HandledNodesEvent::NodeEvent{id: _, event} => {
            assert_matches!(event, OutEvent::Custom("from handler"));
        });

        tx.unbounded_send(InEvent::NextState).expect("send works");
        events = rt.block_on(events.1.into_future()).unwrap();
        assert_matches!(events.0.unwrap(), HandledNodesEvent::NodeEvent{id: _, event} => {
            assert_matches!(event, OutEvent::Custom("from handler2"));
        });

        tx.unbounded_send(InEvent::NextState).expect("send works");
        events = rt.block_on(events.1.into_future()).unwrap();
        assert_matches!(events.0.unwrap(), HandledNodesEvent::TaskClosed{id: _, result, handler: _} => {
            assert_matches!(result, Err(_));
        });
    }
}