use futures::{ready, Future, FutureExt};
use std::{
clone::Clone,
error::Error,
fmt, mem,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
task::{Context, Poll},
};
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
use super::{
port_allocator::{PortAllocator, PortNumber},
receiver::Receiver,
sender::Sender,
};
/// Error returned when establishing a connection through [`Client`] fails.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ConnectError {
    /// No free local port was available and the caller asked not to wait for one.
    LocalPortsExhausted,
    /// The other side answered the request with a rejection because it has no free ports.
    RemotePortsExhausted,
    /// The limit of outstanding connection requests was reached and the caller asked not to wait.
    TooManyPendingConnectionRequests,
    /// The connection was rejected.
    ///
    /// Also returned when no response arrives and the client's `listener_dropped` flag is set.
    Rejected,
    /// No response arrived (and the listener was not dropped), or the task awaiting the
    /// response did not complete.
    Multiplexer,
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            Self::LocalPortsExhausted => "all local ports are in use",
            Self::RemotePortsExhausted => "all remote ports are in use",
            Self::TooManyPendingConnectionRequests => "too many connection requests are pending",
            Self::Rejected => "connection has been rejected by server",
            Self::Multiplexer => "multiplexer error",
        };
        f.write_str(msg)
    }
}

impl Error for ConnectError {}
/// Limits how many connection requests may be outstanding at the same time.
///
/// The state lives behind an `Arc<Mutex<_>>`, so clones share the same counter.
// NOTE(review): "Conntect" is a misspelling of "Connect"; renaming requires updating
// every use of this type and of its inner state type in this file.
#[derive(Clone)]
struct ConntectRequestCrediter(Arc<Mutex<ConntectRequestCrediterInner>>);
/// Shared, mutex-protected state of a `ConntectRequestCrediter`.
struct ConntectRequestCrediterInner {
    /// Maximum number of credits that may be handed out simultaneously.
    limit: u16,
    /// Number of credits currently handed out.
    used: u16,
    /// One-shot notifiers of tasks waiting for a credit.
    /// All of them are signalled (and the list emptied) when a credit is dropped.
    notify_tx: Vec<oneshot::Sender<()>>,
}
impl ConntectRequestCrediter {
    /// Creates a crediter that hands out at most `limit` credits at a time.
    pub fn new(limit: u16) -> Self {
        let inner = ConntectRequestCrediterInner { limit, used: 0, notify_tx: Vec::new() };
        Self(Arc::new(Mutex::new(inner)))
    }

    /// Obtains a credit, waiting until one becomes available.
    ///
    /// When all credits are in use, a one-shot notifier is registered. The task then sleeps
    /// until a `ConnectRequestCredit` is dropped, and retries. Every waiter is woken on each
    /// release, so a waiter may go around the loop several times before succeeding.
    pub async fn request(&self) -> ConnectRequestCredit {
        loop {
            let rx = {
                let mut inner = self.0.lock().unwrap();
                if inner.used < inner.limit {
                    inner.used += 1;
                    return ConnectRequestCredit(self.0.clone());
                }
                // A waiter whose future was dropped (e.g. by a timeout) leaves its notifier
                // behind. Prune those so the list cannot grow without bound while no credit
                // is being released.
                inner.notify_tx.retain(|tx| !tx.is_closed());
                let (tx, rx) = oneshot::channel();
                inner.notify_tx.push(tx);
                rx
            };
            // The lock is released before awaiting. The result is irrelevant:
            // either way the loop re-checks availability.
            let _ = rx.await;
        }
    }

    /// Obtains a credit without waiting, or returns `None` if all credits are in use.
    pub fn try_request(&self) -> Option<ConnectRequestCredit> {
        let mut inner = self.0.lock().unwrap();
        if inner.used < inner.limit {
            inner.used += 1;
            Some(ConnectRequestCredit(self.0.clone()))
        } else {
            None
        }
    }
}
/// A credit handed out by a `ConntectRequestCrediter`.
///
/// While alive it counts towards the crediter's limit. Dropping it returns the credit and
/// signals every task currently waiting for one.
pub(crate) struct ConnectRequestCredit(Arc<Mutex<ConntectRequestCrediterInner>>);

