use crate::{
Executor,
muxing::StreamMuxer,
};
use fnv::FnvHashMap;
use futures::{
prelude::*,
channel::mpsc,
stream::FuturesUnordered
};
use std::{
collections::hash_map,
error,
fmt,
mem,
pin::Pin,
task::{Context, Poll},
};
use super::{
Connected,
ConnectedPoint,
Connection,
ConnectionError,
ConnectionHandler,
IntoConnectionHandler,
PendingConnectionError,
Substream
};
use task::{Task, TaskId};
mod task;
/// Outcome of a pending connection attempt, i.e. the output of the future
/// passed to `Manager::add_pending`: on success the connection information
/// together with the stream muxer `M`, on failure a `PendingConnectionError`.
type ConnectResult<C, M, TE> = Result<(Connected<C>, M), PendingConnectionError<TE>>;
/// Identifier of a connection tracked by a `Manager`.
///
/// This is a thin public wrapper around the `TaskId` of the task that
/// drives the connection; the manager hands one out from `add` and
/// `add_pending`.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(TaskId);
impl ConnectionId {
pub fn new(id: usize) -> Self {
ConnectionId(TaskId(id))
}
}
/// Owns the bookkeeping for a set of connection tasks and relays the events
/// they emit.
///
/// Every connection (pending or established) is driven by a `Task`. The
/// manager talks to each task through a per-task command channel and receives
/// the events of all tasks through one shared event channel.
pub struct Manager<I, O, H, E, HE, C> {
/// Per-task bookkeeping (command sender and connection state), keyed by
/// task ID. Entries are inserted by `add` / `add_pending` and removed when a
/// task reports failure or closure, or through an `Entry`.
tasks: FnvHashMap<TaskId, TaskInfo<I, C>>,
/// ID that will be assigned to the next task; incremented on every insertion.
next_task_id: TaskId,
/// Buffer size used when creating the command channel of each new task.
task_command_buffer_size: usize,
/// Executor that new tasks are handed to. When `None`, tasks are stored in
/// `local_spawns` and driven from `Manager::poll` instead.
executor: Option<Box<dyn Executor + Send>>,
/// Tasks that are polled by the manager itself because no executor is set.
local_spawns: FuturesUnordered<Pin<Box<dyn Future<Output = ()> + Send>>>,
/// Sending half of the shared event channel; a clone is given to every task.
events_tx: mpsc::Sender<task::Event<O, H, E, HE, C>>,
/// Receiving half of the shared event channel, drained in `Manager::poll`.
events_rx: mpsc::Receiver<task::Event<O, H, E, HE, C>>
}
impl<I, O, H, E, HE, C> fmt::Debug for Manager<I, O, H, E, HE, C>
where
    C: fmt::Debug,
{
    /// Formats the manager as a map from task ID to the state of that task.
    ///
    /// Channels, the executor and locally spawned tasks are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut map = f.debug_map();
        for (task_id, info) in self.tasks.iter() {
            map.entry(task_id, &info.state);
        }
        map.finish()
    }
}
/// Configuration options consumed by `Manager::new`.
#[non_exhaustive]
pub struct ManagerConfig {
/// Executor to hand connection tasks to. If `None`, the manager keeps the
/// tasks itself and drives them whenever `Manager::poll` is called.
pub executor: Option<Box<dyn Executor + Send>>,
/// Buffer size of the command channel created for each individual task.
pub task_command_buffer_size: usize,
/// Buffer size of the single event channel shared by all tasks.
pub task_event_buffer_size: usize,
}
impl Default for ManagerConfig {
fn default() -> Self {
ManagerConfig {
executor: None,
task_event_buffer_size: 32,
task_command_buffer_size: 7,
}
}
}
/// What the manager remembers about a single task.
#[derive(Debug)]
struct TaskInfo<I, C> {
/// Sending half of the task's command channel.
sender: mpsc::Sender<task::Command<I>>,
/// Whether the task's connection is still pending or already established.
state: TaskState<C>,
}
/// Connection state of a task as last observed by the manager.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TaskState<C> {
/// The connection has not (yet) been reported as established.
Pending,
/// The connection is established; carries its connection information.
Established(Connected<C>),
}
/// Events produced by `Manager::poll`.
#[derive(Debug)]
pub enum Event<'a, I, O, H, TE, HE, C> {
/// A task reported that its pending connection failed.
///
/// The task's entry has already been removed from the manager. The handler
/// that was supplied for the connection is handed back.
PendingConnectionError {
/// ID of the connection that failed.
id: ConnectionId,
/// Why the connection attempt failed.
error: PendingConnectionError<TE>,
/// The handler originally passed for this connection.
handler: H
},
/// A task reported that its established connection was closed.
///
/// The task's entry has already been removed from the manager.
ConnectionClosed {
/// ID of the closed connection.
id: ConnectionId,
/// Connection information that was stored for the connection.
connected: Connected<C>,
/// The error reported by the task alongside the closure, if any.
error: Option<ConnectionError<HE>>,
},
/// A task reported that its connection is now established; the manager
/// has recorded the new state before emitting this event.
ConnectionEstablished {
/// Entry giving access to the newly established connection.
entry: EstablishedEntry<'a, I, C>,
},
/// A task forwarded an event for its connection.
ConnectionEvent {
/// Entry of the connection the event belongs to.
entry: EstablishedEntry<'a, I, C>,
/// The forwarded event.
event: O
},
/// A task reported a new remote address; the endpoint stored by the
/// manager has already been updated.
AddressChange {
/// Entry of the connection whose address changed.
entry: EstablishedEntry<'a, I, C>,
/// Endpoint as stored before the change.
old_endpoint: ConnectedPoint,
/// Endpoint as stored after the change.
new_endpoint: ConnectedPoint,
},
}
impl<I, O, H, TE, HE, C> Manager<I, O, H, TE, HE, C> {
pub fn new(config: ManagerConfig) -> Self {
let (tx, rx) = mpsc::channel(config.task_event_buffer_size);
Self {
tasks: FnvHashMap::default(),
next_task_id: TaskId(0),
task_command_buffer_size: config.task_command_buffer_size,
executor: config.executor,
local_spawns: FuturesUnordered::new(),
events_tx: tx,
events_rx: rx
}
}
pub fn add_pending<F, M>(&mut self, future: F, handler: H) -> ConnectionId
where
I: Send + 'static,
O: Send + 'static,
TE: error::Error + Send + 'static,
HE: error::Error + Send + 'static,
C: Send + 'static,
M: StreamMuxer + Send + Sync + 'static,
M::OutboundSubstream: Send + 'static,
F: Future<Output = ConnectResult<C, M, TE>> + Send + 'static,
H: IntoConnectionHandler<C> + Send + 'static,
H::Handler: ConnectionHandler<
Substream = Substream<M>,
InEvent = I,
OutEvent = O,
Error = HE
> + Send + 'static,
<H::Handler as ConnectionHandler>::OutboundOpenInfo: Send + 'static,
{
let task_id = self.next_task_id;
self.next_task_id.0 += 1;
let (tx, rx) = mpsc::channel(self.task_command_buffer_size);
self.tasks.insert(task_id, TaskInfo { sender: tx, state: TaskState::Pending });
let task = Box::pin(Task::pending(task_id, self.events_tx.clone(), rx, future, handler));
if let Some(executor) = &mut self.executor {
executor.exec(task);
} else {
self.local_spawns.push(task);
}
ConnectionId(task_id)
}
pub fn add<M>(&mut self, conn: Connection<M, H::Handler>, info: Connected<C>) -> ConnectionId
where
H: IntoConnectionHandler<C> + Send + 'static,
H::Handler: ConnectionHandler<
Substream = Substream<M>,
InEvent = I,
OutEvent = O,
Error = HE
> + Send + 'static,
<H::Handler as ConnectionHandler>::OutboundOpenInfo: Send + 'static,
TE: error::Error + Send + 'static,
HE: error::Error + Send + 'static,
I: Send + 'static,
O: Send + 'static,
M: StreamMuxer + Send + Sync + 'static,
M::OutboundSubstream: Send + 'static,
C: Send + 'static
{
let task_id = self.next_task_id;
self.next_task_id.0 += 1;
let (tx, rx) = mpsc::channel(self.task_command_buffer_size);
self.tasks.insert(task_id, TaskInfo {
sender: tx, state: TaskState::Established(info)
});
let task: Pin<Box<Task<Pin<Box<future::Pending<_>>>, _, _, _, _, _, _>>> =
Box::pin(Task::established(task_id, self.events_tx.clone(), rx, conn));
if let Some(executor) = &mut self.executor {
executor.exec(task);
} else {
self.local_spawns.push(task);
}
ConnectionId(task_id)
}
pub fn entry(&mut self, id: ConnectionId) -> Option<Entry<'_, I, C>> {
if let hash_map::Entry::Occupied(task) = self.tasks.entry(id.0) {
Some(Entry::new(task))
} else {
None
}
}
pub fn is_established(&self, id: &ConnectionId) -> bool {
match self.tasks.get(&id.0) {
Some(TaskInfo { state: TaskState::Established(..), .. }) => true,
_ => false
}
}
pub fn poll<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Event<'a, I, O, H, TE, HE, C>> {
while let Poll::Ready(Some(_)) = self.local_spawns.poll_next_unpin(cx) {}
let event = loop {
match self.events_rx.poll_next_unpin(cx) {
Poll::Ready(Some(event)) => {
if self.tasks.contains_key(event.id()) { break event
}
}
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => unreachable!("Manager holds both sender and receiver."),
}
};
if let hash_map::Entry::Occupied(mut task) = self.tasks.entry(*event.id()) {
Poll::Ready(match event {
task::Event::Notify { id: _, event } =>
Event::ConnectionEvent {
entry: EstablishedEntry { task },
event
},
task::Event::Established { id: _, info } => { task.get_mut().state = TaskState::Established(info); Event::ConnectionEstablished {
entry: EstablishedEntry { task },
}
}
task::Event::Failed { id, error, handler } => {
let id = ConnectionId(id);
let _ = task.remove();
Event::PendingConnectionError { id, error, handler }
}
task::Event::AddressChange { id: _, new_address } => {
let (new, old) = if let TaskState::Established(c) = &mut task.get_mut().state {
let mut new_endpoint = c.endpoint.clone();
new_endpoint.set_remote_address(new_address);
let old_endpoint = mem::replace(&mut c.endpoint, new_endpoint.clone());
(new_endpoint, old_endpoint)
} else {
unreachable!(
"`Event::AddressChange` implies (2) occurred on that task and thus (3)."
)
};
Event::AddressChange {
entry: EstablishedEntry { task },
old_endpoint: old,
new_endpoint: new,
}
}
task::Event::Closed { id, error } => {
let id = ConnectionId(id);
let task = task.remove();
match task.state {
TaskState::Established(connected) =>
Event::ConnectionClosed { id, connected, error },
TaskState::Pending => unreachable!(
"`Event::Closed` implies (2) occurred on that task and thus (3)."
),
}
}
})
} else {
unreachable!("By (1)")
}
}
}
/// A connection looked up via `Manager::entry`, classified by the state the
/// manager has recorded for it.
#[derive(Debug)]
pub enum Entry<'a, I, C> {
/// The connection's task is in the `Pending` state.
Pending(PendingEntry<'a, I, C>),
/// The connection's task is in the `Established` state.
Established(EstablishedEntry<'a, I, C>)
}
impl<'a, I, C> Entry<'a, I, C> {
    /// Wraps an occupied map entry in the `Entry` variant that matches the
    /// state currently recorded for the task.
    fn new(task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo<I, C>>) -> Self {
        if let TaskState::Pending = task.get().state {
            Entry::Pending(PendingEntry { task })
        } else {
            Entry::Established(EstablishedEntry { task })
        }
    }
}
/// Handle to a connection the manager considers established.
#[derive(Debug)]
pub struct EstablishedEntry<'a, I, C> {
/// The occupied slot of the connection's task in the manager's task map.
task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo<I, C>>,
}
impl<'a, I, C> EstablishedEntry<'a, I, C> {
pub fn notify_handler(&mut self, event: I) -> Result<(), I> {
let cmd = task::Command::NotifyHandler(event); self.task.get_mut().sender.try_send(cmd)
.map_err(|e| match e.into_inner() {
task::Command::NotifyHandler(event) => event,
_ => panic!("Unexpected command. Expected `NotifyHandler`") })
}
pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll<Result<(),()>> {
self.task.get_mut().sender.poll_ready(cx).map_err(|_| ())
}
pub fn start_close(mut self) {
match self.task.get_mut().sender.clone().try_send(task::Command::Close) {
Ok(()) => {},
Err(e) => assert!(e.is_disconnected(), "No capacity for close command.")
}
}
pub fn connected(&self) -> &Connected<C> {
match &self.task.get().state {
TaskState::Established(c) => c,
TaskState::Pending => unreachable!("By Entry::new()")
}
}
pub fn remove(self) -> Connected<C> {
match self.task.remove().state {
TaskState::Established(c) => c,
TaskState::Pending => unreachable!("By Entry::new()")
}
}
pub fn id(&self) -> ConnectionId {
ConnectionId(*self.task.key())
}
}
/// Handle to a connection the manager considers still pending.
#[derive(Debug)]
pub struct PendingEntry<'a, I, C> {
/// The occupied slot of the connection's task in the manager's task map.
task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo<I, C>>
}
impl<'a, I, C> PendingEntry<'a, I, C> {
    /// Returns the ID of this pending connection.
    pub fn id(&self) -> ConnectionId {
        let task_id = *self.task.key();
        ConnectionId(task_id)
    }

    /// Removes the pending connection's task from the manager, dropping the
    /// bookkeeping (including the command sender) that was kept for it.
    pub fn abort(self) {
        let _ = self.task.remove();
    }
}