const_ciphers/
aes.rs

1use const_for::const_for;
2
/// Block-cipher mode of operation used by the AES encrypt/decrypt helpers.
///
/// No padding is applied by any mode: ECB and CBC need an input length that is a
/// multiple of 16 bytes, while CTR accepts any length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AesMode {
    /// Electronic codebook: each 16-byte block is encrypted independently, so equal
    /// plaintext blocks yield equal ciphertext blocks.
    ECB,
    /// Cipher block chaining with a 16-byte initialization vector.
    CBC { iv: [u8; 16] },
    /// Counter mode with a 12-byte nonce. The last 4 bytes of each counter block are a
    /// big-endian block counter starting at 1, so the keystream depends only on
    /// `(key, iv)`; an `iv` must never be reused with the same key.
    CTR { iv: [u8; 12] },
}
8
9pub struct AesConst {}
10
11impl AesConst {
12    pub const fn encrypt<const N: usize>(plaintext: &[u8; N], key: &[u8; 16], mode: &AesMode) -> [u8; N] {
13        match mode {
14            AesMode::ECB => Self::encrypt_ecb(plaintext, key),
15            AesMode::CBC { iv } => Self::encrypt_cbc(plaintext, key, iv),
16            AesMode::CTR { iv } => Self::encrypt_ctr(plaintext, key, iv),
17        }
18    }
19
20    pub const fn decrypt<const N: usize>(
21        ciphertext: &[u8; N],
22        key: &[u8; 16],
23        mode: &AesMode,
24    ) -> [u8; N] {
25        match mode {
26            AesMode::ECB => Self::decrypt_ecb(ciphertext, key),
27            AesMode::CBC { iv } => Self::decrypt_cbc(ciphertext, key, iv),
28            AesMode::CTR { iv } => Self::decrypt_ctr(ciphertext, key, iv),
29        }
30    }
31
32    pub const fn encrypt_block(plaintext: &[u8; 16], key: &[u8; 16]) -> [u8; 16] {
33        let round_keys = Self::key_expansion(key);
34        let mut state = Self::add_round_key(plaintext, &round_keys[0]);
35
36        let mut round = 1;
37        while round < 10 {
38            state = Self::sub_bytes(&state);
39            state = Self::shift_rows(&state);
40            state = Self::mix_columns(&state);
41            state = Self::add_round_key(&state, &round_keys[round]);
42            round += 1;
43        }
44
45        state = Self::sub_bytes(&state);
46        state = Self::shift_rows(&state);
47        state = Self::add_round_key(&state, &round_keys[10]);
48
49        state
50    }
51
52    pub const fn decrypt_block(ciphertext: &[u8; 16], key: &[u8; 16]) -> [u8; 16] {
53        let round_keys = Self::key_expansion(key);
54        let mut state = Self::add_round_key(ciphertext, &round_keys[10]);
55
56        let mut round = 9;
57        while round > 0 {
58            state = Self::inv_shift_rows(&state);
59            state = Self::inv_sub_bytes(&state);
60            state = Self::add_round_key(&state, &round_keys[round]);
61            state = Self::inv_mix_columns(&state);
62            round -= 1;
63        }
64
65        state = Self::inv_shift_rows(&state);
66        state = Self::inv_sub_bytes(&state);
67        state = Self::add_round_key(&state, &round_keys[0]);
68
69        state
70    }
71
72    const fn encrypt_ecb<const N: usize>(plaintext: &[u8; N], key: &[u8; 16]) -> [u8; N] {
73        let mut result = [0u8; N];
74        let mut i = 0;
75        while i < N {
76            let mut block = [0u8; 16];
77            let mut j = 0;
78            while j < 16 {
79                block[j] = plaintext[i + j];
80                j += 1;
81            }
82
83            let enc_block = Self::encrypt_block(&block, key);
84
85            j = 0;
86            while j < 16 {
87                result[i + j] = enc_block[j];
88                j += 1;
89            }
90            i += 16;
91        }
92        result
93    }
94
95    const fn decrypt_ecb<const N: usize>(ciphertext: &[u8; N], key: &[u8; 16]) -> [u8; N] {
96        if N % 16 != 0 {
97            panic!("Invalid ciphertext length for ECB.");
98        }
99
100        let mut result = [0u8; N];
101        let mut i = 0;
102        while i < N {
103            let mut block = [0u8; 16];
104            let mut j = 0;
105            while j < 16 {
106                block[j] = ciphertext[i + j];
107                j += 1;
108            }
109
110            let dec_block = Self::decrypt_block(&block, key);
111
112            j = 0;
113            while j < 16 {
114                result[i + j] = dec_block[j];
115                j += 1;
116            }
117            i += 16;
118        }
119        result
120    }
121
122    const fn encrypt_cbc<const N: usize>(
123        plaintext: &[u8; N],
124        key: &[u8; 16],
125        iv: &[u8; 16],
126    ) -> [u8; N] {
127        if N % 16 != 0 {
128            panic!("Invalid plaintext length for CBC.");
129        }
130
131        let mut result = [0u8; N];
132        let mut prev = *iv;
133        let mut i = 0;
134        while i < N {
135            let mut block = [0u8; 16];
136            let mut j = 0;
137            while j < 16 {
138                block[j] = plaintext[i + j] ^ prev[j];
139                j += 1;
140            }
141
142            let enc_block = Self::encrypt_block(&block, key);
143
144            j = 0;
145            while j < 16 {
146                result[i + j] = enc_block[j];
147                j += 1;
148            }
149
150            prev = enc_block;
151            i += 16;
152        }
153        result
154    }
155
156    const fn decrypt_cbc<const N: usize>(
157        ciphertext: &[u8; N],
158        key: &[u8; 16],
159        iv: &[u8; 16],
160    ) -> [u8; N] {
161        let mut result = [0u8; N];
162        let mut prev = *iv;
163        let mut i = 0;
164        while i < N {
165            let mut block = [0u8; 16];
166            let mut j = 0;
167            while j < 16 {
168                block[j] = ciphertext[i + j];
169                j += 1;
170            }
171
172            let dec_block = Self::decrypt_block(&block, key);
173
174            j = 0;
175            while j < 16 {
176                result[i + j] = dec_block[j] ^ prev[j];
177                j += 1;
178            }
179
180            prev = block;
181            i += 16;
182        }
183        result
184    }
185
186    const fn encrypt_ctr<const N: usize>(
187        plaintext: &[u8; N],
188        key: &[u8; 16],
189        iv: &[u8; 12],
190    ) -> [u8; N] {
191        let mut ciphertext = [0u8; N];
192        let mut counter = Self::ctr_init(iv);
193
194        let mut i = 0;
195        while i < N {
196            let keystream = Self::encrypt_block(&counter, key);
197            let block_size = if i + 16 > N { N - i } else { 16 };
198            const_for!(j in 0..block_size => {
199                ciphertext[i + j] = plaintext[i + j] ^ keystream[j];
200            });
201            counter = Self::ctr_increment(&counter);
202            i += 16;
203        }
204
205        ciphertext
206    }
207
208    const fn decrypt_ctr<const N: usize>(
209        ciphertext: &[u8; N],
210        key: &[u8; 16],
211        iv: &[u8; 12],
212    ) -> [u8; N] {
213        // CTR decryption is identical to encryption
214        Self::encrypt_ctr(ciphertext, key, iv)
215    }
216
217    const fn ctr_init(iv: &[u8; 12]) -> [u8; 16] {
218        let mut counter = [0u8; 16];
219        let mut i = 0;
220        while i < 12 {
221            counter[i] = iv[i];
222            i += 1;
223        }
224        counter[15] = 1;
225        counter
226    }
227
228    const fn ctr_increment(counter: &[u8; 16]) -> [u8; 16] {
229        let mut new_counter = *counter;
230        let mut i = 15;
231        while i >= 12 {
232            if new_counter[i] == 255 {
233                new_counter[i] = 0;
234                if i == 12 {
235                    break;
236                }
237                i -= 1;
238            } else {
239                new_counter[i] += 1;
240                break;
241            }
242        }
243        new_counter
244    }
245
246    const SBOX: [u8; 256] = [
247        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab,
248        0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4,
249        0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71,
250        0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
251        0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6,
252        0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb,
253        0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45,
254        0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
255        0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44,
256        0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a,
257        0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49,
258        0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
259        0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25,
260        0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e,
261        0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1,
262        0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
263        0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb,
264        0x16,
265    ];
266
267    const INV_SBOX: [u8; 256] = [
268        0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7,
269        0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde,
270        0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42,
271        0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49,
272        0x6d, 0x8b, 0xd1, 0x25, 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c,
273        0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15,
274        0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7,
275        0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02,
276        0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc,
277        0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad,
278        0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d,
279        0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b,
280        0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 0x1f, 0xdd, 0xa8,
281        0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51,
282        0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0,
283        0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
284        0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c,
285        0x7d,
286    ];
287
288    const RCON: [u8; 11] = [
289        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C,
290    ];
291
292    const fn key_expansion(key: &[u8; 16]) -> [[u8; 16]; 11] {
293        let mut round_keys = [[0u8; 16]; 11];
294        let mut i = 0;
295
296        let mut j = 0;
297        while j < 16 {
298            round_keys[0][j] = key[j];
299            j += 1;
300        }
301
302        while i < 10 {
303            let mut t = [0u8; 4];
304            t[0] = round_keys[i][12];
305            t[1] = round_keys[i][13];
306            t[2] = round_keys[i][14];
307            t[3] = round_keys[i][15];
308
309            t = [t[1], t[2], t[3], t[0]];
310
311            let mut k = 0;
312            while k < 4 {
313                t[k] = Self::sbox(t[k]);
314                k += 1;
315            }
316
317            t[0] ^= Self::RCON[i + 1];
318
319            let mut m = 0;
320            while m < 4 {
321                round_keys[i + 1][m] = round_keys[i][m] ^ t[m];
322                m += 1;
323            }
324
325            let mut n = 4;
326            while n < 16 {
327                round_keys[i + 1][n] = round_keys[i + 1][n - 4] ^ round_keys[i][n];
328                n += 1;
329            }
330
331            i += 1;
332        }
333
334        round_keys
335    }
336
337    const fn sbox(byte: u8) -> u8 {
338        Self::SBOX[byte as usize]
339    }
340
341    const fn inv_sbox(byte: u8) -> u8 {
342        Self::INV_SBOX[byte as usize]
343    }
344
345    const fn add_round_key(state: &[u8; 16], round_key: &[u8; 16]) -> [u8; 16] {
346        let mut new_state = [0u8; 16];
347        let mut i = 0;
348        while i < 16 {
349            new_state[i] = state[i] ^ round_key[i];
350            i += 1;
351        }
352        new_state
353    }
354
355    const fn sub_bytes(state: &[u8; 16]) -> [u8; 16] {
356        let mut new_state = [0u8; 16];
357        let mut i = 0;
358        while i < 16 {
359            new_state[i] = Self::sbox(state[i]);
360            i += 1;
361        }
362        new_state
363    }
364
365    const fn inv_sub_bytes(state: &[u8; 16]) -> [u8; 16] {
366        let mut new_state = [0u8; 16];
367        let mut i = 0;
368        while i < 16 {
369            new_state[i] = Self::inv_sbox(state[i]);
370            i += 1;
371        }
372        new_state
373    }
374
375    const fn shift_rows(state: &[u8; 16]) -> [u8; 16] {
376        let mut new_state = [0u8; 16];
377
378        new_state[0] = state[0];
379        new_state[4] = state[4];
380        new_state[8] = state[8];
381        new_state[12] = state[12];
382
383        new_state[1] = state[5];
384        new_state[5] = state[9];
385        new_state[9] = state[13];
386        new_state[13] = state[1];
387
388        new_state[2] = state[10];
389        new_state[6] = state[14];
390        new_state[10] = state[2];
391        new_state[14] = state[6];
392
393        new_state[3] = state[15];
394        new_state[7] = state[3];
395        new_state[11] = state[7];
396        new_state[15] = state[11];
397
398        new_state
399    }
400
401    const fn inv_shift_rows(state: &[u8; 16]) -> [u8; 16] {
402        let mut new_state = [0u8; 16];
403
404        new_state[0] = state[0];
405        new_state[4] = state[4];
406        new_state[8] = state[8];
407        new_state[12] = state[12];
408
409        new_state[1] = state[13];
410        new_state[5] = state[1];
411        new_state[9] = state[5];
412        new_state[13] = state[9];
413
414        new_state[2] = state[10];
415        new_state[6] = state[14];
416        new_state[10] = state[2];
417        new_state[14] = state[6];
418
419        new_state[3] = state[7];
420        new_state[7] = state[11];
421        new_state[11] = state[15];
422        new_state[15] = state[3];
423
424        new_state
425    }
426
427    const fn gf_mul(a: u8, b: u8) -> u8 {
428        let mut result = 0;
429        let mut a = a;
430        let mut b = b;
431        let mut i = 0;
432        while i < 8 {
433            if (b & 1) != 0 {
434                result ^= a;
435            }
436            let high_bit = (a & 0x80) != 0;
437            a <<= 1;
438            if high_bit {
439                a ^= 0x1b;
440            }
441            b >>= 1;
442            i += 1;
443        }
444        result
445    }
446
447    const fn mix_columns(state: &[u8; 16]) -> [u8; 16] {
448        let mut new_state = [0u8; 16];
449        let mut i = 0;
450        while i < 16 {
451            let s0 = state[i];
452            let s1 = state[i + 1];
453            let s2 = state[i + 2];
454            let s3 = state[i + 3];
455
456            new_state[i] = Self::gf_mul(s0, 2) ^ Self::gf_mul(s1, 3) ^ s2 ^ s3;
457            new_state[i + 1] = s0 ^ Self::gf_mul(s1, 2) ^ Self::gf_mul(s2, 3) ^ s3;
458            new_state[i + 2] = s0 ^ s1 ^ Self::gf_mul(s2, 2) ^ Self::gf_mul(s3, 3);
459            new_state[i + 3] = Self::gf_mul(s0, 3) ^ s1 ^ s2 ^ Self::gf_mul(s3, 2);
460
461            i += 4;
462        }
463        new_state
464    }
465
466    const fn inv_mix_columns(state: &[u8; 16]) -> [u8; 16] {
467        let mut new_state = [0u8; 16];
468        let mut i = 0;
469        while i < 16 {
470            let s0 = state[i];
471            let s1 = state[i + 1];
472            let s2 = state[i + 2];
473            let s3 = state[i + 3];
474
475            new_state[i] = Self::gf_mul(s0, 14)
476                ^ Self::gf_mul(s1, 11)
477                ^ Self::gf_mul(s2, 13)
478                ^ Self::gf_mul(s3, 9);
479            new_state[i + 1] = Self::gf_mul(s0, 9)
480                ^ Self::gf_mul(s1, 14)
481                ^ Self::gf_mul(s2, 11)
482                ^ Self::gf_mul(s3, 13);
483            new_state[i + 2] = Self::gf_mul(s0, 13)
484                ^ Self::gf_mul(s1, 9)
485                ^ Self::gf_mul(s2, 14)
486                ^ Self::gf_mul(s3, 11);
487            new_state[i + 3] = Self::gf_mul(s0, 11)
488                ^ Self::gf_mul(s1, 13)
489                ^ Self::gf_mul(s2, 9)
490                ^ Self::gf_mul(s3, 14);
491
492            i += 4;
493        }
494        new_state
495    }
496}
497
498pub struct Aes {}
499
500impl Aes {
501    pub fn encrypt<const N: usize>(plaintext: &[u8; N], key: &[u8; 16], mode: &AesMode) -> [u8; N] {
502        match mode {
503            AesMode::ECB => Self::encrypt_ecb(plaintext, key),
504            AesMode::CBC { iv } => Self::encrypt_cbc(plaintext, key, iv),
505            AesMode::CTR { iv } => Self::encrypt_ctr(plaintext, key, iv),
506        }
507    }
508
509    pub fn decrypt<const N: usize>(
510        ciphertext: &[u8; N],
511        key: &[u8; 16],
512        mode: &AesMode,
513    ) -> [u8; N] {
514        match mode {
515            AesMode::ECB => Self::decrypt_ecb(ciphertext, key),
516            AesMode::CBC { iv } => Self::decrypt_cbc(ciphertext, key, iv),
517            AesMode::CTR { iv } => Self::decrypt_ctr(ciphertext, key, iv),
518        }
519    }
520
521    pub fn encrypt_block(plaintext: &[u8; 16], key: &[u8; 16]) -> [u8; 16] {
522        let round_keys = Self::key_expansion(key);
523        let mut state = Self::add_round_key(plaintext, &round_keys[0]);
524
525        let mut round = 1;
526        while round < 10 {
527            state = Self::sub_bytes(&state);
528            state = Self::shift_rows(&state);
529            state = Self::mix_columns(&state);
530            state = Self::add_round_key(&state, &round_keys[round]);
531            round += 1;
532        }
533
534        state = Self::sub_bytes(&state);
535        state = Self::shift_rows(&state);
536        state = Self::add_round_key(&state, &round_keys[10]);
537
538        state
539    }
540
541    pub fn decrypt_block(ciphertext: &[u8; 16], key: &[u8; 16]) -> [u8; 16] {
542        let round_keys = Self::key_expansion(key);
543        let mut state = Self::add_round_key(ciphertext, &round_keys[10]);
544
545        let mut round = 9;
546        while round > 0 {
547            state = Self::inv_shift_rows(&state);
548            state = Self::inv_sub_bytes(&state);
549            state = Self::add_round_key(&state, &round_keys[round]);
550            state = Self::inv_mix_columns(&state);
551            round -= 1;
552        }
553
554        state = Self::inv_shift_rows(&state);
555        state = Self::inv_sub_bytes(&state);
556        state = Self::add_round_key(&state, &round_keys[0]);
557
558        state
559    }
560
561    fn encrypt_ecb<const N: usize>(plaintext: &[u8; N], key: &[u8; 16]) -> [u8; N] {
562        let mut result = [0u8; N];
563        let mut i = 0;
564        while i < N {
565            let mut block = [0u8; 16];
566            let mut j = 0;
567            while j < 16 {
568                block[j] = plaintext[i + j];
569                j += 1;
570            }
571
572            let enc_block = Self::encrypt_block(&block, key);
573
574            j = 0;
575            while j < 16 {
576                result[i + j] = enc_block[j];
577                j += 1;
578            }
579            i += 16;
580        }
581        result
582    }
583
584    fn decrypt_ecb<const N: usize>(ciphertext: &[u8; N], key: &[u8; 16]) -> [u8; N] {
585        if N % 16 != 0 {
586            panic!("Invalid ciphertext length for ECB.");
587        }
588
589        let mut result = [0u8; N];
590        let mut i = 0;
591        while i < N {
592            let mut block = [0u8; 16];
593            let mut j = 0;
594            while j < 16 {
595                block[j] = ciphertext[i + j];
596                j += 1;
597            }
598
599            let dec_block = Self::decrypt_block(&block, key);
600
601            j = 0;
602            while j < 16 {
603                result[i + j] = dec_block[j];
604                j += 1;
605            }
606            i += 16;
607        }
608        result
609    }
610
611    fn encrypt_cbc<const N: usize>(plaintext: &[u8; N], key: &[u8; 16], iv: &[u8; 16]) -> [u8; N] {
612        if N % 16 != 0 {
613            panic!("Invalid plaintext length for CBC.");
614        }
615
616        let mut result = [0u8; N];
617        let mut prev = *iv;
618        let mut i = 0;
619        while i < N {
620            let mut block = [0u8; 16];
621            let mut j = 0;
622            while j < 16 {
623                block[j] = plaintext[i + j] ^ prev[j];
624                j += 1;
625            }
626
627            let enc_block = Self::encrypt_block(&block, key);
628
629            j = 0;
630            while j < 16 {
631                result[i + j] = enc_block[j];
632                j += 1;
633            }
634
635            prev = enc_block;
636            i += 16;
637        }
638        result
639    }
640
641    fn decrypt_cbc<const N: usize>(ciphertext: &[u8; N], key: &[u8; 16], iv: &[u8; 16]) -> [u8; N] {
642        let mut result = [0u8; N];
643        let mut prev = *iv;
644        let mut i = 0;
645        while i < N {
646            let mut block = [0u8; 16];
647            let mut j = 0;
648            while j < 16 {
649                block[j] = ciphertext[i + j];
650                j += 1;
651            }
652
653            let dec_block = Self::decrypt_block(&block, key);
654
655            j = 0;
656            while j < 16 {
657                result[i + j] = dec_block[j] ^ prev[j];
658                j += 1;
659            }
660
661            prev = block;
662            i += 16;
663        }
664        result
665    }
666
667    fn encrypt_ctr<const N: usize>(plaintext: &[u8; N], key: &[u8; 16], iv: &[u8; 12]) -> [u8; N] {
668        let mut ciphertext = [0u8; N];
669        let mut counter = Self::ctr_init(iv);
670
671        let mut i = 0;
672        while i < N {
673            let keystream = Self::encrypt_block(&counter, key);
674            let block_size = if i + 16 > N { N - i } else { 16 };
675            const_for!(j in 0..block_size => {
676                ciphertext[i + j] = plaintext[i + j] ^ keystream[j];
677            });
678            counter = Self::ctr_increment(&counter);
679            i += 16;
680        }
681
682        ciphertext
683    }
684
685    fn decrypt_ctr<const N: usize>(ciphertext: &[u8; N], key: &[u8; 16], iv: &[u8; 12]) -> [u8; N] {
686        // CTR decryption is identical to encryption
687        Self::encrypt_ctr(ciphertext, key, iv)
688    }
689
690    fn ctr_init(iv: &[u8; 12]) -> [u8; 16] {
691        let mut counter = [0u8; 16];
692        let mut i = 0;
693        while i < 12 {
694            counter[i] = iv[i];
695            i += 1;
696        }
697        counter[15] = 1;
698        counter
699    }
700
701    fn ctr_increment(counter: &[u8; 16]) -> [u8; 16] {
702        let mut new_counter = *counter;
703        let mut i = 15;
704        while i >= 12 {
705            if new_counter[i] == 255 {
706                new_counter[i] = 0;
707                if i == 12 {
708                    break;
709                }
710                i -= 1;
711            } else {
712                new_counter[i] += 1;
713                break;
714            }
715        }
716        new_counter
717    }
718
719    const SBOX: [u8; 256] = [
720        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab,
721        0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4,
722        0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71,
723        0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
724        0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6,
725        0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb,
726        0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45,
727        0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
728        0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44,
729        0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a,
730        0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49,
731        0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
732        0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25,
733        0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e,
734        0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1,
735        0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
736        0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb,
737        0x16,
738    ];
739
740    const INV_SBOX: [u8; 256] = [
741        0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7,
742        0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde,
743        0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42,
744        0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49,
745        0x6d, 0x8b, 0xd1, 0x25, 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c,
746        0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15,
747        0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7,
748        0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02,
749        0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc,
750        0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad,
751        0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d,
752        0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b,
753        0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 0x1f, 0xdd, 0xa8,
754        0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51,
755        0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0,
756        0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
757        0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c,
758        0x7d,
759    ];
760
761    const RCON: [u8; 11] = [
762        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C,
763    ];
764
765    fn key_expansion(key: &[u8; 16]) -> [[u8; 16]; 11] {
766        let mut round_keys = [[0u8; 16]; 11];
767        let mut i = 0;
768
769        let mut j = 0;
770        while j < 16 {
771            round_keys[0][j] = key[j];
772            j += 1;
773        }
774
775        while i < 10 {
776            let mut t = [0u8; 4];
777            t[0] = round_keys[i][12];
778            t[1] = round_keys[i][13];
779            t[2] = round_keys[i][14];
780            t[3] = round_keys[i][15];
781
782            t = [t[1], t[2], t[3], t[0]];
783
784            let mut k = 0;
785            while k < 4 {
786                t[k] = Self::sbox(t[k]);
787                k += 1;
788            }
789
790            t[0] ^= Self::RCON[i + 1];
791
792            let mut m = 0;
793            while m < 4 {
794                round_keys[i + 1][m] = round_keys[i][m] ^ t[m];
795                m += 1;
796            }
797
798            let mut n = 4;
799            while n < 16 {
800                round_keys[i + 1][n] = round_keys[i + 1][n - 4] ^ round_keys[i][n];
801                n += 1;
802            }
803
804            i += 1;
805        }
806
807        round_keys
808    }
809
810    fn sbox(byte: u8) -> u8 {
811        Self::SBOX[byte as usize]
812    }
813
814    fn inv_sbox(byte: u8) -> u8 {
815        Self::INV_SBOX[byte as usize]
816    }
817
818    fn add_round_key(state: &[u8; 16], round_key: &[u8; 16]) -> [u8; 16] {
819        let mut new_state = [0u8; 16];
820        let mut i = 0;
821        while i < 16 {
822            new_state[i] = state[i] ^ round_key[i];
823            i += 1;
824        }
825        new_state
826    }
827
828    fn sub_bytes(state: &[u8; 16]) -> [u8; 16] {
829        let mut new_state = [0u8; 16];
830        let mut i = 0;
831        while i < 16 {
832            new_state[i] = Self::sbox(state[i]);
833            i += 1;
834        }
835        new_state
836    }
837
838    fn inv_sub_bytes(state: &[u8; 16]) -> [u8; 16] {
839        let mut new_state = [0u8; 16];
840        let mut i = 0;
841        while i < 16 {
842            new_state[i] = Self::inv_sbox(state[i]);
843            i += 1;
844        }
845        new_state
846    }
847
848    fn shift_rows(state: &[u8; 16]) -> [u8; 16] {
849        let mut new_state = [0u8; 16];
850
851        new_state[0] = state[0];
852        new_state[4] = state[4];
853        new_state[8] = state[8];
854        new_state[12] = state[12];
855
856        new_state[1] = state[5];
857        new_state[5] = state[9];
858        new_state[9] = state[13];
859        new_state[13] = state[1];
860
861        new_state[2] = state[10];
862        new_state[6] = state[14];
863        new_state[10] = state[2];
864        new_state[14] = state[6];
865
866        new_state[3] = state[15];
867        new_state[7] = state[3];
868        new_state[11] = state[7];
869        new_state[15] = state[11];
870
871        new_state
872    }
873
874    fn inv_shift_rows(state: &[u8; 16]) -> [u8; 16] {
875        let mut new_state = [0u8; 16];
876
877        new_state[0] = state[0];
878        new_state[4] = state[4];
879        new_state[8] = state[8];
880        new_state[12] = state[12];
881
882        new_state[1] = state[13];
883        new_state[5] = state[1];
884        new_state[9] = state[5];
885        new_state[13] = state[9];
886
887        new_state[2] = state[10];
888        new_state[6] = state[14];
889        new_state[10] = state[2];
890        new_state[14] = state[6];
891
892        new_state[3] = state[7];
893        new_state[7] = state[11];
894        new_state[11] = state[15];
895        new_state[15] = state[3];
896
897        new_state
898    }
899
900    fn gf_mul(a: u8, b: u8) -> u8 {
901        let mut result = 0;
902        let mut a = a;
903        let mut b = b;
904        let mut i = 0;
905        while i < 8 {
906            if (b & 1) != 0 {
907                result ^= a;
908            }
909            let high_bit = (a & 0x80) != 0;
910            a <<= 1;
911            if high_bit {
912                a ^= 0x1b;
913            }
914            b >>= 1;
915            i += 1;
916        }
917        result
918    }
919
920    fn mix_columns(state: &[u8; 16]) -> [u8; 16] {
921        let mut new_state = [0u8; 16];
922        let mut i = 0;
923        while i < 16 {
924            let s0 = state[i];
925            let s1 = state[i + 1];
926            let s2 = state[i + 2];
927            let s3 = state[i + 3];
928
929            new_state[i] = Self::gf_mul(s0, 2) ^ Self::gf_mul(s1, 3) ^ s2 ^ s3;
930            new_state[i + 1] = s0 ^ Self::gf_mul(s1, 2) ^ Self::gf_mul(s2, 3) ^ s3;
931            new_state[i + 2] = s0 ^ s1 ^ Self::gf_mul(s2, 2) ^ Self::gf_mul(s3, 3);
932            new_state[i + 3] = Self::gf_mul(s0, 3) ^ s1 ^ s2 ^ Self::gf_mul(s3, 2);
933
934            i += 4;
935        }
936        new_state
937    }
938
939    fn inv_mix_columns(state: &[u8; 16]) -> [u8; 16] {
940        let mut new_state = [0u8; 16];
941        let mut i = 0;
942        while i < 16 {
943            let s0 = state[i];
944            let s1 = state[i + 1];
945            let s2 = state[i + 2];
946            let s3 = state[i + 3];
947
948            new_state[i] = Self::gf_mul(s0, 14)
949                ^ Self::gf_mul(s1, 11)
950                ^ Self::gf_mul(s2, 13)
951                ^ Self::gf_mul(s3, 9);
952            new_state[i + 1] = Self::gf_mul(s0, 9)
953                ^ Self::gf_mul(s1, 14)
954                ^ Self::gf_mul(s2, 11)
955                ^ Self::gf_mul(s3, 13);
956            new_state[i + 2] = Self::gf_mul(s0, 13)
957                ^ Self::gf_mul(s1, 9)
958                ^ Self::gf_mul(s2, 14)
959                ^ Self::gf_mul(s3, 11);
960            new_state[i + 3] = Self::gf_mul(s0, 11)
961                ^ Self::gf_mul(s1, 13)
962                ^ Self::gf_mul(s2, 9)
963                ^ Self::gf_mul(s3, 14);
964
965            i += 4;
966        }
967        new_state
968    }
969}
970
#[cfg(test)]
mod tests {
    use super::{Aes, AesConst, AesMode};

