ntp_proto/packet/
crypto.rs

1use std::fmt::Display;
2
3use aes_siv::{siv::Aes128Siv, siv::Aes256Siv, Key, KeyInit};
4use rand::Rng;
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
7use crate::keyset::DecodedServerCookie;
8
9use super::extension_fields::ExtensionField;
10
/// Returned by [`Cipher::decrypt`] when a ciphertext cannot be decrypted,
/// e.g. because the nonce, ciphertext or associated data do not authenticate.
#[derive(Debug)]
pub struct DecryptError;

impl Display for DecryptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Could not decrypt ciphertext")
    }
}

impl std::error::Error for DecryptError {}
21
/// Returned when raw key bytes cannot be turned into a cipher key
/// (for instance because they have the wrong length).
#[derive(Debug)]
pub struct KeyError;

impl Display for KeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Invalid key")
    }
}

impl std::error::Error for KeyError {}
32
/// A growable view over a fixed, caller-provided byte slice.
///
/// Only the first `valid` bytes of `buffer` are considered contents; the
/// remainder is spare capacity that in-place encryption may grow into.
struct Buffer<'a> {
    buffer: &'a mut [u8],
    valid: usize,
}

impl<'a> Buffer<'a> {
    /// Wraps `storage`, treating its first `filled` bytes as the contents.
    fn new(storage: &'a mut [u8], filled: usize) -> Self {
        Buffer {
            buffer: storage,
            valid: filled,
        }
    }

    /// Number of bytes currently considered contents.
    fn valid(&self) -> usize {
        self.valid
    }
}

impl AsMut<[u8]> for Buffer<'_> {
    fn as_mut(&mut self) -> &mut [u8] {
        let filled = self.valid;
        &mut self.buffer[0..filled]
    }
}

