1use core::convert::TryInto;
8
/// Size in bytes of one AES block (AES always processes 128-bit blocks).
pub const AES_BLOCK_SIZE: usize = 128 / 8;
11
/// Supported AES key lengths.
///
/// The variant fixes both the number of key bytes expected by
/// [`AesKey::new`] and the number of rounds in the resulting schedule.
// PartialEq/Eq/Hash let callers compare key sizes or use them as map keys;
// the enum is fieldless, so these derives are trivially correct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AesKeySize {
    /// 128-bit key: 16 bytes, 10 rounds.
    Bits128,
    /// 192-bit key: 24 bytes, 12 rounds.
    Bits192,
    /// 256-bit key: 32 bytes, 14 rounds.
    Bits256,
}
19
20pub struct AesKey {
22 pub round_keys: Vec<[u8; AES_BLOCK_SIZE]>,
23 pub nr: usize, }
25
26impl AesKey {
27 pub fn new(key_data: &[u8], key_size: AesKeySize) -> Self {
34 let (key_len, nr, nk) = match key_size {
35 AesKeySize::Bits128 => (16, 10, 4),
36 AesKeySize::Bits192 => (24, 12, 6),
37 AesKeySize::Bits256 => (32, 14, 8),
38 };
39 assert_eq!(
40 key_data.len(),
41 key_len,
42 "Key length mismatch for AES key size"
43 );
44
45 let expanded_len = AES_BLOCK_SIZE * (nr + 1);
46 let mut round_keys = vec![0u8; expanded_len];
47 round_keys[..key_len].copy_from_slice(key_data);
49
50 key_expansion(&mut round_keys, nk, nr);
51
52 let mut round_blocks = Vec::with_capacity(nr + 1);
54 for i in 0..(nr + 1) {
55 let offset = i * AES_BLOCK_SIZE;
56 let block: [u8; AES_BLOCK_SIZE] = round_keys[offset..offset + AES_BLOCK_SIZE]
57 .try_into()
58 .unwrap();
59 round_blocks.push(block);
60 }
61
62 Self {
63 round_keys: round_blocks,
64 nr,
65 }
66 }
67}
68
69pub fn aes_encrypt_block(plaintext: &mut [u8; AES_BLOCK_SIZE], key: &AesKey) {
75 add_round_key(plaintext, &key.round_keys[0]);
76
77 for round in 1..key.nr {
78 sub_bytes(plaintext);
79 shift_rows(plaintext);
80 mix_columns(plaintext);
81 add_round_key(plaintext, &key.round_keys[round]);
82 }
83
84 sub_bytes(plaintext);
86 shift_rows(plaintext);
87 add_round_key(plaintext, &key.round_keys[key.nr]);
88}
89
90pub fn aes_decrypt_block(ciphertext: &mut [u8; AES_BLOCK_SIZE], key: &AesKey) {
96 add_round_key(ciphertext, &key.round_keys[key.nr]);
97 inv_shift_rows(ciphertext);
98 inv_sub_bytes(ciphertext);
99
100 for round in (1..key.nr).rev() {
101 add_round_key(ciphertext, &key.round_keys[round]);
102 inv_mix_columns(ciphertext);
103 inv_shift_rows(ciphertext);
104 inv_sub_bytes(ciphertext);
105 }
106
107 add_round_key(ciphertext, &key.round_keys[0]);
109}
110
/// AES forward substitution box (the FIPS-197 S-box).
///
/// `SBOX[b]` is the substitution of byte `b`. It is used by `sub_bytes` on the
/// cipher state and by `key_expansion` for SubWord.
///
/// NOTE(review): lookups are indexed by secret state/key bytes, so this
/// table-based approach is presumably not cache-timing constant-time —
/// confirm whether that matters for the deployment.
static SBOX: [u8; 256] = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
134
/// AES inverse substitution box (the FIPS-197 inverse S-box).
///
/// `INV_SBOX[SBOX[b]] == b` for every byte `b`. It is used by `inv_sub_bytes`
/// during decryption.
static INV_SBOX: [u8; 256] = [
    0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
    0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
    0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
    0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
    0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
    0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
    0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
    0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
    0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
    0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
    0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
    0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
    0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
    0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
    0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
];
154
/// AES key-schedule round constants.
///
/// For `1 <= i < 255`, `RCON[i]` is `{02}^(i-1)` in GF(2^8), reduced modulo the
/// AES polynomial x^8 + x^4 + x^3 + x + 1. Entry 0 is never read by
/// `key_expansion`, which indexes with `i / nk >= 1`, and stays `0x00`.
///
/// The previous table declared 255 entries but filled in only the first 11.
/// Every later index silently returned 0 instead of the real constant. The
/// whole table is now computed at compile time; entries 1..=10 are unchanged
/// (0x01, 0x02, ..., 0x80, 0x1B, 0x36).
static RCON: [u8; 255] = {
    let mut rcon = [0u8; 255];
    let mut value: u8 = 0x01;
    let mut i = 1;
    while i < 255 {
        rcon[i] = value;
        // Multiply by {02}: shift left, and reduce when the top bit falls off.
        value = if value & 0x80 != 0 {
            (value << 1) ^ 0x1B
        } else {
            value << 1
        };
        i += 1;
    }
    rcon
};
172
173fn sub_bytes(state: &mut [u8; AES_BLOCK_SIZE]) {
174 for b in state.iter_mut() {
175 *b = SBOX[*b as usize];
176 }
177}
178
179fn inv_sub_bytes(state: &mut [u8; AES_BLOCK_SIZE]) {
180 for b in state.iter_mut() {
181 *b = INV_SBOX[*b as usize];
182 }
183}
184
185fn shift_rows(state: &mut [u8; AES_BLOCK_SIZE]) {
186 let row1 = [state[1], state[5], state[9], state[13]];
188 state[1] = row1[1];
189 state[5] = row1[2];
190 state[9] = row1[3];
191 state[13] = row1[0];
192
193 let row2 = [state[2], state[6], state[10], state[14]];
195 state[2] = row2[2];
196 state[6] = row2[3];
197 state[10] = row2[0];
198 state[14] = row2[1];
199
200 let row3 = [state[3], state[7], state[11], state[15]];
202 state[3] = row3[3];
203 state[7] = row3[0];
204 state[11] = row3[1];
205 state[15] = row3[2];
206}
207
208fn inv_shift_rows(state: &mut [u8; AES_BLOCK_SIZE]) {
209 let row1 = [state[1], state[5], state[9], state[13]];
211 state[1] = row1[3];
212 state[5] = row1[0];
213 state[9] = row1[1];
214 state[13] = row1[2];
215
216 let row2 = [state[2], state[6], state[10], state[14]];
218 state[2] = row2[2];
219 state[6] = row2[3];
220 state[10] = row2[0];
221 state[14] = row2[1];
222
223 let row3 = [state[3], state[7], state[11], state[15]];
225 state[3] = row3[1];
226 state[7] = row3[2];
227 state[11] = row3[3];
228 state[15] = row3[0];
229}
230
/// Multiplies `x` by `{02}` in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1.
///
/// When the top bit of `x` is set, the doubled value overflows 8 bits and is
/// reduced by XOR-ing in `0x1B`.
fn xtime(x: u8) -> u8 {
    let carry = x >> 7; // 1 if the top bit is set, otherwise 0
    (x << 1) ^ (carry * 0x1B)
}
238
239fn mix_columns(state: &mut [u8; AES_BLOCK_SIZE]) {
240 for col in 0..4 {
241 let base = col * 4;
242 let t = state[base] ^ state[base + 1] ^ state[base + 2] ^ state[base + 3];
243 let temp0 = state[base];
244 let temp1 = state[base + 1];
245 let temp2 = state[base + 2];
246 let temp3 = state[base + 3];
247
248 state[base] ^= t ^ xtime(temp0 ^ temp1);
249 state[base + 1] ^= t ^ xtime(temp1 ^ temp2);
250 state[base + 2] ^= t ^ xtime(temp2 ^ temp3);
251 state[base + 3] ^= t ^ xtime(temp3 ^ temp0);
252 }
253}
254
255fn inv_mix_columns(state: &mut [u8; AES_BLOCK_SIZE]) {
256 for col in 0..4 {
259 let base = col * 4;
260 let a0 = state[base];
261 let a1 = state[base + 1];
262 let a2 = state[base + 2];
263 let a3 = state[base + 3];
264
265 state[base] = mul(a0, 0x0e) ^ mul(a1, 0x0b) ^ mul(a2, 0x0d) ^ mul(a3, 0x09);
266 state[base + 1] = mul(a0, 0x09) ^ mul(a1, 0x0e) ^ mul(a2, 0x0b) ^ mul(a3, 0x0d);
267 state[base + 2] = mul(a0, 0x0d) ^ mul(a1, 0x09) ^ mul(a2, 0x0e) ^ mul(a3, 0x0b);
268 state[base + 3] = mul(a0, 0x0b) ^ mul(a1, 0x0d) ^ mul(a2, 0x09) ^ mul(a3, 0x0e);
269 }
270}
271
/// Multiplies `x` and `y` in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1.
///
/// This is shift-and-add multiplication. It always runs exactly 8 steps,
/// one per bit of `y`, whatever the operand values.
fn mul(x: u8, y: u8) -> u8 {
    let (product, _, _) = (0..8).fold((0u8, x, y), |(acc, a, b), _| {
        // Add the current multiple of `x` when this bit of `y` is set.
        let acc = if b & 1 == 1 { acc ^ a } else { acc };
        // Double `a`, reducing if the top bit overflows.
        let reduction = if a & 0x80 != 0 { 0x1b } else { 0x00 };
        (acc, (a << 1) ^ reduction, b >> 1)
    });
    product
}
290
291fn add_round_key(state: &mut [u8; AES_BLOCK_SIZE], round_key: &[u8; AES_BLOCK_SIZE]) {
292 for (s, k) in state.iter_mut().zip(round_key) {
293 *s ^= *k;
294 }
295}
296
297fn key_expansion(expanded: &mut [u8], nk: usize, nr: usize) {
299 let total_words = (nr + 1) * 4; let mut i = nk;
301 while i < total_words {
302 let mut temp = [
303 expanded[(i - 1) * 4],
304 expanded[(i - 1) * 4 + 1],
305 expanded[(i - 1) * 4 + 2],
306 expanded[(i - 1) * 4 + 3],
307 ];
308
309 if i % nk == 0 {
310 temp = [temp[1], temp[2], temp[3], temp[0]];
312 for t in temp.iter_mut() {
314 *t = SBOX[*t as usize];
315 }
316 temp[0] ^= RCON[i / nk];
318 } else if nk > 6 && i % nk == 4 {
319 for t in temp.iter_mut() {
320 *t = SBOX[*t as usize];
321 }
322 }
323
324 let wprev = (i - nk) * 4;
325 for (j, tj) in temp.iter().enumerate() {
326 expanded[i * 4 + j] = expanded[wprev + j] ^ tj;
327 }
328 i += 1;
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// NIST SP 800-38A ECB-AES128, first block.
    #[test]
    fn test_aes128_encrypt_block() {
        let key_data = hex_to_bytes("2b7e151628aed2a6abf7158809cf4f3c");
        let mut block = hex_to_array("6bc1bee22e409f96e93d7e117393172a");

        let aes_key = AesKey::new(&key_data, AesKeySize::Bits128);

        aes_encrypt_block(&mut block, &aes_key);

        let expected = hex_to_array("3ad77bb40d7a3660a89ecaf32466ef97");
        assert_eq!(block, expected);
    }

    /// Inverse of `test_aes128_encrypt_block`.
    #[test]
    fn test_aes128_decrypt_block() {
        let key_data = hex_to_bytes("2b7e151628aed2a6abf7158809cf4f3c");
        let mut block = hex_to_array("3ad77bb40d7a3660a89ecaf32466ef97");

        let aes_key = AesKey::new(&key_data, AesKeySize::Bits128);

        aes_decrypt_block(&mut block, &aes_key);

        let expected = hex_to_array("6bc1bee22e409f96e93d7e117393172a");
        assert_eq!(block, expected);
    }

    /// FIPS-197 Appendix C example vectors for every key size.
    ///
    /// These cover the 12- and 14-round paths and the `nk > 6` SubWord branch
    /// of the key schedule, which the AES-128 tests never reach.
    #[test]
    fn test_fips197_appendix_c_all_key_sizes() {
        let plaintext = hex_to_array("00112233445566778899aabbccddeeff");
        let cases = [
            (
                "000102030405060708090a0b0c0d0e0f",
                AesKeySize::Bits128,
                "69c4e0d86a7b0430d8cdb78070b4c55a",
            ),
            (
                "000102030405060708090a0b0c0d0e0f1011121314151617",
                AesKeySize::Bits192,
                "dda97ca4864cdfe06eaf70a0ec0d7191",
            ),
            (
                "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f",
                AesKeySize::Bits256,
                "8ea2b7ca516745bfeafc49904b496089",
            ),
        ];

        for (key_hex, key_size, ciphertext_hex) in cases.iter() {
            let key = AesKey::new(&hex_to_bytes(key_hex), *key_size);
            let expected = hex_to_array(ciphertext_hex);

            let mut block = plaintext;
            aes_encrypt_block(&mut block, &key);
            assert_eq!(block, expected, "encrypt with {:?}", key_size);

            aes_decrypt_block(&mut block, &key);
            assert_eq!(block, plaintext, "decrypt with {:?}", key_size);
        }
    }

    /// `AesKey::new` must reject a key whose length does not match its size.
    #[test]
    #[should_panic(expected = "Key length mismatch")]
    fn test_key_length_mismatch_panics() {
        let _ = AesKey::new(&[0u8; 15], AesKeySize::Bits128);
    }

    /// Parses an even-length lowercase/uppercase hex string into bytes.
    fn hex_to_bytes(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Parses a 32-character hex string into one 16-byte block.
    fn hex_to_array(s: &str) -> [u8; 16] {
        let bytes = hex_to_bytes(s);
        bytes.try_into().unwrap()
    }
}