use async_io::Async;
use bytes::BytesMut;
use futures::{
channel::mpsc::{channel, Receiver, Sender},
prelude::*,
ready,
};
use std::{
collections::{BTreeMap, VecDeque},
io::{self, IoSliceMut},
mem::MaybeUninit,
net::{IpAddr, SocketAddr, UdpSocket},
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Instant,
};
use crate::{ConnectionInner, QuicConnection};
/// A QUIC endpoint whose UDP socket is driven through `async_io`.
///
/// Polling it as a [`Stream`] flushes outgoing datagrams and processes incoming
/// ones. Each newly accepted connection is yielded as a [`QuicConnection`].
pub struct QuicEndpoint {
    /// Shared endpoint state; a clone is handed to every connection created here.
    inner: Arc<EndpointInner>,
}
impl QuicEndpoint {
    /// Creates an endpoint on top of an already-bound UDP socket.
    ///
    /// The socket is first configured by `quinn_udp` and then registered with the
    /// `async_io` reactor. When `server_config` is `Some`, its rustls configuration
    /// becomes the endpoint's server configuration. When it is `None`, no server
    /// configuration is installed.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if `quinn_udp` fails to configure the socket or if
    /// registering the socket with the reactor fails.
    pub fn new(
        udp: UdpSocket,
        server_config: Option<Arc<rustls::ServerConfig>>,
    ) -> io::Result<Self> {
        quinn_udp::UdpSocketState::configure((&udp).into())?;

        let proto_server_config = server_config
            .map(|crypto| Arc::new(quinn_proto::ServerConfig::with_crypto(crypto)));
        let endpoint = quinn_proto::Endpoint::new(Arc::new(Default::default()), proto_server_config);
        let udp_state = Arc::new(quinn_udp::UdpState::new());

        // One receive slot per batch entry. Each slot holds a full GRO burst of
        // maximum-size datagrams; the payload size is capped at 64 KiB.
        let max_payload = endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize;
        let recv_len = max_payload * udp_state.gro_segments() * quinn_udp::BATCH_SIZE;
        let recv_buf = vec![0u8; recv_len].into_boxed_slice();

        let (transmit_sender, transmit_receiver) = channel(quinn_udp::BATCH_SIZE);
        let socket = Async::new(udp)?;

        let endpoint_state = EndpointState {
            connections: BTreeMap::new(),
            endpoint,
            udp: (socket, quinn_udp::UdpSocketState::new(), recv_buf),
            recv_buffer: VecDeque::new(),
            transmit_buffer: VecDeque::with_capacity(quinn_udp::BATCH_SIZE),
            transmit_receiver,
        };

        Ok(Self {
            inner: Arc::new(EndpointInner {
                state: Mutex::new(endpoint_state),
                transmit_sender,
                udp_state,
            }),
        })
    }
}
impl Stream for QuicEndpoint {
type Item = QuicConnection;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut state = self.inner.state.lock().unwrap();
loop {
state.poll_transmit(&self.inner.udp_state, cx);
while let Some(msg) = state.recv_buffer.pop_front() {
match state
.endpoint
.handle(Instant::now(), msg.0, msg.1, msg.2, msg.3)
{
Some((handle, quinn_proto::DatagramEvent::ConnectionEvent(event))) => {
match state.connections.get(&handle) {
Some(conn) => conn.handle_event(event),
None => log::error!("connection not found"),
}
}
Some((handle, quinn_proto::DatagramEvent::NewConnection(conn))) => {
let conn = QuicConnection::new(
handle,
conn,
self.inner.clone(),
self.inner.transmit_sender.clone(),
);
state.connections.insert(handle, conn.inner());
return Poll::Ready(Some(conn));
}
None => {}
}
}
match state.fill_recv_buffer(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => log::error!("endpoint receive error: {:?}", err),
Poll::Pending => return Poll::Pending,
}
}
}
}
/// State shared between a [`QuicEndpoint`] and the connections it creates.
pub(crate) struct EndpointInner {
    /// Protocol endpoint, socket and I/O buffers, all guarded by a single lock.
    state: Mutex<EndpointState>,
    /// Sending half of the transmit channel. It is cloned into every new connection
    /// (presumably so connections can queue outgoing datagrams — confirm in the
    /// connection module). Because this sender lives as long as `state`, the
    /// receiving half in `EndpointState` never observes a closed channel.
    transmit_sender: Sender<quinn_proto::Transmit>,
    /// `quinn_udp` capability state (for example, the GRO segment count) used for
    /// sends and for buffer sizing.
    udp_state: Arc<quinn_udp::UdpState>,
}
impl EndpointInner {
    /// Returns the `quinn_udp` capability state shared by this endpoint.
    pub(crate) fn udp_state(&self) -> &quinn_udp::UdpState {
        &self.udp_state
    }

