aes 0.9.0

Pure Rust implementation of the Advanced Encryption Standard (a.k.a. Rijndael)
Documentation
#![allow(unsafe_op_in_unsafe_fn)]

use crate::x86::arch::*;
use core::mem::{transmute, zeroed};

pub(super) type Aes128RoundKeys = [__m128i; 11];
pub(super) type Aes192RoundKeys = [__m128i; 13];
pub(super) type Aes256RoundKeys = [__m128i; 15];

pub(crate) mod aes128 {
    use super::*;

    #[target_feature(enable = "aes")]
    pub(crate) unsafe fn expand_key(key: &[u8; 16]) -> Aes128RoundKeys {
        unsafe fn expand_round<const RK: i32>(keys: &mut Aes128RoundKeys, pos: usize) {
            let mut t1 = keys[pos - 1];
            let mut t2;
            let mut t3;

            t2 = _mm_aeskeygenassist_si128(t1, RK);
            t2 = _mm_shuffle_epi32(t2, 0xff);
            t3 = _mm_slli_si128(t1, 0x4);
            t1 = _mm_xor_si128(t1, t3);
            t3 = _mm_slli_si128(t3, 0x4);
            t1 = _mm_xor_si128(t1, t3);
            t3 = _mm_slli_si128(t3, 0x4);
            t1 = _mm_xor_si128(t1, t3);
            t1 = _mm_xor_si128(t1, t2);

            keys[pos] = t1;
        }

        let mut keys: Aes128RoundKeys = zeroed();
        let k = _mm_loadu_si128(key.as_ptr().cast());
        keys[0] = k;

        let kr = &mut keys;
        expand_round::<0x01>(kr, 1);
        expand_round::<0x02>(kr, 2);
        expand_round::<0x04>(kr, 3);
        expand_round::<0x08>(kr, 4);
        expand_round::<0x10>(kr, 5);
        expand_round::<0x20>(kr, 6);
        expand_round::<0x40>(kr, 7);
        expand_round::<0x80>(kr, 8);
        expand_round::<0x1B>(kr, 9);
        expand_round::<0x36>(kr, 10);

        keys
    }
}

pub(crate) mod aes192 {
    use super::*;

    #[target_feature(enable = "aes")]
    pub(crate) unsafe fn expand_key(key: &[u8; 24]) -> Aes192RoundKeys {
        unsafe fn shuffle(a: __m128i, b: __m128i, i: usize) -> __m128i {
            let a: [u64; 2] = transmute(a);
            let b: [u64; 2] = transmute(b);
            transmute([a[i], b[0]])
        }

        #[target_feature(enable = "aes")]
        unsafe fn expand_round<const RK: i32>(
            mut t1: __m128i,
            mut t3: __m128i,
        ) -> (__m128i, __m128i) {
            let (mut t2, mut t4);

            t2 = _mm_aeskeygenassist_si128(t3, RK);
            t2 = _mm_shuffle_epi32(t2, 0x55);
            t4 = _mm_slli_si128(t1, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t1 = _mm_xor_si128(t1, t2);
            t2 = _mm_shuffle_epi32(t1, 0xff);
            t4 = _mm_slli_si128(t3, 0x4);
            t3 = _mm_xor_si128(t3, t4);
            t3 = _mm_xor_si128(t3, t2);

            (t1, t3)
        }

        let mut keys: Aes192RoundKeys = zeroed();
        // We are being extra pedantic here to remove out-of-bound access.
        // This should be optimized into movups, movsd sequence.
        let (k0, k1l) = {
            let mut t = [0u8; 32];
            t[..key.len()].copy_from_slice(key);
            (
                _mm_loadu_si128(t.as_ptr().cast()),
                _mm_loadu_si128(t.as_ptr().offset(16).cast()),
            )
        };

        keys[0] = k0;

        let (k1_2, k2r) = expand_round::<0x01>(k0, k1l);
        keys[1] = shuffle(k1l, k1_2, 0);
        keys[2] = shuffle(k1_2, k2r, 1);

        let (k3, k4l) = expand_round::<0x02>(k1_2, k2r);
        keys[3] = k3;

        let (k4_5, k5r) = expand_round::<0x04>(k3, k4l);
        let k4 = shuffle(k4l, k4_5, 0);
        let k5 = shuffle(k4_5, k5r, 1);
        keys[4] = k4;
        keys[5] = k5;

        let (k6, k7l) = expand_round::<0x08>(k4_5, k5r);
        keys[6] = k6;

        let (k7_8, k8r) = expand_round::<0x10>(k6, k7l);
        keys[7] = shuffle(k7l, k7_8, 0);
        keys[8] = shuffle(k7_8, k8r, 1);

        let (k9, k10l) = expand_round::<0x20>(k7_8, k8r);
        keys[9] = k9;

        let (k10_11, k11r) = expand_round::<0x40>(k9, k10l);
        keys[10] = shuffle(k10l, k10_11, 0);
        keys[11] = shuffle(k10_11, k11r, 1);

        let (k12, _) = expand_round::<0x80>(k10_11, k11r);
        keys[12] = k12;

        keys
    }
}

pub(crate) mod aes256 {
    use super::*;

