ant_quic/
cid_generator.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::hash::Hasher;
9
10use rand::{Rng, RngCore};
11
12use crate::Duration;
13use crate::MAX_CID_SIZE;
14use crate::shared::ConnectionId;
15
/// Generates connection IDs for incoming connections
///
/// Implementors must be `Send + Sync`.
pub trait ConnectionIdGenerator: Send + Sync {
    /// Generates a new CID
    ///
    /// Connection IDs MUST NOT contain any information that can be used by
    /// an external observer (that is, one that does not cooperate with the
    /// issuer) to correlate them with other connection IDs for the same
    /// connection. They MUST have high entropy, e.g. due to encrypted data
    /// or cryptographic-grade random data.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Quickly determine whether `cid` could have been generated by this generator
    ///
    /// False positives are permitted, but increase the cost of handling invalid packets.
    ///
    /// The default implementation accepts every CID (always returns `Ok(())`); override it
    /// when generated CIDs carry structure that can be checked.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidCid`] when `cid` is known not to have come from this generator.
    fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
        Ok(())
    }

    /// Returns the length of a CID for connections created by this generator
    fn cid_len(&self) -> usize;

    /// Returns the lifetime of generated Connection IDs
    ///
    /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
    fn cid_lifetime(&self) -> Option<Duration>;
}
41
/// The connection ID was not recognized by the [`ConnectionIdGenerator`]
///
/// Returned by [`ConnectionIdGenerator::validate`]. Implements [`std::error::Error`] so it can be
/// propagated with `?` into boxed or aggregated error types.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct InvalidCid;

impl std::fmt::Display for InvalidCid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("connection ID not recognized by generator")
    }
}

impl std::error::Error for InvalidCid {}
45
/// Generates purely random connection IDs of a fixed, configurable length
///
/// Random CIDs may be shorter than those produced by [`HashedConnectionIdGenerator`], but they
/// carry no structure, so [`validate`](ConnectionIdGenerator::validate) cannot usefully reject
/// anything.
#[derive(Clone, Copy, Debug)]
pub struct RandomConnectionIdGenerator {
    /// Number of random bytes in every generated CID.
    cid_len: usize,
    /// Lifetime reported for generated CIDs; `None` when no lifetime has been set.
    lifetime: Option<Duration>,
}
55
56impl Default for RandomConnectionIdGenerator {
57    fn default() -> Self {
58        Self {
59            cid_len: 8,
60            lifetime: None,
61        }
62    }
63}
64
65impl RandomConnectionIdGenerator {
66    /// Initialize Random CID generator with a fixed CID length
67    ///
68    /// The given length must be less than or equal to MAX_CID_SIZE.
69    pub fn new(cid_len: usize) -> Self {
70        debug_assert!(cid_len <= MAX_CID_SIZE);
71        Self {
72            cid_len,
73            ..Self::default()
74        }
75    }
76
77    /// Set the lifetime of CIDs created by this generator
78    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
79        self.lifetime = Some(d);
80        self
81    }
82}
83
84impl ConnectionIdGenerator for RandomConnectionIdGenerator {
85    fn generate_cid(&mut self) -> ConnectionId {
86        let mut bytes_arr = [0; MAX_CID_SIZE];
87        rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
88
89        ConnectionId::new(&bytes_arr[..self.cid_len])
90    }
91
92    /// Provide the length of dst_cid in short header packet
93    fn cid_len(&self) -> usize {
94        self.cid_len
95    }
96
97    fn cid_lifetime(&self) -> Option<Duration> {
98        self.lifetime
99    }
100}
101
/// Generates 8-byte connection IDs that can be cheaply
/// [`validate`](ConnectionIdGenerator::validate)d
///
/// Each CID is a random nonce followed by a truncated keyed hash of that nonce. The hash is not
/// cryptographic, so CIDs can still be spoofed, but the check helps keep the endpoint from
/// responding to non-QUIC packets at very low cost.
pub struct HashedConnectionIdGenerator {
    /// Key mixed into the hash that signs every generated CID.
    key: u64,
    /// Lifetime reported for generated CIDs; `None` when no lifetime has been set.
    lifetime: Option<Duration>,
}
111
112impl HashedConnectionIdGenerator {
113    /// Create a generator with a random key
114    pub fn new() -> Self {
115        Self::from_key(rand::thread_rng().r#gen())
116    }
117
118    /// Create a generator with a specific key
119    ///
120    /// Allows [`validate`](ConnectionIdGenerator::validate) to recognize a consistent set of
121    /// connection IDs across restarts
122    pub fn from_key(key: u64) -> Self {
123        Self {
124            key,
125            lifetime: None,
126        }
127    }
128
129    /// Set the lifetime of CIDs created by this generator
130    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
131        self.lifetime = Some(d);
132        self
133    }
134}
135
136impl Default for HashedConnectionIdGenerator {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl ConnectionIdGenerator for HashedConnectionIdGenerator {
143    fn generate_cid(&mut self) -> ConnectionId {
144        let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
145        rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
146        let mut hasher = rustc_hash::FxHasher::default();
147        hasher.write_u64(self.key);
148        hasher.write(&bytes_arr[..NONCE_LEN]);
149        bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
150        ConnectionId::new(&bytes_arr)
151    }
152
153    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
154        let (nonce, signature) = cid.split_at(NONCE_LEN);
155        let mut hasher = rustc_hash::FxHasher::default();
156        hasher.write_u64(self.key);
157        hasher.write(nonce);
158        let expected = hasher.finish().to_le_bytes();
159        match expected[..SIGNATURE_LEN] == signature[..] {
160            true => Ok(()),
161            false => Err(InvalidCid),
162        }
163    }
164
165    fn cid_len(&self) -> usize {
166        NONCE_LEN + SIGNATURE_LEN
167    }
168
169    fn cid_lifetime(&self) -> Option<Duration> {
170        self.lifetime
171    }
172}
173
/// Random bytes at the front of each hashed CID: 2^24 (about 16.7 million) distinct nonces.
const NONCE_LEN: usize = 3;
/// Keyed-hash bytes after the nonce, chosen so that a hashed CID is 8 bytes in total.
const SIGNATURE_LEN: usize = 8 - NONCE_LEN;
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        generator.validate(&cid).unwrap();
    }

    /// Generators built from the same key must recognize each other's CIDs (the
    /// "across restarts" use case documented on `from_key`).
    #[test]
    fn keyed_cid_validates_across_instances_with_same_key() {
        let key = 0x0123_4567_89ab_cdef;
        let cid = HashedConnectionIdGenerator::from_key(key).generate_cid();
        HashedConnectionIdGenerator::from_key(key)
            .validate(&cid)
            .unwrap();
    }

    /// Flipping a signature byte must make validation fail.
    #[test]
    fn tampered_keyed_cid_is_rejected() {
        let mut generator = HashedConnectionIdGenerator::from_key(42);
        let cid = generator.generate_cid();
        let mut tampered = cid.to_vec();
        *tampered.last_mut().expect("hashed CIDs are never empty") ^= 0xff;
        assert!(generator.validate(&ConnectionId::new(&tampered)).is_err());
    }

    #[test]
    fn random_cid_has_requested_length() {
        let mut generator = RandomConnectionIdGenerator::new(MAX_CID_SIZE);
        assert_eq!(generator.cid_len(), MAX_CID_SIZE);
        let cid = generator.generate_cid();
        let bytes: &[u8] = &cid;
        assert_eq!(bytes.len(), MAX_CID_SIZE);
    }
}