1use std::fmt::Display;
2
3use aes_siv::{siv::Aes128Siv, siv::Aes256Siv, Key, KeyInit};
4use rand::Rng;
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
7use crate::keyset::DecodedServerCookie;
8
9use super::extension_fields::ExtensionField;
10
/// Error returned by [`Cipher::decrypt`] when a ciphertext could not be decrypted.
#[derive(Debug)]
pub struct DecryptError;

impl std::error::Error for DecryptError {}

impl Display for DecryptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Could not decrypt ciphertext")
    }
}
21
/// Error returned when raw key bytes cannot be turned into a cipher key
/// (for example because they have the wrong length).
#[derive(Debug)]
pub struct KeyError;

impl std::error::Error for KeyError {}

impl Display for KeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Invalid key")
    }
}
32
/// A caller-provided byte slice of which only the first `valid` bytes hold data.
///
/// The bytes past `valid` are spare capacity that can be filled in place,
/// without allocating.
struct Buffer<'a> {
    buffer: &'a mut [u8],
    valid: usize,
}

impl<'a> Buffer<'a> {
    /// Wraps `buffer`, treating its first `valid` bytes as the current contents.
    fn new(buffer: &'a mut [u8], valid: usize) -> Self {
        Buffer { buffer, valid }
    }

    /// Number of bytes that currently hold data.
    fn valid(&self) -> usize {
        self.valid
    }
}

impl AsRef<[u8]> for Buffer<'_> {
    /// Views the valid prefix of the backing slice.
    fn as_ref(&self) -> &[u8] {
        let end = self.valid;
        &self.buffer[..end]
    }
}

