1use std::{
2 io::{Read, Write},
3 sync::Arc,
4};
5
6use aead::{generic_array::GenericArray, KeyInit};
7
8use crate::{
9 nts_record::AeadAlgorithm,
10 packet::{
11 AesSivCmac256, AesSivCmac512, Cipher, CipherHolder, CipherProvider, DecryptError,
12 EncryptResult, ExtensionField,
13 },
14};
15
16pub struct DecodedServerCookie {
17 pub(crate) algorithm: AeadAlgorithm,
18 pub s2c: Box<dyn Cipher>,
19 pub c2s: Box<dyn Cipher>,
20}
21
22impl DecodedServerCookie {
23 fn plaintext(&self) -> Vec<u8> {
24 let mut plaintext = Vec::new();
25
26 let algorithm_bytes = (self.algorithm as u16).to_be_bytes();
27 plaintext.extend_from_slice(&algorithm_bytes);
28 plaintext.extend_from_slice(self.s2c.key_bytes());
29 plaintext.extend_from_slice(self.c2s.key_bytes());
30
31 plaintext
32 }
33}
34
35impl std::fmt::Debug for DecodedServerCookie {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("DecodedServerCookie")
38 .field("algorithm", &self.algorithm)
39 .finish()
40 }
41}
42
43#[derive(Debug)]
44pub struct KeySetProvider {
45 current: Arc<KeySet>,
46 history: usize,
47}
48
49impl KeySetProvider {
50 pub fn new(history: usize) -> Self {
54 KeySetProvider {
55 current: Arc::new(KeySet {
56 keys: vec![AesSivCmac512::new(aes_siv::Aes256SivAead::generate_key(
57 rand::thread_rng(),
58 ))],
59 id_offset: 0,
60 primary: 0,
61 }),
62 history,
63 }
64 }
65
66 #[cfg(feature = "__internal-fuzz")]
67 pub fn dangerous_new_deterministic(history: usize) -> Self {
68 KeySetProvider {
69 current: Arc::new(KeySet {
70 keys: vec![AesSivCmac512::new(
71 std::array::from_fn(|i| (i as u8)).into(),
72 )],
73 id_offset: 0,
74 primary: 0,
75 }),
76 history,
77 }
78 }
79
80 pub fn rotate(&mut self) {
82 let next_key = AesSivCmac512::new(aes_siv::Aes256SivAead::generate_key(rand::thread_rng()));
83 let mut keys = Vec::with_capacity((self.history + 1).min(self.current.keys.len() + 1));
84 for key in self.current.keys
85 [self.current.keys.len().saturating_sub(self.history)..self.current.keys.len()]
86 .iter()
87 {
88 keys.push(AesSivCmac512::new(GenericArray::clone_from_slice(
90 key.key_bytes(),
91 )));
92 }
93 keys.push(next_key);
94 self.current = Arc::new(KeySet {
95 id_offset: self
96 .current
97 .id_offset
98 .wrapping_add(self.current.keys.len().saturating_sub(self.history) as u32),
99 primary: keys.len() as u32 - 1,
100 keys,
101 });
102 }
103
104 pub fn load(
105 reader: &mut impl Read,
106 history: usize,
107 ) -> std::io::Result<(Self, std::time::SystemTime)> {
108 let mut buf = [0; 64];
109 reader.read_exact(&mut buf[0..20])?;
110 let time = std::time::SystemTime::UNIX_EPOCH
111 + std::time::Duration::from_secs(u64::from_be_bytes(buf[0..8].try_into().unwrap()));
112 let id_offset = u32::from_be_bytes(buf[8..12].try_into().unwrap());
113 let primary = u32::from_be_bytes(buf[12..16].try_into().unwrap());
114 let len = u32::from_be_bytes(buf[16..20].try_into().unwrap());
115 if primary > len {
116 return Err(std::io::ErrorKind::Other.into());
117 }
118 let mut keys = vec![];
119 for _ in 0..len {
120 reader.read_exact(&mut buf[0..64])?;
121 keys.push(AesSivCmac512::new(buf.into()));
122 }
123 Ok((
124 KeySetProvider {
125 current: Arc::new(KeySet {
126 keys,
127 id_offset,
128 primary,
129 }),
130 history,
131 },
132 time,
133 ))
134 }
135
136 pub fn store(&self, writer: &mut impl Write) -> std::io::Result<()> {
137 let time = std::time::SystemTime::now()
138 .duration_since(std::time::SystemTime::UNIX_EPOCH)
139 .expect("Could not get current time");
140 writer.write_all(&time.as_secs().to_be_bytes())?;
141 writer.write_all(&self.current.id_offset.to_be_bytes())?;
142 writer.write_all(&self.current.primary.to_be_bytes())?;
143 writer.write_all(&(self.current.keys.len() as u32).to_be_bytes())?;
144 for key in self.current.keys.iter() {
145 writer.write_all(key.key_bytes())?;
146 }
147 Ok(())
148 }
149
150 pub fn get(&self) -> Arc<KeySet> {
152 self.current.clone()
153 }
154}
155
156pub struct KeySet {
157 keys: Vec<AesSivCmac512>,
158 id_offset: u32,
159 primary: u32,
160}
161
162impl KeySet {
163 #[cfg(feature = "__internal-fuzz")]
164 pub fn encode_cookie_pub(&self, cookie: &DecodedServerCookie) -> Vec<u8> {
165 self.encode_cookie(cookie)
166 }
167
168 pub(crate) fn encode_cookie(&self, cookie: &DecodedServerCookie) -> Vec<u8> {
169 let mut output = cookie.plaintext();
170 let plaintext_length = output.as_slice().len();
171
172 output.resize(output.len() + 2 + 4 + 16 + 16, 0);
175
176 output.copy_within(0..plaintext_length, 6);
178 let EncryptResult {
179 nonce_length,
180 ciphertext_length,
181 } = self.keys[self.primary as usize]
182 .encrypt(&mut output[6..], plaintext_length, &[])
183 .expect("Failed to encrypt cookie");
184
185 debug_assert_eq!(nonce_length, 16);
186 debug_assert_eq!(plaintext_length + 16, ciphertext_length);
187
188 output[0..4].copy_from_slice(&(self.primary.wrapping_add(self.id_offset)).to_be_bytes());
189 output[4..6].copy_from_slice(&(ciphertext_length as u16).to_be_bytes());
190 debug_assert_eq!(output.len(), 6 + nonce_length + ciphertext_length);
191 output
192 }
193
194 #[cfg(feature = "__internal-fuzz")]
195 pub fn decode_cookie_pub(&self, cookie: &[u8]) -> Result<DecodedServerCookie, DecryptError> {
196 self.decode_cookie(cookie)
197 }
198
199 pub(crate) fn decode_cookie(&self, cookie: &[u8]) -> Result<DecodedServerCookie, DecryptError> {
200 if cookie.len() < 4 + 2 + 16 {
202 return Err(DecryptError);
203 }
204
205 let id = u32::from_be_bytes(cookie[0..4].try_into().unwrap());
206 let id = id.wrapping_sub(self.id_offset) as usize;
207 let key = self.keys.get(id).ok_or(DecryptError)?;
208
209 let cipher_text_length = u16::from_be_bytes([cookie[4], cookie[5]]) as usize;
210
211 let nonce = &cookie[6..22];
212 let ciphertext = cookie[22..].get(..cipher_text_length).ok_or(DecryptError)?;
213 let plaintext = key.decrypt(nonce, ciphertext, &[])?;
214
215 let [b0, b1, ref key_bytes @ ..] = plaintext[..] else {
216 return Err(DecryptError);
217 };
218
219 let algorithm =
220 AeadAlgorithm::try_deserialize(u16::from_be_bytes([b0, b1])).ok_or(DecryptError)?;
221
222 Ok(match algorithm {
223 AeadAlgorithm::AeadAesSivCmac256 => {
224 const KEY_WIDTH: usize = 32;
225
226 if key_bytes.len() != 2 * KEY_WIDTH {
227 return Err(DecryptError);
228 }
229
230 let (s2c, c2s) = key_bytes.split_at(KEY_WIDTH);
231
232 DecodedServerCookie {
233 algorithm,
234 s2c: Box::new(AesSivCmac256::new(GenericArray::clone_from_slice(s2c))),
235 c2s: Box::new(AesSivCmac256::new(GenericArray::clone_from_slice(c2s))),
236 }
237 }
238 AeadAlgorithm::AeadAesSivCmac512 => {
239 const KEY_WIDTH: usize = 64;
240
241 if key_bytes.len() != 2 * KEY_WIDTH {
242 return Err(DecryptError);
243 }
244
245 let (s2c, c2s) = key_bytes.split_at(KEY_WIDTH);
246
247 DecodedServerCookie {
248 algorithm,
249 s2c: Box::new(AesSivCmac512::new(GenericArray::clone_from_slice(s2c))),
250 c2s: Box::new(AesSivCmac512::new(GenericArray::clone_from_slice(c2s))),
251 }
252 }
253 })
254 }
255
256 #[cfg(test)]
257 pub(crate) fn new() -> Self {
258 Self {
259 keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())],
260 id_offset: 1,
261 primary: 0,
262 }
263 }
264}
265
266impl CipherProvider for KeySet {
267 fn get(&self, context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
268 let mut decoded = None;
269
270 for ef in context {
271 if let ExtensionField::NtsCookie(cookie) = ef {
272 if decoded.is_some() {
273 return None;
275 }
276 decoded = Some(self.decode_cookie(cookie).ok()?);
277 }
278 }
279
280 decoded.map(CipherHolder::DecodedServerCookie)
281 }
282}
283
284impl std::fmt::Debug for KeySet {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 f.debug_struct("KeySet")
287 .field("keys", &self.keys.len())
288 .field("id_offset", &self.id_offset)
289 .field("primary", &self.primary)
290 .finish()
291 }
292}
293
/// Fixed AES-SIV-CMAC-256 cookie for tests and fuzzing.
///
/// The `s2c` key is bytes `0..32` and the `c2s` key is bytes `32..64`.
#[cfg(any(test, feature = "__internal-fuzz"))]
pub fn test_cookie() -> DecodedServerCookie {
    let algorithm = AeadAlgorithm::AeadAesSivCmac256;
    DecodedServerCookie {
        algorithm,
        s2c: Box::new(AesSivCmac256::new((0u8..32).collect())),
        c2s: Box::new(AesSivCmac256::new((32u8..64).collect())),
    }
}
302
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Cookie with AES-SIV-CMAC-256 keys `0..32` (s2c) and `32..64` (c2s).
    fn cmac256_cookie() -> DecodedServerCookie {
        DecodedServerCookie {
            algorithm: AeadAlgorithm::AeadAesSivCmac256,
            s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())),
            c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())),
        }
    }

    /// Cookie with AES-SIV-CMAC-512 keys `0..64` (s2c) and `64..128` (c2s).
    fn cmac512_cookie() -> DecodedServerCookie {
        DecodedServerCookie {
            algorithm: AeadAlgorithm::AeadAesSivCmac512,
            s2c: Box::new(AesSivCmac512::new((0..64_u8).collect())),
            c2s: Box::new(AesSivCmac512::new((64..128_u8).collect())),
        }
    }

    /// Key set with a single all-zero key, `id_offset` 1, primary 0.
    fn zero_keyset() -> KeySet {
        KeySet {
            keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())],
            id_offset: 1,
            primary: 0,
        }
    }

    /// Assert that two cookies carry the same algorithm and key bytes.
    fn assert_same_cookie(expected: &DecodedServerCookie, actual: &DecodedServerCookie) {
        assert_eq!(expected.algorithm, actual.algorithm);
        assert_eq!(expected.s2c.key_bytes(), actual.s2c.key_bytes());
        assert_eq!(expected.c2s.key_bytes(), actual.c2s.key_bytes());
    }

    #[test]
    fn roundtrip_aes_siv_cmac_256() {
        let decoded = cmac256_cookie();
        let keyset = zero_keyset();

        let encoded = keyset.encode_cookie(&decoded);
        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn test_encode_after_rotate() {
        let decoded = cmac256_cookie();

        let mut provider = KeySetProvider::new(1);
        provider.rotate();
        let keyset = provider.get();

        let encoded = keyset.encode_cookie(&decoded);
        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn can_decode_cookie_with_padding() {
        let decoded = cmac512_cookie();
        let keyset = zero_keyset();

        let mut encoded = keyset.encode_cookie(&decoded);
        // Trailing bytes beyond the declared ciphertext length must be ignored.
        encoded.extend([0, 0]);

        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn roundtrip_aes_siv_cmac_512() {
        let decoded = cmac512_cookie();
        let keyset = zero_keyset();

        let encoded = keyset.encode_cookie(&decoded);
        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn test_save_restore() {
        let mut provider = KeySetProvider::new(8);
        provider.rotate();
        provider.rotate();

        let mut output = Cursor::new(vec![]);
        provider.store(&mut output).unwrap();
        let mut input = Cursor::new(output.into_inner());
        let (copy, time) = KeySetProvider::load(&mut input, 8).unwrap();

        let age = std::time::SystemTime::now().duration_since(time).unwrap();
        assert!(age.as_secs() < 2);

        let original = provider.get();
        let restored = copy.get();
        assert_eq!(original.primary, restored.primary);
        assert_eq!(original.id_offset, restored.id_offset);
        for (index, key) in original.keys.iter().enumerate() {
            assert_eq!(key.key_bytes(), restored.keys[index].key_bytes());
        }
    }

    #[test]
    fn old_cookie_still_valid() {
        let decoded = cmac256_cookie();

        let mut provider = KeySetProvider::new(1);
        let encoded = provider.get().encode_cookie(&decoded);

        // Decodes with the key that produced it.
        assert_same_cookie(&decoded, &provider.get().decode_cookie(&encoded).unwrap());

        // A history of 1 keeps the previous key around after one rotation.
        provider.rotate();
        assert_same_cookie(&decoded, &provider.get().decode_cookie(&encoded).unwrap());

        // A second rotation pushes the original key out of the window.
        provider.rotate();
        assert!(provider.get().decode_cookie(&encoded).is_err());
    }

    #[test]
    fn invalid_cookie_length() {
        let input = b"\x23\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x04\x00\x24\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x04\x00\x18\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x04\x00\x28\x00\x10\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let provider = KeySetProvider::new(1);
        let output = provider.get().decode_cookie(input);

        assert!(output.is_err());
    }
}