1use std::collections::{HashMap, VecDeque};
2use std::io::{Error, ErrorKind, Result};
3use std::net::{SocketAddr, ToSocketAddrs};
4use std::sync::Arc;
5
6use futures::lock::Mutex;
7use futures::stream::unfold;
8use futures::{Stream, StreamExt};
9use futures_map::{FuturesUnorderedMap, KeyWaitMap};
10use quiche::{Config, ConnectionId, RecvInfo, SendInfo};
11use ring::{hmac::Key, rand::SystemRandom};
12
13use crate::errors::map_quic_error;
14use crate::{QuicConn, QuicConnState};
15
16enum QuicListenerHandshake {
17 Connection {
18 #[allow(unused)]
19 conn: QuicConnState,
20 is_established: bool,
21 read_size: usize,
23 },
24 Response {
25 buf: Vec<u8>,
27 read_size: usize,
29 },
30}
31
32pub struct QuicListenerState {
34 config: Config,
36 seed_key: Key,
38 handshaking_pool: HashMap<ConnectionId<'static>, QuicConnState>,
40 established_conns: HashMap<ConnectionId<'static>, QuicConnState>,
42 incoming_conns: VecDeque<QuicConn>,
44}
45
46impl QuicListenerState {
47 fn new(config: Config) -> Result<Self> {
49 let rng = SystemRandom::new();
50
51 let seed_key = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng)
52 .map_err(|err| Error::new(ErrorKind::Other, format!("{}", err)))?;
53
54 Ok(Self {
55 config,
56 seed_key,
57 handshaking_pool: Default::default(),
58 incoming_conns: Default::default(),
59 established_conns: Default::default(),
60 })
61 }
62
63 fn get_conn<'a>(&self, id: &ConnectionId<'a>) -> Option<(QuicConnState, bool)> {
67 if let Some(conn) = self.handshaking_pool.get(id) {
68 return Some((conn.clone(), false));
69 }
70
71 if let Some(conn) = self.established_conns.get(id) {
72 return Some((conn.clone(), true));
73 }
74
75 None
76 }
77
78 #[allow(unused)]
80 fn established<'a>(&mut self, id: &ConnectionId<'a>) {
81 let id = id.clone().into_owned();
82 if let Some(conn) = self.handshaking_pool.remove(&id) {
83 self.established_conns.insert(id, conn.clone());
84 self.incoming_conns.push_back(conn.into());
85 }
86 }
87
88 fn remove_conn<'a>(&mut self, id: &ConnectionId<'a>) -> bool {
90 let id = id.clone().into_owned();
91 if self.handshaking_pool.remove(&id).is_some() {
92 return true;
93 }
94
95 if self.established_conns.remove(&id).is_some() {
96 return true;
97 }
98
99 false
100 }
101
102 fn handshake<'a>(
104 &mut self,
105 header: &quiche::Header<'a>,
106 buf: &'a mut [u8],
107 recv_info: RecvInfo,
108 ) -> Result<QuicListenerHandshake> {
109 if header.ty != quiche::Type::Initial {
110 return Err(Error::new(
111 ErrorKind::InvalidData,
112 format!("Invalid packet: {:?}", recv_info),
113 ));
114 }
115
116 self.client_hello(header, buf, recv_info)
117 }
118
119 fn client_hello<'a>(
120 &mut self,
121 header: &quiche::Header<'a>,
122 buf: &'a mut [u8],
123 recv_info: RecvInfo,
124 ) -> Result<QuicListenerHandshake> {
125 if !quiche::version_is_supported(header.version) {
126 return self.negotiation_version(header, recv_info, buf);
127 }
128
129 let token = header.token.as_ref().unwrap();
130
131 if token.is_empty() {
133 return self.retry(header, recv_info, buf);
134 }
135
136 let odcid = Self::validate_token(token, &recv_info.from)?;
138
139 let scid: quiche::ConnectionId<'_> = header.dcid.clone();
140
141 if quiche::MAX_CONN_ID_LEN != scid.len() {
142 return Err(Error::new(
143 ErrorKind::Interrupted,
144 format!("Check dcid length error, len={}", scid.len()),
145 ));
146 }
147
148 let mut quiche_conn = quiche::accept(
149 &scid,
150 Some(&odcid),
151 recv_info.to,
152 recv_info.from,
153 &mut self.config,
154 )
155 .map_err(map_quic_error)?;
156
157 let read_size = quiche_conn.recv(buf, recv_info).map_err(map_quic_error)?;
158
159 let is_established = quiche_conn.is_established();
160
161 let scid = quiche_conn.source_id().into_owned();
162 let dcid = quiche_conn.destination_id().into_owned();
163
164 log::trace!("Create new incoming conn, scid={:?}, dcid={:?}", scid, dcid);
165
166 let conn = QuicConnState::new(quiche_conn, 1, None);
167
168 if is_established {
169 self.established_conns.insert(scid, conn.clone());
170 self.incoming_conns.push_back(conn.clone().into());
171 } else {
172 self.handshaking_pool.insert(scid, conn.clone());
173 }
174
175 Ok(QuicListenerHandshake::Connection {
176 conn,
177 is_established,
178 read_size,
179 })
180 }
181
182 fn negotiation_version<'a>(
183 &mut self,
184 header: &quiche::Header<'a>,
185 _recv_info: RecvInfo,
186 buf: &mut [u8],
187 ) -> Result<QuicListenerHandshake> {
188 let scid = header.scid.clone().into_owned();
189 let dcid = header.dcid.clone().into_owned();
190
191 let mut read_buf = vec![0; 128];
192
193 let write_size = quiche::negotiate_version(&scid, &dcid, buf).map_err(map_quic_error)?;
194
195 read_buf.resize(write_size, 0);
196
197 Ok(QuicListenerHandshake::Response {
198 buf: read_buf,
199 read_size: buf.len(),
200 })
201 }
202 fn retry<'a>(
204 &mut self,
205 header: &quiche::Header<'a>,
206 recv_info: RecvInfo,
207 buf: &mut [u8],
208 ) -> Result<QuicListenerHandshake> {
209 let token = self.mint_token(&header, &recv_info.from);
210
211 let new_scid = ring::hmac::sign(&self.seed_key, &header.dcid);
212 let new_scid = &new_scid.as_ref()[..quiche::MAX_CONN_ID_LEN];
213 let new_scid = quiche::ConnectionId::from_vec(new_scid.to_vec());
214
215 let scid = header.scid.clone().into_owned();
216 let dcid: ConnectionId<'_> = header.dcid.clone().into_owned();
217 let version = header.version;
218
219 let mut read_buf = vec![0; 1200];
220
221 let write_size = quiche::retry(&scid, &dcid, &new_scid, &token, version, &mut read_buf)
222 .map_err(map_quic_error)?;
223
224 read_buf.resize(write_size, 0);
225
226 Ok(QuicListenerHandshake::Response {
227 buf: read_buf,
228 read_size: buf.len(),
229 })
230 }
231
232 fn validate_token<'a>(token: &'a [u8], src: &SocketAddr) -> Result<quiche::ConnectionId<'a>> {
233 if token.len() < 6 {
234 return Err(Error::new(
235 ErrorKind::Interrupted,
236 format!("Invalid token, token length < 6"),
237 ));
238 }
239
240 if &token[..6] != b"quiche" {
241 return Err(Error::new(
242 ErrorKind::Interrupted,
243 format!("Invalid token, not start with 'quiche'"),
244 ));
245 }
246
247 let token = &token[6..];
248
249 let addr = match src.ip() {
250 std::net::IpAddr::V4(a) => a.octets().to_vec(),
251 std::net::IpAddr::V6(a) => a.octets().to_vec(),
252 };
253
254 if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() {
255 return Err(Error::new(
256 ErrorKind::Interrupted,
257 format!("Invalid token, address mismatch"),
258 ));
259 }
260
261 Ok(quiche::ConnectionId::from_ref(&token[addr.len()..]))
262 }
263
264 fn mint_token<'a>(&self, hdr: &quiche::Header<'a>, src: &SocketAddr) -> Vec<u8> {
265 let mut token = Vec::new();
266
267 token.extend_from_slice(b"quiche");
268
269 let addr = match src.ip() {
270 std::net::IpAddr::V4(a) => a.octets().to_vec(),
271 std::net::IpAddr::V6(a) => a.octets().to_vec(),
272 };
273
274 token.extend_from_slice(&addr);
275 token.extend_from_slice(&hdr.dcid);
276
277 token
278 }
279}
280
/// Single key of the listener's `KeyWaitMap`: signalled when a connection has
/// become established and can be returned by `accept`.
#[derive(Hash, Eq, PartialEq, Clone)]
struct QuicListenerAccept;
283
284#[derive(Clone)]
285pub struct QuicListener {
286 laddrs: Arc<Vec<SocketAddr>>,
287 state: Arc<Mutex<QuicListenerState>>,
288 event_map: Arc<KeyWaitMap<QuicListenerAccept, ()>>,
289 send_map: FuturesUnorderedMap<QuicConnState, Result<(Vec<u8>, SendInfo)>>,
290}
291
292impl QuicListener {
293 async fn remove_conn(&self, scid: &ConnectionId<'static>) {
294 let mut raw = self.state.lock().await;
295
296 if raw.remove_conn(scid) {
297 log::trace!("scid={:?}, remove connection from server pool", scid);
298 } else {
299 log::warn!(
300 "scid={:?}, removed from server pool with error: not found",
301 scid
302 );
303 }
304 }
305}
306
307impl QuicListener {
308 pub fn new<A: ToSocketAddrs>(laddrs: A, config: Config) -> Result<Self> {
310 Ok(QuicListener {
311 laddrs: Arc::new(laddrs.to_socket_addrs()?.collect()),
312 state: Arc::new(Mutex::new(QuicListenerState::new(config)?)),
313 event_map: Arc::new(KeyWaitMap::new()),
314 send_map: FuturesUnorderedMap::new(),
315 })
316 }
317
318 pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
320 self.laddrs.iter()
321 }
322
323 pub async fn send(&self) -> Result<(Vec<u8>, SendInfo)> {
324 while let Some((conn, result)) = (&self.send_map).next().await {
325 match result {
326 Ok((buf, send_info)) => {
327 let send = conn.clone().send_owned();
328 self.send_map.insert(conn, send);
329
330 return Ok((buf, send_info));
331 }
332 Err(err) => {
333 log::error!(
334 "QuicConn: id={:?}, send with error, err={}, removed from listener pool",
335 conn.id,
336 err
337 );
338 }
339 }
340 }
341
342 Err(Error::new(ErrorKind::BrokenPipe, "QuicListener broken"))
343 }
344
345 pub async fn recv<Buf: AsMut<[u8]>>(
349 &self,
350 mut buf: Buf,
351 recv_info: RecvInfo,
352 ) -> Result<(usize, Option<Vec<u8>>)> {
353 let buf = buf.as_mut();
354 let header =
355 quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(map_quic_error)?;
356
357 let mut state = self.state.lock().await;
358
359 log::trace!("quic listener: {:?}", header);
360
361 if let Some((conn, is_established)) = state.get_conn(&header.dcid) {
362 drop(state);
364
365 let recv_size = match conn.recv(buf, recv_info).await {
366 Ok(recv_size) => recv_size,
367 Err(err) => {
368 log::error!("conn recv, id={:?}, err={}", conn.id, err);
369
370 self.remove_conn(&header.dcid).await;
371
372 return Ok((buf.len(), None));
373 }
374 };
375
376 if !is_established && conn.is_established().await {
377 state = self.state.lock().await;
379 state.established(&header.dcid);
381
382 self.event_map.insert(QuicListenerAccept, ());
383 }
384
385 return Ok((recv_size, None));
386 }
387
388 match state.handshake(&header, buf, recv_info) {
390 Ok(QuicListenerHandshake::Connection {
391 conn,
392 is_established,
393 read_size,
394 }) => {
395 if is_established {
397 self.event_map.insert(QuicListenerAccept, ());
398 }
399
400 let send = conn.clone().send_owned();
401
402 self.send_map.insert(conn, send);
403
404 return Ok((read_size, None));
405 }
406 Ok(QuicListenerHandshake::Response {
407 buf,
408 read_size: recv_size,
409 }) => return Ok((recv_size, Some(buf))),
410 Err(err) => {
411 log::error!("quic listener handshake, err={}", err);
412
413 return Ok((buf.len(), None));
414 }
415 }
416 }
417
418 pub async fn accept(&self) -> Result<QuicConn> {
420 loop {
421 let mut state = self.state.lock().await;
422
423 if let Some(conn) = state.incoming_conns.pop_front() {
424 return Ok(conn);
425 }
426
427 self.event_map.wait(&QuicListenerAccept, state).await;
428 }
429 }
430
431 pub fn incoming(&self) -> impl Stream<Item = Result<QuicConn>> + Send + Unpin {
433 Box::pin(unfold(self.clone(), |listener| async {
434 let res = listener.accept().await;
435 Some((res, listener))
436 }))
437 }
438}