impl AsMut<[u8]> for Buffer<'_> {
    /// Mutably views the valid prefix of the backing slice.
    fn as_mut(&mut self) -> &mut [u8] {
        let end = self.valid;
        &mut self.buffer[..end]
    }
}
59
60impl aead::Buffer for Buffer<'_> {
61 fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
62 self.buffer
63 .get_mut(self.valid..(self.valid + other.len()))
64 .ok_or(aead::Error)?
65 .copy_from_slice(other);
66 self.valid += other.len();
67 Ok(())
68 }
69
70 fn truncate(&mut self, len: usize) {
71 self.valid = std::cmp::min(self.valid, len);
72 }
73}
74
/// Lengths reported by [`Cipher::encrypt`] for the data it wrote.
///
/// The output buffer starts with `nonce_length` nonce bytes, immediately
/// followed by `ciphertext_length` ciphertext bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct EncryptResult {
    /// Number of nonce bytes at the start of the buffer.
    pub nonce_length: usize,
    /// Number of ciphertext bytes following the nonce.
    pub ciphertext_length: usize,
}
80
/// A symmetric authenticated cipher with associated data.
///
/// Implementors must be `Send + Sync + 'static` and carry the [`ZeroizeOnDrop`]
/// marker, i.e. they promise to wipe their key material when dropped.
pub trait Cipher: Sync + Send + ZeroizeOnDrop + 'static {
    /// Encrypts the first `plaintext_length` bytes of `buffer` in place.
    ///
    /// On success the buffer starts with the nonce, followed by the ciphertext;
    /// the returned [`EncryptResult`] gives the length of each part.
    /// `associated_data` is authenticated but not encrypted.
    ///
    /// # Errors
    /// NOTE(review): the implementations in this file return
    /// `ErrorKind::WriteZero` when `buffer` is too small; confirm for other
    /// implementors.
    fn encrypt(
        &self,
        buffer: &mut [u8],
        plaintext_length: usize,
        associated_data: &[u8],
    ) -> std::io::Result<EncryptResult>;

    /// Decrypts `ciphertext` produced with `nonce` and `associated_data`,
    /// returning the recovered plaintext.
    ///
    /// # Errors
    /// Returns [`DecryptError`] when decryption fails (for instance because the
    /// associated data does not match what was used to encrypt).
    fn decrypt(
        &self,
        nonce: &[u8],
        ciphertext: &[u8],
        associated_data: &[u8],
    ) -> Result<Vec<u8>, DecryptError>;

    /// Returns the raw key material of this cipher.
    fn key_bytes(&self) -> &[u8];
}
104
105pub enum CipherHolder<'a> {
106 DecodedServerCookie(DecodedServerCookie),
107 Other(&'a dyn Cipher),
108}
109
110impl AsRef<dyn Cipher> for CipherHolder<'_> {
111 fn as_ref(&self) -> &dyn Cipher {
112 match self {
113 CipherHolder::DecodedServerCookie(cookie) => cookie.c2s.as_ref(),
114 CipherHolder::Other(cipher) => *cipher,
115 }
116 }
117}
118
/// Source of the [`Cipher`] to use for a message.
pub trait CipherProvider {
    /// Returns the cipher to use, or `None` when no cipher is available.
    ///
    /// `context` is the slice of extension fields supplied by the caller;
    /// implementations may use it to pick a cipher. NOTE(review): every
    /// implementation in this file ignores it.
    fn get(&self, context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>>;
}
122
123pub struct NoCipher;
124
125impl CipherProvider for NoCipher {
126 fn get<'a>(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
127 None
128 }
129}
130
131impl CipherProvider for dyn Cipher {
132 fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
133 Some(CipherHolder::Other(self))
134 }
135}
136
137impl CipherProvider for Option<&dyn Cipher> {
138 fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
139 self.map(CipherHolder::Other)
140 }
141}
142
143impl<C: Cipher> CipherProvider for C {
144 fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
145 Some(CipherHolder::Other(self))
146 }
147}
148
149impl<C: Cipher> CipherProvider for Option<C> {
150 fn get(&self, _context: &[ExtensionField<'_>]) -> Option<CipherHolder<'_>> {
151 self.as_ref().map(|v| CipherHolder::Other(v))
152 }
153}
154
/// AES-SIV cipher backed by `Aes128Siv`, using a 32-byte key.
pub struct AesSivCmac256 {
    key: Key<Aes128Siv>,
}

// Marker only: the actual wiping happens in this type's `Drop` impl.
impl ZeroizeOnDrop for AesSivCmac256 {}
162
163impl AesSivCmac256 {
164 pub fn new(key: Key<Aes128Siv>) -> Self {
165 AesSivCmac256 { key }
166 }
167
168 #[cfg(feature = "nts-pool")]
169 pub fn key_size() -> usize {
170 Self::new(Default::default()).key.len()
172 }
173
174 #[cfg(feature = "nts-pool")]
175 pub fn from_key_bytes(key_bytes: &[u8]) -> Result<Self, KeyError> {
176 (key_bytes.len() == Self::key_size())
177 .then(|| Self::new(*aead::Key::<Aes128Siv>::from_slice(key_bytes)))
178 .ok_or(KeyError)
179 }
180}
181
182impl Drop for AesSivCmac256 {
183 fn drop(&mut self) {
184 self.key.zeroize();
185 }
186}
187
188impl Cipher for AesSivCmac256 {
189 fn encrypt(
190 &self,
191 buffer: &mut [u8],
192 plaintext_length: usize,
193 associated_data: &[u8],
194 ) -> std::io::Result<EncryptResult> {
195 let mut siv = Aes128Siv::new(&self.key);
196 let nonce: [u8; 16] = rand::thread_rng().gen();
197
198 if buffer.len() < nonce.len() + plaintext_length {
201 return Err(std::io::ErrorKind::WriteZero.into());
202 }
203 buffer.copy_within(..plaintext_length, nonce.len());
204 buffer[..nonce.len()].copy_from_slice(&nonce);
206
207 let mut buffer_wrap = Buffer::new(&mut buffer[nonce.len()..], plaintext_length);
210 siv.encrypt_in_place([associated_data, &nonce], &mut buffer_wrap)
211 .map_err(|_| std::io::ErrorKind::Other)?;
212
213 Ok(EncryptResult {
214 nonce_length: nonce.len(),
215 ciphertext_length: buffer_wrap.valid(),
216 })
217 }
218
219 fn decrypt(
220 &self,
221 nonce: &[u8],
222 ciphertext: &[u8],
223 associated_data: &[u8],
224 ) -> Result<Vec<u8>, DecryptError> {
225 let mut siv = Aes128Siv::new(&self.key);
226 siv.decrypt([associated_data, nonce], ciphertext)
227 .map_err(|_| DecryptError)
228 }
229
230 fn key_bytes(&self) -> &[u8] {
231 &self.key
232 }
233}
234
235impl std::fmt::Debug for AesSivCmac256 {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("AesSivCmac256").finish()
239 }
240}
241
/// AES-SIV cipher backed by `Aes256Siv`, using a 64-byte key.
pub struct AesSivCmac512 {
    key: Key<Aes256Siv>,
}
247
248impl AesSivCmac512 {
249 pub fn new(key: Key<Aes256Siv>) -> Self {
250 AesSivCmac512 { key }
251 }
252
253 #[cfg(feature = "nts-pool")]
254 pub fn key_size() -> usize {
255 Self::new(Default::default()).key.len()
257 }
258
259 #[cfg(feature = "nts-pool")]
260 pub fn from_key_bytes(key_bytes: &[u8]) -> Result<Self, KeyError> {
261 (key_bytes.len() == Self::key_size())
262 .then(|| Self::new(*aead::Key::<Aes256Siv>::from_slice(key_bytes)))
263 .ok_or(KeyError)
264 }
265}
266
267impl ZeroizeOnDrop for AesSivCmac512 {}
268
269impl Drop for AesSivCmac512 {
270 fn drop(&mut self) {
271 self.key.zeroize();
272 }
273}
274
275impl Cipher for AesSivCmac512 {
276 fn encrypt(
277 &self,
278 buffer: &mut [u8],
279 plaintext_length: usize,
280 associated_data: &[u8],
281 ) -> std::io::Result<EncryptResult> {
282 let mut siv = Aes256Siv::new(&self.key);
283 let nonce: [u8; 16] = rand::thread_rng().gen();
284
285 if buffer.len() < nonce.len() + plaintext_length {
288 return Err(std::io::ErrorKind::WriteZero.into());
289 }
290 buffer.copy_within(..plaintext_length, nonce.len());
291 buffer[..nonce.len()].copy_from_slice(&nonce);
293
294 let mut buffer_wrap = Buffer::new(&mut buffer[nonce.len()..], plaintext_length);
297 siv.encrypt_in_place([associated_data, &nonce], &mut buffer_wrap)
298 .map_err(|_| std::io::ErrorKind::Other)?;
299
300 Ok(EncryptResult {
301 nonce_length: nonce.len(),
302 ciphertext_length: buffer_wrap.valid(),
303 })
304 }
305
306 fn decrypt(
307 &self,
308 nonce: &[u8],
309 ciphertext: &[u8],
310 associated_data: &[u8],
311 ) -> Result<Vec<u8>, DecryptError> {
312 let mut siv = Aes256Siv::new(&self.key);
313 siv.decrypt([associated_data, nonce], ciphertext)
314 .map_err(|_| DecryptError)
315 }
316
317 fn key_bytes(&self) -> &[u8] {
318 &self.key
319 }
320}
321
322impl std::fmt::Debug for AesSivCmac512 {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.debug_struct("AesSivCmac512").finish()
326 }
327}
328
/// Test-only cipher that leaves data unencrypted and emits a deterministic nonce.
#[cfg(test)]
pub struct IdentityCipher {
    nonce_length: usize,
}

