use bytes::Bytes;
use futures::{future, future::BoxFuture, FutureExt, Sink, Stream};
use std::{
collections::{hash_map::Entry, HashMap},
fmt,
future::IntoFuture,
io,
sync::{Arc, Mutex},
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{mpsc, oneshot},
};
use x25519_dalek::{PublicKey, StaticSecret};
use crate::{
agg::{link_int::LinkInt, task::Task, AggParts},
alc::Channel,
cfg::{Cfg, ExchangedCfg},
control::{Control, Direction, Link},
exec::time::{error::Elapsed, timeout, Instant},
id::{ConnId, OwnedConnId, ServerId},
io::{IoRx, IoTx},
msg::{LinkMsg, RefusedReason},
protocol_err,
};
/// Error returned by [`Server::listen`].
#[derive(Debug)]
pub enum ListenError {
    /// A listener obtained earlier is still receiving connection requests.
    AlreadyListening,
}

impl fmt::Display for ListenError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::AlreadyListening => "already listening",
        };
        f.write_str(text)
    }
}

impl std::error::Error for ListenError {}

/// Maps to [`io::ErrorKind::AddrInUse`], keeping the original error as the source.
impl From<ListenError> for io::Error {
    fn from(err: ListenError) -> Self {
        let kind = io::ErrorKind::AddrInUse;
        io::Error::new(kind, err)
    }
}
/// Error produced while handling an incoming link or connection.
#[derive(Debug)]
pub enum IncomingError {
    /// I/O failure on the link. Timeouts are converted into this variant as well; protocol
    /// violations presumably end up here too (via `protocol_err!`) — confirm.
    Io(io::Error),
    /// The connection was refused.
    Refused,
    /// No listener was available to accept a new connection.
    NotListening,
    /// The link referred to a connection that no longer exists.
    Closed,
    /// The server side went away.
    ServerDropped,
}

impl fmt::Display for IncomingError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::Io(err) => return write!(f, "IO error: {err}"),
            Self::Refused => "connection refused",
            Self::NotListening => "not listening",
            Self::Closed => "connection was closed",
            Self::ServerDropped => "server dropped",
        };
        f.write_str(text)
    }
}

impl From<io::Error> for IncomingError {
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}

/// A timeout is reported as an I/O error.
impl From<Elapsed> for IncomingError {
    fn from(err: Elapsed) -> Self {
        Self::from(io::Error::from(err))
    }
}

impl std::error::Error for IncomingError {}

/// Unwraps [`IncomingError::Io`]; every other variant is wrapped in an [`io::Error`] whose
/// kind is `ConnectionAborted` for [`IncomingError::Closed`] and `ConnectionRefused` otherwise.
impl From<IncomingError> for io::Error {
    fn from(err: IncomingError) -> Self {
        let kind = match err {
            IncomingError::Io(inner) => return inner,
            IncomingError::Closed => io::ErrorKind::ConnectionAborted,
            IncomingError::Refused | IncomingError::NotListening | IncomingError::ServerDropped => {
                io::ErrorKind::ConnectionRefused
            }
        };
        io::Error::new(kind, err)
    }
}
/// Error returned when an outgoing connection could not be established.
#[derive(Debug)]
pub enum ConnectError {
    /// The connection was not established. [`Outgoing::connect`] reports this when the
    /// connection-established notification is dropped without a value.
    Timeout,
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::Timeout => "connect timeout",
        };
        f.write_str(text)
    }
}

impl std::error::Error for ConnectError {}

/// Maps to [`io::ErrorKind::TimedOut`], keeping the original error as the source.
impl From<ConnectError> for io::Error {
    fn from(err: ConnectError) -> Self {
        let kind = io::ErrorKind::TimedOut;
        io::Error::new(kind, err)
    }
}
/// A connection request that arrived at a [`Server`] and is waiting to be accepted or refused.
pub struct Incoming<TX, RX, TAG> {
    /// Local configuration.
    cfg: Arc<Cfg>,
    /// Identifier of the requested connection.
    conn_id: OwnedConnId,
    /// Identifier of the local server.
    server_id: ServerId,
    /// Identifier of the remote server, if it sent one.
    remote_server_id: Option<ServerId>,
    /// Sending side of the queue through which the server delivers links of this connection.
    link_tx: mpsc::Sender<LinkInt<TX, RX, TAG>>,
    /// Receiving side of that queue.
    link_rx: mpsc::Receiver<LinkInt<TX, RX, TAG>>,
    /// Links that have already been taken out of the queue.
    links: Vec<LinkInt<TX, RX, TAG>>,
}

