use std::{
collections::VecDeque,
fmt::Debug,
future::poll_fn,
io,
mem::ManuallyDrop,
net::{SocketAddr, SocketAddrV6},
ops::Deref,
pin::pin,
ptr,
sync::Arc,
task::{Context, Poll, Waker},
time::Instant,
};
use compio::buf::{BufResult, bytes::Bytes};
#[cfg(rustls)]
use compio::net::ToSocketAddrsAsync;
use compio::net::UdpSocket;
use compio::runtime::JoinHandle;
use compio_log::{Instrument, error};
use flume::{Receiver, Sender, unbounded};
use futures_util::{FutureExt, StreamExt, future, select, task::AtomicWaker};
use noq_proto::{
ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
EndpointEvent, FourTuple, ServerConfig, Transmit, VarInt,
};
use rustc_hash::FxHashMap as HashMap;
use crate::{
Connecting, ConnectionEvent, Incoming, RecvMeta, Socket,
sync::{mutex_blocking::Mutex, shared::Shared},
};
/// Mutable endpoint state; always accessed through the mutex held by `EndpointInner`.
#[derive(Debug)]
struct EndpointState {
    /// Protocol-level endpoint that all datagrams and endpoint events are fed into.
    endpoint: noq_proto::Endpoint,
    /// Handle of the background task running `EndpointInner::run`; `None` before it is
    /// spawned and after `shutdown` has taken it.
    worker: Option<JoinHandle<()>>,
    /// Channel to each live connection, keyed by its handle, used to deliver `ConnectionEvent`s.
    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
    /// Error code and reason recorded by `Endpoint::close`; `None` while the endpoint is open.
    close: Option<(VarInt, Bytes)>,
    /// When set, the worker loop exits the next time it finds no connections registered.
    exit_on_idle: bool,
    /// Incoming connection attempts waiting to be handed out by `poll_incoming`.
    incoming: VecDeque<noq_proto::Incoming>,
    /// Wakers registered by `poll_incoming` while `incoming` was empty.
    incoming_wakers: VecDeque<Waker>,
    /// Handshake counters returned by `Endpoint::stats`.
    stats: EndpointStats,
}
/// Counters of the handshakes this endpoint has processed.
///
/// The struct is `Copy`; [`Endpoint::stats`] returns a snapshot.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct EndpointStats {
    /// Incoming handshakes successfully accepted.
    pub accepted_handshakes: u64,
    /// Outgoing handshakes started with `connect`.
    pub outgoing_handshakes: u64,
    /// Incoming handshakes explicitly refused. Handshakes refused automatically while the
    /// endpoint is closed are not counted.
    pub refused_handshakes: u64,
    /// Incoming handshakes explicitly ignored.
    pub ignored_handshakes: u64,
}
impl EndpointState {
    /// Feeds one received payload into the protocol endpoint and dispatches what it produces.
    ///
    /// The first `meta.len` bytes of `buf` are split into `meta.stride`-sized datagrams (one
    /// receive may carry several segments). For each datagram:
    /// - a new incoming connection is queued for [`Self::poll_incoming`], or, if the endpoint
    ///   is closed, refused right away with the refusal sent through `respond_fn`;
    /// - a connection event is forwarded to the owning connection's channel;
    /// - a stateless response is sent through `respond_fn`.
    fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
        // An empty datagram carries nothing to process, and `chunks(0)` panics. Without this
        // guard, any peer could kill the worker by sending a zero-length UDP packet.
        if meta.len == 0 {
            return;
        }
        // Defensive: a zero stride would also make `chunks` panic, so fall back to treating
        // the whole payload as a single datagram.
        let stride = match meta.stride {
            0 => meta.len,
            stride => stride.min(meta.len),
        };
        let now = Instant::now();
        for data in buf[..meta.len].chunks(stride).map(Into::into) {
            let mut resp_buf = Vec::new();
            match self.endpoint.handle(
                now,
                FourTuple::new(meta.remote, meta.local_ip),
                meta.ecn,
                data,
                &mut resp_buf,
            ) {
                Some(DatagramEvent::NewConnection(incoming)) => {
                    if self.close.is_none() {
                        self.incoming.push_back(incoming);
                    } else {
                        let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
                        respond_fn(resp_buf, transmit);
                    }
                }
                Some(DatagramEvent::ConnectionEvent(ch, event)) => {
                    // Send errors are ignored (presumably the connection dropped its receiver).
                    let _ = self
                        .connections
                        .get(&ch)
                        .expect("connection handle is not registered with the endpoint")
                        .send(ConnectionEvent::Proto(event));
                }
                Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
                None => {}
            }
        }
    }

    /// Applies an endpoint event reported by connection `ch`, forwarding any resulting
    /// connection event back to that connection.
    ///
    /// A drained connection is unregistered first. This assumes the protocol endpoint yields
    /// no follow-up event for a drained connection; if it did, the `expect` below would fire.
    fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
        if event.is_drained() {
            self.connections.remove(&ch);
        }
        if let Some(event) = self.endpoint.handle_event(ch, event) {
            let _ = self
                .connections
                .get(&ch)
                .expect("connection handle is not registered with the endpoint")
                .send(ConnectionEvent::Proto(event));
        }
    }

    /// Returns `true` when no connections are registered.
    fn is_idle(&self) -> bool {
        self.connections.is_empty()
    }

    /// Polls for the next queued incoming connection.
    ///
    /// Returns `Ready(None)` once the endpoint has been closed, and `Ready(Some(_))` when a
    /// connection is queued. Otherwise it registers the task's waker and returns `Pending`.
    fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<noq_proto::Incoming>> {
        if self.close.is_some() {
            return Poll::Ready(None);
        }
        if let Some(incoming) = self.incoming.pop_front() {
            return Poll::Ready(Some(incoming));
        }
        // Register each waker only once. Previously every pending poll appended a fresh
        // clone, so a task that re-polls (or recreates the `wait_incoming` future, e.g. in a
        // `select!` loop) grew this queue without bound.
        let waker = cx.waker();
        if !self.incoming_wakers.iter().any(|w| w.will_wake(waker)) {
            self.incoming_wakers.push_back(waker.clone());
        }
        Poll::Pending
    }

    /// Registers a new connection's event channel and wraps it in a `Connecting` future.
    ///
    /// If the endpoint is already closed, a `Close` event is queued to the connection
    /// immediately so it shuts down with the recorded error code and reason.
    fn new_connection(
        &mut self,
        handle: ConnectionHandle,
        conn: noq_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    ) -> Connecting {
        let (tx, rx) = unbounded();
        if let Some((error_code, reason)) = &self.close {
            // Cannot fail: `rx` is still alive in this scope.
            tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
                .unwrap();
        }
        self.connections.insert(handle, tx);
        Connecting::new(handle, conn, socket, events_tx, rx)
    }
}
impl Drop for EndpointState {
    /// Passes every still-queued incoming connection to the protocol endpoint's `ignore`,
    /// so none is dropped without the endpoint being told.
    fn drop(&mut self) {
        let Self {
            endpoint, incoming, ..
        } = self;
        for pending in incoming.drain(..) {
            endpoint.ignore(pending);
        }
    }
}
/// Both halves of one channel, stored together: `.0` sends, `.1` receives.
type ChannelPair<T> = (Sender<T>, Receiver<T>);
/// Endpoint internals shared between user-facing handles and the background worker.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state.
    state: Mutex<EndpointState>,
    /// Socket every connection of this endpoint sends and receives on.
    socket: Socket,
    /// Whether the socket's local address is IPv6; IPv4 remotes are then mapped in `connect`.
    ipv6: bool,
    /// Channel over which connections report `EndpointEvent`s to the worker loop.
    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
    /// Woken by `EndpointRef::drop` when only one other handle remains; `shutdown` waits on it.
    done: AtomicWaker,
}
impl EndpointInner {
fn new(
socket: UdpSocket,
config: EndpointConfig,
server_config: Option<ServerConfig>,
) -> io::Result<Self> {
let socket = Socket::new(socket)?;
let ipv6 = socket.local_addr()?.is_ipv6();
let allow_mtud = !socket.may_fragment();
Ok(Self {
state: Mutex::new(EndpointState {
endpoint: noq_proto::Endpoint::new(
Arc::new(config),
server_config.map(Arc::new),
allow_mtud,
),
worker: None,
connections: HashMap::default(),
close: None,
exit_on_idle: false,
incoming: VecDeque::new(),
incoming_wakers: VecDeque::new(),
stats: EndpointStats::default(),
}),
socket,
ipv6,
events: unbounded(),
done: AtomicWaker::new(),
})
}
fn connect(
&self,
remote: SocketAddr,
server_name: &str,
config: ClientConfig,
) -> Result<Connecting, ConnectError> {
let mut state = self.state.lock();
if state.worker.is_none() {
return Err(ConnectError::EndpointStopping);
}
if remote.is_ipv6() && !self.ipv6 {
return Err(ConnectError::InvalidRemoteAddress(remote));
}
let remote = if self.ipv6 {
SocketAddr::V6(match remote {
SocketAddr::V4(addr) => {
SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
}
SocketAddr::V6(addr) => addr,
})
} else {
remote
};
let (handle, conn) = state
.endpoint
.connect(Instant::now(), config, remote, server_name)?;
state.stats.outgoing_handshakes += 1;
Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
}
fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
let socket = self.socket.clone();
compio::runtime::spawn(async move {
socket.send(buf, &transmit).await;
})
.detach();
}
pub(crate) fn accept(
&self,
incoming: noq_proto::Incoming,
server_config: Option<ServerConfig>,
) -> Result<Connecting, ConnectionError> {
let mut state = self.state.lock();
let mut resp_buf = Vec::new();
let now = Instant::now();
match state
.endpoint
.accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
{
Ok((handle, conn)) => {
state.stats.accepted_handshakes += 1;
Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
}
Err(err) => {
if let Some(transmit) = err.response {
self.respond(resp_buf, transmit);
}
Err(err.cause)
}
}
}
pub(crate) fn refuse(&self, incoming: noq_proto::Incoming) {
let mut state = self.state.lock();
state.stats.refused_handshakes += 1;
let mut resp_buf = Vec::new();
let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
self.respond(resp_buf, transmit);
}
#[allow(clippy::result_large_err)]
pub(crate) fn retry(&self, incoming: noq_proto::Incoming) -> Result<(), noq_proto::RetryError> {
let mut state = self.state.lock();
let mut resp_buf = Vec::new();
let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
self.respond(resp_buf, transmit);
Ok(())
}
pub(crate) fn ignore(&self, incoming: noq_proto::Incoming) {
let mut state = self.state.lock();
state.stats.ignored_handshakes += 1;
state.endpoint.ignore(incoming);
}
async fn run(&self) -> io::Result<()> {
let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
let mut recv_fut = pin!(
self.socket
.recv(Vec::with_capacity(
self.state
.lock()
.endpoint
.config()
.get_max_udp_payload_size()
.min(64 * 1024) as usize
* self.socket.max_gro_segments(),
))
.fuse()
);
let mut event_stream = self.events.1.stream().ready_chunks(100);
loop {
let mut state = select! {
BufResult(res, recv_buf) = recv_fut => {
let mut state = self.state.lock();
match res {
Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
#[cfg(windows)]
Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
Err(e) => break Err(e),
}
recv_fut.set(self.socket.recv(recv_buf).fuse());
state
},
events = event_stream.select_next_some() => {
let mut state = self.state.lock();
for (ch, event) in events {
state.handle_event(ch, event);
}
state
},
};
if state.exit_on_idle && state.is_idle() {
break Ok(());
}
if !state.incoming.is_empty() {
let n = state.incoming.len().min(state.incoming_wakers.len());
state.incoming_wakers.drain(..n).for_each(Waker::wake);
}
}
}
}
/// Cloneable, reference-counted handle to the shared endpoint internals.
///
/// Dropping a handle also takes part in shutdown: the `Drop` impl signals `done` and sets
/// `exit_on_idle` when exactly one other handle remains.
#[derive(Debug, Clone)]
pub(crate) struct EndpointRef(Shared<EndpointInner>);
impl EndpointRef {
    /// Consumes the handle and returns the inner `Shared` without running
    /// `EndpointRef`'s `Drop` impl (so it neither wakes `done` nor sets `exit_on_idle`).
    fn into_inner(self) -> Shared<EndpointInner> {
        let this = ManuallyDrop::new(self);
        // SAFETY: `this` is wrapped in `ManuallyDrop`, so neither `EndpointRef::drop` nor the
        // destructor of its field runs. Reading the field out moves the single `Shared`
        // exactly once, so it is never dropped twice.
        unsafe { ptr::read(&this.0) }
    }
    /// Stops the worker, waits until this is the last reference to the endpoint, then closes
    /// the socket.
    ///
    /// If the endpoint has no connections, the worker is cancelled right away. Otherwise
    /// `exit_on_idle` is set and the worker is awaited until its loop returns.
    async fn shutdown(self) -> io::Result<()> {
        let (worker, idle) = {
            let mut state = self.0.state.lock();
            let idle = state.is_idle();
            if !idle {
                state.exit_on_idle = true;
            }
            // Taking the handle also makes later `connect` calls fail with `EndpointStopping`.
            (state.worker.take(), idle)
        };
        if let Some(worker) = worker {
            if idle {
                worker.cancel().await;
            } else {
                _ = worker.await;
            }
        }
        let mut this = Some(self.into_inner());
        // Wait for every other handle to be dropped. The handle whose drop leaves exactly two
        // references (ours plus itself) wakes `done`.
        let inner = poll_fn(move |cx| {
            let s = match Shared::try_unwrap(this.take().unwrap()) {
                Ok(inner) => return Poll::Ready(inner),
                Err(s) => s,
            };
            s.done.register(cx.waker());
            // Retry after registering, so a drop that happened between the first attempt and
            // `register` is not missed.
            match Shared::try_unwrap(s) {
                Ok(inner) => Poll::Ready(inner),
                Err(s) => {
                    this.replace(s);
                    Poll::Pending
                }
            }
        })
        .await;
        inner.socket.close().await
    }
}
impl Drop for EndpointRef {
    /// When dropping this handle leaves exactly one other reference, signals the two parties
    /// that may be holding it.
    ///
    /// - A pending `shutdown` waits on `done`.
    /// - The worker exits once `exit_on_idle` is set and no connections remain.
    ///
    /// This code cannot tell which one it is, so it signals both.
    fn drop(&mut self) {
        // The count still includes this handle, so 2 means "this one plus exactly one other".
        if Shared::strong_count(&self.0) == 2 {
            self.0.done.wake();
            // NOTE(review): `run` only re-checks `exit_on_idle` after it handles a datagram
            // or an endpoint event. If the endpoint is already idle, the worker (and its
            // socket) stays alive until the next packet arrives. Confirm this is intended.
            self.0.state.lock().exit_on_idle = true;
        }
    }
}
impl Deref for EndpointRef {
    type Target = EndpointInner;

