1#![doc = include_str!("../README.md")]
2#![cfg_attr(not(test), no_std)]
3#![warn(missing_docs)]
4use aes::cipher::{
5 BlockCipher, BlockDecrypt, BlockEncrypt, BlockSizeUser, InnerIvInit, KeyInit,
6 consts::{U12, U16},
7 generic_array::GenericArray,
8};
9use ctr::CtrCore;
10use ghash::{GHash, universal_hash::UniversalHash};
11
12pub use aes;
13pub use ctr;
14pub use ghash;
15
16pub fn recover_counter<
29 C: BlockCipher
30 + BlockEncrypt
31 + BlockDecrypt
32 + BlockSizeUser<BlockSize = <ghash::GHash as BlockSizeUser>::BlockSize>,
33>(
34 cipher: &C,
35 c: &[u8],
36 tag: Option<&GenericArray<u8, <ghash::GHash as BlockSizeUser>::BlockSize>>,
37 aad: &[u8],
38) -> GenericArray<u8, <ghash::GHash as BlockSizeUser>::BlockSize> {
39 let mut c = c;
40 let tag = match tag {
41 Some(t) => t,
42 None => {
43 let (cc, tt) = c.split_at(c.len() - 16);
44 c = cc;
45 aes::Block::from_slice(tt)
46 }
47 };
48
49 let mut ghash_key = ghash::Key::default();
51 cipher.encrypt_block(&mut ghash_key);
52
53 let mut ghash = GHash::new(&ghash_key);
54
55 ghash.update_padded(aad);
56 ghash.update_padded(c);
57
58 let associated_data_bits = (aad.len() as u64) * 8;
59 let buffer_bits = (c.len() as u64) * 8;
60
61 let mut block = ghash::Block::default();
62 block[..8].copy_from_slice(&associated_data_bits.to_be_bytes());
63 block[8..].copy_from_slice(&buffer_bits.to_be_bytes());
64 ghash.update(&[block]);
65
66 let mut mask = ghash.finalize();
67 for (a, b) in mask.iter_mut().zip(tag.iter()) {
68 *a ^= *b;
69 }
70
71 cipher.decrypt_block(&mut mask);
72
73 mask
74}
75
76pub fn instantiate_keystream<
87 C: BlockCipher
88 + BlockEncrypt
89 + BlockSizeUser<BlockSize = <ghash::GHash as BlockSizeUser>::BlockSize>,
90>(
91 cipher: C,
92 y0: &GenericArray<u8, <C as BlockSizeUser>::BlockSize>,
93) -> ctr::Ctr32BE<C> {
94 let mut y0_inc = *y0;
95 {
96 let mut ctr = u32::from_be_bytes(y0_inc[y0.len() - 4..].try_into().unwrap());
97 ctr = ctr.wrapping_add(1);
98 y0_inc[y0.len() - 4..].copy_from_slice(&ctr.to_be_bytes());
99 }
100 let ctr = ctr::Ctr32BE::<C>::from_core(CtrCore::inner_iv_init(cipher, &y0_inc));
101 ctr
102}
103
104pub const fn extract_nonce<
110 C: BlockSizeUser<BlockSize = <ghash::GHash as BlockSizeUser>::BlockSize>,
111>(
112 y0: &GenericArray<u8, <C as BlockSizeUser>::BlockSize>,
113) -> Option<&GenericArray<u8, U12>> {
114 #[expect(unused)]
115 const ASSERT_Y0_IS_16_BYTES: GenericArray<u8, U16> =
116 unsafe { core::mem::zeroed::<aes::Block>() };
117 let y0 = unsafe { core::mem::transmute::<&aes::Block, &[u8; 16]>(y0) };
118 if y0[12] == 0 && y0[13] == 0 && y0[14] == 0 && y0[15] == 1 {
119 Some(unsafe { core::mem::transmute::<&[u8; 16], &GenericArray<u8, U12>>(y0) })
120 } else {
121 None
122 }
123}
124
#[cfg(test)]
mod tests {
    use aes::cipher::StreamCipher;
    use aes_gcm::aead::AeadMutInPlace;

    use super::*;

    /// Longest plaintext exercised by `test_recover`.
    const MAX_LEN: usize = 128;

    /// One FNV-1a-style mixing step, used as a tiny deterministic PRNG so the
    /// tests are reproducible without an RNG dependency.
    const fn fnv1a(state: u64, data: u64) -> u64 {
        let state = state ^ data;
        state.wrapping_mul(0x100000001b3u64)
    }

