futures_quic/
listener.rs

1use std::collections::{HashMap, VecDeque};
2use std::io::{Error, ErrorKind, Result};
3use std::net::{SocketAddr, ToSocketAddrs};
4use std::sync::Arc;
5
6use futures::lock::Mutex;
7use futures::stream::unfold;
8use futures::{Stream, StreamExt};
9use futures_map::{FuturesUnorderedMap, KeyWaitMap};
10use quiche::{Config, ConnectionId, RecvInfo, SendInfo};
11use ring::{hmac::Key, rand::SystemRandom};
12
13use crate::errors::map_quic_error;
14use crate::{QuicConn, QuicConnState};
15
/// Outcome of processing a packet that does not belong to any connection
/// already tracked by the listener.
enum QuicListenerHandshake {
    /// A new server-side connection was accepted and registered in the
    /// listener state.
    Connection {
        #[allow(unused)]
        conn: QuicConnState,
        /// `true` when the connection reported itself as established right
        /// after the packet was fed into it.
        is_established: bool,
        /// the number of bytes processed from the input buffer
        read_size: usize,
    },
    /// No connection was created; a response packet (version negotiation or
    /// retry) has to be sent back to the peer instead.
    Response {
        /// buf of response packet.
        buf: Vec<u8>,
        /// the number of bytes processed from the input buffer
        read_size: usize,
    },
}
31
/// Server-side incoming connection handshake pool.
///
/// Tracks every connection created by the listener, split into the ones that
/// are still handshaking and the ones that are established, plus the queue of
/// established connections that have not been handed to `accept` yet.
pub struct QuicListenerState {
    /// The quic config shared between connections for this listener.
    config: Config,
    /// HMAC key used to derive the new source connection id placed in retry
    /// packets.
    seed_key: Key,
    /// Connections that are still in the handshake phase, keyed by the
    /// connection's source id.
    handshaking_pool: HashMap<ConnectionId<'static>, QuicConnState>,
    /// Connections whose handshake has completed, keyed by the connection's
    /// source id.
    established_conns: HashMap<ConnectionId<'static>, QuicConnState>,
    /// Queue of newly established inbound connections waiting to be returned
    /// by `accept`.
    incoming_conns: VecDeque<QuicConn>,
}
45
46impl QuicListenerState {
47    /// Create `HandshakePool` with provided `config`.
48    fn new(config: Config) -> Result<Self> {
49        let rng = SystemRandom::new();
50
51        let seed_key = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng)
52            .map_err(|err| Error::new(ErrorKind::Other, format!("{}", err)))?;
53
54        Ok(Self {
55            config,
56            seed_key,
57            handshaking_pool: Default::default(),
58            incoming_conns: Default::default(),
59            established_conns: Default::default(),
60        })
61    }
62
63    /// Get connection by id.
64    ///
65    /// If found, returns tuple (QuicConnState, is_established).
66    fn get_conn<'a>(&self, id: &ConnectionId<'a>) -> Option<(QuicConnState, bool)> {
67        if let Some(conn) = self.handshaking_pool.get(id) {
68            return Some((conn.clone(), false));
69        }
70
71        if let Some(conn) = self.established_conns.get(id) {
72            return Some((conn.clone(), true));
73        }
74
75        None
76    }
77
78    /// Move connection from handshaking set to established set by id.
79    #[allow(unused)]
80    fn established<'a>(&mut self, id: &ConnectionId<'a>) {
81        let id = id.clone().into_owned();
82        if let Some(conn) = self.handshaking_pool.remove(&id) {
83            self.established_conns.insert(id, conn.clone());
84            self.incoming_conns.push_back(conn.into());
85        }
86    }
87
88    /// remove connection from pool.
89    fn remove_conn<'a>(&mut self, id: &ConnectionId<'a>) -> bool {
90        let id = id.clone().into_owned();
91        if self.handshaking_pool.remove(&id).is_some() {
92            return true;
93        }
94
95        if self.established_conns.remove(&id).is_some() {
96            return true;
97        }
98
99        false
100    }
101
102    /// Process Initial packet.
103    fn handshake<'a>(
104        &mut self,
105        header: &quiche::Header<'a>,
106        buf: &'a mut [u8],
107        recv_info: RecvInfo,
108    ) -> Result<QuicListenerHandshake> {
109        if header.ty != quiche::Type::Initial {
110            return Err(Error::new(
111                ErrorKind::InvalidData,
112                format!("Invalid packet: {:?}", recv_info),
113            ));
114        }
115
116        self.client_hello(header, buf, recv_info)
117    }
118
119    fn client_hello<'a>(
120        &mut self,
121        header: &quiche::Header<'a>,
122        buf: &'a mut [u8],
123        recv_info: RecvInfo,
124    ) -> Result<QuicListenerHandshake> {
125        if !quiche::version_is_supported(header.version) {
126            return self.negotiation_version(header, recv_info, buf);
127        }
128
129        let token = header.token.as_ref().unwrap();
130
131        // generate new token and retry
132        if token.is_empty() {
133            return self.retry(header, recv_info, buf);
134        }
135
136        // check token .
137        let odcid = Self::validate_token(token, &recv_info.from)?;
138
139        let scid: quiche::ConnectionId<'_> = header.dcid.clone();
140
141        if quiche::MAX_CONN_ID_LEN != scid.len() {
142            return Err(Error::new(
143                ErrorKind::Interrupted,
144                format!("Check dcid length error, len={}", scid.len()),
145            ));
146        }
147
148        let mut quiche_conn = quiche::accept(
149            &scid,
150            Some(&odcid),
151            recv_info.to,
152            recv_info.from,
153            &mut self.config,
154        )
155        .map_err(map_quic_error)?;
156
157        let read_size = quiche_conn.recv(buf, recv_info).map_err(map_quic_error)?;
158
159        let is_established = quiche_conn.is_established();
160
161        let scid = quiche_conn.source_id().into_owned();
162        let dcid = quiche_conn.destination_id().into_owned();
163
164        log::trace!("Create new incoming conn, scid={:?}, dcid={:?}", scid, dcid);
165
166        let conn = QuicConnState::new(quiche_conn, 1, None);
167
168        if is_established {
169            self.established_conns.insert(scid, conn.clone());
170            self.incoming_conns.push_back(conn.clone().into());
171        } else {
172            self.handshaking_pool.insert(scid, conn.clone());
173        }
174
175        Ok(QuicListenerHandshake::Connection {
176            conn,
177            is_established,
178            read_size,
179        })
180    }
181
182    fn negotiation_version<'a>(
183        &mut self,
184        header: &quiche::Header<'a>,
185        _recv_info: RecvInfo,
186        buf: &mut [u8],
187    ) -> Result<QuicListenerHandshake> {
188        let scid = header.scid.clone().into_owned();
189        let dcid = header.dcid.clone().into_owned();
190
191        let mut read_buf = vec![0; 128];
192
193        let write_size = quiche::negotiate_version(&scid, &dcid, buf).map_err(map_quic_error)?;
194
195        read_buf.resize(write_size, 0);
196
197        Ok(QuicListenerHandshake::Response {
198            buf: read_buf,
199            read_size: buf.len(),
200        })
201    }
202    /// Generate retry package
203    fn retry<'a>(
204        &mut self,
205        header: &quiche::Header<'a>,
206        recv_info: RecvInfo,
207        buf: &mut [u8],
208    ) -> Result<QuicListenerHandshake> {
209        let token = self.mint_token(&header, &recv_info.from);
210
211        let new_scid = ring::hmac::sign(&self.seed_key, &header.dcid);
212        let new_scid = &new_scid.as_ref()[..quiche::MAX_CONN_ID_LEN];
213        let new_scid = quiche::ConnectionId::from_vec(new_scid.to_vec());
214
215        let scid = header.scid.clone().into_owned();
216        let dcid: ConnectionId<'_> = header.dcid.clone().into_owned();
217        let version = header.version;
218
219        let mut read_buf = vec![0; 1200];
220
221        let write_size = quiche::retry(&scid, &dcid, &new_scid, &token, version, &mut read_buf)
222            .map_err(map_quic_error)?;
223
224        read_buf.resize(write_size, 0);
225
226        Ok(QuicListenerHandshake::Response {
227            buf: read_buf,
228            read_size: buf.len(),
229        })
230    }
231
232    fn validate_token<'a>(token: &'a [u8], src: &SocketAddr) -> Result<quiche::ConnectionId<'a>> {
233        if token.len() < 6 {
234            return Err(Error::new(
235                ErrorKind::Interrupted,
236                format!("Invalid token, token length < 6"),
237            ));
238        }
239
240        if &token[..6] != b"quiche" {
241            return Err(Error::new(
242                ErrorKind::Interrupted,
243                format!("Invalid token, not start with 'quiche'"),
244            ));
245        }
246
247        let token = &token[6..];
248
249        let addr = match src.ip() {
250            std::net::IpAddr::V4(a) => a.octets().to_vec(),
251            std::net::IpAddr::V6(a) => a.octets().to_vec(),
252        };
253
254        if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() {
255            return Err(Error::new(
256                ErrorKind::Interrupted,
257                format!("Invalid token, address mismatch"),
258            ));
259        }
260
261        Ok(quiche::ConnectionId::from_ref(&token[addr.len()..]))
262    }
263
264    fn mint_token<'a>(&self, hdr: &quiche::Header<'a>, src: &SocketAddr) -> Vec<u8> {
265        let mut token = Vec::new();
266
267        token.extend_from_slice(b"quiche");
268
269        let addr = match src.ip() {
270            std::net::IpAddr::V4(a) => a.octets().to_vec(),
271            std::net::IpAddr::V6(a) => a.octets().to_vec(),
272        };
273
274        token.extend_from_slice(&addr);
275        token.extend_from_slice(&hdr.dcid);
276
277        token
278    }
279}
280
/// Key under which the listener's event map signals that a new connection was
/// pushed onto the incoming queue.
#[derive(Clone, PartialEq, Eq, Hash)]
struct QuicListenerAccept;
283
/// Server-side QUIC listener.
///
/// The listener performs no socket I/O in this file: received datagrams are
/// handed to `recv`, and datagrams to transmit are obtained from `send` or
/// from the response returned by `recv`.
#[derive(Clone)]
pub struct QuicListener {
    /// Socket addresses resolved from the `laddrs` argument of `new`.
    laddrs: Arc<Vec<SocketAddr>>,
    /// Connection pools and the incoming queue, shared between clones.
    state: Arc<Mutex<QuicListenerState>>,
    /// Wakes `accept` when a connection is pushed onto the incoming queue.
    event_map: Arc<KeyWaitMap<QuicListenerAccept, ()>>,
    /// One pending `send_owned` future per connection created by this
    /// listener; polled by `send`.
    send_map: FuturesUnorderedMap<QuicConnState, Result<(Vec<u8>, SendInfo)>>,
}
291
292impl QuicListener {
293    async fn remove_conn(&self, scid: &ConnectionId<'static>) {
294        let mut raw = self.state.lock().await;
295
296        if raw.remove_conn(scid) {
297            log::trace!("scid={:?}, remove connection from server pool", scid);
298        } else {
299            log::warn!(
300                "scid={:?}, removed from server pool with error: not found",
301                scid
302            );
303        }
304    }
305}
306
307impl QuicListener {
308    /// Create a new `QuicListener` instance with provided `laddrs` and `config`.
309    pub fn new<A: ToSocketAddrs>(laddrs: A, config: Config) -> Result<Self> {
310        Ok(QuicListener {
311            laddrs: Arc::new(laddrs.to_socket_addrs()?.collect()),
312            state: Arc::new(Mutex::new(QuicListenerState::new(config)?)),
313            event_map: Arc::new(KeyWaitMap::new()),
314            send_map: FuturesUnorderedMap::new(),
315        })
316    }
317
    /// Returns an iterator over the socket addresses that were resolved from
    /// the `laddrs` argument of `new`.
    ///
    /// NOTE(review): nothing in this file binds a socket; presumably the
    /// caller binds these addresses itself, confirm against the caller.
    pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
        self.laddrs.iter()
    }