#[cfg(test)]
impl IdentityCipher {
    /// Creates an identity cipher configured with the given nonce length.
    pub fn new(nonce_length: usize) -> Self {
        IdentityCipher { nonce_length }
    }
}

// The identity cipher holds no key, so there is nothing to wipe on drop.
#[cfg(test)]
impl ZeroizeOnDrop for IdentityCipher {}
343
#[cfg(test)]
impl Cipher for IdentityCipher {
    /// Prepends a deterministic nonce (`0, 1, 2, ...`, wrapping after 255) of
    /// exactly `nonce_length` bytes and leaves the plaintext as-is.
    ///
    /// # Errors
    /// Returns `WriteZero` if `buffer` cannot hold the nonce plus the plaintext.
    fn encrypt(
        &self,
        buffer: &mut [u8],
        plaintext_length: usize,
        associated_data: &[u8],
    ) -> std::io::Result<EncryptResult> {
        debug_assert!(associated_data.is_empty());

        // Cycle through the byte values rather than casting `nonce_length` to
        // `u8`. The cast would truncate lengths above 255 and yield a nonce
        // shorter than configured.
        let nonce: Vec<u8> = (0..=u8::MAX).cycle().take(self.nonce_length).collect();

        if buffer.len() < nonce.len() + plaintext_length {
            return Err(std::io::ErrorKind::WriteZero.into());
        }
        // Shift the plaintext right to make room for the nonce in front of it.
        buffer.copy_within(..plaintext_length, nonce.len());
        buffer[..nonce.len()].copy_from_slice(&nonce);

        Ok(EncryptResult {
            nonce_length: nonce.len(),
            ciphertext_length: plaintext_length,
        })
    }

