use bytes::Bytes;
use futures::{
future, pin_mut,
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
Future, FutureExt,
};
use lazy_static::lazy_static;
use std::{
collections::HashMap,
convert::TryFrom,
error::Error,
fmt,
marker::PhantomData,
mem::size_of,
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tokio::{
sync::{mpsc, mpsc::Permit, oneshot, Mutex},
time::{sleep, timeout},
try_join,
};
use super::{
client::{Client, ConnectRequest, ConnectResponse},
credit::{credit_monitor_pair, credit_send_pair, ChannelCreditMonitor, CreditProvider},
listener::{Listener, RemoteConnectMsg, Request},
msg::{ExchangedCfg, MultiplexMsg},
port_allocator::{PortAllocator, PortNumber},
receiver::{PortReceiveMsg, ReceivedData, ReceivedPortRequests, Receiver},
sender::Sender,
Cfg, MultiplexError, PROTOCOL_VERSION,
};
/// Builds a `MultiplexError::Protocol` from any expression that implements
/// `ToString` (e.g. a `&str` literal or the result of `format!`).
///
/// The path is written relative to the parent module (`super::MultiplexError`),
/// so the macro is only usable from modules where that path resolves.
macro_rules! protocol_err {
    ($msg:expr) => {
        super::MultiplexError::Protocol($msg.to_string())
    };
}
/// State the multiplexer keeps for one local port, keyed by local port number.
#[derive(Debug)]
enum PortState {
    /// An `OpenPort` request (or a port transferred via `PortData`) was sent and
    /// the multiplexer is waiting for the remote `PortOpened` or `Rejected` answer.
    Connecting {
        /// Delivers the remote answer to whoever requested the connection.
        response_tx: oneshot::Sender<ConnectResponse>,
    },
    /// The local port is connected to `remote_port` on the remote endpoint.
    Connected {
        /// Port number on the remote endpoint; target of all outgoing port messages.
        remote_port: u32,
        /// Fed with the credits granted by remote `PortCredits` messages; closed when
        /// the remote receiver closes or is dropped.
        sender_credit_provider: CreditProvider,
        /// Forwards received data and port requests to the local receiver;
        /// `None` once the remote sent `SendFinish`.
        receiver_tx_data: Option<mpsc::UnboundedSender<PortReceiveMsg>>,
        /// Charges credits for data and ports received on this port.
        receiver_credit_monitor: ChannelCreditMonitor,
        /// The local receiver requested closing (`PortEvt::ReceiverClosed`).
        receiver_closed: bool,
        /// The local receiver was dropped (`PortEvt::ReceiverDropped`).
        receiver_dropped: bool,
        /// The local sender was dropped (`PortEvt::SenderDropped`).
        sender_dropped: bool,
        /// Set when the remote receiver was closed or dropped; the local sender
        /// holds a weak reference to it.
        remote_receiver_closed: Arc<AtomicBool>,
        /// Notification channels signalled when the remote receiver closes;
        /// replaced by `None` once they have been signalled.
        remote_receiver_closed_notify: Arc<Mutex<Option<Vec<oneshot::Sender<()>>>>>,
        /// The remote endpoint sent `ReceiveFinish`.
        remote_receiver_dropped: bool,
    },
}
/// Event sent to the multiplexer over the shared port event queue by local
/// [`Sender`]s, [`Receiver`]s and `Request`s (each holds a clone of the queue sender).
#[derive(Debug)]
pub(crate) enum PortEvt {
    /// A remote connect request was accepted locally on `local_port`.
    Accepted {
        /// Local port that will serve the connection.
        local_port: PortNumber,
        /// Port number of the requesting side on the remote endpoint.
        remote_port: u32,
        /// Receives the newly created sender/receiver pair.
        port_tx: oneshot::Sender<(Sender, Receiver)>,
    },
    /// A remote connect request was rejected locally.
    Rejected {
        /// Port number of the requesting side on the remote endpoint.
        remote_port: u32,
        /// Forwarded in the `Rejected` message.
        /// NOTE(review): presumably means "no local port was available" — confirm.
        no_ports: bool,
    },
    /// One chunk of data for `remote_port`; `first` and `last` are copied
    /// verbatim into the `Data` message header.
    SendData {
        remote_port: u32,
        data: Bytes,
        first: bool,
        last: bool,
    },
    /// Local ports to transfer over `remote_port`; each one is registered in the
    /// connecting state with its response channel before `PortData` is sent.
    SendPorts {
        remote_port: u32,
        first: bool,
        last: bool,
        /// Copied into the `PortData` message header.
        wait: bool,
        ports: Vec<(PortNumber, oneshot::Sender<ConnectResponse>)>,
    },
    /// Returns receive credits for `remote_port` to the remote sender (`PortCredits`).
    ReturnCredits {
        remote_port: u32,
        credits: u32,
    },
    /// The local sender of `local_port` was dropped.
    SenderDropped {
        local_port: u32,
    },
    /// The local receiver of `local_port` requested closing.
    ReceiverClosed {
        local_port: u32,
    },
    /// The local receiver of `local_port` was dropped.
    ReceiverDropped {
        local_port: u32,
    },
}
/// Local event selected by the main loop of `Multiplexer::run` and processed by
/// `handle_event`.
#[derive(Debug)]
enum GlobalEvt {
    /// A local client wants to open a port on the remote endpoint.
    ConnectReq(ConnectRequest),
    /// The connect request channel closed, i.e. all local clients were dropped.
    AllClientsDropped,
    /// The local listener was dropped.
    ListenerDropped,
    /// An event from a local port object.
    Port(PortEvt),
    /// Termination was requested or the connection became idle; send `Goodbye`.
    SendGoodbye,
    /// No other event was ready; flush the transport sink.
    Flush,
}
/// Command processed by the send task.
enum SendCmd {
    /// Write a message to the transport sink (without flushing).
    Send(TransportMsg),
    /// Flush the transport sink.
    Flush,
}
/// A multiplexer message as carried by the transport: a serialized header and,
/// for `MultiplexMsg::Data` only, a payload transmitted as a separate frame.
struct TransportMsg {
    /// Message header.
    msg: MultiplexMsg,
    /// Payload; present exactly when `msg` is `MultiplexMsg::Data`.
    data: Option<Bytes>,
}
impl TransportMsg {
fn new(msg: MultiplexMsg) -> Self {
assert!(!matches!(&msg, &MultiplexMsg::Data { .. }), "MultiplexMsg::Data with missing data");
Self { msg, data: None }
}
fn with_data(msg: MultiplexMsg, data: Bytes) -> Self {
assert!(matches!(&msg, &MultiplexMsg::Data { .. }), "MultiplexMsg with unexpected data");
Self { msg, data: Some(data) }
}
}
/// Channel multiplexer running over a single transport sink/stream pair.
///
/// Created by [`Multiplexer::new`] together with a [`Client`] and a [`Listener`].
/// After the handshake performed by `new`, nothing is processed until
/// [`Multiplexer::run`] is awaited.
pub struct Multiplexer<TransportSink, TransportStream> {
    /// Prefix for this multiplexer's log messages.
    trace_id: String,
    /// Local configuration.
    local_cfg: Cfg,
    /// Configuration received in the remote `Hello`.
    remote_cfg: ExchangedCfg,
    /// Protocol version received in the remote `Hello`.
    remote_protocol_version: u8,
    /// Connect requests from the local client; taken by `run`.
    connect_rx: Option<mpsc::UnboundedReceiver<ConnectRequest>>,
    /// Queues to the local listener for (waiting, non-waiting) remote connect
    /// requests; `None` once the local listener was dropped.
    listen_tx: Option<(mpsc::Sender<RemoteConnectMsg>, mpsc::Sender<RemoteConnectMsg>)>,
    /// Local port number allocator, shared with client, listener and requests.
    port_allocator: PortAllocator,
    /// State of every local port that is connecting or connected.
    ports: HashMap<PortNumber, PortState>,
    /// Sending half of the port event queue; cloned into senders, receivers and requests.
    channel_tx: mpsc::Sender<PortEvt>,
    /// Receiving half of the port event queue; taken by `run`.
    channel_rx: Option<mpsc::Receiver<PortEvt>>,
    /// Termination requests from the client and the listener; taken by `run`.
    terminate_rx: Option<mpsc::UnboundedReceiver<()>>,
    /// All local clients were dropped.
    all_clients_dropped: bool,
    /// The remote endpoint sent `ClientFinish`.
    remote_client_dropped: bool,
    /// The remote endpoint sent `ListenerFinish`; shared with the local client.
    remote_listener_dropped: Arc<AtomicBool>,
    /// `Goodbye` was handed to the send task.
    goodbye_sent: bool,
    /// `Goodbye` was received from the remote endpoint.
    goodbye_received: bool,
    /// Transport sink; taken by `run`.
    transport_sink: Option<TransportSink>,
    /// Transport stream; taken by `run`.
    transport_stream: Option<TransportStream>,
}
impl<TransportSink, TransportStream> fmt::Debug for Multiplexer<TransportSink, TransportStream> {
    /// Shows the identifying configuration of the multiplexer.
    ///
    /// Port table, channels and transport halves are not printed; the output ends
    /// with `..` (`finish_non_exhaustive`) to signal the omitted fields.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Multiplexer");
        out.field("trace_id", &self.trace_id);
        out.field("local_cfg", &self.local_cfg);
        out.field("remote_cfg", &self.remote_cfg);
        // The local version is a crate constant rather than a field; it is shown
        // next to the remote one so mismatches are easy to spot.
        out.field("local_protocol_version", &PROTOCOL_VERSION);
        out.field("remote_protocol_version", &self.remote_protocol_version);
        out.finish_non_exhaustive()
    }
}
/// Construction (with hello handshake) and the event loop of the multiplexer.
impl<TransportSink, TransportSinkError, TransportStream, TransportStreamError>
    Multiplexer<TransportSink, TransportStream>
