1use async_io::Async;
2use bytes::BytesMut;
3use futures::{
4 channel::mpsc::{channel, Receiver, Sender},
5 prelude::*,
6 ready,
7};
8use std::{
9 collections::{BTreeMap, VecDeque},
10 io::{self, IoSliceMut},
11 mem::MaybeUninit,
12 net::{IpAddr, SocketAddr, UdpSocket},
13 pin::Pin,
14 sync::{Arc, Mutex},
15 task::{Context, Poll},
16 time::Instant,
17};
18
19use crate::{ConnectionInner, QuicConnection};
20
/// A QUIC endpoint bound to a single UDP socket.
///
/// Polling it as a [`Stream`] drives the socket: it flushes queued transmits,
/// reads datagrams and feeds them to `quinn_proto`. It yields each newly
/// created [`QuicConnection`]. `poll_next` has no `Ready(None)` path, so the
/// stream never ends on its own.
pub struct QuicEndpoint {
    /// State shared with every connection this endpoint creates; a clone of
    /// this `Arc` is passed to `QuicConnection::new`.
    inner: Arc<EndpointInner>,
}
24
25impl QuicEndpoint {
26 pub fn new(
27 udp: UdpSocket,
28 server_config: Option<Arc<rustls::ServerConfig>>,
29 ) -> io::Result<Self> {
30 quinn_udp::UdpSocketState::configure((&udp).into())?;
31 let config = server_config.map(|c| Arc::new(quinn_proto::ServerConfig::with_crypto(c)));
32 let endpoint = quinn_proto::Endpoint::new(Arc::new(Default::default()), config);
33 let udp_state = Arc::new(quinn_udp::UdpState::new());
34 let recv_buf = vec![
35 0u8;
36 endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
37 * udp_state.gro_segments()
38 * quinn_udp::BATCH_SIZE
39 ]
40 .into_boxed_slice();
41 let (transmit_sender, transmit_receiver) = channel(quinn_udp::BATCH_SIZE);
42 let state = Mutex::new(EndpointState {
43 connections: BTreeMap::new(),
44 endpoint,
45 udp: (Async::new(udp)?, quinn_udp::UdpSocketState::new(), recv_buf),
46 recv_buffer: VecDeque::new(),
47 transmit_buffer: VecDeque::with_capacity(quinn_udp::BATCH_SIZE),
48 transmit_receiver,
49 });
50 let inner = Arc::new(EndpointInner {
51 state,
52 transmit_sender,
53 udp_state,
54 });
55 Ok(Self { inner })
56 }
57}
58
59impl Stream for QuicEndpoint {
60 type Item = QuicConnection;
61
62 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63 let mut state = self.inner.state.lock().unwrap();
64 loop {
65 state.poll_transmit(&self.inner.udp_state, cx);
66 while let Some(msg) = state.recv_buffer.pop_front() {
67 match state
68 .endpoint
69 .handle(Instant::now(), msg.0, msg.1, msg.2, msg.3)
70 {
71 Some((handle, quinn_proto::DatagramEvent::ConnectionEvent(event))) => {
72 match state.connections.get(&handle) {
73 Some(conn) => conn.handle_event(event),
74 None => log::error!("connection not found"),
75 }
76 }
77 Some((handle, quinn_proto::DatagramEvent::NewConnection(conn))) => {
78 let conn = QuicConnection::new(
79 handle,
80 conn,
81 self.inner.clone(),
82 self.inner.transmit_sender.clone(),
83 );
84 state.connections.insert(handle, conn.inner());
85 return Poll::Ready(Some(conn));
86 }
87 None => {}
88 }
89 }
90 match state.fill_recv_buffer(cx) {
91 Poll::Ready(Ok(())) => {}
92 Poll::Ready(Err(err)) => log::error!("endpoint receive error: {:?}", err),
93 Poll::Pending => return Poll::Pending,
94 }
95 }
96 }
97}
98
/// Endpoint state shared between the [`QuicEndpoint`] stream and the
/// connections it creates.
pub(crate) struct EndpointInner {
    /// Mutable endpoint state. Every access goes through this lock.
    state: Mutex<EndpointState>,
    /// Sending half of the transmit channel. A clone is handed to each new
    /// `QuicConnection`; the receiving half lives in `EndpointState`. Keeping
    /// a sender here keeps the channel open for as long as this struct lives.
    transmit_sender: Sender<quinn_proto::Transmit>,
    /// UDP capabilities reported by `quinn_udp` (e.g. `gro_segments()`).
    /// Shared read-only via `udp_state()`.
    udp_state: Arc<quinn_udp::UdpState>,
}
104
105impl EndpointInner {
106 pub(crate) fn udp_state(&self) -> &quinn_udp::UdpState {
107 &self.udp_state
108 }
109 pub(crate) fn handle_enpoint_event(
110 &self,
111 handle: quinn_proto::ConnectionHandle,
112 event: quinn_proto::EndpointEvent,
113 ) -> Option<quinn_proto::ConnectionEvent> {
114 self.state
115 .lock()
116 .unwrap()
117 .endpoint
118 .handle_event(handle, event)
119 }
120}
121
/// Mutable endpoint state, kept behind the mutex in `EndpointInner::state`.
struct EndpointState {
    /// Connections created by this endpoint, keyed by their `quinn_proto`
    /// handle. `poll_next` routes `DatagramEvent::ConnectionEvent`s through
    /// this map.
    /// NOTE(review): entries are inserted in `poll_next`, but no removal is
    /// visible in this file. Confirm closed connections are removed elsewhere;
    /// otherwise this map grows without bound.
    connections: BTreeMap<quinn_proto::ConnectionHandle, Arc<ConnectionInner>>,
    /// The socket as a tuple of:
    /// - the `async_io` handle used for readiness and I/O,
    /// - the `quinn_udp` socket state,
    /// - the receive buffer, which `fill_recv_buffer` splits into
    ///   `quinn_udp::BATCH_SIZE` slots.
    udp: (Async<UdpSocket>, quinn_udp::UdpSocketState, Box<[u8]>),
    /// Receiving half of the transmit channel. Its sender clones are handed to
    /// each `QuicConnection`.
    transmit_receiver: Receiver<quinn_proto::Transmit>,
    /// Transmits pulled from the endpoint or the channel but not yet sent.
    transmit_buffer: VecDeque<quinn_proto::Transmit>,
    /// Received datagram segments not yet passed to `quinn_proto`. Each entry
    /// is (sender address, local destination IP if known, ECN codepoint,
    /// payload).
    recv_buffer: VecDeque<(
        SocketAddr,
        Option<IpAddr>,
        Option<quinn_proto::EcnCodepoint>,
        BytesMut,
    )>,
    /// The `quinn_proto` endpoint that all datagrams and events go through.
    endpoint: quinn_proto::Endpoint,
}
135
136impl EndpointState {
137 fn poll_transmit(&mut self, udp_state: &quinn_udp::UdpState, cx: &mut Context) {
138 for _ in 0..3 {
139 while self.transmit_buffer.len() < self.transmit_buffer.capacity() {
140 match self.endpoint.poll_transmit() {
141 Some(t) => self.transmit_buffer.push_back(t),
142 None => break,
143 }
144 }
145 while self.transmit_buffer.len() < self.transmit_buffer.capacity() {
146 match self.transmit_receiver.poll_next_unpin(cx) {
147 Poll::Ready(Some(t)) => self.transmit_buffer.push_back(t),
148 Poll::Ready(None) => unreachable!(),
149 Poll::Pending => break,
150 }
151 }
152 match poll_send(
153 &mut self.udp.1,
154 udp_state,
155 &self.udp.0,
156 cx,
157 self.transmit_buffer.make_contiguous(),
158 ) {
159 Poll::Ready(Ok(n)) => drop(self.transmit_buffer.drain(0..n)),
160 Poll::Ready(Err(err)) => log::error!("endpoint send error: {:?}", err),
161 Poll::Pending => break,
162 }
163 }
164 }
165
166 fn fill_recv_buffer<'a>(&'a mut self, cx: &mut Context) -> Poll<io::Result<()>> {
167 loop {
168 let mut metas = [quinn_udp::RecvMeta::default(); quinn_udp::BATCH_SIZE];
169 let mut iovs = MaybeUninit::<[IoSliceMut<'a>; quinn_udp::BATCH_SIZE]>::uninit();
170 self.udp
171 .2
172 .chunks_mut(self.udp.2.len() / quinn_udp::BATCH_SIZE)
173 .enumerate()
174 .for_each(|(i, buf)| unsafe {
175 iovs.as_mut_ptr()
176 .cast::<IoSliceMut>()
177 .add(i)
178 .write(IoSliceMut::<'a>::new(buf));
179 });
180 let mut iovs = unsafe { iovs.assume_init() };
181 match poll_recv(&self.udp.1, &self.udp.0, cx, &mut iovs, &mut metas) {
182 Poll::Ready(Ok(n)) => {
183 for (meta, buf) in metas.iter().zip(iovs.iter()).take(n) {
184 let mut b: BytesMut = buf[0..meta.len].into();
185 while !b.is_empty() {
186 let b = b.split_to(meta.stride.min(b.len()));
187 self.recv_buffer
188 .push_back((meta.addr, meta.dst_ip, meta.ecn, b));
189 }
190 }
191 return Poll::Ready(Ok(()));
192 }
193 Poll::Pending => return Poll::Pending,
194 Poll::Ready(Err(err)) => {
195 if err.kind() != io::ErrorKind::ConnectionReset {
196 return Poll::Ready(Err(err));
197 }
198 todo!()
199 }
200 }
201 }
202 }
203}
204
205fn poll_send(
206 uss: &mut quinn_udp::UdpSocketState,
207 us: &quinn_udp::UdpState,
208 io: &Async<UdpSocket>,
209 cx: &mut Context,
210 t: &[quinn_proto::Transmit],
211) -> Poll<io::Result<usize>> {
212 loop {
213 ready!(io.poll_writable(cx))?;
214 if let Ok(n) = uss.send(io.into(), us, t) {
215 return Poll::Ready(Ok(n));
216 }
217 }
218}
219fn poll_recv(
220 uss: &quinn_udp::UdpSocketState,
221 io: &Async<UdpSocket>,
222 cx: &mut Context,
223 b: &mut [IoSliceMut<'_>],
224 m: &mut [quinn_udp::RecvMeta],
225) -> Poll<io::Result<usize>> {
226 loop {
227 ready!(io.poll_readable(cx))?;
228 if let Ok(res) = uss.recv(io.into(), b, m) {
229 return Poll::Ready(Ok(res));
230 }
231 }
232}