// aead_gcm_stream/lib.rs

1use aead::{
2  consts::U16, generic_array::GenericArray, Key, KeyInit, KeySizeUser,
3};
4use cipher::{BlockCipher, BlockEncrypt, BlockSizeUser};
5use ctr::Ctr32BE;
6use ghash::{universal_hash::UniversalHash, GHash};
7
/// Incremental GHASH state for GCM: buffers partial 16-byte blocks and
/// tracks the AAD/message byte lengths needed for the final length block.
#[derive(Clone)]
struct GcmGhash<const TAG_SIZE: usize> {
  // Keyed GHASH universal hash (keyed with H = E_K(0^128) by the caller).
  ghash: GHash,
  // Keystream bytes (E_K(J0)) XORed into the GHASH output in `finalize`.
  ghash_pad: [u8; TAG_SIZE],
  // Holds a trailing partial block until 16 bytes accumulate.
  msg_buf: [u8; TAG_SIZE],
  // Number of valid bytes currently in `msg_buf`.
  msg_buf_offset: usize,
  // Total associated-data length in bytes.
  ad_len: usize,
  // Total message length in bytes.
  msg_len: usize,
}
17
impl<const TAG_SIZE: usize> GcmGhash<TAG_SIZE> {
  /// Creates a GHASH state keyed with `h`.
  ///
  /// `h` must be exactly 16 bytes — the `try_into().unwrap()` below panics
  /// otherwise. `ghash_pad` is XORed into the digest by `finalize`.
  /// Always returns `Ok`; the `Result` wrapper is vestigial.
  fn new(h: &[u8], ghash_pad: [u8; TAG_SIZE]) -> Result<Self, ()> {
    let ghash = GHash::new(h.try_into().unwrap());

    Ok(Self {
      ghash,
      ghash_pad,
      msg_buf: [0u8; TAG_SIZE],
      msg_buf_offset: 0,
      ad_len: 0,
      msg_len: 0,
    })
  }

  /// Absorbs the associated data (zero-padded to a block boundary) and
  /// records its length for the final length block. GHASH processes all
  /// AAD before any message data; that ordering is the caller's duty.
  fn set_aad(&mut self, aad: &[u8]) {
    self.ad_len = aad.len();
    self.ghash.update_padded(aad);
  }

  /// Absorbs `msg` into the hash, buffering any trailing partial block in
  /// `msg_buf` until a full 16-byte block accumulates.
  fn update(&mut self, msg: &[u8]) {
    // First top up a previously buffered partial block, if any.
    if self.msg_buf_offset > 0 {
      let taking = std::cmp::min(msg.len(), TAG_SIZE - self.msg_buf_offset);
      self.msg_buf[self.msg_buf_offset..self.msg_buf_offset + taking]
        .copy_from_slice(&msg[..taking]);
      self.msg_buf_offset += taking;
      assert!(self.msg_buf_offset <= TAG_SIZE);

      self.msg_len += taking;

      if self.msg_buf_offset == TAG_SIZE {
        // Buffer filled: hash it and recurse on the remaining bytes
        // (the recursion takes the unbuffered path below).
        self
          .ghash
          .update(std::slice::from_ref(ghash::Block::from_slice(
            &self.msg_buf,
          )));
        self.msg_buf_offset = 0;
        return self.update(&msg[taking..]);
      } else {
        // `msg` fit entirely inside the buffer; nothing to hash yet.
        return;
      }
    }

    self.msg_len += msg.len();

    assert_eq!(self.msg_buf_offset, 0);
    let full_blocks = msg.len() / 16;
    let leftover = msg.len() - 16 * full_blocks;
    // leftover < 16 by construction; the assert guards TAG_SIZE < 16
    // configurations, where `msg_buf` could not hold the remainder
    // (this file only instantiates TAG_SIZE = 16).
    assert!(leftover < TAG_SIZE);
    if full_blocks > 0 {
      // Safety: Transmute [u8] to [[u8; 16]], like slice::as_chunks.
      // Then transmute [[u8; 16]] to [GenericArray<U16>], per repr(transparent).
      let blocks = unsafe {
        std::slice::from_raw_parts(
          msg[..16 * full_blocks].as_ptr().cast(),
          full_blocks,
        )
      };
      assert_eq!(
        std::mem::size_of_val(blocks) + leftover,
        std::mem::size_of_val(msg)
      );
      self.ghash.update(blocks);
    }

    // Stash the trailing partial block for the next call (or `finalize`).
    self.msg_buf[0..leftover].copy_from_slice(&msg[full_blocks * 16..]);
    self.msg_buf_offset = leftover;
    assert!(self.msg_buf_offset < TAG_SIZE);
  }

