ant_quic/
cid_generator.rs1use std::hash::Hasher;
9
10use rand::{Rng, RngCore};
11
12use crate::Duration;
13use crate::MAX_CID_SIZE;
14use crate::shared::ConnectionId;
15
16pub trait ConnectionIdGenerator: Send + Sync {
18 fn generate_cid(&mut self) -> ConnectionId;
26
27 fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
31 Ok(())
32 }
33
34 fn cid_len(&self) -> usize;
36 fn cid_lifetime(&self) -> Option<Duration>;
40}
41
/// Error returned by [`ConnectionIdGenerator::validate`] when a connection ID
/// is rejected.
///
/// `PartialEq`/`Eq` are derived so callers (and tests) can compare validation
/// results directly, e.g. `assert_eq!(gen.validate(&cid), Err(InvalidCid))`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidCid;
45
/// Generator that fills every byte of each CID with random data.
///
/// Its `validate` is the trait's default, which accepts any CID.
#[derive(Clone, Copy, Debug)]
pub struct RandomConnectionIdGenerator {
    /// Number of bytes in each generated CID.
    cid_len: usize,
    /// Value reported by `cid_lifetime`; `None` unless `set_lifetime` is called.
    lifetime: Option<Duration>,
}
55
56impl Default for RandomConnectionIdGenerator {
57 fn default() -> Self {
58 Self {
59 cid_len: 8,
60 lifetime: None,
61 }
62 }
63}
64
65impl RandomConnectionIdGenerator {
66 pub fn new(cid_len: usize) -> Self {
70 debug_assert!(cid_len <= MAX_CID_SIZE);
71 Self {
72 cid_len,
73 ..Self::default()
74 }
75 }
76
77 pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
79 self.lifetime = Some(d);
80 self
81 }
82}
83
84impl ConnectionIdGenerator for RandomConnectionIdGenerator {
85 fn generate_cid(&mut self) -> ConnectionId {
86 let mut bytes_arr = [0; MAX_CID_SIZE];
87 rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
88
89 ConnectionId::new(&bytes_arr[..self.cid_len])
90 }
91
92 fn cid_len(&self) -> usize {
94 self.cid_len
95 }
96
97 fn cid_lifetime(&self) -> Option<Duration> {
98 self.lifetime
99 }
100}
101
/// Generator whose CIDs carry a keyed hash, so `validate` can recognise CIDs
/// issued under the same `key` without storing them.
pub struct HashedConnectionIdGenerator {
    /// Secret mixed into every CID signature. Anyone who knows it can produce
    /// CIDs that pass `validate`.
    key: u64,
    /// Value reported by `cid_lifetime`; `None` unless `set_lifetime` is called.
    lifetime: Option<Duration>,
}
111
112impl HashedConnectionIdGenerator {
113 pub fn new() -> Self {
115 Self::from_key(rand::thread_rng().r#gen())
116 }
117
118 pub fn from_key(key: u64) -> Self {
123 Self {
124 key,
125 lifetime: None,
126 }
127 }
128
129 pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
131 self.lifetime = Some(d);
132 self
133 }
134}
135
136impl Default for HashedConnectionIdGenerator {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl ConnectionIdGenerator for HashedConnectionIdGenerator {
143 fn generate_cid(&mut self) -> ConnectionId {
144 let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
145 rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
146 let mut hasher = rustc_hash::FxHasher::default();
147 hasher.write_u64(self.key);
148 hasher.write(&bytes_arr[..NONCE_LEN]);
149 bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
150 ConnectionId::new(&bytes_arr)
151 }
152
153 fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
154 let (nonce, signature) = cid.split_at(NONCE_LEN);
155 let mut hasher = rustc_hash::FxHasher::default();
156 hasher.write_u64(self.key);
157 hasher.write(nonce);
158 let expected = hasher.finish().to_le_bytes();
159 match expected[..SIGNATURE_LEN] == signature[..] {
160 true => Ok(()),
161 false => Err(InvalidCid),
162 }
163 }
164
165 fn cid_len(&self) -> usize {
166 NONCE_LEN + SIGNATURE_LEN
167 }
168
169 fn cid_lifetime(&self) -> Option<Duration> {
170 self.lifetime
171 }
172}
173
/// Number of random bytes at the start of a hashed CID.
const NONCE_LEN: usize = 3;
/// Number of keyed-hash bytes that follow the nonce. Nonce plus signature
/// make a hashed CID exactly 8 bytes long.
const SIGNATURE_LEN: usize = 8 - NONCE_LEN;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        generator.validate(&cid).unwrap();
    }

    #[test]
    fn hashed_cid_has_advertised_length() {
        let mut generator = HashedConnectionIdGenerator::from_key(42);
        let cid = generator.generate_cid();
        assert_eq!(cid.len(), generator.cid_len());
        assert_eq!(cid.len(), NONCE_LEN + SIGNATURE_LEN);
    }

    #[test]
    fn reject_tampered_signature() {
        let mut generator = HashedConnectionIdGenerator::from_key(42);
        let cid = generator.generate_cid();
        let mut bytes = cid.to_vec();
        // Flipping a signature byte leaves the nonce (and thus the expected
        // signature) unchanged, so this must always fail validation.
        bytes[NONCE_LEN] ^= 0xff;
        assert!(generator.validate(&ConnectionId::new(&bytes)).is_err());
    }

    #[test]
    fn random_cid_has_requested_length() {
        for len in [0, 4, MAX_CID_SIZE] {
            let mut generator = RandomConnectionIdGenerator::new(len);
            assert_eq!(generator.cid_len(), len);
            assert_eq!(generator.generate_cid().len(), len);
        }
    }
}