322
323    pub async fn send(&self) -> Result<(Vec<u8>, SendInfo)> {
324        while let Some((conn, result)) = (&self.send_map).next().await {
325            match result {
326                Ok((buf, send_info)) => {
327                    let send = conn.clone().send_owned();
328                    self.send_map.insert(conn, send);
329
330                    return Ok((buf, send_info));
331                }
332                Err(err) => {
333                    log::error!(
334                        "QuicConn: id={:?}, send with error, err={}, removed from listener pool",
335                        conn.id,
336                        err
337                    );
338                }
339            }
340        }
341
342        Err(Error::new(ErrorKind::BrokenPipe, "QuicListener broken"))
343    }
344
    /// Processes QUIC packets received from the peer.
    ///
    /// The packet is routed by its destination connection id: packets for a
    /// tracked connection are fed into that connection, every other packet is
    /// treated as a handshake attempt.
    ///
    /// On success returns a tuple of the number of bytes processed from the
    /// input buffer and, when the listener has to answer statelessly
    /// (version negotiation or retry), the response packet to send back.
    ///
    /// # Errors
    ///
    /// Only a header that cannot be parsed is reported as an error. Failures
    /// of the connection's `recv` or of the handshake are logged and the
    /// whole buffer is reported as consumed.
    pub async fn recv<Buf: AsMut<[u8]>>(
        &self,
        mut buf: Buf,
        recv_info: RecvInfo,
    ) -> Result<(usize, Option<Vec<u8>>)> {
        let buf = buf.as_mut();
        let header =
            quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(map_quic_error)?;

        let mut state = self.state.lock().await;

        log::trace!("quic listener: {:?}", header);

        if let Some((conn, is_established)) = state.get_conn(&header.dcid) {
            // release the lock before call [QuicConnState::recv] function.
            drop(state);

            let recv_size = match conn.recv(buf, recv_info).await {
                Ok(recv_size) => recv_size,
                Err(err) => {
                    log::error!("conn recv, id={:?}, err={}", conn.id, err);

                    // The connection is no longer usable: stop tracking it and
                    // report the packet as consumed.
                    self.remove_conn(&header.dcid).await;

                    return Ok((buf.len(), None));
                }
            };

            // The connection was still handshaking when it was looked up and
            // this packet completed the handshake.
            if !is_established && conn.is_established().await {
                // relock the state.
                state = self.state.lock().await;
                // move the connection to established set and push state into incoming queue.
                state.established(&header.dcid);

                // wake up a pending `accept` call.
                self.event_map.insert(QuicListenerAccept, ());
            }

            return Ok((recv_size, None));
        }

        // Perform the handshake process.
        match state.handshake(&header, buf, recv_info) {
            Ok(QuicListenerHandshake::Connection {
                conn,
                is_established,
                read_size,
            }) => {
                // notify incoming queue read ops.
                if is_established {
                    self.event_map.insert(QuicListenerAccept, ());
                }

                // Register the first send future so `send` starts yielding
                // packets of the new connection.
                let send = conn.clone().send_owned();

                self.send_map.insert(conn, send);

                return Ok((read_size, None));
            }
            Ok(QuicListenerHandshake::Response {
                buf,
                read_size: recv_size,
            }) => return Ok((recv_size, Some(buf))),
            Err(err) => {
                // Invalid handshake packets are dropped: log and report the
                // buffer as consumed.
                log::error!("quic listener handshake, err={}", err);

                return Ok((buf.len(), None));
            }
        }
    }
417
418    /// Accept a new inbound connection.
419    pub async fn accept(&self) -> Result<QuicConn> {
420        loop {
421            let mut state = self.state.lock().await;
422
423            if let Some(conn) = state.incoming_conns.pop_front() {
424                return Ok(conn);
425            }
426
427            self.event_map.wait(&QuicListenerAccept, state).await;
428        }
429    }
430
431    /// Returns a stream of incoming connections.
432    pub fn incoming(&self) -> impl Stream<Item = Result<QuicConn>> + Send + Unpin {
433        Box::pin(unfold(self.clone(), |listener| async {
434            let res = listener.accept().await;
435            Some((res, listener))
436        }))
437    }
438}