  /// Completes GHASH: flushes the buffered partial block (zero-padded),
  /// hashes the length block (AAD and message lengths in bits, as two
  /// big-endian u64s), and XORs in `ghash_pad` to produce the 16-byte tag.
  ///
  /// NOTE(review): the XOR loop indexes `ghash_pad[i]` for i in 0..16, so
  /// this panics if TAG_SIZE < 16; only TAG_SIZE = 16 is used in this file.
  fn finalize(mut self) -> GenericArray<u8, U16> {
    if self.msg_buf_offset > 0 {
      self
        .ghash
        .update_padded(&self.msg_buf[..self.msg_buf_offset]);
    }

    // len(AAD) || len(msg), in bits, big-endian.
    let mut final_block = [0u8; 16];
    final_block[..8].copy_from_slice(&(8 * self.ad_len as u64).to_be_bytes());
    final_block[8..].copy_from_slice(&(8 * self.msg_len as u64).to_be_bytes());

    self.ghash.update(&[final_block.into()]);
    let mut hash = self.ghash.finalize();

    // Tag = GHASH output XOR pad (the pad is E_K(J0), supplied by `new`).
    for (i, b) in hash.iter_mut().enumerate() {
      *b ^= self.ghash_pad[i];
    }

    hash
  }
}
108
/// Streaming AES-GCM: CTR-mode keystream plus an incremental GHASH tag.
///
/// Unlike one-shot AEAD APIs, data is fed in chunks via `encrypt`/`decrypt`
/// and the tag is produced by `finish`. No tag verification happens in this
/// type — callers must compare the returned tag themselves.
pub struct AesGcm<Aes>
where
  Aes: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
{
  /// Encryption cipher (CTR mode with a 32-bit big-endian counter)
  ctr: Ctr32BE<Aes>,

  /// GHASH authenticator
  ghash: GcmGhash<16>,
}
119
// The key size is that of the underlying block cipher, so `Key<Self>` in
// `AesGcm::new` is exactly the cipher's own key type.
impl<Aes> KeySizeUser for AesGcm<Aes>
where
  Aes:
    KeySizeUser + BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
{
  type KeySize = Aes::KeySize;
}
127
128impl<Aes> AesGcm<Aes>
129where
130  Aes: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt + KeyInit,
131{
132  pub fn new(key: &Key<Self>, nonce: &[u8]) -> Self {
133    let cipher = Aes::new(key);
134    let mut ghash_key = ghash::Key::default();
135    cipher.encrypt_block(&mut ghash_key);
136
137    use cipher::InnerIvInit;
138    use cipher::StreamCipherSeek;
139
140    let mut nonce_block = GenericArray::default();
141    if nonce.len() == 12 {
142      nonce_block[..nonce.len()].copy_from_slice(nonce);
143    } else {
144      // We calculate GHASH(nonce || padding || 0^64 || len_u64(nonce)) to get J0 (initial counter block)
145      // See NIST SP 800-38D, section 7.1, algorithm 4, step 2 or section 7.2, algorithm 5, step 3
146      // https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
147      let mut ghash = GHash::new(&ghash_key);
148      ghash.update_padded(nonce);
149      ghash.update_padded(&(8 * nonce.len() as u128).to_be_bytes());
150      nonce_block.copy_from_slice(&ghash.finalize());
151      // We subtract 1 from the counter block to align with the CTR implementation below
152      for i in nonce_block.iter_mut().rev() {
153        *i = i.wrapping_sub(1);
154        if *i != 0xff {
155          break;
156        }
157      }
158    }
159    let mut ctr = ctr::Ctr32BE::from_core(ctr::CtrCore::inner_iv_init(
160      cipher,
161      &nonce_block,
162    ));
163    ctr.seek(Aes::block_size());
164
165    let mut pad = [0u8; 16];
166    ctr.apply_keystream(&mut pad);
167
168    let ghash = GcmGhash::new(&ghash_key, pad).unwrap();
169    Self { ctr, ghash }
170  }
171}
172
173use cipher::StreamCipher;
174
impl<Aes> AesGcm<Aes>
where
  Aes: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt + KeyInit,
{
  /// Absorbs the associated data into the tag. GHASH hashes AAD before any
  /// message data, so this should be called before the first
  /// `encrypt`/`decrypt`; that ordering is not enforced here.
  pub fn set_aad(&mut self, aad: &[u8]) {
    self.ghash.set_aad(aad);
  }

  /// Encrypts `block` in place, then absorbs the resulting ciphertext into
  /// the tag — GCM authenticates ciphertext, so keystream must be applied
  /// before hashing.
  pub fn encrypt(&mut self, block: &mut [u8]) {
    self.ctr.apply_keystream(block);
    self.ghash.update(block);
  }

  /// Absorbs the ciphertext into the tag, then decrypts `block` in place —
  /// the hash covers ciphertext on both sides, so hashing must precede the
  /// keystream here.
  pub fn decrypt(&mut self, block: &mut [u8]) {
    self.ghash.update(block);
    self.ctr.apply_keystream(block);
  }

  /// Consumes the stream and returns the 16-byte authentication tag.
  /// NOTE(review): no tag verification is performed here; callers must
  /// compare tags themselves (ideally in constant time).
  pub fn finish(self) -> GenericArray<u8, U16> {
    self.ghash.finalize()
  }
}