use futures::{
future::BoxFuture,
ready,
stream::Stream,
task::{Context, Poll},
FutureExt,
};
use std::{error::Error, fmt, pin::Pin, sync::Arc};
use tokio::sync::{mpsc, oneshot, Mutex};
use super::{
multiplexer::PortEvt,
port_allocator::{PortAllocator, PortNumber},
receiver::Receiver,
sender::Sender,
};
/// Error returned by [`Listener`] and [`Request`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ListenerError {
    /// A non-waiting request could not be accepted because no local port was
    /// free at that moment.
    LocalPortsExhausted,
    /// A channel to the multiplexer was closed, or the multiplexer dropped the
    /// reply channel of an accept handshake.
    MultiplexerError,
}

impl fmt::Display for ListenerError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            Self::LocalPortsExhausted => "all local ports are in use",
            Self::MultiplexerError => "multiplexer error",
        };
        f.write_str(msg)
    }
}

impl Error for ListenerError {}
pub struct Request {
remote_port: u32,
wait: bool,
allocator: PortAllocator,
tx: mpsc::Sender<PortEvt>,
done_tx: Option<oneshot::Sender<()>>,
}
impl Request {
    /// Creates a request for `remote_port`.
    ///
    /// Spawns a watcher task (`tokio::spawn`) on a oneshot completion channel.
    /// If the request is dropped before `accept_from` or `reject` signals
    /// completion, the watcher sends `PortEvt::Rejected { no_ports: false }`
    /// on a clone of `tx`.
    pub(crate) fn new(
        remote_port: u32, wait: bool, allocator: PortAllocator, tx: mpsc::Sender<PortEvt>,
    ) -> Self {
        let (done_tx, done_rx) = oneshot::channel::<()>();
        let reject_tx = tx.clone();
        tokio::spawn(async move {
            // An error means the sender was dropped without an explicit answer.
            let dropped_unanswered = done_rx.await.is_err();
            if dropped_unanswered {
                let _ = reject_tx.send(PortEvt::Rejected { remote_port, no_ports: false }).await;
            }
        });
        Self { remote_port, wait, allocator, tx, done_tx: Some(done_tx) }
    }

    /// The remote port number of this request.
    pub fn remote_port(&self) -> u32 {
        self.remote_port
    }

    /// Whether this request is willing to wait for a free local port.
    pub fn is_wait(&self) -> bool {
        self.wait
    }

    /// Accepts the request with a local port taken from the request's allocator.
    ///
    /// A waiting request blocks until a port can be allocated. A non-waiting
    /// request only takes a port that is free right now; if none is, the
    /// request is rejected with `no_ports = true`.
    ///
    /// # Errors
    /// - `ListenerError::LocalPortsExhausted` if a non-waiting request found no free port.
    /// - `ListenerError::MultiplexerError` if the accept handshake fails.
    pub async fn accept(self) -> Result<(Sender, Receiver), ListenerError> {
        if !self.wait {
            return match self.allocator.try_allocate() {
                Some(local_port) => self.accept_from(local_port).await,
                None => {
                    self.reject(true).await;
                    Err(ListenerError::LocalPortsExhausted)
                }
            };
        }
        let local_port = self.allocator.allocate().await;
        self.accept_from(local_port).await
    }

    /// Accepts the request using the already-allocated `local_port`.
    ///
    /// Sends `PortEvt::Accepted`, signals completion to the watcher task, then
    /// waits for the `(Sender, Receiver)` pair on the reply channel.
    ///
    /// # Errors
    /// `ListenerError::MultiplexerError` if the reply channel is dropped.
    pub async fn accept_from(mut self, local_port: PortNumber) -> Result<(Sender, Receiver), ListenerError> {
        let (port_tx, port_rx) = oneshot::channel();
        let accepted = PortEvt::Accepted { local_port, remote_port: self.remote_port, port_tx };
        // A failed send drops `port_tx` along with the event, so `port_rx` errors below.
        let _ = self.tx.send(accepted).await;
        self.mark_done();
        port_rx.await.map_err(|_| ListenerError::MultiplexerError)
    }