    // Round-trip tests alone cannot catch a cipher that is self-consistent but
    // not AES (e.g. a wrong S-box or key schedule used on both sides), so the
    // known-answer vectors below pin the output to published NIST values.

    /// FIPS-197 Appendix C.1 (AES-128) cipher key.
    const FIPS197_KEY: [u8; 16] = [
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, //
        0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
    ];
    /// FIPS-197 Appendix C.1 plaintext block.
    const FIPS197_PLAINTEXT: [u8; 16] = [
        0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, //
        0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
    ];
    /// FIPS-197 Appendix C.1 expected ciphertext block.
    const FIPS197_CIPHERTEXT: [u8; 16] = [
        0x69, 0xc4, 0xe0, 0xd8, 0x6a, 0x7b, 0x04, 0x30, //
        0xd8, 0xcd, 0xb7, 0x80, 0x70, 0xb4, 0xc5, 0x5a,
    ];

    /// NIST SP 800-38A (F.1.1 / F.2.1) AES-128 key.
    const SP800_38A_KEY: [u8; 16] = [
        0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, //
        0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c,
    ];
    /// First two plaintext blocks shared by the SP 800-38A ECB and CBC vectors.
    const SP800_38A_PLAINTEXT: [u8; 32] = [
        0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, //
        0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a, //
        0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, //
        0x9e, 0xb7, 0x6f, 0xac, 0x45, 0xaf, 0x8e, 0x51,
    ];
    /// SP 800-38A F.1.1 ECB-AES128 ciphertext for the two blocks above.
    const SP800_38A_ECB_CIPHERTEXT: [u8; 32] = [
        0x3a, 0xd7, 0x7b, 0xb4, 0x0d, 0x7a, 0x36, 0x60, //
        0xa8, 0x9e, 0xca, 0xf3, 0x24, 0x66, 0xef, 0x97, //
        0xf5, 0xd3, 0xd5, 0x85, 0x03, 0xb9, 0x69, 0x9d, //
        0xe7, 0x85, 0x89, 0x5a, 0x96, 0xfd, 0xba, 0xaf,
    ];
    /// SP 800-38A F.2.1 CBC-AES128 initialization vector.
    const SP800_38A_CBC_IV: [u8; 16] = [
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, //
        0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
    ];
    /// SP 800-38A F.2.1 CBC-AES128 ciphertext for the two blocks above.
    const SP800_38A_CBC_CIPHERTEXT: [u8; 32] = [
        0x76, 0x49, 0xab, 0xac, 0x81, 0x19, 0xb2, 0x46, //
        0xce, 0xe9, 0x8e, 0x9b, 0x12, 0xe9, 0x19, 0x7d, //
        0x50, 0x86, 0xcb, 0x9b, 0x50, 0x72, 0x19, 0xee, //
        0x95, 0xdb, 0x11, 0x3a, 0x91, 0x76, 0x78, 0xb2,
    ];