    /// Gives direct access to the shared endpoint internals behind the handle.
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
/// Endpoint bound to one UDP socket, used to open outgoing and accept incoming connections.
///
/// Cloning yields another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    inner: EndpointRef,
    /// Client configuration used by [`Endpoint::connect`] when no explicit config is given.
    pub default_client_config: Option<ClientConfig>,
}
impl Endpoint {
    /// Creates an endpoint on `socket` and spawns its background worker on the current
    /// compio runtime.
    ///
    /// `server_config`, if present, lets the endpoint accept incoming connections.
    /// `default_client_config` is used by [`Endpoint::connect`] when no config is passed.
    ///
    /// # Errors
    /// Fails if the socket cannot be wrapped or its local address cannot be queried.
    pub fn new(
        socket: UdpSocket,
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        default_client_config: Option<ClientConfig>,
    ) -> io::Result<Self> {
        let shared = EndpointInner::new(socket, config, server_config)?;
        let inner = EndpointRef(Shared::new(shared));
        let worker_ref = inner.clone();
        let driver = async move {
            // Presumably `error!` can compile to nothing, which would leave `e` unused.
            #[allow(unused)]
            if let Err(e) = worker_ref.run().await {
                error!("I/O error: {}", e);
            }
        };
        let worker = compio::runtime::spawn(driver.in_current_span());
        inner.state.lock().worker = Some(worker);
        Ok(Self {
            inner,
            default_client_config,
        })
    }