    /// Rejects the request; `no_ports` is forwarded in `PortEvt::Rejected`.
    pub async fn reject(mut self, no_ports: bool) {
        let rejected = PortEvt::Rejected { remote_port: self.remote_port, no_ports };
        let _ = self.tx.send(rejected).await;
        self.mark_done();
    }

    /// Tells the watcher task that the request was answered explicitly.
    fn mark_done(&mut self) {
        let done_tx = self.done_tx.take().unwrap();
        let _ = done_tx.send(());
    }
}
/// Explicit, intentionally empty destructor for [`Request`].
///
/// An unanswered request is rejected by the watcher task spawned in
/// `Request::new`. When the fields are dropped after this runs, the still-`Some`
/// `done_tx` makes that task's `done_rx` fail, and the task then sends
/// `PortEvt::Rejected { no_ports: false }`.
// NOTE(review): declaring `Drop` at all forbids moving fields out of `Request`
// by destructuring; presumably kept on purpose — confirm before removing.
impl Drop for Request {
fn drop(&mut self) {
}
}
/// Message delivered to a [`Listener`] through its request queues.
pub(crate) enum RemoteConnectMsg {
/// A connection request to be answered by the listener.
Request(Request),
/// Marks the listener as closed; `accept` and `inspect` return `Ok(None)` from then on.
// NOTE(review): the name suggests it is sent when the remote client is dropped — confirm at the sender.
ClientDropped,
}
/// Receives connection requests from the remote endpoint and accepts or rejects them.
// Field declaration order is also drop order; keep it unchanged.
pub struct Listener {
/// Request queue whose messages `accept` only takes after reserving a local port.
wait_rx: mpsc::Receiver<RemoteConnectMsg>,
/// Request queue that `accept` serves immediately: accepted if a port is free
/// right now, otherwise rejected with `no_ports = true`.
no_wait_rx: mpsc::Receiver<RemoteConnectMsg>,
/// Source of local port numbers for accepted connections.
port_allocator: PortAllocator,
/// Signalled by `terminate()`; the receiving side's reaction is not visible here.
terminate_tx: mpsc::UnboundedSender<()>,
/// Set once `ClientDropped` has been received; afterwards no requests are returned.
closed: bool,
}
impl fmt::Debug for Listener {
    /// Formats the listener, exposing only its port allocator.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Listener");
        out.field("port_allocator", &self.port_allocator);
        out.finish()
    }
}
impl Listener {
    /// Builds an open listener. It reads requests from `wait_rx` and
    /// `no_wait_rx`, assigns local ports from `port_allocator` and can send a
    /// termination signal through `terminate_tx`.
    pub(crate) fn new(
        wait_rx: mpsc::Receiver<RemoteConnectMsg>, no_wait_rx: mpsc::Receiver<RemoteConnectMsg>,
        port_allocator: PortAllocator, terminate_tx: mpsc::UnboundedSender<()>,
    ) -> Self {
        Self { closed: false, port_allocator, terminate_tx, wait_rx, no_wait_rx }
    }

    /// Returns a clone of the allocator this listener draws local ports from.
    pub fn port_allocator(&self) -> PortAllocator {
        self.port_allocator.clone()
    }

    /// Waits for the next connection request and accepts it.
    ///
    /// Requests from the waiting queue are only taken once a local port has
    /// been reserved. Requests from the no-wait queue are handled as soon as
    /// they arrive. Such a request is accepted if a port is free right now;
    /// otherwise it is rejected with `no_ports = true` and listening continues.
    ///
    /// Returns `Ok(None)` if the listener is closed, or becomes closed because
    /// a `ClientDropped` message arrives.
    ///
    /// # Errors
    /// `ListenerError::MultiplexerError` if a request queue yields `None` or the
    /// accept handshake fails.
    pub async fn accept(&mut self) -> Result<Option<(Sender, Receiver)>, ListenerError> {
        while !self.closed {
            tokio::select! {
                local_port = self.port_allocator.allocate() => {
                    // A port is reserved, so whichever request arrives next can be served.
                    return match self.inspect().await? {
                        Some(req) => req.accept_from(local_port).await.map(Some),
                        None => Ok(None),
                    };
                },
                msg = self.no_wait_rx.recv() => {
                    let no_wait_req = match msg {
                        Some(RemoteConnectMsg::Request(req)) => req,
                        Some(RemoteConnectMsg::ClientDropped) => {
                            self.closed = true;
                            return Ok(None);
                        }
                        None => return Err(ListenerError::MultiplexerError),
                    };
                    match self.port_allocator.try_allocate() {
                        Some(local_port) => return no_wait_req.accept_from(local_port).await.map(Some),
                        // No port is free at this instant: refuse and keep listening.
                        None => no_wait_req.reject(true).await,
                    }
                },
            }
        }
        Ok(None)
    }

