async_quic/
endpoint.rs

1use async_io::Async;
2use bytes::BytesMut;
3use futures::{
4    channel::mpsc::{channel, Receiver, Sender},
5    prelude::*,
6    ready,
7};
8use std::{
9    collections::{BTreeMap, VecDeque},
10    io::{self, IoSliceMut},
11    mem::MaybeUninit,
12    net::{IpAddr, SocketAddr, UdpSocket},
13    pin::Pin,
14    sync::{Arc, Mutex},
15    task::{Context, Poll},
16    time::Instant,
17};
18
19use crate::{ConnectionInner, QuicConnection};
20
21pub struct QuicEndpoint {
22    inner: Arc<EndpointInner>,
23}
24
25impl QuicEndpoint {
26    pub fn new(
27        udp: UdpSocket,
28        server_config: Option<Arc<rustls::ServerConfig>>,
29    ) -> io::Result<Self> {
30        quinn_udp::UdpSocketState::configure((&udp).into())?;
31        let config = server_config.map(|c| Arc::new(quinn_proto::ServerConfig::with_crypto(c)));
32        let endpoint = quinn_proto::Endpoint::new(Arc::new(Default::default()), config);
33        let udp_state = Arc::new(quinn_udp::UdpState::new());
34        let recv_buf = vec![
35            0u8;
36            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
37                * udp_state.gro_segments()
38                * quinn_udp::BATCH_SIZE
39        ]
40        .into_boxed_slice();
41        let (transmit_sender, transmit_receiver) = channel(quinn_udp::BATCH_SIZE);
42        let state = Mutex::new(EndpointState {
43            connections: BTreeMap::new(),
44            endpoint,
45            udp: (Async::new(udp)?, quinn_udp::UdpSocketState::new(), recv_buf),
46            recv_buffer: VecDeque::new(),
47            transmit_buffer: VecDeque::with_capacity(quinn_udp::BATCH_SIZE),
48            transmit_receiver,
49        });
50        let inner = Arc::new(EndpointInner {
51            state,
52            transmit_sender,
53            udp_state,
54        });
55        Ok(Self { inner })
56    }
57}
58
59impl Stream for QuicEndpoint {
60    type Item = QuicConnection;
61
62    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63        let mut state = self.inner.state.lock().unwrap();
64        loop {
65            state.poll_transmit(&self.inner.udp_state, cx);
66            while let Some(msg) = state.recv_buffer.pop_front() {
67                match state
68                    .endpoint
69                    .handle(Instant::now(), msg.0, msg.1, msg.2, msg.3)
70                {
71                    Some((handle, quinn_proto::DatagramEvent::ConnectionEvent(event))) => {
72                        match state.connections.get(&handle) {
73                            Some(conn) => conn.handle_event(event),
74                            None => log::error!("connection not found"),
75                        }
76                    }
77                    Some((handle, quinn_proto::DatagramEvent::NewConnection(conn))) => {
78                        let conn = QuicConnection::new(
79                            handle,
80                            conn,
81                            self.inner.clone(),
82                            self.inner.transmit_sender.clone(),
83                        );
84                        state.connections.insert(handle, conn.inner());
85                        return Poll::Ready(Some(conn));
86                    }
87                    None => {}
88                }
89            }
90            match state.fill_recv_buffer(cx) {
91                Poll::Ready(Ok(())) => {}
92                Poll::Ready(Err(err)) => log::error!("endpoint receive error: {:?}", err),
93                Poll::Pending => return Poll::Pending,
94            }
95        }
96    }
97}
98
99pub(crate) struct EndpointInner {
100    state: Mutex<EndpointState>,
101    transmit_sender: Sender<quinn_proto::Transmit>,
102    udp_state: Arc<quinn_udp::UdpState>,
103}
104
105impl EndpointInner {
106    pub(crate) fn udp_state(&self) -> &quinn_udp::UdpState {
107        &self.udp_state
108    }
109    pub(crate) fn handle_enpoint_event(
110        &self,
111        handle: quinn_proto::ConnectionHandle,
112        event: quinn_proto::EndpointEvent,
113    ) -> Option<quinn_proto::ConnectionEvent> {
114        self.state
115            .lock()
116            .unwrap()
117            .endpoint
118            .handle_event(handle, event)
119    }
120}
121
122struct EndpointState {
123    connections: BTreeMap<quinn_proto::ConnectionHandle, Arc<ConnectionInner>>,
124    udp: (Async<UdpSocket>, quinn_udp::UdpSocketState, Box<[u8]>),
125    transmit_receiver: Receiver<quinn_proto::Transmit>,
126    transmit_buffer: VecDeque<quinn_proto::Transmit>,
127    recv_buffer: VecDeque<(
128        SocketAddr,
129        Option<IpAddr>,
130        Option<quinn_proto::EcnCodepoint>,
131        BytesMut,
132    )>,
133    endpoint: quinn_proto::Endpoint,
134}
135
136impl EndpointState {
137    fn poll_transmit(&mut self, udp_state: &quinn_udp::UdpState, cx: &mut Context) {
138        for _ in 0..3 {
139            while self.transmit_buffer.len() < self.transmit_buffer.capacity() {
140                match self.endpoint.poll_transmit() {
141                    Some(t) => self.transmit_buffer.push_back(t),
142                    None => break,
143                }
144            }
145            while self.transmit_buffer.len() < self.transmit_buffer.capacity() {
146                match self.transmit_receiver.poll_next_unpin(cx) {
147                    Poll::Ready(Some(t)) => self.transmit_buffer.push_back(t),
148                    Poll::Ready(None) => unreachable!(),
149                    Poll::Pending => break,
150                }
151            }
152            match poll_send(
153                &mut self.udp.1,
154                udp_state,
155                &self.udp.0,
156                cx,
157                self.transmit_buffer.make_contiguous(),
158            ) {
159                Poll::Ready(Ok(n)) => drop(self.transmit_buffer.drain(0..n)),
160                Poll::Ready(Err(err)) => log::error!("endpoint send error: {:?}", err),
161                Poll::Pending => break,
162            }
163        }
164    }
165
166    fn fill_recv_buffer<'a>(&'a mut self, cx: &mut Context) -> Poll<io::Result<()>> {
167        loop {
168            let mut metas = [quinn_udp::RecvMeta::default(); quinn_udp::BATCH_SIZE];
169            let mut iovs = MaybeUninit::<[IoSliceMut<'a>; quinn_udp::BATCH_SIZE]>::uninit();
170            self.udp
171                .2
172                .chunks_mut(self.udp.2.len() / quinn_udp::BATCH_SIZE)
173                .enumerate()
174                .for_each(|(i, buf)| unsafe {
175                    iovs.as_mut_ptr()
176                        .cast::<IoSliceMut>()
177                        .add(i)
178                        .write(IoSliceMut::<'a>::new(buf));
179                });
180            let mut iovs = unsafe { iovs.assume_init() };
181            match poll_recv(&self.udp.1, &self.udp.0, cx, &mut iovs, &mut metas) {
182                Poll::Ready(Ok(n)) => {
183                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(n) {
184                        let mut b: BytesMut = buf[0..meta.len].into();
185                        while !b.is_empty() {
186                            let b = b.split_to(meta.stride.min(b.len()));
187                            self.recv_buffer
188                                .push_back((meta.addr, meta.dst_ip, meta.ecn, b));
189                        }
190                    }
191                    return Poll::Ready(Ok(()));
192                }
193                Poll::Pending => return Poll::Pending,
194                Poll::Ready(Err(err)) => {
195                    if err.kind() != io::ErrorKind::ConnectionReset {
196                        return Poll::Ready(Err(err));
197                    }
198                    todo!()
199                }
200            }
201        }
202    }
203}
204
205fn poll_send(
206    uss: &mut quinn_udp::UdpSocketState,
207    us: &quinn_udp::UdpState,
208    io: &Async<UdpSocket>,
209    cx: &mut Context,
210    t: &[quinn_proto::Transmit],
211) -> Poll<io::Result<usize>> {
212    loop {
213        ready!(io.poll_writable(cx))?;
214        if let Ok(n) = uss.send(io.into(), us, t) {
215            return Poll::Ready(Ok(n));
216        }
217    }
218}
219fn poll_recv(
220    uss: &quinn_udp::UdpSocketState,
221    io: &Async<UdpSocket>,
222    cx: &mut Context,
223    b: &mut [IoSliceMut<'_>],
224    m: &mut [quinn_udp::RecvMeta],
225) -> Poll<io::Result<usize>> {
226    loop {
227        ready!(io.poll_readable(cx))?;
228        if let Ok(res) = uss.recv(io.into(), b, m) {
229            return Poll::Ready(Ok(res));
230        }
231    }
232}