impl<TX, RX, TAG> fmt::Debug for Incoming<TX, RX, TAG>
where
    TAG: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Only the tags of links already moved into `links` are shown; queued ones are not.
        let tags = self.links.iter().map(|l| l.tag()).collect::<Vec<_>>();
        let mut out = f.debug_struct("Incoming");
        out.field("cfg", &self.cfg);
        out.field("id", &self.id());
        out.field("server_id", &self.server_id);
        out.field("remote_server_id", &self.remote_server_id);
        out.field("link_tags", &tags);
        out.finish()
    }
}
impl<TX, RX, TAG> Incoming<TX, RX, TAG> {
    /// Returns the identifier of the requested connection.
    pub fn id(&self) -> ConnId {
        self.conn_id.get()
    }

    /// Returns the identifier of the local server.
    pub fn server_id(&self) -> ServerId {
        self.server_id
    }

    /// Returns the identifier of the remote server, if known.
    pub fn remote_server_id(&self) -> Option<ServerId> {
        self.remote_server_id
    }

    /// Moves every link currently queued on `link_rx` into `links`, without waiting.
    fn update_links(&mut self) {
        let queue = &mut self.link_rx;
        self.links.extend(std::iter::from_fn(|| queue.try_recv().ok()));
    }

    /// Returns the tags of all links received so far, including ones still queued.
    pub fn link_tags(&mut self) -> Vec<&TAG> {
        self.update_links();
        self.links.iter().map(|entry| entry.tag()).collect()
    }

    /// Returns the user data sent by the remote side of each link received so far.
    pub fn link_remote_user_datas(&mut self) -> Vec<&[u8]> {
        self.update_links();
        self.links.iter().map(|entry| entry.remote_user_data()).collect()
    }

    /// Waits until another link for this connection arrives and records it.
    ///
    /// # Errors
    /// Returns [`IncomingError::ServerDropped`] if the link queue yields no more items.
    ///
    /// NOTE(review): `self.link_tx` keeps a sender for this queue alive, so `recv` likely
    /// never returns `None` here and this error may be unreachable — confirm.
    pub async fn link_added(&mut self) -> Result<(), IncomingError> {
        match self.link_rx.recv().await {
            Some(fresh) => {
                self.links.push(fresh);
                Ok(())
            }
            None => Err(IncomingError::ServerDropped),
        }
    }
}
impl<TX, RX, TAG> Incoming<TX, RX, TAG>
where
    RX: Stream<Item = Result<Bytes, io::Error>> + Unpin + Send + 'static,
    TX: Sink<Bytes, Error = io::Error> + Unpin + Send + 'static,
    TAG: Send + Sync + 'static,
{
    /// Accepts the connection, returning its task, data channel and control handle.
    ///
    /// Links that are still queued are collected first, so they become part of the connection.
    /// The link queue is handed over as well, so links arriving later can still join.
    pub fn accept(mut self) -> (Task<TX, RX, TAG>, Channel, Control<TX, RX, TAG>) {
        self.update_links();
        let Self { cfg, conn_id, server_id, remote_server_id, link_tx, link_rx, links } = self;
        let link_queue = Some((link_tx, link_rx));
        let AggParts { task, channel, control, .. } = AggParts::new(
            cfg,
            conn_id,
            Direction::Incoming,
            Some(server_id),
            remote_server_id,
            links,
            link_queue,
        );
        (task, channel, control)
    }

    /// Refuses the connection, telling each received link that it was refused.
    ///
    /// Notification is best-effort: per-link send failures are ignored. The whole operation
    /// gives up after `link_non_working_timeout`.
    pub async fn refuse(mut self) {
        // Stop further links from being queued, then collect those that already arrived.
        self.link_rx.close();
        self.update_links();

        let grace = self.cfg.link_non_working_timeout;
        let notifications = self.links.iter_mut().map(|link| async move {
            let _ = link.send_msg_and_flush(LinkMsg::Refused { reason: RefusedReason::ConnectionRefused }).await;
        });
        let _ = timeout(grace, future::join_all(notifications)).await;
    }
}
/// Mutable server state shared behind the [`Server`]'s mutex.
struct ServerInner<TX, RX, TAG> {
    /// Configuration shared with every connection.
    cfg: Arc<Cfg>,
    /// Identifier of this server.
    server_id: ServerId,
    /// Link queues of known connections, keyed by connection id.
    conns: HashMap<ConnId, mpsc::Sender<LinkInt<TX, RX, TAG>>>,
    /// Handed to each `OwnedConnId`; ids of closed connections arrive on `closed_conns_rx`.
    closed_conns_tx: mpsc::UnboundedSender<ConnId>,
    /// Ids of closed connections that still have to be removed from `conns`.
    closed_conns_rx: mpsc::UnboundedReceiver<ConnId>,
    /// Queue feeding the current listener. A closed sender means nobody is listening.
    listen_tx: mpsc::Sender<Incoming<TX, RX, TAG>>,
}