    // ---------------------------------------------------------------------
    // `AesConst` (const fn implementation)
    // ---------------------------------------------------------------------

    #[test]
    fn test_const_single_block_mode() {
        let plaintext = [0u8; 16];
        let key = [0xFF; 16];

        let ciphertext = AesConst::encrypt_block(&plaintext, &key);
        let decrypted = AesConst::decrypt_block(&ciphertext, &key);

        assert_ne!(ciphertext, plaintext, "Single block mode produced identity output");
        assert_eq!(decrypted, plaintext, "Single block mode test failed");
    }

    #[test]
    fn test_const_single_block_known_answer() {
        assert_eq!(
            AesConst::encrypt_block(&FIPS197_PLAINTEXT, &FIPS197_KEY),
            FIPS197_CIPHERTEXT,
            "FIPS-197 C.1 encryption mismatch"
        );
        assert_eq!(
            AesConst::decrypt_block(&FIPS197_CIPHERTEXT, &FIPS197_KEY),
            FIPS197_PLAINTEXT,
            "FIPS-197 C.1 decryption mismatch"
        );
    }

    #[test]
    fn test_const_block_evaluates_at_compile_time() {
        // These are `const` items, so the cipher must run during compilation;
        // the runtime tests above never exercise that path.
        const CIPHERTEXT: [u8; 16] = AesConst::encrypt_block(&FIPS197_PLAINTEXT, &FIPS197_KEY);
        const DECRYPTED: [u8; 16] = AesConst::decrypt_block(&CIPHERTEXT, &FIPS197_KEY);

        assert_eq!(CIPHERTEXT, FIPS197_CIPHERTEXT, "Compile-time encryption mismatch");
        assert_eq!(DECRYPTED, FIPS197_PLAINTEXT, "Compile-time decryption mismatch");
    }