impl AsRef<[u8]> for Buffer<'_> {
    fn as_ref(&self) -> &[u8] {
        let filled = self.valid;
        &self.buffer[0..filled]
    }
}
59
60impl aead::Buffer for Buffer<'_> {
61    fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
62        self.buffer
63            .get_mut(self.valid..(self.valid + other.len()))
64            .ok_or(aead::Error)?
65            .copy_from_slice(other);
66        self.valid += other.len();
67        Ok(())
68    }
69
70    fn truncate(&mut self, len: usize) {
71        self.valid = std::cmp::min(self.valid, len);
72    }
73}
74
/// Layout of the output written by [`Cipher::encrypt`]: the buffer starts
/// with `nonce_length` nonce bytes, immediately followed by
/// `ciphertext_length` ciphertext bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EncryptResult {
    /// Number of nonce bytes at the start of the buffer.
    pub nonce_length: usize,
    /// Number of ciphertext bytes directly after the nonce.
    pub ciphertext_length: usize,
}
80
81pub trait Cipher: Sync + Send + ZeroizeOnDrop + 'static {
82    /// encrypts the plaintext present in the buffer
83    ///
84    /// - encrypts `plaintext_length` bytes from the buffer
85    /// - puts the nonce followed by the ciphertext into the buffer
86    /// - returns the size of the nonce and ciphertext
87    fn encrypt(
88        &self,
89        buffer: &mut [u8],
90        plaintext_length: usize,
91        associated_data: &[u8],
92    ) -> std::io::Result<EncryptResult>;
93
94    // MUST support arbitrary length nonces
95    fn decrypt(
96        &self,
97        nonce: &[u8],
98        ciphertext: &[u8],
99        associated_data: &[u8],
100    ) -> Result<Vec<u8>, DecryptError>;
101
102    fn key_bytes(&self) -> &[u8];
103}
104
105pub enum CipherHolder<'a> {
106    DecodedServerCookie(DecodedServerCookie),
107    Other(&'a dyn Cipher),
108}
109
110impl AsRef<dyn Cipher> for CipherHolder<'_> {
111    fn as_ref(&self) -> &dyn Cipher {
112        match self {
113            CipherHolder::DecodedServerCookie(cookie) => cookie.c2s.as_ref(),
114            CipherHolder::Other(cipher) => *cipher,
115        }
116    }
117}
118
119pub trait CipherProvider {
120    fn get(&self, context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>>;
121}
122
123pub struct NoCipher;
124
125impl CipherProvider for NoCipher {
126    fn get<'a>(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
127        None
128    }
129}
130
131impl CipherProvider for dyn Cipher {
132    fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
133        Some(CipherHolder::Other(self))
134    }
135}
136
137impl CipherProvider for Option<&dyn Cipher> {
138    fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
139        self.map(CipherHolder::Other)
140    }
141}
142
143impl<C: Cipher> CipherProvider for C {
144    fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
145        Some(CipherHolder::Other(self))
146    }
147}
148
149impl<C: Cipher> CipherProvider for Option<C> {
150    fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
151        self.as_ref().map(|v| CipherHolder::Other(v))
152    }
153}
154
155pub struct AesSivCmac256 {
156    // 128 vs 256 difference is due to using the official name (us) vs
157    // the number of bits of security (aes_siv crate)
158    key: Key<Aes128Siv>,
159}
160
161impl ZeroizeOnDrop for AesSivCmac256 {}
162
163impl AesSivCmac256 {
164    pub fn new(key: Key<Aes128Siv>) -> Self {
165        AesSivCmac256 { key }
166    }
167
168    #[cfg(feature = "nts-pool")]
169    pub fn key_size() -> usize {
170        // prefer trust in compiler optimisation over trust in mental arithmetic
171        Self::new(Default::default()).key.len()
172    }
173
174    #[cfg(feature = "nts-pool")]
175    pub fn from_key_bytes(key_bytes: &[u8]) -> Result<Self, KeyError> {
176        (key_bytes.len() == Self::key_size())
177            .then(|| Self::new(*aead::Key::<Aes128Siv>::from_slice(key_bytes)))
178            .ok_or(KeyError)
179    }
180}
181
182impl Drop for AesSivCmac256 {
183    fn drop(&mut self) {
184        self.key.zeroize();
185    }
186}
187
188impl Cipher for AesSivCmac256 {
189    fn encrypt(
190        &self,
191        buffer: &mut [u8],
192        plaintext_length: usize,
193        associated_data: &[u8],
194    ) -> std::io::Result<EncryptResult> {
195        let mut siv = Aes128Siv::new(&self.key);
196        let nonce: [u8; 16] = rand::thread_rng().gen();
197
198        // Prepare the buffer for in place encryption by moving the plaintext
199        // back, creating space for the nonce.
200        if buffer.len() < nonce.len() + plaintext_length {
201            return Err(std::io::ErrorKind::WriteZero.into());
202        }
203        buffer.copy_within(..plaintext_length, nonce.len());
204        // And place the nonce where the caller expects it
205        buffer[..nonce.len()].copy_from_slice(&nonce);
206
207        // Create a wrapper around the plaintext portion of the buffer that has
208        // the methods aes_siv needs to do encryption in-place.
209        let mut buffer_wrap = Buffer::new(&mut buffer[nonce.len()..], plaintext_length);
210        siv.encrypt_in_place([associated_data, &nonce], &mut buffer_wrap)
211            .map_err(|_| std::io::ErrorKind::Other)?;
212
213        Ok(EncryptResult {
214            nonce_length: nonce.len(),
215            ciphertext_length: buffer_wrap.valid(),
216        })
217    }
218
219    fn decrypt(
220        &self,
221        nonce: &[u8],
222        ciphertext: &[u8],
223        associated_data: &[u8],
224    ) -> Result<Vec<u8>, DecryptError> {
225        let mut siv = Aes128Siv::new(&self.key);
226        siv.decrypt([associated_data, nonce], ciphertext)
227            .map_err(|_| DecryptError)
228    }
229
230    fn key_bytes(&self) -> &[u8] {
231        &self.key
232    }
233}
234
235// Ensure siv is not shown in debug output
236impl std::fmt::Debug for AesSivCmac256 {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("AesSivCmac256").finish()
239    }
240}
241
242pub struct AesSivCmac512 {
243    // 256 vs 512 difference is due to using the official name (us) vs
244    // the number of bits of security (aes_siv crate)
245    key: Key<Aes256Siv>,
246}
247
248impl AesSivCmac512 {
249    pub fn new(key: Key<Aes256Siv>) -> Self {
250        AesSivCmac512 { key }
251    }
252
253    #[cfg(feature = "nts-pool")]
254    pub fn key_size() -> usize {
255        // prefer trust in compiler optimisation over trust in mental arithmetic
256        Self::new(Default::default()).key.len()
257    }
258
259    #[cfg(feature = "nts-pool")]
260    pub fn from_key_bytes(key_bytes: &[u8]) -> Result<Self, KeyError> {
261        (key_bytes.len() == Self::key_size())
262            .then(|| Self::new(*aead::Key::<Aes256Siv>::from_slice(key_bytes)))
263            .ok_or(KeyError)
264    }
265}
266
267impl ZeroizeOnDrop for AesSivCmac512 {}
268
269impl Drop for AesSivCmac512 {
270    fn drop(&mut self) {
271        self.key.zeroize();
272    }
273}
274
275impl Cipher for AesSivCmac512 {
276    fn encrypt(
277        &self,
278        buffer: &mut [u8],
279        plaintext_length: usize,
280        associated_data: &[u8],
281    ) -> std::io::Result<EncryptResult> {
282        let mut siv = Aes256Siv::new(&self.key);
283        let nonce: [u8; 16] = rand::thread_rng().gen();
284
285        // Prepare the buffer for in place encryption by moving the plaintext
286        // back, creating space for the nonce.
287        if buffer.len() < nonce.len() + plaintext_length {
288            return Err(std::io::ErrorKind::WriteZero.into());
289        }
290        buffer.copy_within(..plaintext_length, nonce.len());
291        // And place the nonce where the caller expects it
292        buffer[..nonce.len()].copy_from_slice(&nonce);
293
294        // Create a wrapper around the plaintext portion of the buffer that has
295        // the methods aes_siv needs to do encryption in-place.
296        let mut buffer_wrap = Buffer::new(&mut buffer[nonce.len()..], plaintext_length);
297        siv.encrypt_in_place([associated_data, &nonce], &mut buffer_wrap)
298            .map_err(|_| std::io::ErrorKind::Other)?;
299
300        Ok(EncryptResult {
301            nonce_length: nonce.len(),
302            ciphertext_length: buffer_wrap.valid(),
303        })
304    }
305
306    fn decrypt(
307        &self,
308        nonce: &[u8],
309        ciphertext: &[u8],
310        associated_data: &[u8],
311    ) -> Result<Vec<u8>, DecryptError> {
312        let mut siv = Aes256Siv::new(&self.key);
313        siv.decrypt([associated_data, nonce], ciphertext)
314            .map_err(|_| DecryptError)
315    }
316
317    fn key_bytes(&self) -> &[u8] {
318        &self.key
319    }
320}
321
322// Ensure siv is not shown in debug output
323impl std::fmt::Debug for AesSivCmac512 {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        f.debug_struct("AesSivCmac512").finish()
326    }
327}
328
/// Test-only cipher that performs no encryption.
///
/// Its "ciphertext" is the plaintext unchanged, preceded by a deterministic
/// nonce of exactly `nonce_length` bytes: 0, 1, 2, … (wrapping at 256).
#[cfg(test)]
pub struct IdentityCipher {
    nonce_length: usize,
}

