ntp_proto/
keyset.rs

1use std::{
2    io::{Read, Write},
3    sync::Arc,
4};
5
6use aead::{generic_array::GenericArray, KeyInit};
7
8use crate::{
9    nts_record::AeadAlgorithm,
10    packet::{
11        AesSivCmac256, AesSivCmac512, Cipher, CipherHolder, CipherProvider, DecryptError,
12        EncryptResult, ExtensionField,
13    },
14};
15
16pub struct DecodedServerCookie {
17    pub(crate) algorithm: AeadAlgorithm,
18    pub s2c: Box<dyn Cipher>,
19    pub c2s: Box<dyn Cipher>,
20}
21
22impl DecodedServerCookie {
23    fn plaintext(&self) -> Vec<u8> {
24        let mut plaintext = Vec::new();
25
26        let algorithm_bytes = (self.algorithm as u16).to_be_bytes();
27        plaintext.extend_from_slice(&algorithm_bytes);
28        plaintext.extend_from_slice(self.s2c.key_bytes());
29        plaintext.extend_from_slice(self.c2s.key_bytes());
30
31        plaintext
32    }
33}
34
35impl std::fmt::Debug for DecodedServerCookie {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("DecodedServerCookie")
38            .field("algorithm", &self.algorithm)
39            .finish()
40    }
41}
42
43#[derive(Debug)]
44pub struct KeySetProvider {
45    current: Arc<KeySet>,
46    history: usize,
47}
48
49impl KeySetProvider {
50    /// Create a new keysetprovider that keeps history old
51    /// keys around (so in total, history+1 keys are valid
52    /// at any time)
53    pub fn new(history: usize) -> Self {
54        KeySetProvider {
55            current: Arc::new(KeySet {
56                keys: vec![AesSivCmac512::new(aes_siv::Aes256SivAead::generate_key(
57                    rand::thread_rng(),
58                ))],
59                id_offset: 0,
60                primary: 0,
61            }),
62            history,
63        }
64    }
65
66    #[cfg(feature = "__internal-fuzz")]
67    pub fn dangerous_new_deterministic(history: usize) -> Self {
68        KeySetProvider {
69            current: Arc::new(KeySet {
70                keys: vec![AesSivCmac512::new(
71                    std::array::from_fn(|i| (i as u8)).into(),
72                )],
73                id_offset: 0,
74                primary: 0,
75            }),
76            history,
77        }
78    }
79
80    /// Rotate a new key in as primary, forgetting an old one if needed
81    pub fn rotate(&mut self) {
82        let next_key = AesSivCmac512::new(aes_siv::Aes256SivAead::generate_key(rand::thread_rng()));
83        let mut keys = Vec::with_capacity((self.history + 1).min(self.current.keys.len() + 1));
84        for key in self.current.keys
85            [self.current.keys.len().saturating_sub(self.history)..self.current.keys.len()]
86            .iter()
87        {
88            // This is the rare case where we do really want to make a copy.
89            keys.push(AesSivCmac512::new(GenericArray::clone_from_slice(
90                key.key_bytes(),
91            )));
92        }
93        keys.push(next_key);
94        self.current = Arc::new(KeySet {
95            id_offset: self
96                .current
97                .id_offset
98                .wrapping_add(self.current.keys.len().saturating_sub(self.history) as u32),
99            primary: keys.len() as u32 - 1,
100            keys,
101        });
102    }
103
104    pub fn load(
105        reader: &mut impl Read,
106        history: usize,
107    ) -> std::io::Result<(Self, std::time::SystemTime)> {
108        let mut buf = [0; 64];
109        reader.read_exact(&mut buf[0..20])?;
110        let time = std::time::SystemTime::UNIX_EPOCH
111            + std::time::Duration::from_secs(u64::from_be_bytes(buf[0..8].try_into().unwrap()));
112        let id_offset = u32::from_be_bytes(buf[8..12].try_into().unwrap());
113        let primary = u32::from_be_bytes(buf[12..16].try_into().unwrap());
114        let len = u32::from_be_bytes(buf[16..20].try_into().unwrap());
115        if primary > len {
116            return Err(std::io::ErrorKind::Other.into());
117        }
118        let mut keys = vec![];
119        for _ in 0..len {
120            reader.read_exact(&mut buf[0..64])?;
121            keys.push(AesSivCmac512::new(buf.into()));
122        }
123        Ok((
124            KeySetProvider {
125                current: Arc::new(KeySet {
126                    keys,
127                    id_offset,
128                    primary,
129                }),
130                history,
131            },
132            time,
133        ))
134    }
135
136    pub fn store(&self, writer: &mut impl Write) -> std::io::Result<()> {
137        let time = std::time::SystemTime::now()
138            .duration_since(std::time::SystemTime::UNIX_EPOCH)
139            .expect("Could not get current time");
140        writer.write_all(&time.as_secs().to_be_bytes())?;
141        writer.write_all(&self.current.id_offset.to_be_bytes())?;
142        writer.write_all(&self.current.primary.to_be_bytes())?;
143        writer.write_all(&(self.current.keys.len() as u32).to_be_bytes())?;
144        for key in self.current.keys.iter() {
145            writer.write_all(key.key_bytes())?;
146        }
147        Ok(())
148    }
149
150    /// Get the current KeySet
151    pub fn get(&self) -> Arc<KeySet> {
152        self.current.clone()
153    }
154}
155
156pub struct KeySet {
157    keys: Vec<AesSivCmac512>,
158    id_offset: u32,
159    primary: u32,
160}
161
162impl KeySet {
163    #[cfg(feature = "__internal-fuzz")]
164    pub fn encode_cookie_pub(&self, cookie: &DecodedServerCookie) -> Vec<u8> {
165        self.encode_cookie(cookie)
166    }
167
168    pub(crate) fn encode_cookie(&self, cookie: &DecodedServerCookie) -> Vec<u8> {
169        let mut output = cookie.plaintext();
170        let plaintext_length = output.as_slice().len();
171
172        // Add space for header (4 + 2 bytes), additional ciphertext
173        // data from the cmac (16 bytes) and nonce (16 bytes).
174        output.resize(output.len() + 2 + 4 + 16 + 16, 0);
175
176        // And move plaintext to make space for header
177        output.copy_within(0..plaintext_length, 6);
178        let EncryptResult {
179            nonce_length,
180            ciphertext_length,
181        } = self.keys[self.primary as usize]
182            .encrypt(&mut output[6..], plaintext_length, &[])
183            .expect("Failed to encrypt cookie");
184
185        debug_assert_eq!(nonce_length, 16);
186        debug_assert_eq!(plaintext_length + 16, ciphertext_length);
187
188        output[0..4].copy_from_slice(&(self.primary.wrapping_add(self.id_offset)).to_be_bytes());
189        output[4..6].copy_from_slice(&(ciphertext_length as u16).to_be_bytes());
190        debug_assert_eq!(output.len(), 6 + nonce_length + ciphertext_length);
191        output
192    }
193
194    #[cfg(feature = "__internal-fuzz")]
195    pub fn decode_cookie_pub(&self, cookie: &[u8]) -> Result<DecodedServerCookie, DecryptError> {
196        self.decode_cookie(cookie)
197    }
198
199    pub(crate) fn decode_cookie(&self, cookie: &[u8]) -> Result<DecodedServerCookie, DecryptError> {
200        // we need at least an id, cipher text length and nonce for this message to be valid
201        if cookie.len() < 4 + 2 + 16 {
202            return Err(DecryptError);
203        }
204
205        let id = u32::from_be_bytes(cookie[0..4].try_into().unwrap());
206        let id = id.wrapping_sub(self.id_offset) as usize;
207        let key = self.keys.get(id).ok_or(DecryptError)?;
208
209        let cipher_text_length = u16::from_be_bytes([cookie[4], cookie[5]]) as usize;
210
211        let nonce = &cookie[6..22];
212        let ciphertext = cookie[22..].get(..cipher_text_length).ok_or(DecryptError)?;
213        let plaintext = key.decrypt(nonce, ciphertext, &[])?;
214
215        let [b0, b1, ref key_bytes @ ..] = plaintext[..] else {
216            return Err(DecryptError);
217        };
218
219        let algorithm =
220            AeadAlgorithm::try_deserialize(u16::from_be_bytes([b0, b1])).ok_or(DecryptError)?;
221
222        Ok(match algorithm {
223            AeadAlgorithm::AeadAesSivCmac256 => {
224                const KEY_WIDTH: usize = 32;
225
226                if key_bytes.len() != 2 * KEY_WIDTH {
227                    return Err(DecryptError);
228                }
229
230                let (s2c, c2s) = key_bytes.split_at(KEY_WIDTH);
231
232                DecodedServerCookie {
233                    algorithm,
234                    s2c: Box::new(AesSivCmac256::new(GenericArray::clone_from_slice(s2c))),
235                    c2s: Box::new(AesSivCmac256::new(GenericArray::clone_from_slice(c2s))),
236                }
237            }
238            AeadAlgorithm::AeadAesSivCmac512 => {
239                const KEY_WIDTH: usize = 64;
240
241                if key_bytes.len() != 2 * KEY_WIDTH {
242                    return Err(DecryptError);
243                }
244
245                let (s2c, c2s) = key_bytes.split_at(KEY_WIDTH);
246
247                DecodedServerCookie {
248                    algorithm,
249                    s2c: Box::new(AesSivCmac512::new(GenericArray::clone_from_slice(s2c))),
250                    c2s: Box::new(AesSivCmac512::new(GenericArray::clone_from_slice(c2s))),
251                }
252            }
253        })
254    }
255
256    #[cfg(test)]
257    pub(crate) fn new() -> Self {
258        Self {
259            keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())],
260            id_offset: 1,
261            primary: 0,
262        }
263    }
264}
265
266impl CipherProvider for KeySet {
267    fn get(&self, context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
268        let mut decoded = None;
269
270        for ef in context {
271            if let ExtensionField::NtsCookie(cookie) = ef {
272                if decoded.is_some() {
273                    // more than one cookie, abort
274                    return None;
275                }
276                decoded = Some(self.decode_cookie(cookie).ok()?);
277            }
278        }
279
280        decoded.map(CipherHolder::DecodedServerCookie)
281    }
282}
283
284impl std::fmt::Debug for KeySet {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        f.debug_struct("KeySet")
287            .field("keys", &self.keys.len())
288            .field("id_offset", &self.id_offset)
289            .field("primary", &self.primary)
290            .finish()
291    }
292}
293
/// Fixture cookie: AES-SIV-CMAC-256 with `s2c` key bytes `0..32` and `c2s`
/// key bytes `32..64`.
#[cfg(any(test, feature = "__internal-fuzz"))]
pub fn test_cookie() -> DecodedServerCookie {
    let s2c = AesSivCmac256::new((0u8..32).collect());
    let c2s = AesSivCmac256::new((32u8..64).collect());
    DecodedServerCookie {
        algorithm: AeadAlgorithm::AeadAesSivCmac256,
        s2c: Box::new(s2c),
        c2s: Box::new(c2s),
    }
}
302
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// An AES-SIV-CMAC-512 cookie with recognisable, non-zero key bytes.
    fn cookie_512() -> DecodedServerCookie {
        DecodedServerCookie {
            algorithm: AeadAlgorithm::AeadAesSivCmac512,
            s2c: Box::new(AesSivCmac512::new((0..64_u8).collect())),
            c2s: Box::new(AesSivCmac512::new((64..128_u8).collect())),
        }
    }

    /// Assert that `actual` carries the same algorithm and keys as `expected`.
    #[track_caller]
    fn assert_same_cookie(expected: &DecodedServerCookie, actual: &DecodedServerCookie) {
        assert_eq!(expected.algorithm, actual.algorithm);
        assert_eq!(expected.s2c.key_bytes(), actual.s2c.key_bytes());
        assert_eq!(expected.c2s.key_bytes(), actual.c2s.key_bytes());
    }

    #[test]
    fn roundtrip_aes_siv_cmac_256() {
        let decoded = test_cookie();
        let keyset = KeySet::new();

        let encoded = keyset.encode_cookie(&decoded);
        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn test_encode_after_rotate() {
        let decoded = test_cookie();

        let mut provider = KeySetProvider::new(1);
        provider.rotate();
        let keyset = provider.get();

        let encoded = keyset.encode_cookie(&decoded);
        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn can_decode_cookie_with_padding() {
        let decoded = cookie_512();
        let keyset = KeySet::new();

        let mut encoded = keyset.encode_cookie(&decoded);
        // Trailing bytes past the declared ciphertext length must be ignored.
        encoded.extend([0, 0]);

        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn roundtrip_aes_siv_cmac_512() {
        let decoded = cookie_512();
        let keyset = KeySet::new();

        let encoded = keyset.encode_cookie(&decoded);
        let round = keyset.decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);
    }

    #[test]
    fn decode_rejects_unknown_key_id() {
        let keyset = KeySet::new();
        let mut encoded = keyset.encode_cookie(&test_cookie());

        // Point the key id one past the only key in the set.
        let bogus_id = keyset.id_offset.wrapping_add(1);
        encoded[0..4].copy_from_slice(&bogus_id.to_be_bytes());

        assert!(keyset.decode_cookie(&encoded).is_err());
    }

    #[test]
    fn decode_rejects_tampered_ciphertext() {
        let keyset = KeySet::new();
        let mut encoded = keyset.encode_cookie(&test_cookie());

        // The final byte belongs to the authenticated ciphertext.
        let last = encoded.len() - 1;
        encoded[last] ^= 0x01;

        assert!(keyset.decode_cookie(&encoded).is_err());
    }

    #[test]
    fn test_save_restore() {
        let mut provider = KeySetProvider::new(8);
        provider.rotate();
        provider.rotate();

        let mut output = Cursor::new(vec![]);
        provider.store(&mut output).unwrap();
        let mut input = Cursor::new(output.into_inner());
        let (copy, time) = KeySetProvider::load(&mut input, 8).unwrap();

        assert!(
            std::time::SystemTime::now()
                .duration_since(time)
                .unwrap()
                .as_secs()
                < 2
        );

        let original = provider.get();
        let restored = copy.get();
        assert_eq!(original.primary, restored.primary);
        assert_eq!(original.id_offset, restored.id_offset);
        assert_eq!(original.keys.len(), restored.keys.len());
        for (a, b) in original.keys.iter().zip(restored.keys.iter()) {
            assert_eq!(a.key_bytes(), b.key_bytes());
        }
    }

    #[test]
    fn load_rejects_truncated_input() {
        // Shorter than the 20-byte header.
        let mut input = Cursor::new(vec![0u8; 10]);
        assert!(KeySetProvider::load(&mut input, 1).is_err());
    }

    #[test]
    fn load_rejects_primary_past_key_count() {
        let mut data = Vec::new();
        data.extend_from_slice(&0u64.to_be_bytes()); // timestamp
        data.extend_from_slice(&0u32.to_be_bytes()); // id_offset
        data.extend_from_slice(&2u32.to_be_bytes()); // primary
        data.extend_from_slice(&1u32.to_be_bytes()); // key count
        data.extend_from_slice(&[0u8; 64]); // the single key

        assert!(KeySetProvider::load(&mut Cursor::new(data), 1).is_err());
    }

    #[test]
    fn old_cookie_still_valid() {
        let decoded = test_cookie();

        let mut provider = KeySetProvider::new(1);
        let encoded = provider.get().encode_cookie(&decoded);

        let round = provider.get().decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);

        // With history 1, the key survives exactly one rotation...
        provider.rotate();

        let round = provider.get().decode_cookie(&encoded).unwrap();
        assert_same_cookie(&decoded, &round);

        // ...and is forgotten after the second.
        provider.rotate();

        assert!(provider.get().decode_cookie(&encoded).is_err());
    }

    #[test]
    fn invalid_cookie_length() {
        // this cookie data lies about its length, pretending to be longer than it actually is.
        let input = b"\x23\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x04\x00\x24\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x04\x00\x18\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x04\x00\x28\x00\x10\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let provider = KeySetProvider::new(1);

        let output = provider.get().decode_cookie(input);

        assert!(output.is_err());
    }
}