// algos/cs/security/aes.rs

1//! DISCLAIMER: This library is a toy example of the AES (Rijndael) block cipher in pure Rust.
2//! It is *EXCLUSIVELY* for demonstration and educational purposes. Absolutely DO NOT use it
3//! for real cryptographic or security-sensitive operations. It is not audited, not vetted,
4//! and very likely insecure in practice. If you need AES or any cryptographic operations in
5//! production, please use a vetted, well-reviewed cryptography library (e.g. RustCrypto).
6
7use core::convert::TryInto;
8
/// Size of one AES block in bytes. AES always operates on 128-bit blocks,
/// whatever the key size.
pub const AES_BLOCK_SIZE: usize = 128 / 8;
11
/// Supported AES key sizes: 128, 192, or 256 bits.
///
/// The variant selects both the required key-material length and the number of
/// rounds used by [`AesKey::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AesKeySize {
    /// 16-byte key, 10 rounds (AES-128).
    Bits128,
    /// 24-byte key, 12 rounds (AES-192).
    Bits192,
    /// 32-byte key, 14 rounds (AES-256).
    Bits256,
}
19
20/// An AES key schedule object, storing the round keys after expansion.
21pub struct AesKey {
22    pub round_keys: Vec<[u8; AES_BLOCK_SIZE]>,
23    pub nr: usize, // number of rounds
24}
25
26impl AesKey {
27    /// Create a new AES key from a given key material (raw bytes) and key size variant.
28    ///
29    /// # Panics
30    /// Panics if the key material length doesn't match the indicated `AesKeySize`.
31    ///
32    /// DO NOT USE THIS FOR REAL SECURITY.
33    pub fn new(key_data: &[u8], key_size: AesKeySize) -> Self {
34        let (key_len, nr, nk) = match key_size {
35            AesKeySize::Bits128 => (16, 10, 4),
36            AesKeySize::Bits192 => (24, 12, 6),
37            AesKeySize::Bits256 => (32, 14, 8),
38        };
39        assert_eq!(
40            key_data.len(),
41            key_len,
42            "Key length mismatch for AES key size"
43        );
44
45        let expanded_len = AES_BLOCK_SIZE * (nr + 1);
46        let mut round_keys = vec![0u8; expanded_len];
47        // copy initial key
48        round_keys[..key_len].copy_from_slice(key_data);
49
50        key_expansion(&mut round_keys, nk, nr);
51
52        // Convert round_keys to round-key blocks
53        let mut round_blocks = Vec::with_capacity(nr + 1);
54        for i in 0..(nr + 1) {
55            let offset = i * AES_BLOCK_SIZE;
56            let block: [u8; AES_BLOCK_SIZE] = round_keys[offset..offset + AES_BLOCK_SIZE]
57                .try_into()
58                .unwrap();
59            round_blocks.push(block);
60        }
61
62        Self {
63            round_keys: round_blocks,
64            nr,
65        }
66    }
67}
68
69/// Encrypt a single 128-bit block `plaintext` in place using the provided AES key schedule.
70/// *This is toy code. DO NOT use in production.*
71///
72/// # Panics
73/// Panics if `plaintext.len() != 16`.
74pub fn aes_encrypt_block(plaintext: &mut [u8; AES_BLOCK_SIZE], key: &AesKey) {
75    add_round_key(plaintext, &key.round_keys[0]);
76
77    for round in 1..key.nr {
78        sub_bytes(plaintext);
79        shift_rows(plaintext);
80        mix_columns(plaintext);
81        add_round_key(plaintext, &key.round_keys[round]);
82    }
83
84    // final round
85    sub_bytes(plaintext);
86    shift_rows(plaintext);
87    add_round_key(plaintext, &key.round_keys[key.nr]);
88}
89
90/// Decrypt a single 128-bit block `ciphertext` in place using the provided AES key schedule.
91/// *This is toy code. DO NOT use in production.*
92///
93/// # Panics
94/// Panics if `ciphertext.len() != 16`.
95pub fn aes_decrypt_block(ciphertext: &mut [u8; AES_BLOCK_SIZE], key: &AesKey) {
96    add_round_key(ciphertext, &key.round_keys[key.nr]);
97    inv_shift_rows(ciphertext);
98    inv_sub_bytes(ciphertext);
99
100    for round in (1..key.nr).rev() {
101        add_round_key(ciphertext, &key.round_keys[round]);
102        inv_mix_columns(ciphertext);
103        inv_shift_rows(ciphertext);
104        inv_sub_bytes(ciphertext);
105    }
106
107    // final
108    add_round_key(ciphertext, &key.round_keys[0]);
109}
110
111// ---------------- Internal Implementation Details (toy) ---------------- //
112// The following code includes S-Boxes, inverse S-Boxes, Rcon constants,
// and the standard AES round transformations. DO NOT rely on any of it for real usage.
114
/// Forward S-box, used by `sub_bytes` and by the SubWord step of `key_expansion`.
///
/// `SBOX[b]` is the substitution for input byte `b`; each row of 16 entries covers the
/// inputs `0xr0..=0xrf`. These values should match the AES S-box in FIPS-197 (Figure 7).
static SBOX: [u8; 256] = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
134
/// Inverse S-box, used by `inv_sub_bytes` during decryption.
///
/// Intended to satisfy `INV_SBOX[SBOX[b] as usize] == b` for every byte `b`, so that
/// `inv_sub_bytes` undoes `sub_bytes`. Laid out like `SBOX`: each row of 16 entries covers
/// the inputs `0xr0..=0xrf`. Should match the inverse S-box in FIPS-197 (Figure 14).
static INV_SBOX: [u8; 256] = [
    0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
    0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
    0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
    0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
    0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
    0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
    0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
    0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
    0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
    0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
    0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
    0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
    0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
    0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
    0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
];
154
/// Round constants for key expansion, indexed by `i / nk` (word index / key words).
///
/// Entry 0 is a placeholder and is never read, because `i / nk >= 1` whenever a round
/// constant is applied. For `j >= 1`, `RCON[j]` is the previous entry doubled in
/// GF(2^8) (reduction constant 0x1B), starting from 0x01. The key expansion reads at most
/// index 10 (AES-128), 8 (AES-192) or 7 (AES-256), so these 11 entries cover every
/// key size.
static RCON: [u8; 11] = [
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36,
];
172
173fn sub_bytes(state: &mut [u8; AES_BLOCK_SIZE]) {
174    for b in state.iter_mut() {
175        *b = SBOX[*b as usize];
176    }
177}
178
179fn inv_sub_bytes(state: &mut [u8; AES_BLOCK_SIZE]) {
180    for b in state.iter_mut() {
181        *b = INV_SBOX[*b as usize];
182    }
183}
184
185fn shift_rows(state: &mut [u8; AES_BLOCK_SIZE]) {
186    // row 1 shift by 1
187    let row1 = [state[1], state[5], state[9], state[13]];
188    state[1] = row1[1];
189    state[5] = row1[2];
190    state[9] = row1[3];
191    state[13] = row1[0];
192
193    // row 2 shift by 2
194    let row2 = [state[2], state[6], state[10], state[14]];
195    state[2] = row2[2];
196    state[6] = row2[3];
197    state[10] = row2[0];
198    state[14] = row2[1];
199
200    // row 3 shift by 3
201    let row3 = [state[3], state[7], state[11], state[15]];
202    state[3] = row3[3];
203    state[7] = row3[0];
204    state[11] = row3[1];
205    state[15] = row3[2];
206}
207
208fn inv_shift_rows(state: &mut [u8; AES_BLOCK_SIZE]) {
209    // row 1 shift right by 1
210    let row1 = [state[1], state[5], state[9], state[13]];
211    state[1] = row1[3];
212    state[5] = row1[0];
213    state[9] = row1[1];
214    state[13] = row1[2];
215
216    // row 2 shift right by 2
217    let row2 = [state[2], state[6], state[10], state[14]];
218    state[2] = row2[2];
219    state[6] = row2[3];
220    state[10] = row2[0];
221    state[14] = row2[1];
222
223    // row 3 shift right by 3
224    let row3 = [state[3], state[7], state[11], state[15]];
225    state[3] = row3[1];
226    state[7] = row3[2];
227    state[11] = row3[3];
228    state[15] = row3[0];
229}
230
/// Multiplies `x` by 2 in GF(2^8), reducing by 0x1B when the high bit overflows.
///
/// `x >> 7` is 0 or 1, so the reduction constant is XORed in exactly when bit 7 of `x`
/// was set; no branch is needed.
fn xtime(x: u8) -> u8 {
    let carry = x >> 7;
    (x << 1) ^ (0x1B * carry)
}
238
239fn mix_columns(state: &mut [u8; AES_BLOCK_SIZE]) {
240    for col in 0..4 {
241        let base = col * 4;
242        let t = state[base] ^ state[base + 1] ^ state[base + 2] ^ state[base + 3];
243        let temp0 = state[base];
244        let temp1 = state[base + 1];
245        let temp2 = state[base + 2];
246        let temp3 = state[base + 3];
247
248        state[base] ^= t ^ xtime(temp0 ^ temp1);
249        state[base + 1] ^= t ^ xtime(temp1 ^ temp2);
250        state[base + 2] ^= t ^ xtime(temp2 ^ temp3);
251        state[base + 3] ^= t ^ xtime(temp3 ^ temp0);
252    }
253}
254
255fn inv_mix_columns(state: &mut [u8; AES_BLOCK_SIZE]) {
256    // The standard approach is to multiply the state columns by the inverse of the MDS matrix
257    // We'll do it in the typical inline approach:
258    for col in 0..4 {
259        let base = col * 4;
260        let a0 = state[base];
261        let a1 = state[base + 1];
262        let a2 = state[base + 2];
263        let a3 = state[base + 3];
264
265        state[base] = mul(a0, 0x0e) ^ mul(a1, 0x0b) ^ mul(a2, 0x0d) ^ mul(a3, 0x09);
266        state[base + 1] = mul(a0, 0x09) ^ mul(a1, 0x0e) ^ mul(a2, 0x0b) ^ mul(a3, 0x0d);
267        state[base + 2] = mul(a0, 0x0d) ^ mul(a1, 0x09) ^ mul(a2, 0x0e) ^ mul(a3, 0x0b);
268        state[base + 3] = mul(a0, 0x0b) ^ mul(a1, 0x0d) ^ mul(a2, 0x09) ^ mul(a3, 0x0e);
269    }
270}
271
/// Multiplies `x` and `y` in GF(2^8), reducing by 0x1B on overflow of the high bit.
///
/// Shift-and-add: for each set bit of `y`, XOR in the current multiple of `x`, then
/// double that multiple. The loop stops once no bits of `y` remain, which yields the
/// same product as always running eight iterations.
fn mul(x: u8, y: u8) -> u8 {
    let mut product = 0u8;
    let mut multiple = x;
    let mut remaining = y;
    while remaining != 0 {
        if remaining & 1 == 1 {
            product ^= multiple;
        }
        multiple = (multiple << 1) ^ (0x1b * (multiple >> 7));
        remaining >>= 1;
    }
    product
}
290
291fn add_round_key(state: &mut [u8; AES_BLOCK_SIZE], round_key: &[u8; AES_BLOCK_SIZE]) {
292    for (s, k) in state.iter_mut().zip(round_key) {
293        *s ^= *k;
294    }
295}
296
297// Key Expansion routines
298fn key_expansion(expanded: &mut [u8], nk: usize, nr: usize) {
299    let total_words = (nr + 1) * 4; // number of 32-bit words
300    let mut i = nk;
301    while i < total_words {
302        let mut temp = [
303            expanded[(i - 1) * 4],
304            expanded[(i - 1) * 4 + 1],
305            expanded[(i - 1) * 4 + 2],
306            expanded[(i - 1) * 4 + 3],
307        ];
308
309        if i % nk == 0 {
310            // rotate
311            temp = [temp[1], temp[2], temp[3], temp[0]];
312            // sub
313            for t in temp.iter_mut() {
314                *t = SBOX[*t as usize];
315            }
316            // rcon
317            temp[0] ^= RCON[i / nk];
318        } else if nk > 6 && i % nk == 4 {
319            for t in temp.iter_mut() {
320                *t = SBOX[*t as usize];
321            }
322        }
323
324        let wprev = (i - nk) * 4;
325        for (j, tj) in temp.iter().enumerate() {
326            expanded[i * 4 + j] = expanded[wprev + j] ^ tj;
327        }
328        i += 1;
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    // NIST SP 800-38A, Appendix F.1.1 (ECB-AES128), block #1:
    // key=2b7e151628aed2a6abf7158809cf4f3c, plaintext=6bc1bee22e409f96e93d7e117393172a,
    // ciphertext=3ad77bb40d7a3660a89ecaf32466ef97.
    // (This is not the FIPS-197 example; that one is covered in `test_fips197_appendix_b`.)
    #[test]
    fn test_aes128_encrypt_block() {
        let key_data = hex_to_bytes("2b7e151628aed2a6abf7158809cf4f3c");
        let mut block = hex_to_array("6bc1bee22e409f96e93d7e117393172a");

        let aes_key = AesKey::new(&key_data, AesKeySize::Bits128);

        aes_encrypt_block(&mut block, &aes_key);

        let expected = hex_to_array("3ad77bb40d7a3660a89ecaf32466ef97");
        assert_eq!(block, expected);
    }

    #[test]
    fn test_aes128_decrypt_block() {
        let key_data = hex_to_bytes("2b7e151628aed2a6abf7158809cf4f3c");
        let mut block = hex_to_array("3ad77bb40d7a3660a89ecaf32466ef97");

        let aes_key = AesKey::new(&key_data, AesKeySize::Bits128);

        aes_decrypt_block(&mut block, &aes_key);

        let expected = hex_to_array("6bc1bee22e409f96e93d7e117393172a");
        assert_eq!(block, expected);
    }

    // FIPS-197 Appendix B: cipher example (AES-128).
    #[test]
    fn test_fips197_appendix_b() {
        check_vector(
            "2b7e151628aed2a6abf7158809cf4f3c",
            AesKeySize::Bits128,
            "3243f6a8885a308d313198a2e0370734",
            "3925841d02dc09fbdc118597196a0b32",
        );
    }

    // FIPS-197 Appendix C.1-C.3: the same plaintext under 128-, 192- and 256-bit keys.
    // These are the only vectors here that exercise the AES-192/256 key expansion.
    const APPENDIX_C_PLAINTEXT: &str = "00112233445566778899aabbccddeeff";

    #[test]
    fn test_fips197_appendix_c1_aes128() {
        check_vector(
            "000102030405060708090a0b0c0d0e0f",
            AesKeySize::Bits128,
            APPENDIX_C_PLAINTEXT,
            "69c4e0d86a7b0430d8cdb78070b4c55a",
        );
    }

    #[test]
    fn test_fips197_appendix_c2_aes192() {
        check_vector(
            "000102030405060708090a0b0c0d0e0f1011121314151617",
            AesKeySize::Bits192,
            APPENDIX_C_PLAINTEXT,
            "dda97ca4864cdfe06eaf70a0ec0d7191",
        );
    }

    #[test]
    fn test_fips197_appendix_c3_aes256() {
        check_vector(
            "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f",
            AesKeySize::Bits256,
            APPENDIX_C_PLAINTEXT,
            "8ea2b7ca516745bfeafc49904b496089",
        );
    }

    // FIPS-197 Appendix A.1: the first round key is the cipher key itself and the last
    // one is w[40..44).
    #[test]
    fn test_aes128_key_schedule_endpoints() {
        let key_data = hex_to_bytes("2b7e151628aed2a6abf7158809cf4f3c");
        let aes_key = AesKey::new(&key_data, AesKeySize::Bits128);
        assert_eq!(
            aes_key.round_keys[0],
            hex_to_array("2b7e151628aed2a6abf7158809cf4f3c")
        );
        assert_eq!(
            aes_key.round_keys[10],
            hex_to_array("d014f9a8c9ee2589e13f0cc8b6630ca6")
        );
    }

    #[test]
    fn test_round_counts() {
        let cases = [
            (AesKeySize::Bits128, 16usize, 10usize),
            (AesKeySize::Bits192, 24, 12),
            (AesKeySize::Bits256, 32, 14),
        ];
        for &(size, key_len, nr) in cases.iter() {
            let key_data = vec![0u8; key_len];
            let aes_key = AesKey::new(&key_data, size);
            assert_eq!(aes_key.nr, nr);
            assert_eq!(aes_key.round_keys.len(), nr + 1);
        }
    }

    #[test]
    #[should_panic(expected = "Key length mismatch")]
    fn test_wrong_key_length_panics() {
        let _ = AesKey::new(&[0u8; 15], AesKeySize::Bits128);
    }

    // Helpers

    /// Encrypts `plain_hex` and checks the result equals `cipher_hex`, then decrypts and
    /// checks the block round-trips to `plain_hex`.
    fn check_vector(key_hex: &str, key_size: AesKeySize, plain_hex: &str, cipher_hex: &str) {
        let aes_key = AesKey::new(&hex_to_bytes(key_hex), key_size);
        let mut block = hex_to_array(plain_hex);

        aes_encrypt_block(&mut block, &aes_key);
        assert_eq!(block, hex_to_array(cipher_hex), "encryption mismatch");

        aes_decrypt_block(&mut block, &aes_key);
        assert_eq!(block, hex_to_array(plain_hex), "decryption mismatch");
    }

    /// Parses an even-length ASCII hex string into bytes (panics on malformed input).
    fn hex_to_bytes(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Parses a 32-character hex string into a 16-byte block (panics otherwise).
    fn hex_to_array(s: &str) -> [u8; 16] {
        let bytes = hex_to_bytes(s);
        bytes.try_into().unwrap()
    }
}