#[cfg(test)]
impl IdentityCipher {
    pub fn new(nonce_length: usize) -> Self {
        Self { nonce_length }
    }
}

#[cfg(test)]
impl ZeroizeOnDrop for IdentityCipher {}

#[cfg(test)]
impl Cipher for IdentityCipher {
    fn encrypt(
        &self,
        buffer: &mut [u8],
        plaintext_length: usize,
        associated_data: &[u8],
    ) -> std::io::Result<EncryptResult> {
        debug_assert!(associated_data.is_empty());

        // Iterate in `usize` and truncate each element, so the nonce always
        // has exactly `nonce_length` bytes. Casting the *length* to u8 first
        // would silently shorten nonces of 256 bytes or more.
        let nonce: Vec<u8> = (0..self.nonce_length).map(|i| i as u8).collect();

        // Prepare the buffer for in place encryption by moving the plaintext
        // back, creating space for the nonce. `checked_add` guards against an
        // overflowing size computation for absurd `plaintext_length` values.
        match nonce.len().checked_add(plaintext_length) {
            Some(required) if required <= buffer.len() => {}
            _ => return Err(std::io::ErrorKind::WriteZero.into()),
        }
        buffer.copy_within(..plaintext_length, nonce.len());
        // And place the nonce where the caller expects it
        buffer[..nonce.len()].copy_from_slice(&nonce);

        Ok(EncryptResult {
            nonce_length: nonce.len(),
            ciphertext_length: plaintext_length,
        })
    }