    #[test]
    fn test_const_ecb_mode() {
        let plaintext = [0u8; 32];
        let key = [0xFF; 16];
        let mode = AesMode::ECB;

        let ciphertext = AesConst::encrypt(&plaintext, &key, &mode);
        let decrypted = AesConst::decrypt(&ciphertext, &key, &mode);

        assert_ne!(ciphertext, plaintext, "ECB mode produced identity output");
        assert_eq!(decrypted, plaintext, "ECB mode test failed");
    }

    #[test]
    fn test_const_ecb_known_answer() {
        let mode = AesMode::ECB;

        assert_eq!(
            AesConst::encrypt(&SP800_38A_PLAINTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_ECB_CIPHERTEXT,
            "SP 800-38A ECB encryption mismatch"
        );
        assert_eq!(
            AesConst::decrypt(&SP800_38A_ECB_CIPHERTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_PLAINTEXT,
            "SP 800-38A ECB decryption mismatch"
        );
    }

    #[test]
    fn test_const_cbc_mode() {
        let plaintext = [0u8; 32];
        let key = [0xFF; 16];
        let iv = [0x00; 16];
        let mode = AesMode::CBC { iv };

        let ciphertext = AesConst::encrypt(&plaintext, &key, &mode);
        let decrypted = AesConst::decrypt(&ciphertext, &key, &mode);

        assert_ne!(ciphertext, plaintext, "CBC mode produced identity output");
        assert_eq!(decrypted, plaintext, "CBC mode test failed");
    }

    #[test]
    fn test_const_cbc_known_answer() {
        let mode = AesMode::CBC { iv: SP800_38A_CBC_IV };

        assert_eq!(
            AesConst::encrypt(&SP800_38A_PLAINTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_CBC_CIPHERTEXT,
            "SP 800-38A CBC encryption mismatch"
        );
        assert_eq!(
            AesConst::decrypt(&SP800_38A_CBC_CIPHERTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_PLAINTEXT,
            "SP 800-38A CBC decryption mismatch"
        );
    }

    #[test]
    fn test_const_ctr_mode() {
        let plaintext = [0u8; 32];
        let key = [0xFF; 16];
        let iv = [0x00; 12];
        let mode = AesMode::CTR { iv };

        let ciphertext = AesConst::encrypt(&plaintext, &key, &mode);
        let decrypted = AesConst::decrypt(&ciphertext, &key, &mode);

        assert_ne!(ciphertext, plaintext, "CTR mode produced identity output");
        assert_eq!(decrypted, plaintext, "CTR mode test failed");
    }

    // ---------------------------------------------------------------------
    // `Aes` (runtime implementation)
    // ---------------------------------------------------------------------

    #[test]
    fn test_nonconst_single_block_mode() {
        let plaintext = [0u8; 16];
        let key = [0xFF; 16];

        let ciphertext = Aes::encrypt_block(&plaintext, &key);
        let decrypted = Aes::decrypt_block(&ciphertext, &key);

        assert_ne!(ciphertext, plaintext, "Single block mode produced identity output");
        assert_eq!(decrypted, plaintext, "Single block mode test failed");
    }

    #[test]
    fn test_nonconst_single_block_known_answer() {
        assert_eq!(
            Aes::encrypt_block(&FIPS197_PLAINTEXT, &FIPS197_KEY),
            FIPS197_CIPHERTEXT,
            "FIPS-197 C.1 encryption mismatch"
        );
        assert_eq!(
            Aes::decrypt_block(&FIPS197_CIPHERTEXT, &FIPS197_KEY),
            FIPS197_PLAINTEXT,
            "FIPS-197 C.1 decryption mismatch"
        );
    }

    #[test]
    fn test_nonconst_ecb_mode() {
        let plaintext = [0u8; 32];
        let key = [0xFF; 16];
        let mode = AesMode::ECB;

        let ciphertext = Aes::encrypt(&plaintext, &key, &mode);
        let decrypted = Aes::decrypt(&ciphertext, &key, &mode);

        assert_ne!(ciphertext, plaintext, "ECB mode produced identity output");
        assert_eq!(decrypted, plaintext, "ECB mode test failed");
    }

    #[test]
    fn test_nonconst_ecb_known_answer() {
        let mode = AesMode::ECB;

        assert_eq!(
            Aes::encrypt(&SP800_38A_PLAINTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_ECB_CIPHERTEXT,
            "SP 800-38A ECB encryption mismatch"
        );
        assert_eq!(
            Aes::decrypt(&SP800_38A_ECB_CIPHERTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_PLAINTEXT,
            "SP 800-38A ECB decryption mismatch"
        );
    }

    #[test]
    fn test_nonconst_cbc_mode() {
        let plaintext = [0u8; 32];
        let key = [0xFF; 16];
        let iv = [0x00; 16];
        let mode = AesMode::CBC { iv };

        let ciphertext = Aes::encrypt(&plaintext, &key, &mode);
        let decrypted = Aes::decrypt(&ciphertext, &key, &mode);

        assert_ne!(ciphertext, plaintext, "CBC mode produced identity output");
        assert_eq!(decrypted, plaintext, "CBC mode test failed");
    }

    #[test]
    fn test_nonconst_cbc_known_answer() {
        let mode = AesMode::CBC { iv: SP800_38A_CBC_IV };

        assert_eq!(
            Aes::encrypt(&SP800_38A_PLAINTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_CBC_CIPHERTEXT,
            "SP 800-38A CBC encryption mismatch"
        );
        assert_eq!(
            Aes::decrypt(&SP800_38A_CBC_CIPHERTEXT, &SP800_38A_KEY, &mode),
            SP800_38A_PLAINTEXT,
            "SP 800-38A CBC decryption mismatch"
        );
    }

    #[test]
    fn test_nonconst_ctr_mode() {
        let plaintext = [0u8; 32];
        let key = [0xFF; 16];
        let iv = [0x00; 12];
        let mode = AesMode::CTR { iv };

        let ciphertext = Aes::encrypt(&plaintext, &key, &mode);
        let decrypted = Aes::decrypt(&ciphertext, &key, &mode);

        assert_ne!(ciphertext, plaintext, "CTR mode produced identity output");
        assert_eq!(decrypted, plaintext, "CTR mode test failed");
    }
}