    /// Receives the next connection request from either queue without answering it.
    ///
    /// Returns `Ok(None)` if the listener is closed. Receiving `ClientDropped`
    /// marks the listener closed and also yields `Ok(None)`.
    ///
    /// # Errors
    /// `ListenerError::MultiplexerError` if the queue that produced a value yielded `None`.
    pub async fn inspect(&mut self) -> Result<Option<Request>, ListenerError> {
        if self.closed {
            return Ok(None);
        }
        let msg = tokio::select! {
            msg = self.wait_rx.recv() => msg,
            msg = self.no_wait_rx.recv() => msg,
        };
        match msg.ok_or(ListenerError::MultiplexerError)? {
            RemoteConnectMsg::Request(req) => Ok(Some(req)),
            RemoteConnectMsg::ClientDropped => {
                self.closed = true;
                Ok(None)
            }
        }
    }

    /// Turns this listener into a [`Stream`] of accepted connections.
    pub fn into_stream(self) -> ListenerStream {
        ListenerStream::new(self)
    }

    /// Sends a termination signal over `terminate_tx`; a closed receiver is ignored.
    pub fn terminate(&self) {
        let _ = self.terminate_tx.send(());
    }
}
/// Explicit, intentionally empty destructor for [`Listener`].
///
/// After this runs, the fields are dropped in declaration order.
// NOTE(review): declaring `Drop` forbids moving individual fields out of
// `Listener` (moving the whole value, as `into_stream` does, is still fine).
// It is presumably kept on purpose — confirm before removing.
impl Drop for Listener {
fn drop(&mut self) {
}
}
/// Stream of accepted connections, obtained from [`Listener::into_stream`].
///
/// Each item is the outcome of one `Listener::accept` call. The stream yields
/// `None` whenever `accept` returns `Ok(None)`.
pub struct ListenerStream {
/// The listener, shared behind a Tokio mutex so the boxed `'static` accept
/// future can own a handle to it.
server: Arc<Mutex<Listener>>,
/// In-flight accept future; `None` when no accept is pending.
// The fully spelled-out boxed future type trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
accept_fut: Option<BoxFuture<'static, Option<Result<(Sender, Receiver), ListenerError>>>>,
}
impl ListenerStream {
    /// Wraps `server` so that accept futures can own a shared handle to it.
    fn new(server: Listener) -> Self {
        let shared = Arc::new(Mutex::new(server));
        Self { accept_fut: None, server: shared }
    }

    /// Locks the shared listener and performs one accept. The
    /// `Result<Option<_>>` is flipped into the `Option<Result<_>>` shape that a
    /// stream yields.
    async fn accept(server: Arc<Mutex<Listener>>) -> Option<Result<(Sender, Receiver), ListenerError>> {
        let mut listener = server.lock().await;
        let outcome = listener.accept().await;
        outcome.transpose()
    }

    /// Drives the pending accept future, starting a new one if none is in flight.
    fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<(Sender, Receiver), ListenerError>>> {
        let server = &self.server;
        let fut = self.accept_fut.get_or_insert_with(|| Self::accept(Arc::clone(server)).boxed());
        let item = ready!(fut.as_mut().poll(cx));
        // This accept has finished; clear the slot so the next poll starts a fresh one.
        self.accept_fut = None;
        Poll::Ready(item)
    }
}
impl Stream for ListenerStream {
    type Item = Result<(Sender, Receiver), ListenerError>;

    /// Unwraps the pin (which requires `ListenerStream: Unpin`) and delegates
    /// to the inherent `poll_next`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this: &mut ListenerStream = self.get_mut();
        this.poll_next(cx)
    }
}