    fn decrypt(
        &self,
        nonce: &[u8],
        ciphertext: &[u8],
        associated_data: &[u8],
    ) -> Result<Vec<u8>, DecryptError> {
        debug_assert!(associated_data.is_empty());

        debug_assert_eq!(nonce.len(), self.nonce_length);

        Ok(ciphertext.to_vec())
    }

    fn key_bytes(&self) -> &[u8] {
        unimplemented!()
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// The plaintext every round-trip test encrypts: the bytes 0 through 15.
    fn plaintext() -> Vec<u8> {
        (0..16).collect()
    }

    /// Encrypts [`plaintext`] in a buffer with 32 spare bytes, returning the
    /// buffer together with the nonce/ciphertext layout reported by `cipher`.
    fn seal(cipher: &dyn Cipher, associated_data: &[u8]) -> (Vec<u8>, EncryptResult) {
        let mut buffer = plaintext();
        buffer.extend_from_slice(&[0; 32]);
        let layout = cipher.encrypt(&mut buffer, 16, associated_data).unwrap();
        (buffer, layout)
    }

    /// Decrypts the nonce and ciphertext that [`seal`] left in `buffer`.
    fn open(
        cipher: &dyn Cipher,
        buffer: &[u8],
        layout: EncryptResult,
        associated_data: &[u8],
    ) -> Result<Vec<u8>, DecryptError> {
        let (nonce, rest) = buffer.split_at(layout.nonce_length);
        cipher.decrypt(nonce, &rest[..layout.ciphertext_length], associated_data)
    }

    #[test]
    fn test_aes_siv_cmac_256() {
        let key = AesSivCmac256::new([0u8; 32].into());
        let (buffer, layout) = seal(&key, &[]);
        assert_eq!(open(&key, &buffer, layout, &[]).unwrap(), plaintext());
    }

    #[test]
    fn test_aes_siv_cmac_256_with_assoc_data() {
        let key = AesSivCmac256::new([0u8; 32].into());
        let (buffer, layout) = seal(&key, &[1]);
        assert!(open(&key, &buffer, layout, &[2]).is_err());
        assert_eq!(open(&key, &buffer, layout, &[1]).unwrap(), plaintext());
    }

    #[test]
    fn test_aes_siv_cmac_512() {
        let key = AesSivCmac512::new([0u8; 64].into());
        let (buffer, layout) = seal(&key, &[]);
        assert_eq!(open(&key, &buffer, layout, &[]).unwrap(), plaintext());
    }

    #[test]
    fn test_aes_siv_cmac_512_with_assoc_data() {
        let key = AesSivCmac512::new([0u8; 64].into());
        let (buffer, layout) = seal(&key, &[1]);
        assert!(open(&key, &buffer, layout, &[2]).is_err());
        assert_eq!(open(&key, &buffer, layout, &[1]).unwrap(), plaintext());
    }

    #[cfg(feature = "nts-pool")]
    #[test]
    fn key_functions_correctness() {
        use aead::KeySizeUser;
        assert_eq!(Aes128Siv::key_size(), AesSivCmac256::key_size());
        assert_eq!(Aes256Siv::key_size(), AesSivCmac512::key_size());

        let key_bytes: Vec<u8> = (1..=64).collect();
        assert!(AesSivCmac256::from_key_bytes(&key_bytes).is_err());

        let short = &key_bytes[..AesSivCmac256::key_size()];
        assert_eq!(
            AesSivCmac256::from_key_bytes(short).unwrap().key_bytes(),
            short
        );

        let long = &key_bytes[..AesSivCmac512::key_size()];
        assert_eq!(
            AesSivCmac512::from_key_bytes(long).unwrap().key_bytes(),
            long
        );
    }
}