1use aead::{
2 consts::U16, generic_array::GenericArray, Key, KeyInit, KeySizeUser,
3};
4use cipher::{BlockCipher, BlockEncrypt, BlockSizeUser};
5use ctr::Ctr32BE;
6use ghash::{universal_hash::UniversalHash, GHash};
7
/// Incremental GHASH state used to produce the authentication tag.
///
/// Message bytes may be supplied to `update` in pieces of any size; bytes
/// that do not complete a whole block are parked in `msg_buf` until more
/// input arrives or `finalize` is called.
#[derive(Clone)]
struct GcmGhash<const TAG_SIZE: usize> {
  /// Underlying GHASH instance that absorbs the AAD and message blocks.
  ghash: GHash,
  /// Bytes XORed into the GHASH output by `finalize` to form the tag.
  ghash_pad: [u8; TAG_SIZE],
  /// Holds a trailing partial block between `update` calls.
  msg_buf: [u8; TAG_SIZE],
  /// Number of valid bytes currently stored in `msg_buf`.
  msg_buf_offset: usize,
  /// Length in bytes of the AAD most recently passed to `set_aad`.
  ad_len: usize,
  /// Total number of message bytes passed to `update` so far.
  msg_len: usize,
}
17
18impl<const TAG_SIZE: usize> GcmGhash<TAG_SIZE> {
19 fn new(h: &[u8], ghash_pad: [u8; TAG_SIZE]) -> Result<Self, ()> {
20 let ghash = GHash::new(h.try_into().unwrap());
21
22 Ok(Self {
23 ghash,
24 ghash_pad,
25 msg_buf: [0u8; TAG_SIZE],
26 msg_buf_offset: 0,
27 ad_len: 0,
28 msg_len: 0,
29 })
30 }
31
32 fn set_aad(&mut self, aad: &[u8]) {
33 self.ad_len = aad.len();
34 self.ghash.update_padded(aad);
35 }
36
37 fn update(&mut self, msg: &[u8]) {
38 if self.msg_buf_offset > 0 {
39 let taking = std::cmp::min(msg.len(), TAG_SIZE - self.msg_buf_offset);
40 self.msg_buf[self.msg_buf_offset..self.msg_buf_offset + taking]
41 .copy_from_slice(&msg[..taking]);
42 self.msg_buf_offset += taking;
43 assert!(self.msg_buf_offset <= TAG_SIZE);
44
45 self.msg_len += taking;
46
47 if self.msg_buf_offset == TAG_SIZE {
48 self
49 .ghash
50 .update(std::slice::from_ref(ghash::Block::from_slice(
51 &self.msg_buf,
52 )));
53 self.msg_buf_offset = 0;
54 return self.update(&msg[taking..]);
55 } else {
56 return;
57 }
58 }
59
60 self.msg_len += msg.len();
61
62 assert_eq!(self.msg_buf_offset, 0);
63 let full_blocks = msg.len() / 16;
64 let leftover = msg.len() - 16 * full_blocks;
65 assert!(leftover < TAG_SIZE);
66 if full_blocks > 0 {
67 let blocks = unsafe {
70 std::slice::from_raw_parts(
71 msg[..16 * full_blocks].as_ptr().cast(),
72 full_blocks,
73 )
74 };
75 assert_eq!(
76 std::mem::size_of_val(blocks) + leftover,
77 std::mem::size_of_val(msg)
78 );
79 self.ghash.update(blocks);
80 }
81
82 self.msg_buf[0..leftover].copy_from_slice(&msg[full_blocks * 16..]);
83 self.msg_buf_offset = leftover;
84 assert!(self.msg_buf_offset < TAG_SIZE);
85 }
86
87 fn finalize(mut self) -> GenericArray<u8, U16> {
88 if self.msg_buf_offset > 0 {
89 self
90 .ghash
91 .update_padded(&self.msg_buf[..self.msg_buf_offset]);
92 }
93
94 let mut final_block = [0u8; 16];
95 final_block[..8].copy_from_slice(&(8 * self.ad_len as u64).to_be_bytes());
96 final_block[8..].copy_from_slice(&(8 * self.msg_len as u64).to_be_bytes());
97
98 self.ghash.update(&[final_block.into()]);
99 let mut hash = self.ghash.finalize();
100
101 for (i, b) in hash.iter_mut().enumerate() {
102 *b ^= self.ghash_pad[i];
103 }
104
105 hash
106 }
107}
108
/// Streaming AES-GCM state: a counter-mode keystream paired with the GHASH
/// authenticator that produces the tag.
pub struct AesGcm<Aes>
where
  Aes: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
{
  /// Counter-mode keystream (`Ctr32BE`) applied to the data in
  /// `encrypt`/`decrypt`.
  ctr: Ctr32BE<Aes>,

  /// Running GHASH state fed by `set_aad`, `encrypt` and `decrypt`;
  /// consumed by `finish` to yield the tag.
  ghash: GcmGhash<16>,
}
119
120impl<Aes> KeySizeUser for AesGcm<Aes>
121where
122 Aes:
123 KeySizeUser + BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt,
124{
125 type KeySize = Aes::KeySize;
126}
127
128impl<Aes> AesGcm<Aes>
129where
130 Aes: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt + KeyInit,
131{
132 pub fn new(key: &Key<Self>, nonce: &[u8]) -> Self {
133 let cipher = Aes::new(key);
134 let mut ghash_key = ghash::Key::default();
135 cipher.encrypt_block(&mut ghash_key);
136
137 use cipher::InnerIvInit;
138 use cipher::StreamCipherSeek;
139
140 let mut nonce_block = GenericArray::default();
141 if nonce.len() == 12 {
142 nonce_block[..nonce.len()].copy_from_slice(nonce);
143 } else {
144 let mut ghash = GHash::new(&ghash_key);
148 ghash.update_padded(nonce);
149 ghash.update_padded(&(8 * nonce.len() as u128).to_be_bytes());
150 nonce_block.copy_from_slice(&ghash.finalize());
151 for i in nonce_block.iter_mut().rev() {
153 *i = i.wrapping_sub(1);
154 if *i != 0xff {
155 break;
156 }
157 }
158 }
159 let mut ctr = ctr::Ctr32BE::from_core(ctr::CtrCore::inner_iv_init(
160 cipher,
161 &nonce_block,
162 ));
163 ctr.seek(Aes::block_size());
164
165 let mut pad = [0u8; 16];
166 ctr.apply_keystream(&mut pad);
167
168 let ghash = GcmGhash::new(&ghash_key, pad).unwrap();
169 Self { ctr, ghash }
170 }
171}
172
173use cipher::StreamCipher;
174
175impl<Aes> AesGcm<Aes>
176where
177 Aes: BlockCipher + BlockSizeUser<BlockSize = U16> + BlockEncrypt + KeyInit,
178{
179 pub fn set_aad(&mut self, aad: &[u8]) {
180 self.ghash.set_aad(aad);
181 }
182
183 pub fn encrypt(&mut self, block: &mut [u8]) {
184 self.ctr.apply_keystream(block);
185 self.ghash.update(block);
186 }
187
188 pub fn decrypt(&mut self, block: &mut [u8]) {
189 self.ghash.update(block);
190 self.ctr.apply_keystream(block);
191 }
192
193 pub fn finish(self) -> GenericArray<u8, U16> {
194 self.ghash.finalize()
195 }
196}