    /// Binds a UDP socket to `addr` and builds a client endpoint with the default endpoint
    /// config, no server config and no default client config.
    #[cfg(rustls)]
    pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
        Self::new(UdpSocket::bind(addr).await?, EndpointConfig::default(), None, None)
    }

    /// Binds a UDP socket to `addr` and builds a server endpoint using `config`.
    #[cfg(rustls)]
    pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
        Self::new(UdpSocket::bind(addr).await?, EndpointConfig::default(), Some(config), None)
    }

    /// Returns a snapshot of the endpoint's handshake counters.
    pub fn stats(&self) -> EndpointStats {
        let state = self.inner.state.lock();
        state.stats
    }

    /// Opens a connection to `remote`, using `config` or, if it is `None`, the endpoint's
    /// default client config.
    ///
    /// # Errors
    /// `ConnectError::NoDefaultClientConfig` when neither config is available, plus any
    /// error from the underlying connect.
    pub fn connect(
        &self,
        remote: SocketAddr,
        server_name: &str,
        config: Option<ClientConfig>,
    ) -> Result<Connecting, ConnectError> {
        let config = match config {
            Some(explicit) => explicit,
            None => self
                .default_client_config
                .clone()
                .ok_or(ConnectError::NoDefaultClientConfig)?,
        };
        self.inner.connect(remote, server_name, config)
    }

    /// Waits for the next incoming connection attempt.
    ///
    /// Returns `None` once the endpoint has been closed.
    pub async fn wait_incoming(&self) -> Option<Incoming> {
        let incoming = future::poll_fn(|cx| self.inner.state.lock().poll_incoming(cx)).await?;
        Some(Incoming::new(incoming, self.inner.clone()))
    }

    /// Replaces the server config used for incoming connections (`None` disables it).
    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
        let server_config = server_config.map(Arc::new);
        self.inner
            .state
            .lock()
            .endpoint
            .set_server_config(server_config);
    }

    /// Returns the local address of the underlying socket.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.socket.local_addr()
    }

    /// Returns the number of open connections as reported by the protocol endpoint.
    pub fn open_connections(&self) -> usize {
        self.inner.state.lock().endpoint.open_connections()
    }

    /// Closes the endpoint with `error_code` and `reason`.
    ///
    /// Every registered connection is told to close. Connections created afterwards are
    /// closed as they are registered, new incoming attempts are refused, and pending
    /// `wait_incoming` calls return `None`. Calls after the first have no effect.
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let mut state = self.inner.state.lock();
        if state.close.is_some() {
            return;
        }
        let reason = Bytes::copy_from_slice(reason);
        state.close = Some((error_code, reason.clone()));
        state.connections.values().for_each(|conn| {
            // Ignored: a connection whose receiver is gone no longer needs the notice.
            let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
        });
        state.incoming_wakers.drain(..).for_each(Waker::wake);
    }

    /// Stops the worker, waits for all other handles to the endpoint to be dropped, then
    /// closes the socket.
    pub async fn shutdown(self) -> io::Result<()> {
        self.inner.shutdown().await
    }
}