impl<TX, RX, TAG> ServerInner<TX, RX, TAG> {
    /// Creates empty state for a server that is not listening yet.
    fn new(cfg: Arc<Cfg>, server_id: ServerId) -> Self {
        // The receiver is dropped right away, so `listen_tx` starts out closed and
        // `Server::listen` sees "not listening".
        let (listen_tx, _) = mpsc::channel(cfg.connect_queue.get());
        let (closed_conns_tx, closed_conns_rx) = mpsc::unbounded_channel();
        Self { conns: HashMap::new(), closed_conns_tx, closed_conns_rx, listen_tx, cfg, server_id }
    }

    /// Removes table entries of every connection reported closed so far.
    fn cleanup_links(&mut self) {
        let Self { conns, closed_conns_rx, .. } = self;
        while let Ok(closed) = closed_conns_rx.try_recv() {
            conns.remove(&closed);
        }
    }
}
/// Handle to a server accepting incoming links and registering outgoing connections.
///
/// Cloning is cheap: all clones share the same underlying state.
pub struct Server<TX, RX, TAG> {
    /// Copy of the server id so it can be read without locking.
    server_id: ServerId,
    /// Shared mutable state.
    inner: Arc<Mutex<ServerInner<TX, RX, TAG>>>,
}

impl<TX, RX, TAG> fmt::Debug for Server<TX, RX, TAG> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Server");
        out.field("id", &self.server_id);
        out.finish()
    }
}