    /// Returns `ciphertext` unchanged, since the identity cipher never encrypts.
    fn decrypt(
        &self,
        nonce: &[u8],
        ciphertext: &[u8],
        associated_data: &[u8],
    ) -> Result<Vec<u8>, DecryptError> {
        debug_assert!(associated_data.is_empty());

        debug_assert_eq!(nonce.len(), self.nonce_length);

        Ok(ciphertext.to_vec())
    }

    /// Not supported: the identity cipher has no key.
    ///
    /// # Panics
    /// Always panics.
    fn key_bytes(&self) -> &[u8] {
        unimplemented!()
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// The plaintext every round-trip test encrypts: the bytes `0..16`.
    fn sample_plaintext() -> Vec<u8> {
        (0..16).collect()
    }

    /// Encrypts the sample plaintext in a buffer with 32 spare bytes and returns
    /// the buffer together with the lengths reported by `encrypt`.
    fn encrypt_sample(cipher: &dyn Cipher, associated_data: &[u8]) -> (Vec<u8>, EncryptResult) {
        let mut buffer = sample_plaintext();
        buffer.resize(buffer.len() + 32, 0);
        let lengths = cipher.encrypt(&mut buffer, 16, associated_data).unwrap();
        (buffer, lengths)
    }

    /// Splits an encrypted buffer into its nonce and ciphertext parts.
    fn nonce_and_ciphertext(buffer: &[u8], lengths: EncryptResult) -> (&[u8], &[u8]) {
        let EncryptResult {
            nonce_length,
            ciphertext_length,
        } = lengths;
        (
            &buffer[..nonce_length],
            &buffer[nonce_length..(nonce_length + ciphertext_length)],
        )
    }

    /// Encrypts and decrypts without associated data and checks the plaintext survives.
    fn check_roundtrip(cipher: &dyn Cipher) {
        let (buffer, lengths) = encrypt_sample(cipher, &[]);
        let (nonce, ciphertext) = nonce_and_ciphertext(&buffer, lengths);
        let decrypted = cipher.decrypt(nonce, ciphertext, &[]).unwrap();
        assert_eq!(decrypted, sample_plaintext());
    }

    /// Checks that decryption succeeds only with the associated data used to encrypt.
    fn check_associated_data(cipher: &dyn Cipher) {
        let (buffer, lengths) = encrypt_sample(cipher, &[1]);
        let (nonce, ciphertext) = nonce_and_ciphertext(&buffer, lengths);
        assert!(cipher.decrypt(nonce, ciphertext, &[2]).is_err());
        let decrypted = cipher.decrypt(nonce, ciphertext, &[1]).unwrap();
        assert_eq!(decrypted, sample_plaintext());
    }

    #[test]
    fn test_aes_siv_cmac_256() {
        check_roundtrip(&AesSivCmac256::new([0u8; 32].into()));
    }

    #[test]
    fn test_aes_siv_cmac_256_with_assoc_data() {
        check_associated_data(&AesSivCmac256::new([0u8; 32].into()));
    }

    #[test]
    fn test_aes_siv_cmac_512() {
        check_roundtrip(&AesSivCmac512::new([0u8; 64].into()));
    }

    #[test]
    fn test_aes_siv_cmac_512_with_assoc_data() {
        check_associated_data(&AesSivCmac512::new([0u8; 64].into()));
    }

    #[cfg(feature = "nts-pool")]
    #[test]
    fn key_functions_correctness() {
        use aead::KeySizeUser;
        assert_eq!(Aes128Siv::key_size(), AesSivCmac256::key_size());
        assert_eq!(Aes256Siv::key_size(), AesSivCmac512::key_size());

        let key_bytes: Vec<u8> = (1..=64).collect();
        assert!(AesSivCmac256::from_key_bytes(&key_bytes).is_err());

        let short_key = &key_bytes[..AesSivCmac256::key_size()];
        let cipher_256 = AesSivCmac256::from_key_bytes(short_key).unwrap();
        assert_eq!(cipher_256.key_bytes(), short_key);

        let long_key = &key_bytes[..AesSivCmac512::key_size()];
        let cipher_512 = AesSivCmac512::from_key_bytes(long_key).unwrap();
        assert_eq!(cipher_512.key_bytes(), long_key);
    }
}