//! async-quic 0.3.1
//!
//! Runtime-independent async QUIC implementation based on quinn-proto.
//! Documentation
use async_io::Async;
use bytes::BytesMut;
use futures::{
    channel::mpsc::{channel, Receiver, Sender},
    prelude::*,
    ready,
};
use std::{
    collections::{BTreeMap, VecDeque},
    io::{self, IoSliceMut},
    mem::MaybeUninit,
    net::{IpAddr, SocketAddr, UdpSocket},
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
    time::Instant,
};

use crate::{ConnectionInner, QuicConnection};

/// A QUIC endpoint bound to a UDP socket.
///
/// Incoming connections are produced through its [`Stream`] implementation.
pub struct QuicEndpoint {
    inner: Arc<EndpointInner>,
}

impl QuicEndpoint {
    /// Creates an endpoint on top of an already-bound UDP socket.
    ///
    /// `server_config`, when present, is wrapped into a
    /// `quinn_proto::ServerConfig` and passed to the `quinn_proto::Endpoint`;
    /// with `None` the endpoint is created without a server configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if quinn-udp fails to configure the socket or if the
    /// socket cannot be registered with `async_io`.
    pub fn new(
        udp: UdpSocket,
        server_config: Option<Arc<rustls::ServerConfig>>,
    ) -> io::Result<Self> {
        quinn_udp::UdpSocketState::configure((&udp).into())?;

        let proto_server_config = server_config
            .map(|crypto| quinn_proto::ServerConfig::with_crypto(crypto))
            .map(Arc::new);
        let endpoint = quinn_proto::Endpoint::new(Arc::new(Default::default()), proto_server_config);
        let udp_state = Arc::new(quinn_udp::UdpState::new());

        // One slot per batch entry, each large enough for a full GRO burst of
        // maximum-size datagrams (capped at 64 KiB per datagram).
        let datagram_len =
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize;
        let recv_len = datagram_len * udp_state.gro_segments() * quinn_udp::BATCH_SIZE;
        let recv_buf = vec![0u8; recv_len].into_boxed_slice();

        let (transmit_sender, transmit_receiver) = channel(quinn_udp::BATCH_SIZE);
        let socket = Async::new(udp)?;

        let state = EndpointState {
            connections: BTreeMap::new(),
            endpoint,
            udp: (socket, quinn_udp::UdpSocketState::new(), recv_buf),
            recv_buffer: VecDeque::new(),
            transmit_buffer: VecDeque::with_capacity(quinn_udp::BATCH_SIZE),
            transmit_receiver,
        };

        Ok(Self {
            inner: Arc::new(EndpointInner {
                state: Mutex::new(state),
                transmit_sender,
                udp_state,
            }),
        })
    }
}