// Implemented by hand so that cloning does not require `TX`, `RX` or `TAG` to be `Clone`.
impl<TX, RX, TAG> Clone for Server<TX, RX, TAG> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner, server_id: self.server_id }
    }
}
impl<TX, RX, TAG> Server<TX, RX, TAG>
where
    TAG: Send + Sync + 'static,
    RX: Stream<Item = Result<Bytes, io::Error>> + Unpin + Send + 'static,
    TX: Sink<Bytes, Error = io::Error> + Unpin + Send + 'static,
{
    /// Creates a new server with a freshly generated [`ServerId`] and the given configuration.
    pub fn new(cfg: Cfg) -> Self {
        let server_id = ServerId::generate();
        Self { server_id, inner: Arc::new(Mutex::new(ServerInner::new(Arc::new(cfg), server_id))) }
    }

    /// Returns the identifier of this server.
    pub fn id(&self) -> ServerId {
        self.server_id
    }

    /// Creates an outgoing connection that is registered with this server.
    ///
    /// The connection id is recorded in the server's connection table. Links passed to
    /// [`Server::add_incoming`] that carry the same id are routed to this connection.
    ///
    /// # Panics
    /// Panics if the server's internal mutex is poisoned.
    pub fn connect(&self) -> (Task<TX, RX, TAG>, Outgoing, Control<TX, RX, TAG>) {
        let mut inner = self.inner.lock().unwrap();
        // Drop table entries of connections reported closed. Otherwise pruning only happens in
        // `add_incoming`, and a server that mainly creates outgoing connections would keep
        // stale senders (and queued closed ids) around until the next incoming link.
        inner.cleanup_links();
        let conn_id = ConnId::generate();
        let (link_tx, link_rx) = mpsc::channel(inner.cfg.connect_queue.get());
        let AggParts { task, channel, control, connected_rx } = AggParts::new(
            inner.cfg.clone(),
            OwnedConnId::new(conn_id, inner.closed_conns_tx.clone()),
            Direction::Outgoing,
            Some(self.server_id),
            None,
            Vec::new(),
            Some((link_tx.clone(), link_rx)),
        );
        inner.conns.insert(conn_id, link_tx);
        (task, Outgoing { channel, connected_rx }, control)
    }

    /// Starts listening for incoming connections.
    ///
    /// # Errors
    /// Returns [`ListenError::AlreadyListening`] while the receiving side of a previously
    /// returned [`Listener`] is still open.
    ///
    /// # Panics
    /// Panics if the server's internal mutex is poisoned.
    pub fn listen(&self) -> Result<Listener<TX, RX, TAG>, ListenError> {
        let mut inner = self.inner.lock().unwrap();
        if !inner.listen_tx.is_closed() {
            return Err(ListenError::AlreadyListening);
        }
        let (listen_tx, listen_rx) = mpsc::channel(inner.cfg.connect_queue.get());
        inner.listen_tx = listen_tx;
        Ok(Listener { server_id: inner.server_id, listen_rx })
    }

    /// Performs the server side of the link handshake and routes the link.
    ///
    /// A `Welcome` message is sent and a `Connect` message is expected in return. The
    /// connection id in it is decrypted with an x25519 shared secret. The link is then:
    /// * added to an existing connection with that id, or
    /// * announced to the [`Listener`] as a new [`Incoming`] connection, unless the remote
    ///   marked it as belonging to an existing connection, or
    /// * refused, in which case a `Refused` message is sent to the remote.
    ///
    /// Announcing a new connection waits for free space in the listener queue.
    ///
    /// # Errors
    /// * [`IncomingError::Io`] on I/O failure, handshake timeout or an unexpected message.
    /// * [`IncomingError::NotListening`] if a new connection arrives while nobody listens.
    /// * [`IncomingError::Closed`] if the link refers to a connection that no longer exists.
    ///
    /// # Panics
    /// Panics if `user_data` is longer than `u16::MAX` bytes, or if the server's internal
    /// mutex is poisoned.
    pub async fn add_incoming(
        &self, mut tx: TX, mut rx: RX, tag: TAG, user_data: &[u8],
    ) -> Result<Link<TAG>, IncomingError> {
        assert!(user_data.len() <= u16::MAX as usize, "user_data is too big");

        // Copy what the handshake needs so the mutex is not held across `.await`.
        let server_id;
        let cfg;
        let closed_conns_tx;
        {
            let mut inner = self.inner.lock().unwrap();
            inner.cleanup_links();
            server_id = inner.server_id;
            cfg = inner.cfg.clone();
            closed_conns_tx = inner.closed_conns_tx.clone();
        }

        // Handshake, bounded by `link_ping_timeout`. A fresh key pair is generated per link.
        // Inside the async block `server_id` and `cfg` are shadowed by the remote's values.
        let (remote_server_id, conn_id, existing, remote_cfg, roundtrip, remote_user_data) =
            timeout(cfg.link_ping_timeout, async {
                let random: [u8; 32] = rand::random();
                let server_secret = StaticSecret::from(random);
                let server_public_key = PublicKey::from(&server_secret);
                let start = Instant::now();
                LinkMsg::Welcome {
                    extensions: 0,
                    public_key: server_public_key,
                    server_id,
                    user_data: user_data.to_vec(),
                    cfg: (&*cfg).into(),
                }
                .send(&mut tx)
                .await?;
                let LinkMsg::Connect {
                    extensions: _,
                    public_key: client_public_key,
                    server_id,
                    connection_id: encrypted_conn_id,
                    existing_connection,
                    user_data: remote_user_data,
                    cfg,
                } = LinkMsg::recv(&mut rx).await?
                else {
                    return Err::<_, IncomingError>(protocol_err!("expected Connect message").into());
                };
                let shared_secret = server_secret.diffie_hellman(&client_public_key);
                let conn_id = encrypted_conn_id.decrypt(&shared_secret);
                Ok((server_id, conn_id, existing_connection, cfg, start.elapsed(), remote_user_data))
            })
            .await??;
        tracing::debug!(?server_id, ?conn_id, ?existing, "handling incoming link");

        /// Routing decision for the incoming link.
        enum Connection<TX, RX, TAG> {
            /// Join the connection already registered under `conn_id`.
            Existing {
                link_tx: mpsc::Sender<LinkInt<TX, RX, TAG>>,
            },
            /// Start a new connection; a listener slot has already been reserved.
            New {
                link_tx: mpsc::Sender<LinkInt<TX, RX, TAG>>,
                link_rx: mpsc::Receiver<LinkInt<TX, RX, TAG>>,
                listen_tx_permit: mpsc::OwnedPermit<Incoming<TX, RX, TAG>>,
            },
            /// Reject the link.
            Refuse {
                reason: RefusedReason,
                err: IncomingError,
            },
        }

        // A listener slot is reserved only when a new connection is needed. The reservation may
        // wait, so it happens without holding the mutex. The table is checked again afterwards
        // because another link may have registered the connection in the meantime.
        let mut need_listen_tx_permit = false;
        let connection = loop {
            let listen_tx_permit = if need_listen_tx_permit {
                let listen_tx = self.inner.lock().unwrap().listen_tx.clone();
                Some(listen_tx.reserve_owned().await)
            } else {
                None
            };
            let mut inner = self.inner.lock().unwrap();
            match inner.conns.entry(conn_id) {
                Entry::Occupied(ocu) => break Connection::Existing { link_tx: ocu.get().clone() },
                Entry::Vacant(vac) if !existing => match listen_tx_permit {
                    Some(Ok(listen_tx_permit)) => {
                        let (link_tx, link_rx) = mpsc::channel(cfg.connect_queue.get());
                        vac.insert(link_tx.clone());
                        break Connection::New { link_tx, link_rx, listen_tx_permit };
                    }
                    // Reserving failed because the listener's receiver is gone.
                    Some(Err(_)) => {
                        break Connection::Refuse {
                            reason: RefusedReason::NotListening,
                            err: IncomingError::NotListening,
                        }
                    }
                    None => need_listen_tx_permit = true,
                },
                // The remote expects an existing connection that is not (or no longer) known.
                Entry::Vacant(_) => {
                    break Connection::Refuse { reason: RefusedReason::Closed, err: IncomingError::Closed }
                }
            }
        };

        match connection {
            Connection::Existing { link_tx } => match link_tx.reserve_owned().await {
                Ok(link_tx_permit) => {
                    let link_int = LinkInt::new(
                        tag,
                        conn_id,
                        tx,
                        rx,
                        cfg,
                        remote_cfg,
                        Direction::Incoming,
                        roundtrip,
                        remote_user_data,
                    );
                    let link = Link::from(&link_int);
                    link_tx_permit.send(link_int);
                    tracing::debug!(?conn_id, "link joins existing connection");
                    Ok(link)
                }
                // The connection stopped accepting links after it was looked up.
                Err(_) => {
                    tracing::debug!("refusing link that belongs to closed connection");
                    timeout(
                        cfg.link_ping_timeout,
                        LinkMsg::Refused { reason: RefusedReason::Closed }.send(&mut tx),
                    )
                    .await??;
                    Err(IncomingError::Closed)
                }
            },
            Connection::New { link_tx, link_rx, listen_tx_permit } => {
                let link_int = LinkInt::new(
                    tag,
                    conn_id,
                    tx,
                    rx,
                    cfg.clone(),
                    remote_cfg,
                    Direction::Incoming,
                    roundtrip,
                    remote_user_data,
                );
                let link = Link::from(&link_int);
                link_tx
                    .try_send(link_int)
                    .expect("a freshly created link channel has room for the first link");
                listen_tx_permit.send(Incoming {
                    cfg,
                    conn_id: OwnedConnId::new(conn_id, closed_conns_tx),
                    server_id: self.server_id,
                    remote_server_id,
                    link_tx,
                    link_rx,
                    links: Vec::new(),
                });
                tracing::debug!(?conn_id, "link starts new connection");
                Ok(link)
            }
            Connection::Refuse { reason, err } => {
                tracing::debug!(?reason, %err, "refusing link");
                timeout(cfg.link_ping_timeout, LinkMsg::Refused { reason }.send(&mut tx)).await??;
                Err(err)
            }
        }
    }
}
impl<R, W, TAG> Server<IoTx<W>, IoRx<R>, TAG>
where
    R: AsyncRead + Send + Unpin + 'static,
    W: AsyncWrite + Send + Unpin + 'static,
    TAG: Send + Sync + 'static,
{
    /// Like [`Server::add_incoming`], but for a byte-oriented reader/writer pair.
    ///
    /// `write` is wrapped in [`IoTx`] and `read` in [`IoRx`] before being handed on.
    pub async fn add_incoming_io(
        &self, read: R, write: W, tag: TAG, user_data: &[u8],
    ) -> Result<Link<TAG>, IncomingError> {
        let sink = IoTx::new(write);
        let stream = IoRx::new(read);
        self.add_incoming(sink, stream, tag, user_data).await
    }
}
/// Receives connection requests for a [`Server`]; obtained from [`Server::listen`].
pub struct Listener<TX, RX, TAG> {
    /// Identifier of the server this listener belongs to.
    server_id: ServerId,
    /// Queue of connection requests sent by the server.
    listen_rx: mpsc::Receiver<Incoming<TX, RX, TAG>>,
}

