// arcium-primitives 0.4.5
//
// Arcium primitives
// Documentation
use aes::Aes128Enc;
use rand::RngCore;

use crate::ciphers::{block::U8x16, BlockCipher};

/// Pseudorandom generator whose output blocks are produced by encrypting a running 128-bit
/// counter with a borrowed AES-128 encryptor.
pub struct Aes128Prng<'a> {
    /// Borrowed AES-128 encryption instance used to encrypt each counter value.
    enc: &'a Aes128Enc,
    /// Current counter value; it is advanced by one before every generated block.
    iv: u128,
}

impl Aes128Prng<'_> {
    /// Advances the counter by one, encrypts it, and copies the leading `dest.len()` bytes of
    /// the resulting block into `dest`.
    ///
    /// The counter wraps around modulo 2^128 instead of overflowing, so an IV equal (or close)
    /// to `u128::MAX` neither panics in debug builds nor behaves differently between build
    /// profiles.
    ///
    /// # Panics
    ///
    /// Panics if `dest` is longer than the encrypted block, because the block is sliced to
    /// `dest.len()` bytes.
    #[inline]
    fn fill_block(&mut self, dest: &mut [u8]) {
        // Explicit wrapping arithmetic: plain `+= 1` would panic on overflow in debug builds.
        self.iv = self.iv.wrapping_add(1);
        let buff = self.enc.encrypt(&self.iv.into());
        // Only the first `dest.len()` bytes are used; any remaining bytes are discarded.
        dest.copy_from_slice(&AsRef::<[u8]>::as_ref(&buff)[..dest.len()]);
    }
}

impl<'a> Aes128Prng<'a> {
    /// Creates a generator that borrows `enc` for encryption and starts counting from `iv`.
    ///
    /// The counter is incremented before the first block is produced, so the first output
    /// block is the encryption of `iv + 1`, not of `iv` itself.
    pub fn new(enc: &'a Aes128Enc, iv: U8x16) -> Self {
        let counter: u128 = iv.into();
        Self { enc, iv: counter }
    }
}

impl RngCore for Aes128Prng<'_> {
    // TODO: do not discard remaining state bytes in functions next_u32 and next_u64
    /// Returns the first four bytes of a freshly generated block, read as little-endian.
    ///
    /// The rest of that block is thrown away.
    fn next_u32(&mut self) -> u32 {
        let mut word = [0u8; 4];
        // A 4-byte buffer fits in a single block, so one block fill covers it.
        self.fill_block(&mut word);
        u32::from_le_bytes(word)
    }

    /// Returns the first eight bytes of a freshly generated block, read as little-endian.
    ///
    /// The rest of that block is thrown away.
    fn next_u64(&mut self) -> u64 {
        let mut word = [0u8; 8];
        // An 8-byte buffer fits in a single block, so one block fill covers it.
        self.fill_block(&mut word);
        u64::from_le_bytes(word)
    }

    /// Fills `dest` with generator output, 16 bytes per generated block.
    ///
    /// A trailing chunk shorter than 16 bytes receives only the leading bytes of its block.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        dest.chunks_mut(16)
            .for_each(|chunk| self.fill_block(chunk));
    }

    /// Same as [`fill_bytes`](RngCore::fill_bytes); always returns `Ok(())`.
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use aes::{cipher::KeyInit, Aes128Enc};
    use ff::Field;
    use rand::RngCore;

    use super::Aes128Prng;
    use crate::algebra::field::binary::Gf2_128;

    /// Builds a fresh PRNG from `key` and `iv` and returns its first `N` output bytes.
    fn first_bytes<const N: usize>(key: &Aes128Enc, iv: Gf2_128) -> [u8; N] {
        let mut out = [0u8; N];
        Aes128Prng::new(key, iv).fill_bytes(&mut out);
        out
    }

    /// Consecutive 16-byte outputs of a single generator must all be distinct.
    #[test]
    fn test_ctr_mode_no_repeated_blocks() {
        let key = Aes128Enc::new_from_slice(&[0xAB; 16]).unwrap();
        let mut prng = Aes128Prng::new(&key, Gf2_128::from_limbs([0x42, 0x42]));

        // 256 blocks (4KB in total) is a reasonable sample to check for collisions.
        let blocks: Vec<[u8; 16]> = (0..256)
            .map(|_| {
                let mut block = [0u8; 16];
                prng.fill_bytes(&mut block);
                block
            })
            .collect();

        // Distinct counter inputs must yield distinct AES outputs, so compare every pair.
        for (i, earlier) in blocks.iter().enumerate() {
            for (j, later) in blocks.iter().enumerate().skip(i + 1) {
                assert_ne!(earlier, later, "blocks {i} and {j} collide");
            }
        }
    }

    /// Two generators built from the same key and IV must emit the same byte stream.
    #[test]
    fn test_deterministic_output() {
        let key = Aes128Enc::new_from_slice(&[0x01; 16]).unwrap();
        let iv = Gf2_128::ZERO;

        let out1: [u8; 128] = first_bytes(&key, iv);
        let out2: [u8; 128] = first_bytes(&key, iv);

        assert_eq!(out1, out2);
    }

    /// Generators sharing a key but starting from different IVs must diverge.
    #[test]
    fn test_different_ivs_produce_different_output() {
        let key = Aes128Enc::new_from_slice(&[0x01; 16]).unwrap();

        let out1: [u8; 64] = first_bytes(&key, Gf2_128::ZERO);
        let out2: [u8; 64] = first_bytes(&key, Gf2_128::from_limbs([1, 0]));

        assert_ne!(out1, out2);
    }

    /// Requests smaller than one block (`next_u32`, `next_u64`) must still yield output.
    #[test]
    fn test_partial_block_fill() {
        let key = Aes128Enc::new_from_slice(&[0xCC; 16]).unwrap();
        let mut prng = Aes128Prng::new(&key, Gf2_128::ZERO);

        let v32 = prng.next_u32();
        let v64 = prng.next_u64();

        // At least one of the two values should be non-zero; both being zero is extremely
        // unlikely with a non-degenerate key.
        assert!(v32 != 0 || v64 != 0);
    }
}