impl Drop for ConnectRequestCredit {
    fn drop(&mut self) {
        let mut guard = self.0.lock().unwrap();
        guard.used -= 1;
        let waiters = mem::take(&mut guard.notify_tx);
        // Unlock before signalling so that woken tasks can take the lock right away.
        drop(guard);
        waiters.into_iter().for_each(|waiter| {
            // A waiter that has gone away cannot be notified; that is fine.
            let _ = waiter.send(());
        });
    }
}
/// A connection request sent by `Client::connect_ext` over the client's `tx` channel.
// NOTE(review): the receiving side is presumably the multiplexer; confirm.
#[derive(Debug)]
pub(crate) struct ConnectRequest {
    /// Local port to use for the connection (given by the caller or taken from the port allocator).
    pub local_port: PortNumber,
    /// Signal channel (capacity 1) whose receiving end backs `Connect::sent`.
    /// Presumably signalled once the request has been transmitted; verify against the receiver.
    pub sent_tx: mpsc::Sender<()>,
    /// Channel for delivering the outcome.
    /// If dropped without sending, the connect attempt fails with `Rejected` or `Multiplexer`.
    pub response_tx: oneshot::Sender<ConnectResponse>,
    /// The `wait` flag the caller passed to `Client::connect_ext`.
    pub wait: bool,
}
/// Outcome of a `ConnectRequest`, delivered through its `response_tx` channel.
#[derive(Debug)]
pub(crate) enum ConnectResponse {
    /// The connection was established; carries its sending and receiving halves.
    Accepted(Sender, Receiver),
    /// The connection was refused.
    Rejected {
        /// `true` maps to `ConnectError::RemotePortsExhausted`, `false` to `ConnectError::Rejected`.
        no_ports: bool,
    },
}
/// An in-flight connection request, as returned by `Client::connect_ext`.
///
/// Awaiting it yields the connected `(Sender, Receiver)` pair or a [`ConnectError`].
pub struct Connect {
    pub(crate) sent_rx: mpsc::Receiver<()>,
    pub(crate) response: JoinHandle<Result<(Sender, Receiver), ConnectError>>,
}

impl Connect {
    /// Waits until a signal arrives on the `sent_rx` channel or its sending side is dropped.
    ///
    /// Both cases return normally.
    pub async fn sent(&mut self) {
        let _: Option<()> = self.sent_rx.recv().await;
    }
}

impl Future for Connect {
    type Output = Result<(Sender, Receiver), ConnectError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let joined = ready!(Pin::into_inner(self).response.poll_unpin(cx));
        // If the task awaiting the response did not complete, report a multiplexer error.
        let outcome = match joined {
            Ok(outcome) => outcome,
            Err(_) => Err(ConnectError::Multiplexer),
        };
        Poll::Ready(outcome)
    }
}
/// Handle for opening new connections.
///
/// Clones share the same limit on outstanding connection requests.
#[derive(Clone)]
pub struct Client {
    tx: mpsc::UnboundedSender<ConnectRequest>,
    crediter: ConntectRequestCrediter,
    port_allocator: PortAllocator,
    listener_dropped: Arc<AtomicBool>,
    terminate_tx: mpsc::UnboundedSender<()>,
}

impl fmt::Debug for Client {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Only the port allocator is shown; channels and internal state are omitted.
        let mut builder = f.debug_struct("Client");
        builder.field("port_allocator", &self.port_allocator);
        builder.finish()
    }
}
impl Client {
pub(crate) fn new(
tx: mpsc::UnboundedSender<ConnectRequest>, limit: u16, port_allocator: PortAllocator,
listener_dropped: Arc<AtomicBool>, terminate_tx: mpsc::UnboundedSender<()>,
) -> Client {
Client {
tx,
crediter: ConntectRequestCrediter::new(limit),
port_allocator,
listener_dropped,
terminate_tx,
}
}
pub fn port_allocator(&self) -> PortAllocator {
self.port_allocator.clone()
}
pub async fn connect(&self) -> Result<(Sender, Receiver), ConnectError> {
self.connect_ext(None, true).await?.await
}
pub async fn connect_ext(&self, local_port: Option<PortNumber>, wait: bool) -> Result<Connect, ConnectError> {
let local_port = match local_port {
Some(local_port) => local_port,
None => {
if wait {
self.port_allocator.allocate().await
} else {
self.port_allocator.try_allocate().ok_or(ConnectError::LocalPortsExhausted)?
}
}
};
let credit = if wait {
self.crediter.request().await
} else {
match self.crediter.try_request() {
Some(credit) => credit,
None => return Err(ConnectError::TooManyPendingConnectionRequests),
}
};
let (sent_tx, sent_rx) = mpsc::channel(1);
let (response_tx, response_rx) = oneshot::channel();
let req = ConnectRequest { local_port, sent_tx, response_tx, wait };
let _ = self.tx.send(req);
let listener_dropped = self.listener_dropped.clone();
let response = tokio::spawn(async move {
let _credit = credit;
match response_rx.await {
Ok(ConnectResponse::Accepted(sender, receiver)) => Ok((sender, receiver)),
Ok(ConnectResponse::Rejected { no_ports }) => {
if no_ports {
Err(ConnectError::RemotePortsExhausted)
} else {
Err(ConnectError::Rejected)
}
}
Err(_) => {
if listener_dropped.load(Ordering::SeqCst) {
Err(ConnectError::Rejected)
} else {
Err(ConnectError::Multiplexer)
}
}
}
});
Ok(Connect { sent_rx, response })
}
pub fn terminate(&self) {
let _ = self.terminate_tx.send(());
}
}