    /// Forwards an endpoint event emitted by connection `handle` to the protocol endpoint.
    ///
    /// If the event says the connection is drained, its entry is also removed from the
    /// routing table. The protocol endpoint releases the handle at that point (and may
    /// reuse it later), so the entry could otherwise never be looked up again. It would
    /// still keep the `ConnectionInner` alive, and possibly this `EndpointInner` too if
    /// the connection holds an `Arc` back to it (NOTE(review): confirm in
    /// `ConnectionInner`).
    ///
    /// Returns the event, if any, that the protocol endpoint wants delivered back to the
    /// connection.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint state mutex is poisoned.
    pub(crate) fn handle_enpoint_event(
        &self,
        handle: quinn_proto::ConnectionHandle,
        event: quinn_proto::EndpointEvent,
    ) -> Option<quinn_proto::ConnectionEvent> {
        let mut state = self.state.lock().unwrap();
        // The drained check must happen before `event` is moved into the endpoint.
        let removed = if event.is_drained() {
            state.connections.remove(&handle)
        } else {
            None
        };
        let reply = state.endpoint.handle_event(handle, event);
        // Release the lock before the removed entry (possibly the last reference to the
        // connection) is dropped, in case its destructor needs the endpoint state.
        drop(state);
        drop(removed);
        reply
    }
}
/// Mutable endpoint state, created in `QuicEndpoint::new` and kept behind
/// `EndpointInner::state`'s mutex.
struct EndpointState {
    /// Connections by handle. Entries are added when a connection is accepted and are
    /// used to route `ConnectionEvent`s back to the matching connection.
    connections: BTreeMap<quinn_proto::ConnectionHandle, Arc<ConnectionInner>>,
    /// The socket, as a tuple of:
    /// - the reactor-registered UDP socket;
    /// - the `quinn_udp` per-socket state;
    /// - the receive buffer, split into `quinn_udp::BATCH_SIZE` equal slots on every
    ///   receive.
    udp: (Async<UdpSocket>, quinn_udp::UdpSocketState, Box<[u8]>),
    /// Receiving half of the transmit channel; its senders are cloned into connections.
    transmit_receiver: Receiver<quinn_proto::Transmit>,
    /// Transmits staged for the next batched send; whatever the socket did not accept
    /// stays here.
    transmit_buffer: VecDeque<quinn_proto::Transmit>,
    /// Received datagrams not yet fed to `endpoint`, stored as
    /// `(RecvMeta::addr, RecvMeta::dst_ip, RecvMeta::ecn, payload)`.
    recv_buffer: VecDeque<(
        SocketAddr,
        Option<IpAddr>,
        Option<quinn_proto::EcnCodepoint>,
        BytesMut,
    )>,
    /// The `quinn_proto` protocol state machine; this module performs its socket I/O.
    endpoint: quinn_proto::Endpoint,
}
impl EndpointState {
    /// Flushes queued outgoing datagrams to the socket.
    ///
    /// Runs up to three rounds. Each round does the following:
    /// - tops the staging buffer up to `quinn_udp::BATCH_SIZE` transmits, first from
    ///   the protocol endpoint and then from the connections' transmit channel;
    /// - hands the staged batch to the socket, removing whatever the socket accepted.
    ///
    /// Transmits the socket did not accept stay staged for the next call. A round that
    /// finds nothing to send ends the flush without touching the socket. This way an
    /// idle endpoint never waits on write readiness it has no use for.
    ///
    /// Send errors are logged, not returned.
    fn poll_transmit(&mut self, udp_state: &quinn_udp::UdpState, cx: &mut Context) {
        for _ in 0..3 {
            while self.transmit_buffer.len() < quinn_udp::BATCH_SIZE {
                match self.endpoint.poll_transmit() {
                    Some(t) => self.transmit_buffer.push_back(t),
                    None => break,
                }
            }
            while self.transmit_buffer.len() < quinn_udp::BATCH_SIZE {
                match self.transmit_receiver.poll_next_unpin(cx) {
                    Poll::Ready(Some(t)) => self.transmit_buffer.push_back(t),
                    // `EndpointInner` holds a `Sender` for as long as this state exists,
                    // so the channel cannot be closed here.
                    Poll::Ready(None) => unreachable!(),
                    Poll::Pending => break,
                }
            }
            if self.transmit_buffer.is_empty() {
                // Nothing staged: skip the empty send and the write-readiness
                // registration it would cause. The channel poll above has already
                // registered our waker for new transmits.
                break;
            }
            match poll_send(
                &mut self.udp.1,
                udp_state,
                &self.udp.0,
                cx,
                self.transmit_buffer.make_contiguous(),
            ) {
                Poll::Ready(Ok(n)) => drop(self.transmit_buffer.drain(0..n)),
                Poll::Ready(Err(err)) => log::error!("endpoint send error: {:?}", err),
                Poll::Pending => break,
            }
        }
    }