// Generic parameters are named like the struct's (previously `N, R, W`, which did not match
// and were easily confused with the reader/writer parameters used elsewhere in this module).
impl<TX, RX, TAG> fmt::Debug for Listener<TX, RX, TAG> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("Listener").field("server_id", &self.server_id).finish()
    }
}
impl<TX, RX, TAG> Listener<TX, RX, TAG>
where
    TAG: Send + Sync + 'static,
    RX: Stream<Item = Result<Bytes, io::Error>> + Unpin + Send + 'static,
    TX: Sink<Bytes, Error = io::Error> + Unpin + Send + 'static,
{
    /// Returns the identifier of the server this listener belongs to.
    pub fn id(&self) -> ServerId {
        self.server_id
    }

    /// Waits for the next connection request.
    ///
    /// # Errors
    /// Returns [`IncomingError::ServerDropped`] once no sender for the listen queue remains.
    pub async fn next(&mut self) -> Result<Incoming<TX, RX, TAG>, IncomingError> {
        match self.listen_rx.recv().await {
            Some(request) => Ok(request),
            None => Err(IncomingError::ServerDropped),
        }
    }

    /// Waits for the next connection request and accepts it immediately.
    ///
    /// # Errors
    /// Same as [`Listener::next`].
    pub async fn accept(&mut self) -> Result<(Task<TX, RX, TAG>, Channel, Control<TX, RX, TAG>), IncomingError> {
        self.next().await.map(|request| request.accept())
    }
}
pub struct Outgoing {
channel: Channel,
connected_rx: oneshot::Receiver<Arc<ExchangedCfg>>,
}
impl fmt::Debug for Outgoing {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Outgoing").field("id", &self.id()).finish()
}
}
impl Outgoing {
pub fn id(&self) -> ConnId {
self.channel.id()
}
pub async fn connect(self) -> Result<Channel, ConnectError> {
let Self { mut channel, connected_rx } = self;
let remote_cfg = connected_rx.await.map_err(|_| ConnectError::Timeout)?;
channel.set_remote_cfg(remote_cfg);
Ok(channel)
}
}
impl IntoFuture for Outgoing {
type Output = Result<Channel, ConnectError>;
type IntoFuture = BoxFuture<'static, Result<Channel, ConnectError>>;
fn into_future(self) -> Self::IntoFuture {
self.connect().boxed()
}
}
/// Creates a standalone outgoing connection that is not tied to any [`Server`].
///
/// The connection id is untracked and no server id or link queue is supplied. As a result,
/// incoming links cannot be routed to this connection through a server.
pub fn connect<TX, RX, TAG>(cfg: Cfg) -> (Task<TX, RX, TAG>, Outgoing, Control<TX, RX, TAG>)
where
    RX: Stream<Item = Result<Bytes, io::Error>> + Unpin + Send + 'static,
    TX: Sink<Bytes, Error = io::Error> + Unpin + Send + 'static,
    TAG: Send + Sync + 'static,
{
    let shared_cfg = Arc::new(cfg);
    let conn_id = OwnedConnId::untracked(ConnId::generate());
    let parts = AggParts::new(shared_cfg, conn_id, Direction::Outgoing, None, None, Vec::new(), None);
    let AggParts { task, channel, control, connected_rx } = parts;
    let pending = Outgoing { channel, connected_rx };
    (task, pending, control)
}