where
    TransportSink: Sink<Bytes, Error = TransportSinkError> + Send + Unpin,
    TransportSinkError: Error + Send + Sync + 'static,
    TransportStream: Stream<Item = Result<Bytes, TransportStreamError>> + Send + Unpin,
    TransportStreamError: Error + Send + Sync + 'static,
{
    /// Creates a multiplexer over `transport_sink` / `transport_stream`.
    ///
    /// First performs the hello handshake (see `exchange_hello`). The handshake is
    /// bounded by `cfg.connection_timeout` when one is configured. On success returns:
    /// * the multiplexer, which must be driven by awaiting [`Multiplexer::run`],
    /// * a [`Client`] for opening ports on the remote endpoint,
    /// * a [`Listener`] for accepting ports opened by the remote endpoint.
    ///
    /// # Errors
    /// `MultiplexError::Timeout` if the handshake does not complete within the
    /// connection timeout. Otherwise, the transport error that aborted the handshake.
    pub async fn new(
        cfg: &Cfg, mut transport_sink: TransportSink, mut transport_stream: TransportStream,
    ) -> Result<(Self, Client, Listener), MultiplexError<TransportSinkError, TransportStreamError>> {
        // NOTE(review): `Cfg::check` is defined elsewhere; presumably it validates the
        // configuration (possibly by panicking) — confirm.
        cfg.check();

        // An explicitly configured trace id wins; otherwise generate a short one.
        let trace_id = match cfg.trace_id.clone() {
            Some(trace_id) => trace_id,
            None => generate_trace_id(),
        };

        log::trace!("{}: exchanging hello", &trace_id);
        let fut = Self::exchange_hello(&trace_id, cfg, &mut transport_sink, &mut transport_stream);
        // With a timeout there are two nested results: the first `?` propagates the
        // elapsed timeout, the second the handshake's own error.
        let (remote_protocol_version, remote_cfg) = match cfg.connection_timeout {
            Some(dur) => timeout(dur, fut).await.map_err(|_| MultiplexError::Timeout)??,
            None => fut.await?,
        };

        // One bounded queue carries the events of all local port objects.
        let (channel_tx, channel_rx) = mpsc::channel(cfg.shared_send_queue);
        // Remote connect requests reach the listener through one queue per `wait` mode.
        // NOTE(review): the `+ 1` presumably reserves room for the
        // `RemoteConnectMsg::ClientDropped` sent on `ClientFinish` — confirm.
        let (listen_wait_tx, listen_wait_rx) = mpsc::channel(usize::from(cfg.connect_queue) + 1);
        let (listen_no_wait_tx, listen_no_wait_rx) = mpsc::channel(usize::from(cfg.connect_queue) + 1);
        let (connect_tx, connect_rx) = mpsc::unbounded_channel();
        // Client and listener each receive a sender for termination requests.
        let (terminate_tx, terminate_rx) = mpsc::unbounded_channel();
        let port_allocator = PortAllocator::new(cfg.max_ports);
        // Shared with the client (passed to `Client::new` below).
        let remote_listener_dropped = Arc::new(AtomicBool::new(false));

        let multiplexer = Multiplexer {
            trace_id,
            remote_protocol_version,
            local_cfg: cfg.clone(),
            remote_cfg: remote_cfg.clone(),
            connect_rx: Some(connect_rx),
            listen_tx: Some((listen_wait_tx, listen_no_wait_tx)),
            port_allocator: port_allocator.clone(),
            ports: HashMap::new(),
            channel_tx,
            channel_rx: Some(channel_rx),
            terminate_rx: Some(terminate_rx),
            remote_client_dropped: false,
            remote_listener_dropped: remote_listener_dropped.clone(),
            all_clients_dropped: false,
            goodbye_sent: false,
            goodbye_received: false,
            transport_sink: Some(transport_sink),
            transport_stream: Some(transport_stream),
        };

        let client = Client::new(
            connect_tx,
            remote_cfg.connect_queue,
            port_allocator.clone(),
            remote_listener_dropped,
            terminate_tx.clone(),
        );

        let listener = Listener::new(listen_wait_rx, listen_no_wait_rx, port_allocator, terminate_tx);

        log::trace!("{}: multiplexer created", &multiplexer.trace_id);
        Ok((multiplexer, client, listener))
    }

    /// Feeds `msg` into `sink` without flushing.
    ///
    /// The serialized header is fed as one item. For `Data` messages the payload
    /// follows as a second, separate item.
    async fn feed_msg(
        trace_id: &str, msg: TransportMsg, sink: &mut TransportSink,
    ) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        log::trace!("{} ==> {:?}", trace_id, &msg.msg);
        sink.feed(msg.msg.to_vec().into()).await.map_err(MultiplexError::SinkError)?;
        if let Some(data) = msg.data {
            log::trace!("{} ==> [{} bytes data]", trace_id, data.len());
            sink.feed(data).await.map_err(MultiplexError::SinkError)?;
        }
        Ok(())
    }

    /// Flushes the transport sink.
    async fn flush(
        trace_id: &str, sink: &mut TransportSink,
    ) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        log::trace!("{} ==> [flush]", trace_id);
        sink.flush().await.map_err(MultiplexError::SinkError)
    }

    /// Receives one complete message from the transport stream.
    ///
    /// Decodes a header item. If it is a `Data` header, the payload item that
    /// follows it is read as well.
    ///
    /// # Errors
    /// * `StreamError` on a transport error.
    /// * `StreamClosed` if the stream ends, including between a `Data` header and
    ///   its payload.
    /// * The error from `MultiplexMsg::from_slice` if the header cannot be decoded.
    async fn recv_msg(
        trace_id: &str, stream: &mut TransportStream,
    ) -> Result<TransportMsg, MultiplexError<TransportSinkError, TransportStreamError>> {
        let msg_data = match stream.next().await {
            Some(Ok(msg_data)) => msg_data,
            Some(Err(err)) => return Err(MultiplexError::StreamError(err)),
            None => return Err(MultiplexError::StreamClosed),
        };

        let msg = MultiplexMsg::from_slice(&msg_data)?;
        log::trace!("{} <== {:?}", trace_id, &msg);

        let data = if let MultiplexMsg::Data { .. } = &msg {
            match stream.next().await {
                Some(Ok(data)) => {
                    log::trace!("{} <== [{} bytes data]", trace_id, data.len());
                    Some(data)
                }
                Some(Err(err)) => return Err(MultiplexError::StreamError(err)),
                None => return Err(MultiplexError::StreamClosed),
            }
        } else {
            None
        };

        Ok(TransportMsg { msg, data })
    }

    /// Performs the hello handshake.
    ///
    /// Returns the remote protocol version and configuration.
    ///
    /// Sending (`Reset` then `Hello`, each followed by a flush) and receiving run
    /// concurrently via `try_join!`, so neither side waits for the other before
    /// sending. Received messages other than `Hello`, and items that fail with a
    /// protocol error, are skipped until the remote `Hello` arrives.
    async fn exchange_hello(
        trace_id: &str, cfg: &Cfg, sink: &mut TransportSink, stream: &mut TransportStream,
    ) -> Result<(u8, ExchangedCfg), MultiplexError<TransportSinkError, TransportStreamError>> {
        let send_task = async {
            Self::feed_msg(trace_id, TransportMsg::new(MultiplexMsg::Reset), sink).await?;
            Self::flush(trace_id, sink).await?;

            Self::feed_msg(
                trace_id,
                TransportMsg::new(MultiplexMsg::Hello { version: PROTOCOL_VERSION, cfg: cfg.into() }),
                sink,
            )
            .await?;
            Self::flush(trace_id, sink).await?;

            Ok(())
        };

        let recv_task = async {
            loop {
                match Self::recv_msg(trace_id, stream).await {
                    Ok(TransportMsg { msg: MultiplexMsg::Hello { version, cfg }, .. }) => {
                        break Ok((version, cfg))
                    }
                    // NOTE(review): presumably these are leftovers from before the
                    // remote's `Reset`, so they are ignored rather than failing the handshake — confirm.
                    Ok(_) => (),
                    Err(MultiplexError::Protocol(_)) => (),
                    Err(err) => return Err(err),
                }
            }
        };

        Ok(try_join!(send_task, recv_task)?.1)
    }

    /// Returns whether the multiplexer should terminate; `run` then sends `Goodbye`.
    ///
    /// True once `Goodbye` was sent or received. Also true when the connection is
    /// idle and can never become busy again, which requires all of:
    /// * no local port is connecting or connected,
    /// * no new outgoing connection is possible: all local clients are dropped or
    ///   the remote listener is dropped,
    /// * no new incoming connection is possible: the local listener is dropped or
    ///   the remote client is dropped.
    fn should_terminate(&self) -> bool {
        let mut terminate = true;

        terminate &= self.ports.is_empty();
        terminate &= self.all_clients_dropped || self.remote_listener_dropped.load(Ordering::SeqCst);
        terminate &= self.listen_tx.is_none() || self.remote_client_dropped;

        terminate |= self.goodbye_sent;
        terminate |= self.goodbye_received;

        terminate
    }

    /// Registers `local_port` as connected to `remote_port` on the remote endpoint.
    ///
    /// Returns the user-facing [`Sender`] / [`Receiver`] pair for the port.
    ///
    /// Sizing of the created objects:
    /// * Outgoing credits start from the remote's `port_receive_buffer`.
    /// * Incoming credits are monitored against the local `receive_buffer`.
    /// * The sender is given the remote's `chunk_size`.
    ///
    /// # Panics
    /// If `local_port` is already in the connected state.
    fn create_port(&mut self, local_port: PortNumber, remote_port: u32) -> (Sender, Receiver) {
        let local_port_num = *local_port;
        log::trace!(
            "{}: created port {} connected to remote port {}",
            &self.trace_id,
            &local_port_num,
            remote_port
        );

        let sender_tx = self.channel_tx.clone();
        let (sender_credit_provider, sender_credit_user) = credit_send_pair(self.remote_cfg.port_receive_buffer);

        let receiver_tx = self.channel_tx.clone();
        let (receiver_tx_data, receiver_rx_data) = mpsc::unbounded_channel();
        let (receiver_credit_monitor, receiver_credit_returner) =
            credit_monitor_pair(self.local_cfg.receive_buffer);

        // The port state keeps the strong references; the sender only gets weak ones
        // (`Arc::downgrade` below).
        let hangup_notify = Arc::new(Mutex::new(Some(Vec::new())));
        let hangup_recved = Arc::new(AtomicBool::new(false));

        // `remote_port` in this pattern shadows the parameter: it is the remote port of
        // the previously stored state and is only used in the panic message.
        if let Some(PortState::Connected { remote_port, .. }) = self.ports.insert(
            local_port,
            PortState::Connected {
                remote_port,
                sender_credit_provider,
                receiver_tx_data: Some(receiver_tx_data),
                receiver_credit_monitor,
                remote_receiver_closed_notify: hangup_notify.clone(),
                remote_receiver_closed: hangup_recved.clone(),
                receiver_closed: false,
                receiver_dropped: false,
                sender_dropped: false,
                remote_receiver_dropped: false,
            },
        ) {
            panic!(
                "create_port called for local port {} already connected to remote port {}",
                local_port_num, remote_port
            );
        }

        let sender = Sender::new(
            local_port_num,
            remote_port,
            self.remote_cfg.chunk_size as usize,
            self.local_cfg.max_data_size,
            sender_tx,
            sender_credit_user,
            Arc::downgrade(&hangup_recved),
            Arc::downgrade(&hangup_notify),
        );
        let receiver = Receiver::new(
            local_port_num,
            remote_port,
            self.local_cfg.max_data_size,
            self.local_cfg.max_received_ports,
            receiver_tx,
            receiver_rx_data,
            receiver_credit_returner,
        );
        (sender, receiver)
    }

    /// Removes the state of `local_port` once the port is completely finished.
    ///
    /// That requires all of:
    /// * the local sender was dropped,
    /// * the local receiver was dropped,
    /// * the remote sent `SendFinish` (`receiver_tx_data` is `None`),
    /// * the remote receiver was dropped.
    ///
    /// # Panics
    /// If `local_port` is not in the connected state.
    fn maybe_free_port(&mut self, local_port: u32) {
        let mut free = true;

        if let Some(PortState::Connected {
            receiver_tx_data,
            receiver_dropped,
            sender_dropped,
            remote_receiver_dropped,
            ..
        }) = self.ports.get(&local_port)
        {
            free &= *sender_dropped;
            free &= *receiver_dropped;
            free &= receiver_tx_data.is_none();
            free &= *remote_receiver_dropped;
        } else {
            panic!("maybe_free_port called for port {} not in connected state.", &local_port);
        }

        if free {
            log::trace!("{}: freed port {}", &self.trace_id, &local_port);
            self.ports.remove(&local_port);
        }
    }

    /// Writes the commands received on `rx` to the transport sink.
    ///
    /// Each iteration first waits until the sink is ready. It then handles, in this
    /// priority order (`biased`):
    /// 1. one `SendCmd`, if available;
    /// 2. otherwise, once no message has been sent for `ping_interval`, a `Ping`,
    ///    which is fed and flushed immediately.
    ///
    /// A `ping_interval` of `None` disables pings. The task returns after writing
    /// `Goodbye` or when `rx` is closed. A final flush is then attempted and its
    /// error ignored.
    async fn send_task(
        trace_id: &str, mut sink: &mut TransportSink, ping_interval: Option<Duration>,
        mut rx: mpsc::Receiver<SendCmd>,
    ) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        /// Completes after `ping_interval`, or never if pings are disabled.
        async fn get_next_ping(ping_interval: Option<Duration>) {
            match ping_interval {
                Some(interval) => sleep(interval).await,
                None => future::pending().await,
            }
        }

        let mut next_ping = get_next_ping(ping_interval).fuse().boxed();

        loop {
            // Don't take a command off the queue before the sink can accept it.
            SinkReady::new(&mut sink).await.map_err(MultiplexError::SinkError)?;

            tokio::select! {
                biased;

                cmd_opt = rx.recv() => {
                    match cmd_opt {
                        Some(SendCmd::Send (msg)) => {
                            let is_goodbye = matches!(&msg, TransportMsg {msg: MultiplexMsg::Goodbye, ..});
                            Self::feed_msg(trace_id, msg, sink).await?;
                            if is_goodbye {
                                break;
                            }
                            // Any sent message restarts the ping timer.
                            next_ping = get_next_ping(ping_interval).fuse().boxed();
                        }
                        Some(SendCmd::Flush) => Self::flush(trace_id, sink).await?,
                        None => break,
                    }
                }

                () = &mut next_ping => {
                    Self::feed_msg(trace_id, TransportMsg::new(MultiplexMsg::Ping), sink).await?;
                    Self::flush(trace_id, sink).await?;
                    next_ping = get_next_ping(ping_interval).fuse().boxed();
                }
            }
        }

        // Best effort: the connection may already be going away.
        let _ = Self::flush(trace_id, sink).await;

        Ok(())
    }

    /// Reads messages from the transport stream and forwards them to `tx`.
    ///
    /// A slot in `tx` is reserved before each read, so a message is only read when it
    /// can be delivered.
    ///
    /// # Errors
    /// * `MultiplexError::Timeout` if no message is received for `connection_timeout`
    ///   (`None` disables the timeout).
    /// * Any error from `recv_msg`.
    ///
    /// Returns `Ok` after forwarding `Goodbye`, or when `tx` is closed.
    async fn recv_task(
        trace_id: &str, stream: &mut TransportStream, connection_timeout: Option<Duration>,
        tx: mpsc::Sender<TransportMsg>,
    ) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        /// Completes after `connection_timeout`, or never if no timeout is configured.
        async fn get_connection_timeout(connection_timeout: Option<Duration>) {
            match connection_timeout {
                Some(timeout) => sleep(timeout).await,
                None => future::pending().await,
            }
        }

        let mut next_timeout = get_connection_timeout(connection_timeout).fuse().boxed();

        while let Ok(tx_permit) = tx.reserve().await {
            tokio::select! {
                biased;

                msg = Self::recv_msg(trace_id, stream) => {
                    let msg = msg?;
                    let is_goodbye = matches!(&msg, TransportMsg {msg: MultiplexMsg::Goodbye, ..});
                    tx_permit.send(msg);
                    if is_goodbye {
                        break;
                    }
                    // Every received message (including pings) restarts the timeout.
                    next_timeout = get_connection_timeout(connection_timeout).fuse().boxed();
                },

                () = &mut next_timeout => return Err(MultiplexError::Timeout),
            }
        }

        Ok(())
    }

    /// Runs the multiplexer until the connection has been terminated.
    ///
    /// Three activities run concurrently:
    /// * `send_task` writes queued `SendCmd`s to the transport sink and sends pings.
    /// * `recv_task` reads messages from the transport stream.
    /// * This loop turns local events into `SendCmd`s (`handle_event`) and applies
    ///   received messages (`handle_received_msg`).
    ///
    /// Returns `Ok(())` once `Goodbye` was both sent and received and the send task
    /// has finished.
    ///
    /// # Errors
    /// The first error raised by either task or by message handling. Examples:
    /// transport failure, connection timeout, a remote `Reset`, a protocol violation.
    pub async fn run(mut self) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        let trace_id = self.trace_id.clone();
        let mut transport_sink = self.transport_sink.take().unwrap();
        let mut transport_stream = self.transport_stream.take().unwrap();

        log::trace!("{}: multiplexer run", &trace_id);

        // Capacity 1: the loop below only produces a command once a permit is available.
        // NOTE(review): pinging at half of the remote's connection timeout presumably
        // keeps the remote's receive timeout from firing on an idle connection — confirm.
        let (send_tx, send_rx) = mpsc::channel(1);
        let send_task = Self::send_task(
            &trace_id,
            &mut transport_sink,
            self.remote_cfg.connection_timeout.map(|d| d / 2),
            send_rx,
        )
        .fuse();
        pin_mut!(send_task);

        let (recv_tx, mut recv_rx) = mpsc::channel(1);
        let recv_task =
            Self::recv_task(&trace_id, &mut transport_stream, self.local_cfg.connection_timeout, recv_tx).fuse();
        pin_mut!(recv_task);

        let mut channel_rx = self.channel_rx.take().unwrap();
        let mut connect_rx = self.connect_rx.take().unwrap();
        let mut terminate_rx = self.terminate_rx.take().unwrap();
        // Whether a flush was issued since the last event that queued a message.
        let mut flushed = false;
        let mut send_task_ended = false;

        while !(self.goodbye_sent && self.goodbye_received && send_task_ended) {
            let send_prep_task = async {
                // Reserve a send slot before taking an event off any queue, so no event
                // is dequeued that could not be forwarded. Fails once the send task has
                // ended and dropped its receiver.
                let permit = match send_tx.reserve().await {
                    Ok(permit) => permit,
                    Err(_) => return None,
                };

                // Branches are tried in the listed order (`biased`).
                let event = tokio::select! {
                    biased;

                    () = async { match &self.listen_tx {
                        Some((listen_wait_tx, _)) => listen_wait_tx.closed().await,
                        None => future::pending().await
                    }} => {
                        flushed = false;
                        GlobalEvt::ListenerDropped
                    },

                    connect_req_opt = connect_rx.recv(), if !self.all_clients_dropped => {
                        flushed = false;
                        match connect_req_opt {
                            Some(connect_req) => GlobalEvt::ConnectReq(connect_req),
                            None => GlobalEvt::AllClientsDropped,
                        }
                    },

                    Some(msg) = channel_rx.recv() => {
                        flushed = false;
                        GlobalEvt::Port(msg)
                    },

                    Some(()) = terminate_rx.recv(), if !self.goodbye_sent => {
                        GlobalEvt::SendGoodbye
                    }

                    () = future::ready(()), if self.should_terminate() && !self.goodbye_sent => {
                        GlobalEvt::SendGoodbye
                    }

                    // Lowest priority: flush once after a burst of messages, when no
                    // other event is ready.
                    () = future::ready(()), if !flushed => {
                        flushed = true;
                        GlobalEvt::Flush
                    },
                };

                Some((permit, event))
            };

            tokio::select! {
                Some((permit, event)) = send_prep_task => self.handle_event(&trace_id, permit, event).await?,
                Some(msg) = recv_rx.recv() => self.handle_received_msg(&trace_id, msg).await?,
                res = &mut send_task => {
                    match res {
                        Ok(()) => send_task_ended = true,
                        Err(err) => return Err(err),
                    }
                }
                // `Ok` from the receive task only means `Goodbye` was forwarded; the
                // loop condition handles termination.
                Err(err) = &mut recv_task => return Err(err),
            }
        }

        log::trace!("{}: multiplexer normal exit", &trace_id);
        Ok(())
    }

    /// Handles one local event, enqueueing at most one `SendCmd` through `permit`.
    ///
    /// Currently never returns an error.
    ///
    /// # Panics
    /// On internal invariant violations. Examples: an event for a port in the wrong
    /// state, a drop or close event reported more than once for the same port.
    async fn handle_event(
        &mut self, trace_id: &str, permit: Permit<'_, SendCmd>, event: GlobalEvt,
    ) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        log::trace!("{}: processing event {:?}", trace_id, &event);

        /// Enqueues a payload-less message.
        fn send_msg(permit: Permit<'_, SendCmd>, msg: MultiplexMsg) {
            permit.send(SendCmd::Send(TransportMsg::new(msg)))
        }

        match event {
            // `sent_tx` is bound to a named `_sent_tx` (not `_`), so it is dropped only
            // when this handler returns.
            // NOTE(review): presumably that drop tells the requester the request was
            // processed — confirm in `client.rs`.
            GlobalEvt::ConnectReq(ConnectRequest { local_port, sent_tx: _sent_tx, response_tx, wait }) => {
                if !self.remote_listener_dropped.load(Ordering::SeqCst) {
                    let local_port_num = *local_port;
                    if self.ports.insert(local_port, PortState::Connecting { response_tx }).is_some() {
                        panic!("ConnectRequest for already used local port {}", local_port_num);
                    }
                    send_msg(permit, MultiplexMsg::OpenPort { client_port: local_port_num, wait });
                } else {
                    // Nobody on the remote side can accept: reject locally without
                    // sending anything.
                    let _ = response_tx.send(ConnectResponse::Rejected { no_ports: false });
                }
            }

            GlobalEvt::Port(PortEvt::Accepted { local_port, remote_port, port_tx }) => {
                let local_port_num = *local_port;
                send_msg(
                    permit,
                    MultiplexMsg::PortOpened { client_port: remote_port, server_port: local_port_num },
                );

                let (sender, receiver) = self.create_port(local_port, remote_port);
                let _ = port_tx.send((sender, receiver));
            }
            GlobalEvt::Port(PortEvt::Rejected { remote_port, no_ports }) => {
                send_msg(permit, MultiplexMsg::Rejected { client_port: remote_port, no_ports });
            }
            GlobalEvt::Port(PortEvt::SendData { remote_port, data, first, last }) => {
                permit.send(SendCmd::Send(TransportMsg::with_data(
                    MultiplexMsg::Data { port: remote_port, first, last },
                    data,
                )));
            }
            GlobalEvt::Port(PortEvt::SendPorts { remote_port, ports, first, last, wait }) => {
                // Each transferred port waits in the connecting state for the remote answer.
                let mut port_nums = Vec::new();
                for (port, response_tx) in ports {
                    let port_num = *port;
                    if self.ports.insert(port, PortState::Connecting { response_tx }).is_some() {
                        panic!("SendPorts with already used local port {}", port_num);
                    }
                    port_nums.push(port_num);
                }

                send_msg(
                    permit,
                    MultiplexMsg::PortData { port: remote_port, first, last, wait, ports: port_nums },
                );
            }
            GlobalEvt::Port(PortEvt::ReturnCredits { remote_port, credits }) => {
                send_msg(permit, MultiplexMsg::PortCredits { port: remote_port, credits });
            }
            GlobalEvt::Port(PortEvt::SenderDropped { local_port }) => {
                if let Some(PortState::Connected { remote_port, sender_dropped, .. }) =
                    self.ports.get_mut(&local_port)
                {
                    if *sender_dropped {
                        panic!("PortEvt SenderDropped more than once for port {}", &local_port);
                    }
                    *sender_dropped = true;
                    send_msg(permit, MultiplexMsg::SendFinish { port: *remote_port });
                    self.maybe_free_port(local_port);
                } else {
                    panic!("PortEvt SenderDropped for port {} in invalid state", &local_port);
                }
            }
            GlobalEvt::Port(PortEvt::ReceiverClosed { local_port }) => {
                // The receiver still exists, so the port cannot be freed here.
                if let Some(PortState::Connected { remote_port, receiver_closed, receiver_dropped, .. }) =
                    self.ports.get_mut(&local_port)
                {
                    if *receiver_closed || *receiver_dropped {
                        panic!(
                            "PortEvt ReceiverClosed or ReceiverDropped more than once for port {}",
                            &local_port
                        );
                    }
                    *receiver_closed = true;
                    send_msg(permit, MultiplexMsg::ReceiveClose { port: *remote_port });
                } else {
                    panic!("PortEvt ReceiverClosed for non-connected port {}", &local_port);
                }
            }
            GlobalEvt::Port(PortEvt::ReceiverDropped { local_port }) => match self.ports.get_mut(&local_port) {
                Some(PortState::Connected { remote_port, receiver_dropped, .. }) => {
                    if *receiver_dropped {
                        panic!("PortEvt ReceiverDropped more than once for port {}.", &local_port);
                    }
                    *receiver_dropped = true;
                    send_msg(permit, MultiplexMsg::ReceiveFinish { port: *remote_port });
                    self.maybe_free_port(local_port);
                }
                _ => {
                    panic!("PortEvt ReceiverDropped for port {} in invalid state.", &local_port);
                }
            },

            GlobalEvt::AllClientsDropped => {
                self.all_clients_dropped = true;
                send_msg(permit, MultiplexMsg::ClientFinish);
            }

            GlobalEvt::ListenerDropped => {
                // `None` also disables the listener-dropped branch in `run`.
                self.listen_tx = None;
                send_msg(permit, MultiplexMsg::ListenerFinish);
            }

            GlobalEvt::SendGoodbye => {
                self.goodbye_sent = true;
                send_msg(permit, MultiplexMsg::Goodbye);
            }

            GlobalEvt::Flush => {
                permit.send(SendCmd::Flush);
            }
        }

        Ok(())
    }

    /// Applies one message received from the remote endpoint to the local state.
    ///
    /// # Errors
    /// * `MultiplexError::Reset` if the remote sent `Reset`.
    /// * A protocol error for a message that is invalid in the current state.
    ///   Examples: `Hello` after the handshake, data for an unknown port, chunks
    ///   larger than the local `chunk_size`, duplicate finish or close messages,
    ///   overflowing connect queues.
    /// * Errors from the credit accounting.
    async fn handle_received_msg(
        &mut self, _trace_id: &str, received_msg: TransportMsg,
    ) -> Result<(), MultiplexError<TransportSinkError, TransportStreamError>> {
        let TransportMsg { msg, data } = received_msg;

        match msg {
            MultiplexMsg::Reset => {
                return Err(MultiplexError::Reset);
            }
            // `Hello` is only valid during the handshake performed in `new`.
            MultiplexMsg::Hello { .. } => {
                return Err(protocol_err!(
                    "received Hello message for already established multiplexer connection"
                ));
            }
            // Nothing to do: receiving any message already restarted the receive
            // timeout in `recv_task`.
            MultiplexMsg::Ping => (),
            MultiplexMsg::OpenPort { client_port, wait } => {
                let req = RemoteConnectMsg::Request(Request::new(
                    client_port,
                    wait,
                    self.port_allocator.clone(),
                    self.channel_tx.clone(),
                ));
                // If the local listener is gone, `req` is simply dropped here.
                // NOTE(review): presumably dropping a `Request` answers the remote with a
                // rejection — confirm in `listener.rs`.
                if let Some((listen_wait_tx, listen_no_wait_tx)) = &self.listen_tx {
                    let res = if wait { listen_wait_tx.try_send(req) } else { listen_no_wait_tx.try_send(req) };
                    if let Err(mpsc::error::TrySendError::Full(_)) = res {
                        return Err(protocol_err!("remote endpoint sent too many OpenPort requests"));
                    }
                }
            }
            MultiplexMsg::PortOpened { client_port, server_port } => {
                if let Some((local_port, PortState::Connecting { response_tx })) =
                    self.ports.remove_entry(&client_port)
                {
                    let (sender, receiver) = self.create_port(local_port, server_port);
                    let _ = response_tx.send(ConnectResponse::Accepted(sender, receiver));
                } else {
                    return Err(protocol_err!(format!(
                        "received PortOpened message for port {} not in connecting state",
                        client_port
                    )));
                }
            }
            MultiplexMsg::Rejected { client_port, no_ports } => {
                if let Some(PortState::Connecting { response_tx }) = self.ports.remove(&client_port) {
                    let _ = response_tx.send(ConnectResponse::Rejected { no_ports });
                } else {
                    return Err(protocol_err!(format!(
                        "received Rejected message for port {} not in connecting state",
                        client_port
                    )));
                }
            }
            MultiplexMsg::Data { port, first, last } => {
                if let Some(PortState::Connected {
                    receiver_tx_data: Some(receiver_tx_data),
                    receiver_credit_monitor,
                    ..
                }) = self.ports.get_mut(&port)
                {
                    // `recv_msg` always attaches the payload to a `Data` header.
                    let data = data.unwrap();
                    // Every chunk costs at least one credit, even an empty one.
                    let used_credit = match u32::try_from(data.len()) {
                        Ok(size) if size <= self.local_cfg.chunk_size => {
                            receiver_credit_monitor.use_credits(size.max(1))?
                        }
                        _ => {
                            return Err(protocol_err!(format!(
                                "received data exceeds maximum chunk size on port {}",
                                &port
                            )))
                        }
                    };
                    let _ = receiver_tx_data.send(PortReceiveMsg::Data(ReceivedData {
                        buf: data,
                        first,
                        last,
                        credit: used_credit,
                    }));
                } else {
                    return Err(protocol_err!(format!(
                        "received data for non-connected or finished local port {}",
                        &port
                    )));
                }
            }
            MultiplexMsg::PortData { port, first, last, wait, ports } => {
                if let Some(PortState::Connected {
                    receiver_tx_data: Some(receiver_tx_data),
                    receiver_credit_monitor,
                    ..
                }) = self.ports.get_mut(&port)
                {
                    // Each transferred port number is charged as the size of a `u32`;
                    // overflow while computing the size counts as exceeding the chunk size.
                    let used_credit =
                        match ports.len().checked_mul(size_of::<u32>()).and_then(|v| u32::try_from(v).ok()) {
                            Some(size) if size <= self.local_cfg.chunk_size => {
                                receiver_credit_monitor.use_credits(size)?
                            }
                            _ => {
                                return Err(protocol_err!(format!(
                                    "received ports exceeds maximum chunk size on port {}",
                                    &port
                                )))
                            }
                        };
                    let port_allocator = self.port_allocator.clone();
                    let channel_tx = self.channel_tx.clone();
                    let requests = ports
                        .into_iter()
                        .map(|remote_port| {
                            Request::new(remote_port, wait, port_allocator.clone(), channel_tx.clone())
                        })
                        .collect();
                    let _ = receiver_tx_data.send(PortReceiveMsg::PortRequests(ReceivedPortRequests {
                        requests,
                        first,
                        last,
                        credit: used_credit,
                    }));
                } else {
                    return Err(protocol_err!(format!(
                        "received port data for non-connected or finished local port {}",
                        &port
                    )));
                }
            }
            MultiplexMsg::PortCredits { port, credits } => {
                if let Some(PortState::Connected { sender_credit_provider, .. }) = self.ports.get_mut(&port) {
                    sender_credit_provider.provide(credits)?;
                } else {
                    return Err(protocol_err!(format!(
                        "received port credits message for port {} not in connected state",
                        &port
                    )));
                }
            }
            MultiplexMsg::SendFinish { port } => {
                if let Some(PortState::Connected { receiver_tx_data, .. }) = self.ports.get_mut(&port) {
                    // Taking the channel both marks the remote sender as finished and
                    // detects a duplicate `SendFinish`.
                    if let Some(receiver_tx_data) = receiver_tx_data.take() {
                        let _ = receiver_tx_data.send(PortReceiveMsg::Finished);
                        self.maybe_free_port(port);
                    } else {
                        return Err(protocol_err!(format!(
                            "received SendFinish message for local port {} more than once",
                            &port
                        )));
                    }
                } else {
                    return Err(protocol_err!(format!(
                        "received SendFinish message for local port {} not in connected state",
                        &port
                    )));
                }
            }
            MultiplexMsg::ReceiveClose { port } => {
                if let Some(PortState::Connected {
                    sender_credit_provider,
                    remote_receiver_closed_notify,
                    remote_receiver_closed,
                    ..
                }) = self.ports.get_mut(&port)
                {
                    if !remote_receiver_closed.load(Ordering::SeqCst) {
                        sender_credit_provider.close(true);
                        remote_receiver_closed.store(true, Ordering::SeqCst);
                        // The list is taken (left as `None`), so waiters are signalled once.
                        let notifies = remote_receiver_closed_notify.lock().await.take().unwrap();
                        for tx in notifies {
                            let _ = tx.send(());
                        }
                        self.maybe_free_port(port);
                    } else {
                        return Err(protocol_err!(format!(
                            "received more than one ReceiveClose message for port {}",
                            &port
                        )));
                    }
                } else {
                    return Err(protocol_err!(format!(
                        "received ReceiveClose message for port {} not in connected state",
                        &port
                    )));
                }
            }
            MultiplexMsg::ReceiveFinish { port } => {
                if let Some(PortState::Connected {
                    sender_credit_provider,
                    remote_receiver_closed_notify,
                    remote_receiver_closed,
                    remote_receiver_dropped,
                    ..
                }) = self.ports.get_mut(&port)
                {
                    // The close handling may already have run on an earlier `ReceiveClose`.
                    if !remote_receiver_closed.load(Ordering::SeqCst) {
                        sender_credit_provider.close(false);
                        remote_receiver_closed.store(true, Ordering::SeqCst);
                        let notifies = remote_receiver_closed_notify.lock().await.take().unwrap();
                        for tx in notifies {
                            let _ = tx.send(());
                        }
                    }
                    *remote_receiver_dropped = true;
                    self.maybe_free_port(port);
                } else {
                    return Err(protocol_err!(format!(
                        "received ReceiveFinish message for port {} not in connected state",
                        &port
                    )));
                }
            }
            MultiplexMsg::ClientFinish => {
                // Tell the listener through both queues that the remote client is gone.
                if let Some((listen_wait_tx, listen_no_wait_tx)) = &self.listen_tx {
                    let mut failed = false;
                    if let Err(mpsc::error::TrySendError::Full(_)) =
                        listen_wait_tx.try_send(RemoteConnectMsg::ClientDropped)
                    {
                        failed = true;
                    }
                    if let Err(mpsc::error::TrySendError::Full(_)) =
                        listen_no_wait_tx.try_send(RemoteConnectMsg::ClientDropped)
                    {
                        failed = true;
                    }
                    if failed {
                        return Err(protocol_err!(
                            "remote endpoint sent too many OpenPort or ClientFinish requests"
                        ));
                    }
                }
                self.remote_client_dropped = true;
            }
            MultiplexMsg::ListenerFinish => {
                self.remote_listener_dropped.store(true, Ordering::SeqCst);
            }
            MultiplexMsg::Goodbye => {
                self.goodbye_received = true;
            }
        }

        Ok(())
    }
}
/// Empty `Drop` implementation.
///
/// NOTE(review): the body does nothing. The impl presumably exists for its
/// language-level effects:
/// * a type with an explicit `Drop` impl cannot be destructured by moving its
///   fields out;
/// * drop check treats the type parameters as possibly used during drop.
///
/// Confirm the intent before removing it.
impl<TransportSink, TransportStream> Drop for Multiplexer<TransportSink, TransportStream> {
    fn drop(&mut self) {
        // No explicit cleanup; the fields are dropped afterwards in declaration order.
    }
}
/// Future that completes when the wrapped sink is ready to accept an `Item`.
///
/// It only polls the sink's readiness and never sends anything.
struct SinkReady<S, Item> {
    /// The sink (or a mutable reference to it) whose readiness is awaited.
    sink: S,
    /// Selects which `Sink<Item>` implementation is polled; no `Item` is stored.
    _item: PhantomData<Item>,
}
impl<S, Item> SinkReady<S, Item> {
fn new(sink: S) -> Self {
Self { sink, _item: PhantomData }
}
}
impl<S, Item> Future for SinkReady<S, Item>
where
    S: Sink<Item> + Unpin,
    Item: Unpin,
{
    type Output = Result<(), S::Error>;

    /// Polls the sink's readiness.
    ///
    /// Resolves with the sink's error if readiness fails.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // `SinkReady` is `Unpin` here (both `S` and `Item` are), so it may be unpinned.
        let this = self.get_mut();
        this.sink.poll_ready_unpin(cx)
    }
}
fn generate_trace_id() -> String {
lazy_static! {
static ref ID: AtomicU16 = AtomicU16::new(0);
}
let id = ID.fetch_add(1, Ordering::SeqCst);
format!("{:04x}", id)
}