    /// Receives one batch of datagrams from the socket into `recv_buffer`.
    ///
    /// Each received slot is copied out of the receive buffer. If the kernel coalesced
    /// several datagrams into one slot (GRO), the slot is split at `meta.stride`
    /// boundaries so that each datagram is queued separately.
    ///
    /// Resolves to `Ok(())` after one successful batch, or to `Pending` when the
    /// socket has nothing to read. `ConnectionReset` errors are skipped and the receive
    /// is retried. Any other error is returned.
    fn fill_recv_buffer(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        loop {
            let mut metas = [quinn_udp::RecvMeta::default(); quinn_udp::BATCH_SIZE];
            // The iovecs borrow only `self.udp.2`, and only for this iteration, so
            // retrying via `continue` is accepted by the borrow checker.
            let mut iovs = Self::recv_iovs(&mut self.udp.2);
            match poll_recv(&self.udp.1, &self.udp.0, cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(n)) => {
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(n) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        while !data.is_empty() {
                            let datagram = data.split_to(meta.stride.min(data.len()));
                            self.recv_buffer
                                .push_back((meta.addr, meta.dst_ip, meta.ecn, datagram));
                        }
                    }
                    return Poll::Ready(Ok(()));
                }
                Poll::Pending => return Poll::Pending,
                // UDP sockets can report `ConnectionReset` on receive (for example, on
                // Windows after an ICMP port-unreachable for an earlier send). It does
                // not mean this socket is broken, so ignore it and receive again.
                Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionReset => continue,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            }
        }
    }

    /// Splits `buf` into `quinn_udp::BATCH_SIZE` equally sized receive slots.
    ///
    /// Bytes beyond `BATCH_SIZE` whole slots are left unused.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than `quinn_udp::BATCH_SIZE` bytes.
    fn recv_iovs<'b>(buf: &'b mut [u8]) -> [IoSliceMut<'b>; quinn_udp::BATCH_SIZE] {
        let slot_len = buf.len() / quinn_udp::BATCH_SIZE;
        assert!(slot_len > 0, "receive buffer is smaller than one batch");
        let mut iovs = MaybeUninit::<[IoSliceMut<'b>; quinn_udp::BATCH_SIZE]>::uninit();
        let base = iovs.as_mut_ptr().cast::<IoSliceMut<'b>>();
        for (i, slot) in buf
            .chunks_exact_mut(slot_len)
            .take(quinn_udp::BATCH_SIZE)
            .enumerate()
        {
            // SAFETY: `take` limits `i` to `0..BATCH_SIZE`, so `base.add(i)` points to an
            // element inside the array owned by `iovs`.
            unsafe { base.add(i).write(IoSliceMut::new(slot)) };
        }
        // SAFETY: `slot_len * BATCH_SIZE <= buf.len()`, so `chunks_exact_mut` yields at
        // least `BATCH_SIZE` chunks. `take` then stops at exactly `BATCH_SIZE`, so every
        // element was written by the loop above.
        unsafe { iovs.assume_init() }
    }
}
/// Sends as many of the transmits in `t` as the socket accepts right now.
///
/// Waits for write readiness on `io`, then calls `quinn_udp` to send, and resolves
/// to the number of transmits it reports as sent. A failed send attempt is not
/// surfaced: the error is discarded and the function waits for readiness again
/// before retrying. Only an error from the readiness check is returned.
fn poll_send(
    uss: &mut quinn_udp::UdpSocketState,
    us: &quinn_udp::UdpState,
    io: &Async<UdpSocket>,
    cx: &mut Context,
    t: &[quinn_proto::Transmit],
) -> Poll<io::Result<usize>> {
    loop {
        match ready!(io.poll_writable(cx)) {
            Err(readiness_err) => return Poll::Ready(Err(readiness_err)),
            Ok(()) => {}
        }
        match uss.send(io.into(), us, t) {
            Ok(sent) => return Poll::Ready(Ok(sent)),
            // Discard the error and go back to waiting for writability.
            Err(_) => {}
        }
    }
}
/// Receives a batch of datagrams into the slots `b`, with metadata written to `m`.
///
/// Waits for read readiness on `io`, then calls `quinn_udp` to receive, and resolves
/// to the count it returns. A failed receive attempt is not surfaced: the error is
/// discarded and the function waits for readiness again before retrying. Only an
/// error from the readiness check is returned.
fn poll_recv(
    uss: &quinn_udp::UdpSocketState,
    io: &Async<UdpSocket>,
    cx: &mut Context,
    b: &mut [IoSliceMut<'_>],
    m: &mut [quinn_udp::RecvMeta],
) -> Poll<io::Result<usize>> {
    loop {
        match ready!(io.poll_readable(cx)) {
            Err(readiness_err) => return Poll::Ready(Err(readiness_err)),
            Ok(()) => {}
        }
        match uss.recv(io.into(), b, m) {
            Ok(received) => return Poll::Ready(Ok(received)),
            // Discard the error and go back to waiting for readability.
            Err(_) => {}
        }
    }
}