aes 0.7.3

Pure Rust implementation of the Advanced Encryption Standard (a.k.a. Rijndael), including support for AES in counter mode (a.k.a. AES-CTR).
//! AES in counter mode (a.k.a. AES-CTR)
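//!
//! A minimal usage sketch (assumes the crate is built with the `ctr` feature
//! and uses the `NewCipher`/`StreamCipher` traits re-exported through
//! `aes::cipher`; the key and nonce values below are placeholders):
//!
//! ```ignore
//! use aes::Aes128Ctr;
//! use aes::cipher::{generic_array::GenericArray, NewCipher, StreamCipher};
//!
//! let key = GenericArray::from_slice(b"very secret key.");   // 16-byte key
//! let nonce = GenericArray::from_slice(b"and secret nonce"); // 16-byte nonce
//!
//! let mut cipher = Aes128Ctr::new(key, nonce);
//! let mut data = *b"hello world!";
//! cipher.apply_keystream(&mut data); // encrypt in place
//!
//! // Re-applying the keystream with the same key/nonce decrypts.
//! let mut cipher = Aes128Ctr::new(key, nonce);
//! cipher.apply_keystream(&mut data);
//! assert_eq!(&data, b"hello world!");
//! ```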

// TODO(tarcieri): support generic CTR API

#![allow(clippy::unreadable_literal)]

use super::arch::*;
use core::mem;

use super::{Aes128, Aes192, Aes256};
use crate::BLOCK_SIZE;
use cipher::{
    consts::U16,
    errors::{LoopError, OverflowError},
    generic_array::GenericArray,
    BlockCipher, FromBlockCipher, SeekNum, StreamCipher, StreamCipherSeek,
};

const PAR_BLOCKS: usize = 8;
const PAR_BLOCKS_SIZE: usize = PAR_BLOCKS * BLOCK_SIZE;

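/// XOR `key` into `buf`; both slices must have the same length.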
#[inline(always)]
pub fn xor(buf: &mut [u8], key: &[u8]) {
    debug_assert_eq!(buf.len(), key.len());
    for (a, b) in buf.iter_mut().zip(key) {
        *a ^= *b;
    }
}

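/// XOR eight consecutive 16-byte keystream blocks into a 128-byte buffer.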
#[inline(always)]
fn xor_block8(buf: &mut [u8], ctr: [__m128i; 8]) {
    debug_assert_eq!(buf.len(), PAR_BLOCKS_SIZE);

    // Safety: `loadu` and `storeu` support unaligned access
    #[allow(clippy::cast_ptr_alignment)]
    unsafe {
        // compiler should unroll this loop
        for i in 0..8 {
            let ptr = buf.as_mut_ptr().offset(16 * i) as *mut __m128i;
            let data = _mm_loadu_si128(ptr);
            let data = _mm_xor_si128(data, ctr[i as usize]);
            _mm_storeu_si128(ptr, data);
        }
    }
}

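/// Reverse the byte order within each 64-bit half of the vector, converting
/// between the big-endian block layout and the lane layout used internally.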
#[inline(always)]
fn swap_bytes(v: __m128i) -> __m128i {
    unsafe {
        let mask = _mm_set_epi64x(0x08090a0b0c0d0e0f, 0x0001020304050607);
        _mm_shuffle_epi8(v, mask)
    }
}

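/// Increment the counter: add 1 to the high 64-bit lane, which holds the
/// low-order 64 bits of the big-endian counter after `swap_bytes`.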
#[inline(always)]
fn inc_be(v: __m128i) -> __m128i {
    unsafe { _mm_add_epi64(v, _mm_set_epi64x(1, 0)) }
}

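/// Load a 16-byte array into an `__m128i` using an unaligned load.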
#[inline(always)]
fn load(val: &GenericArray<u8, U16>) -> __m128i {
    // Safety: `loadu` supports unaligned loads
    #[allow(clippy::cast_ptr_alignment)]
    unsafe {
        _mm_loadu_si128(val.as_ptr() as *const __m128i)
    }
}

macro_rules! impl_ctr {
    ($name:ident, $cipher:ty, $doc:expr) => {
        #[doc=$doc]
        #[derive(Clone)]
        #[cfg_attr(docsrs, doc(cfg(feature = "ctr")))]
        pub struct $name {
            /// Initial counter block derived from the nonce (byte-swapped form)
            nonce: __m128i,
            /// Current counter block (byte-swapped form)
            ctr: __m128i,
            cipher: $cipher,
            /// Cached keystream block used for partial-block processing
            block: [u8; BLOCK_SIZE],
            /// Current byte position within `block`
            pos: u8,
        }

        impl $name {
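            /// Encrypt the current counter into `self.block` without
            /// advancing the counter (used for partial-block processing).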
            #[inline(always)]
            fn gen_block(&mut self) {
                let block = self.cipher.encrypt(swap_bytes(self.ctr));
                self.block = unsafe { mem::transmute(block) }
            }

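            /// Return the keystream block for the current counter and
            /// advance the counter by one.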
            #[inline(always)]
            fn next_block(&mut self) -> __m128i {
                let block = swap_bytes(self.ctr);
                self.ctr = inc_be(self.ctr);
                self.cipher.encrypt(block)
            }

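            /// Return eight keystream blocks at once, advancing the counter
            /// by eight and using the parallel `encrypt8` code path.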
            #[inline(always)]
            fn next_block8(&mut self) -> [__m128i; 8] {
                let mut ctr = self.ctr;
                let mut block8: [__m128i; 8] = unsafe { mem::zeroed() };
                for i in 0..8 {
                    block8[i] = swap_bytes(ctr);
                    ctr = inc_be(ctr);
                }
                self.ctr = ctr;

                self.cipher.encrypt8(block8)
            }

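            /// Return the current counter offset (in blocks) relative to the
            /// initial nonce, taken from the high 64-bit lanes.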
            #[inline(always)]
            fn get_u64_ctr(&self) -> u64 {
                let (ctr, nonce) = unsafe {
                    (
                        mem::transmute::<__m128i, [u64; 2]>(self.ctr)[1],
                        mem::transmute::<__m128i, [u64; 2]>(self.nonce)[1],
                    )
                };
                ctr.wrapping_sub(nonce)
            }

            /// Check that processing the provided data will not overflow the block counter
            #[inline(always)]
            fn check_data_len(&self, data: &[u8]) -> Result<(), LoopError> {
                let bs = BLOCK_SIZE;
                let leftover_bytes = bs - self.pos as usize;
                if data.len() < leftover_bytes {
                    return Ok(());
                }
                let blocks = 1 + (data.len() - leftover_bytes) / bs;
                self.get_u64_ctr()
                    .checked_add(blocks as u64)
                    .ok_or(LoopError)
                    .map(|_| ())
            }
        }

        impl FromBlockCipher for $name {
            type BlockCipher = $cipher;
            type NonceSize = <$cipher as BlockCipher>::BlockSize;

            fn from_block_cipher(
                cipher: $cipher,
                nonce: &GenericArray<u8, Self::NonceSize>,
            ) -> Self {
                let nonce = swap_bytes(load(nonce));
                Self {
                    nonce,
                    ctr: nonce,
                    cipher,
                    block: [0u8; BLOCK_SIZE],
                    pos: 0,
                }
            }
        }

        impl StreamCipher for $name {
            #[inline]
            fn try_apply_keystream(&mut self, mut data: &mut [u8]) -> Result<(), LoopError> {
                self.check_data_len(data)?;
                let bs = BLOCK_SIZE;
                let pos = self.pos as usize;
                debug_assert!(bs > pos);

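                // Finish the partially used keystream block, if any.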
                if pos != 0 {
                    if data.len() < bs - pos {
                        let n = pos + data.len();
                        xor(data, &self.block[pos..n]);
                        self.pos = n as u8;
                        return Ok(());
                    } else {
                        let (l, r) = data.split_at_mut(bs - pos);
                        data = r;
                        xor(l, &self.block[pos..]);
                        self.ctr = inc_be(self.ctr);
                    }
                }

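                // Process eight blocks at a time through the parallel path.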
                let mut chunks = data.chunks_exact_mut(PAR_BLOCKS_SIZE);
                for chunk in &mut chunks {
                    xor_block8(chunk, self.next_block8());
                }
                data = chunks.into_remainder();

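                // Process the remaining full blocks one at a time.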
                let mut chunks = data.chunks_exact_mut(bs);
                for chunk in &mut chunks {
                    let block = self.next_block();

                    unsafe {
                        let t = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
                        let res = _mm_xor_si128(block, t);
                        _mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, res);
                    }
                }

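                // Cache a fresh keystream block for a trailing partial block.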
                let rem = chunks.into_remainder();
                self.pos = rem.len() as u8;
                if !rem.is_empty() {
                    self.gen_block();
                    for (a, b) in rem.iter_mut().zip(&self.block) {
                        *a ^= *b;
                    }
                }

                Ok(())
            }
        }

        impl StreamCipherSeek for $name {
            fn try_current_pos<T: SeekNum>(&self) -> Result<T, OverflowError> {
                T::from_block_byte(self.get_u64_ctr(), self.pos, BLOCK_SIZE as u8)
            }

            fn try_seek<T: SeekNum>(&mut self, pos: T) -> Result<(), LoopError> {
                let res: (u64, u8) = pos.to_block_byte(BLOCK_SIZE as u8)?;
                self.ctr = unsafe { _mm_add_epi64(self.nonce, _mm_set_epi64x(res.0 as i64, 0)) };
                self.pos = res.1;
                if self.pos != 0 {
                    self.gen_block()
                }
                Ok(())
            }
        }

        opaque_debug::implement!($name);
    };
}

impl_ctr!(Aes128Ctr, Aes128, "AES-128 in CTR mode");
impl_ctr!(Aes192Ctr, Aes192, "AES-192 in CTR mode");
impl_ctr!(Aes256Ctr, Aes256, "AES-256 in CTR mode");