impl Stream for QuicEndpoint {
    type Item = QuicConnection;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut state = self.inner.state.lock().unwrap();
        loop {
            state.poll_transmit(&self.inner.udp_state, cx);
            while let Some(msg) = state.recv_buffer.pop_front() {
                match state
                    .endpoint
                    .handle(Instant::now(), msg.0, msg.1, msg.2, msg.3)
                {
                    Some((handle, quinn_proto::DatagramEvent::ConnectionEvent(event))) => {
                        match state.connections.get(&handle) {
                            Some(conn) => conn.handle_event(event),
                            None => log::error!("connection not found"),
                        }
                    }
                    Some((handle, quinn_proto::DatagramEvent::NewConnection(conn))) => {
                        let conn = QuicConnection::new(
                            handle,
                            conn,
                            self.inner.clone(),
                            self.inner.transmit_sender.clone(),
                        );
                        state.connections.insert(handle, conn.inner());
                        return Poll::Ready(Some(conn));
                    }
                    None => {}
                }
            }
            match state.fill_recv_buffer(cx) {
                Poll::Ready(Ok(())) => {}
                Poll::Ready(Err(err)) => log::error!("endpoint receive error: {:?}", err),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

pub(crate) struct EndpointInner {
    state: Mutex<EndpointState>,
    transmit_sender: Sender<quinn_proto::Transmit>,
    udp_state: Arc<quinn_udp::UdpState>,
}

impl EndpointInner {
    pub(crate) fn udp_state(&self) -> &quinn_udp::UdpState {
        &self.udp_state
    }
    pub(crate) fn handle_enpoint_event(
        &self,
        handle: quinn_proto::ConnectionHandle,
        event: quinn_proto::EndpointEvent,
    ) -> Option<quinn_proto::ConnectionEvent> {
        self.state
            .lock()
            .unwrap()
            .endpoint
            .handle_event(handle, event)
    }
}

struct EndpointState {
    connections: BTreeMap<quinn_proto::ConnectionHandle, Arc<ConnectionInner>>,
    udp: (Async<UdpSocket>, quinn_udp::UdpSocketState, Box<[u8]>),
    transmit_receiver: Receiver<quinn_proto::Transmit>,
    transmit_buffer: VecDeque<quinn_proto::Transmit>,
    recv_buffer: VecDeque<(
        SocketAddr,
        Option<IpAddr>,
        Option<quinn_proto::EcnCodepoint>,
        BytesMut,
    )>,
    endpoint: quinn_proto::Endpoint,
}

impl EndpointState {
    fn poll_transmit(&mut self, udp_state: &quinn_udp::UdpState, cx: &mut Context) {
        for _ in 0..3 {
            while self.transmit_buffer.len() < self.transmit_buffer.capacity() {
                match self.endpoint.poll_transmit() {
                    Some(t) => self.transmit_buffer.push_back(t),
                    None => break,
                }
            }
            while self.transmit_buffer.len() < self.transmit_buffer.capacity() {
                match self.transmit_receiver.poll_next_unpin(cx) {
                    Poll::Ready(Some(t)) => self.transmit_buffer.push_back(t),
                    Poll::Ready(None) => unreachable!(),
                    Poll::Pending => break,
                }
            }
            match poll_send(
                &mut self.udp.1,
                udp_state,
                &self.udp.0,
                cx,
                self.transmit_buffer.make_contiguous(),
            ) {
                Poll::Ready(Ok(n)) => drop(self.transmit_buffer.drain(0..n)),
                Poll::Ready(Err(err)) => log::error!("endpoint send error: {:?}", err),
                Poll::Pending => break,
            }
        }
    }

    fn fill_recv_buffer<'a>(&'a mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        loop {
            let mut metas = [quinn_udp::RecvMeta::default(); quinn_udp::BATCH_SIZE];
            let mut iovs = MaybeUninit::<[IoSliceMut<'a>; quinn_udp::BATCH_SIZE]>::uninit();
            self.udp
                .2
                .chunks_mut(self.udp.2.len() / quinn_udp::BATCH_SIZE)
                .enumerate()
                .for_each(|(i, buf)| unsafe {
                    iovs.as_mut_ptr()
                        .cast::<IoSliceMut>()
                        .add(i)
                        .write(IoSliceMut::<'a>::new(buf));
                });
            let mut iovs = unsafe { iovs.assume_init() };
            match poll_recv(&self.udp.1, &self.udp.0, cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(n)) => {
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(n) {
                        let mut b: BytesMut = buf[0..meta.len].into();
                        while !b.is_empty() {
                            let b = b.split_to(meta.stride.min(b.len()));
                            self.recv_buffer
                                .push_back((meta.addr, meta.dst_ip, meta.ecn, b));
                        }
                    }
                    return Poll::Ready(Ok(()));
                }
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Err(err)) => {
                    if err.kind() != io::ErrorKind::ConnectionReset {
                        return Poll::Ready(Err(err));
                    }
                    todo!()
                }
            }
        }
    }
}

/// Hands `t` to the socket once it is writable.
///
/// Returns how many transmits quinn-udp accepted. `WouldBlock` (stale
/// readiness) and `Interrupted` are retried after waiting for the next
/// writability event. Any other error is returned to the caller instead of
/// being silently retried.
fn poll_send(
    uss: &mut quinn_udp::UdpSocketState,
    us: &quinn_udp::UdpState,
    io: &Async<UdpSocket>,
    cx: &mut Context,
    t: &[quinn_proto::Transmit],
) -> Poll<io::Result<usize>> {
    loop {
        ready!(io.poll_writable(cx))?;
        match uss.send(io.into(), us, t) {
            Ok(n) => return Poll::Ready(Ok(n)),
            Err(err)
                if matches!(
                    err.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
                ) => {}
            Err(err) => return Poll::Ready(Err(err)),
        }
    }
}
/// Receives a batch of datagrams into `b`, filling `m` with per-buffer metadata.
///
/// Returns the number of buffers filled. `WouldBlock` (stale readiness) and
/// `Interrupted` are retried after waiting for the next readability event.
/// `ConnectionReset` is also retried: it reports a failure for some earlier
/// peer and does not affect the socket itself. Any other error is returned to
/// the caller instead of being silently retried.
fn poll_recv(
    uss: &quinn_udp::UdpSocketState,
    io: &Async<UdpSocket>,
    cx: &mut Context,
    b: &mut [IoSliceMut<'_>],
    m: &mut [quinn_udp::RecvMeta],
) -> Poll<io::Result<usize>> {
    loop {
        ready!(io.poll_readable(cx))?;
        match uss.recv(io.into(), b, m) {
            Ok(res) => return Poll::Ready(Ok(res)),
            Err(err)
                if matches!(
                    err.kind(),
                    io::ErrorKind::WouldBlock
                        | io::ErrorKind::Interrupted
                        | io::ErrorKind::ConnectionReset
                ) => {}
            Err(err) => return Poll::Ready(Err(err)),
        }
    }
}