    #[target_feature(enable = "aes")]
    pub(crate) unsafe fn expand_key(key: &[u8; 32]) -> Aes256RoundKeys {
        unsafe fn expand_round<const RK: i32>(keys: &mut Aes256RoundKeys, pos: usize) {
            let mut t1 = keys[pos - 2];
            let mut t2;
            let mut t3 = keys[pos - 1];
            let mut t4;

            t2 = _mm_aeskeygenassist_si128(t3, RK);
            t2 = _mm_shuffle_epi32(t2, 0xff);
            t4 = _mm_slli_si128(t1, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t1 = _mm_xor_si128(t1, t2);

            keys[pos] = t1;

            t4 = _mm_aeskeygenassist_si128(t1, 0x00);
            t2 = _mm_shuffle_epi32(t4, 0xaa);
            t4 = _mm_slli_si128(t3, 0x4);
            t3 = _mm_xor_si128(t3, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t3 = _mm_xor_si128(t3, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t3 = _mm_xor_si128(t3, t4);
            t3 = _mm_xor_si128(t3, t2);

            keys[pos + 1] = t3;
        }

        unsafe fn expand_round_last<const RK: i32>(keys: &mut Aes256RoundKeys, pos: usize) {
            let mut t1 = keys[pos - 2];
            let mut t2;
            let t3 = keys[pos - 1];
            let mut t4;

            t2 = _mm_aeskeygenassist_si128(t3, RK);
            t2 = _mm_shuffle_epi32(t2, 0xff);
            t4 = _mm_slli_si128(t1, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t4 = _mm_slli_si128(t4, 0x4);
            t1 = _mm_xor_si128(t1, t4);
            t1 = _mm_xor_si128(t1, t2);

            keys[pos] = t1;
        }

        let mut keys: Aes256RoundKeys = zeroed();

        let kp = key.as_ptr().cast::<__m128i>();
        keys[0] = _mm_loadu_si128(kp);
        keys[1] = _mm_loadu_si128(kp.add(1));

        let k = &mut keys;
        expand_round::<0x01>(k, 2);
        expand_round::<0x02>(k, 4);
        expand_round::<0x04>(k, 6);
        expand_round::<0x08>(k, 8);
        expand_round::<0x10>(k, 10);
        expand_round::<0x20>(k, 12);
        expand_round_last::<0x40>(k, 14);

        keys
    }
}

#[target_feature(enable = "aes")]
pub(crate) unsafe fn inv_keys<const N: usize>(keys: &[__m128i; N]) -> [__m128i; N] {
    let mut inv_keys: [__m128i; N] = zeroed();
    inv_keys[0] = keys[N - 1];
    for i in 1..N - 1 {
        inv_keys[i] = _mm_aesimc_si128(keys[N - 1 - i]);
    }
    inv_keys[N - 1] = keys[0];
    inv_keys
}