    /// Seals messages of every length in `0..=MAX_LEN`, with and without AAD,
    /// under several pseudo-random keys and nonces. It then checks that:
    /// - `recover_counter` + `extract_nonce` give back the nonce,
    /// - the detached-tag and appended-tag (`tag: None`) forms agree on `J0`,
    /// - `instantiate_keystream` decrypts the ciphertext back to the plaintext.
    fn test_recover<
        C: BlockCipher + BlockEncrypt + BlockDecrypt + BlockSizeUser<BlockSize = U16> + KeyInit,
    >() {
        let mut rng = 0xcbf29ce484222325u64;
        let mut key = aes_gcm::Key::<C>::default();

        let mut plaintext = [0u8; MAX_LEN];
        for (dst, src) in plaintext.iter_mut().zip(b"RealPlaintext".iter().cycle()) {
            *dst = *src;
        }
        // The fixture is read-only from here on; every check works on copies.
        let plaintext = plaintext;

        for rep in 0..5 {
            rng = fnv1a(rng, rep);
            key.chunks_mut(8).enumerate().for_each(|(i, k)| {
                rng = fnv1a(rng, i as u64);
                let bytes = rng.to_be_bytes();
                k.copy_from_slice(&bytes[..k.len()]);
            });
            let aes_cipher = C::new(&key);
            let mut cipher = aes_gcm::AesGcm::<C, U12>::new(&key);

            let mut nonce = aes_gcm::Nonce::<U12>::default();
            nonce.chunks_mut(4).enumerate().for_each(|(i, n)| {
                rng = fnv1a(rng, i as u64);
                let bytes = rng.to_be_bytes();
                n.copy_from_slice(&bytes[..n.len()]);
            });

            for pt_len in 0..=MAX_LEN {
                for aad in [b"".as_slice(), b"GenuineAAD".as_slice()] {
                    let mut sealed = plaintext;
                    let tag = cipher
                        .encrypt_in_place_detached(&nonce, aad, &mut sealed[..pt_len])
                        .unwrap();
                    let ciphertext = &sealed[..pt_len];

                    // Detached tag.
                    let recovered_j0 =
                        recover_counter(&aes_cipher, ciphertext, Some(&tag), aad);
                    let recovered_nonce = extract_nonce::<C>(&recovered_j0)
                        .expect("J0 derived from a 96-bit nonce ends in 00 00 00 01");
                    assert_eq!(recovered_nonce, &nonce);

                    // Tag appended to the ciphertext (`tag: None`).
                    let mut with_tag = [0u8; MAX_LEN + 16];
                    with_tag[..pt_len].copy_from_slice(ciphertext);
                    with_tag[pt_len..pt_len + 16].copy_from_slice(&tag);
                    assert_eq!(
                        recover_counter(&aes_cipher, &with_tag[..pt_len + 16], None, aad),
                        recovered_j0
                    );

                    // The rebuilt keystream decrypts a copy of the ciphertext.
                    let mut decrypted = [0u8; MAX_LEN];
                    decrypted[..pt_len].copy_from_slice(ciphertext);
                    instantiate_keystream(&aes_cipher, &recovered_j0)
                        .apply_keystream(&mut decrypted[..pt_len]);
                    assert_eq!(decrypted[..pt_len], plaintext[..pt_len]);
                }
            }
        }
    }

    #[test]
    fn test_recover_j0_aes128() {
        test_recover::<aes::Aes128>();
    }

    #[test]
    fn test_recover_j0_aes256() {
        test_recover::<aes::Aes256>();
    }

    #[test]
    fn extract_nonce_requires_96_bit_counter_suffix() {
        let mut j0 = aes::Block::default();
        j0[..12].copy_from_slice(b"twelve bytes");
        // Suffix 00 00 00 00: not the J0 of a 96-bit nonce.
        assert!(extract_nonce::<aes::Aes128>(&j0).is_none());

        j0[15] = 1;
        let nonce = extract_nonce::<aes::Aes128>(&j0).expect("suffix is 00 00 00 01");
        assert_eq!(nonce.as_slice(), b"twelve bytes");

        j0[12] = 1;
        assert!(extract_nonce::<aes::Aes128>(&j0).is_none());
    }

    #[test]
    #[should_panic]
    fn recover_counter_panics_on_input_shorter_than_tag() {
        let cipher = aes::Aes128::new(&Default::default());
        let _ = recover_counter(&cipher, &[0u8; 15], None, b"");
    }
}