use crate::bytes::vec_u64_from_bytes;
use crate::data_types::{get_size_in_bits, get_types_vector, Type, UINT64};
use crate::data_values::Value;
use crate::errors::Result;
use aes::cipher::KeyInit;
use aes::cipher::{generic_array::GenericArray, BlockEncrypt};
use aes::Aes128;
use cipher::block_padding::NoPadding;
use rand::rngs::OsRng;
use rand::RngCore;
/// Size in bytes of one AES block; `PrfSession` batch sizes are rounded up to a multiple of it.
const BLOCK_SIZE: usize = 16;
/// Fills every byte of `bytes` with randomness obtained from the operating system.
///
/// # Errors
/// Returns a runtime error if the OS random source reports a failure.
pub fn get_bytes_from_os(bytes: &mut [u8]) -> Result<()> {
    match OsRng.try_fill_bytes(bytes) {
        Ok(_) => Ok(()),
        Err(_) => Err(runtime_error!("OS random generator failed")),
    }
}
/// Length in bytes of a PRNG seed / PRF key; it is used directly as the AES-128 key.
pub const SEED_SIZE: usize = 16;
/// Largest batch (in bytes) a `PrfSession` encrypts at once; also the PRNG's first batch size.
const BUFFER_SIZE: usize = 512;
/// First batch size (in bytes) used when `Prf::output_value` derives a typed value.
const INITIAL_BUFFER_SIZE: usize = 64;
/// Seeded pseudo-random generator: AES-128 applied to a stream of counter blocks.
pub struct PRNG {
// AES-128 cipher keyed with the seed passed to `PRNG::new`.
aes: Aes128,
// Counter and buffered keystream consumed by the `get_random_*` methods.
random_source: PrfSession,
}
impl PRNG {
pub fn new(seed: Option<[u8; SEED_SIZE]>) -> Result<PRNG> {
let bytes = match seed {
Some(bytes) => bytes,
None => {
let mut bytes = [0u8; SEED_SIZE];
get_bytes_from_os(&mut bytes)?;
bytes
}
};
let aes = aes::Aes128::new(GenericArray::from_slice(&bytes));
Ok(PRNG {
aes,
random_source: PrfSession::new(0, BUFFER_SIZE)?,
})
}
pub fn get_random_bytes(&mut self, n: usize) -> Result<Vec<u8>> {
self.random_source
.generate_random_bytes(&self.aes, n as u64)
}
fn get_random_key(&mut self) -> Result<[u8; SEED_SIZE]> {
let bytes = self.get_random_bytes(SEED_SIZE)?;
let mut new_seed = [0u8; SEED_SIZE];
new_seed.copy_from_slice(&bytes[0..SEED_SIZE]);
Ok(new_seed)
}
pub fn get_random_value(&mut self, t: Type) -> Result<Value> {
match t {
Type::Scalar(_) | Type::Array(_, _) => {
let bit_size = get_size_in_bits(t)?;
let byte_size = (bit_size + 7) / 8;
let bits_to_flush = 8 * byte_size - bit_size;
let mut bytes = self.get_random_bytes(byte_size as usize)?;
if !bytes.is_empty() {
*bytes.last_mut().unwrap() >>= bits_to_flush;
}
Ok(Value::from_bytes(bytes))
}
Type::Tuple(_) | Type::Vector(_, _) | Type::NamedTuple(_) => {
let ts = get_types_vector(t)?;
let mut v = vec![];
for sub_t in ts {
v.push(self.get_random_value((*sub_t).clone())?)
}
Ok(Value::from_vector(v))
}
}
}
pub fn get_random_in_range(&mut self, modulus: Option<u64>) -> Result<u64> {
if let Some(m) = modulus {
let rem = ((u64::MAX % m) + 1) % m;
let rejection_bound = u64::MAX - rem;
let mut r;
loop {
r = vec_u64_from_bytes(&self.get_random_bytes(8)?, UINT64)?[0];
if r <= rejection_bound {
break;
}
}
Ok(r % m)
} else {
Ok(vec_u64_from_bytes(&self.get_random_bytes(8)?, UINT64)?[0])
}
}
}
/// Keyed pseudo-random function: deterministically maps a `u64` input to pseudo-random
/// output (bytes, typed values or permutations) under an AES-128 key.
pub(super) struct Prf {
// AES-128 cipher holding the PRF key.
aes: Aes128,
}
impl Prf {
    /// Creates a PRF keyed with `key`, or with a freshly generated key when `key` is `None`.
    ///
    /// A missing key is drawn from a [`PRNG`] seeded by the operating system.
    ///
    /// # Errors
    /// Propagates failures of the OS random generator.
    pub fn new(key: Option<[u8; SEED_SIZE]>) -> Result<Prf> {
        let key_bytes = match key {
            Some(bytes) => bytes,
            None => {
                let mut gen = PRNG::new(None)?;
                gen.get_random_key()?
            }
        };
        let aes = aes::Aes128::new(GenericArray::from_slice(&key_bytes));
        Ok(Prf { aes })
    }

    /// Returns `n` pseudo-random bytes determined by the key and `input` (test-only helper).
    #[cfg(test)]
    fn output_bytes(&mut self, input: u64, n: u64) -> Result<Vec<u8>> {
        let initial_buffer_size = usize::min(BUFFER_SIZE, n as usize);
        PrfSession::new(input, initial_buffer_size)?.generate_random_bytes(&self.aes, n)
    }

    /// Deterministically derives a value of type `t` from the key and `input`.
    ///
    /// The same key, `input` and `t` always yield the same value.
    pub(super) fn output_value(&mut self, input: u64, t: Type) -> Result<Value> {
        PrfSession::new(input, INITIAL_BUFFER_SIZE)?.recursively_generate_value(&self.aes, t)
    }

    /// Deterministically derives a permutation of `0..n` from the key and `input`,
    /// returned as a flat `UINT64` array of length `n`.
    ///
    /// # Errors
    /// Returns an error if `n` exceeds 2^30.
    pub(super) fn output_permutation(&mut self, input: u64, n: u64) -> Result<Value> {
        const MAX_PERMUTATION_SIZE: u64 = 1 << 30;
        if n > MAX_PERMUTATION_SIZE {
            // The guard admits n == 2^30, so the message must say "at most", not "less than".
            return Err(runtime_error!("n should be at most 2^30"));
        }
        let initial_buffer_size = usize::min(BUFFER_SIZE, n as usize);
        let mut session = PrfSession::new(input, initial_buffer_size)?;
        let mut a: Vec<u64> = (0..n).collect();
        // Forward Fisher–Yates: with j uniform in [0, i], after step i the prefix a[..=i] is a
        // uniformly random arrangement. `i + 1 <= 2^30` always fits in u32 thanks to the guard.
        for i in 1..n {
            let j = session.generate_u32_in_range(&self.aes, i as u32 + 1)?;
            a.swap(i as usize, j as usize);
        }
        Value::from_flattened_array_u64(&a, UINT64)
    }
}
/// Keystream for one PRF input: consecutive 128-bit counter blocks are encrypted with AES in
/// batches, and the resulting bytes are handed out sequentially. Nothing is encrypted until
/// bytes are first requested.
struct PrfSession {
// Next counter block to encrypt; starts at `input << 64` and is incremented once per block.
input: u128,
// Keystream bytes produced by the most recent batch.
buffer: Vec<u8>,
// Index in `buffer` of the next unconsumed byte.
next_byte: usize,
// Number of valid keystream bytes in `buffer` for the current batch.
current_buffer_size: usize,
// Size of the next batch; doubled after each batch until it reaches `BUFFER_SIZE`.
next_buffer_size: usize,
}
impl PrfSession {
    /// Starts the keystream for PRF input `input` (counter blocks begin at `input << 64`).
    ///
    /// `initial_buffer_size` is the size in bytes of the first batch. It is rounded up to a
    /// whole number of AES blocks, with a minimum of one block. The buffer starts out marked
    /// as fully consumed, so nothing is encrypted until bytes are first requested.
    pub fn new(input: u64, initial_buffer_size: usize) -> Result<Self> {
        // A zero-sized batch would never grow (batches double and 0 * 2 == 0), so any later
        // byte request would spin forever. The byte stream depends only on the counter
        // sequence, not on batch sizes, so clamping to one block does not change output.
        let initial_buffer_size = usize::max(initial_buffer_size, 1);
        let initial_buffer_size = (initial_buffer_size + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
        Ok(Self {
            input: (input as u128) << 64,
            buffer: vec![0u8; initial_buffer_size],
            next_byte: initial_buffer_size,
            current_buffer_size: initial_buffer_size,
            next_buffer_size: initial_buffer_size,
        })
    }

    /// Encrypts the next `next_buffer_size / BLOCK_SIZE` counter blocks into `buffer`,
    /// resets the read position and grows the next batch size (capped at `BUFFER_SIZE`).
    fn generate_one_batch(&mut self, aes: &Aes128) -> Result<()> {
        let batch_size = self.next_buffer_size;
        // Plaintext: little-endian encodings of consecutive 128-bit counters, one per block.
        let mut counter_blocks = vec![0u8; batch_size];
        for block in counter_blocks.chunks_exact_mut(BLOCK_SIZE) {
            block.copy_from_slice(&self.input.to_le_bytes());
            self.input = self.input.wrapping_add(1);
        }
        if self.buffer.len() != batch_size {
            self.buffer.resize(batch_size, 0);
        }
        aes.encrypt_padded_b2b::<NoPadding>(&counter_blocks, &mut self.buffer)
            .map_err(|e| runtime_error!("Encryption error: {e:?}"))?;
        self.current_buffer_size = batch_size;
        // Batches double up to BUFFER_SIZE: short sessions encrypt little, long ones use
        // large batches.
        if self.next_buffer_size < BUFFER_SIZE {
            self.next_buffer_size = usize::min(BUFFER_SIZE, self.next_buffer_size * 2);
        }
        self.next_byte = 0;
        Ok(())
    }

    /// Returns the next `n` keystream bytes.
    fn generate_random_bytes(&mut self, aes: &Aes128, n: u64) -> Result<Vec<u8>> {
        let mut bytes = vec![0u8; n as usize];
        self.fill_random_bytes(aes, bytes.as_mut_slice())?;
        Ok(bytes)
    }

    /// Overwrites `buff` with the next `buff.len()` keystream bytes, encrypting new batches
    /// whenever the current one is exhausted.
    fn fill_random_bytes(&mut self, aes: &Aes128, mut buff: &mut [u8]) -> Result<()> {
        while !buff.is_empty() {
            if self.next_byte == self.current_buffer_size {
                // The read position is never rewound here, so a failed refill cannot cause
                // already-consumed bytes to be handed out again.
                self.generate_one_batch(aes)?;
            }
            let take = usize::min(self.current_buffer_size - self.next_byte, buff.len());
            buff[..take].copy_from_slice(&self.buffer[self.next_byte..self.next_byte + take]);
            self.next_byte += take;
            buff = &mut buff[take..];
        }
        Ok(())
    }

    /// Produces a value of type `tp` from the keystream.
    ///
    /// Scalars and arrays take `ceil(bits / 8)` bytes, and the last byte is shifted right so
    /// its unused high bits are zero. Tuples, vectors and named tuples are generated element
    /// by element, in order.
    fn recursively_generate_value(&mut self, aes: &Aes128, tp: Type) -> Result<Value> {
        match tp {
            Type::Scalar(_) | Type::Array(_, _) => {
                let bit_size = get_size_in_bits(tp)?;
                let byte_size = (bit_size + 7) / 8;
                let bits_to_flush = 8 * byte_size - bit_size;
                let mut bytes = self.generate_random_bytes(aes, byte_size)?;
                if !bytes.is_empty() {
                    *bytes.last_mut().unwrap() >>= bits_to_flush;
                }
                Ok(Value::from_bytes(bytes))
            }
            Type::Tuple(_) | Type::Vector(_, _) | Type::NamedTuple(_) => {
                let ts = get_types_vector(tp)?;
                let mut v = vec![];
                for sub_t in ts {
                    let value = self.recursively_generate_value(aes, (*sub_t).clone())?;
                    v.push(value);
                }
                Ok(Value::from_vector(v))
            }
        }
    }

    /// Reads the next `NEED_BYTES` (at most 8) keystream bytes as a little-endian integer,
    /// refilling the buffer when it runs out part-way through the read.
    fn generate_random_number_const<const NEED_BYTES: usize>(
        &mut self,
        aes: &Aes128,
    ) -> Result<u64> {
        let mut res = [0u8; 8];
        let use_bytes = std::cmp::min(self.current_buffer_size - self.next_byte, NEED_BYTES);
        res[..use_bytes].copy_from_slice(&self.buffer[self.next_byte..self.next_byte + use_bytes]);
        if use_bytes == NEED_BYTES {
            self.next_byte += use_bytes;
        } else {
            // Take the remainder from the start of a fresh batch (always >= one block).
            self.generate_one_batch(aes)?;
            self.next_byte = NEED_BYTES - use_bytes;
            res[use_bytes..NEED_BYTES].copy_from_slice(&self.buffer[..self.next_byte]);
        }
        let mask = if NEED_BYTES == 8 {
            u64::MAX
        } else {
            (1 << (NEED_BYTES * 8)) - 1
        };
        Ok(u64::from_le_bytes(res) & mask)
    }

    /// Runtime dispatch to `generate_random_number_const` for 1..=8 bytes.
    fn generate_random_number(&mut self, aes: &Aes128, need_bytes: usize) -> Result<u64> {
        match need_bytes {
            1 => self.generate_random_number_const::<1>(aes),
            2 => self.generate_random_number_const::<2>(aes),
            3 => self.generate_random_number_const::<3>(aes),
            4 => self.generate_random_number_const::<4>(aes),
            5 => self.generate_random_number_const::<5>(aes),
            6 => self.generate_random_number_const::<6>(aes),
            7 => self.generate_random_number_const::<7>(aes),
            8 => self.generate_random_number_const::<8>(aes),
            _ => Err(runtime_error!("Unsupported need bytes")),
        }
    }

    /// Returns a uniformly distributed integer in `0..modulus` using rejection sampling.
    ///
    /// One byte more than needed to cover `modulus.next_power_of_two()` is drawn, which keeps
    /// the rejection probability below 1/256.
    ///
    /// # Errors
    /// Returns an error if `modulus` is 0 (the range would be empty).
    fn generate_u32_in_range(&mut self, aes: &Aes128, modulus: u32) -> Result<u32> {
        if modulus == 0 {
            // `% modulus` below would otherwise panic.
            return Err(runtime_error!("Modulus must be positive"));
        }
        let modulus = modulus as u64;
        let need_bytes = (modulus.next_power_of_two().trailing_zeros() + 7) / 8 + 1;
        let max_rand_value = (1u64 << (need_bytes as u64 * 8)) - 1;
        // Values above `rejection_bound` belong to an incomplete residue class; reject them.
        let num_biased = (max_rand_value + 1) % modulus;
        let rejection_bound = max_rand_value - num_biased;
        loop {
            let rand_value = self.generate_random_number(aes, need_bytes as usize)?;
            if rand_value <= rejection_bound {
                return Ok((rand_value % modulus) as u32);
            }
        }
    }
}
/// Checks whether the byte histogram `counters`, built from `n` samples, has a Shannon
/// entropy within `1020 / n` bits of the 8-bit maximum.
///
/// Empty bins contribute nothing (the standard `0 * log2(0) = 0` convention). Previously an
/// empty bin produced NaN and forced `false` regardless of the tolerance. Returns `false`
/// when `n == 0`.
pub fn entropy_test(counters: [u32; 256], n: u64) -> bool {
    if n == 0 {
        return false;
    }
    let total = n as f64;
    let entropy = counters
        .iter()
        .filter(|&&c| c != 0)
        .map(|&c| {
            let prob_c = f64::from(c) / total;
            prob_c.log2() * prob_c
        })
        .fold(0f64, |acc, term| acc - term);
    let precision = 1020_f64 / total;
    (entropy - 8f64).abs() < precision
}
/// Pearson's chi-squared statistic for the observed `counters` against a uniform expectation
/// of `expected_count_per_element` per category, i.e. `sum((c - e)^2) / e`.
pub fn chi_statistics(counters: &[u64], expected_count_per_element: u64) -> f64 {
    let expected = expected_count_per_element as f64;
    let squared_deviations = counters
        .iter()
        .map(|&observed| (observed as f64 - expected).powi(2))
        .fold(0_f64, |total, term| total + term);
    squared_deviations / expected
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::data_types::{
        array_type, named_tuple_type, scalar_type, tuple_type, vector_type, BIT, INT32, UINT64,
        UINT8,
    };

    /// Two generators built from the same seed must emit identical byte streams.
    #[test]
    fn test_prng_fixed_seed() {
        let helper = |n: usize| -> Result<()> {
            let seed = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
            let mut prng1 = PRNG::new(Some(*seed))?;
            let mut prng2 = PRNG::new(Some(*seed))?;
            let rand_bytes1 = prng1.get_random_bytes(n)?;
            let rand_bytes2 = prng2.get_random_bytes(n)?;
            assert_eq!(rand_bytes1, rand_bytes2);
            Ok(())
        };
        helper(1).unwrap();
        helper(19).unwrap();
        helper(1000).unwrap();
    }

    /// Bytes from an OS-seeded generator should have near-maximal entropy.
    #[test]
    fn test_prng_random_seed() {
        let mut prng = PRNG::new(None).unwrap();
        let mut counters = [0; 256];
        let n = 10_000_001;
        let rand_bytes = prng.get_random_bytes(n).unwrap();
        for byte in rand_bytes {
            counters[byte as usize] += 1;
        }
        assert!(entropy_test(counters, n as u64));
    }

    /// Random values must match the requested type.
    #[test]
    fn test_prng_random_value() {
        let mut g = PRNG::new(None).unwrap();
        let mut helper = |t: Type| -> Result<()> {
            let v = g.get_random_value(t.clone())?;
            assert!(v.check_type(t)?);
            Ok(())
        };
        || -> Result<()> {
            helper(scalar_type(BIT))?;
            helper(scalar_type(UINT8))?;
            helper(scalar_type(INT32))?;
            helper(array_type(vec![2, 5], BIT))?;
            helper(array_type(vec![2, 5], UINT8))?;
            helper(array_type(vec![2, 5], INT32))?;
            helper(tuple_type(vec![scalar_type(BIT), scalar_type(INT32)]))?;
            helper(tuple_type(vec![
                vector_type(3, scalar_type(BIT)),
                vector_type(5, scalar_type(BIT)),
                scalar_type(BIT),
                scalar_type(INT32),
            ]))?;
            helper(named_tuple_type(vec![
                ("field 1".to_owned(), scalar_type(BIT)),
                ("field 2".to_owned(), scalar_type(INT32)),
            ]))
        }()
        .unwrap()
    }

    /// Unused high bits of the last byte of a bit array must be cleared.
    #[test]
    fn test_prng_random_value_flush() {
        let mut g = PRNG::new(None).unwrap();
        let mut helper = |t: Type, expected: u8| -> Result<()> {
            let v = g.get_random_value(t.clone())?;
            v.access_bytes(|bytes| {
                if !bytes.is_empty() {
                    assert!(bytes.last() < Some(&expected));
                }
                Ok(())
            })?;
            Ok(())
        };
        || -> Result<()> {
            helper(array_type(vec![2, 5], BIT), 4)?;
            helper(array_type(vec![3, 3], BIT), 2)?;
            helper(array_type(vec![7], BIT), 128)?;
            helper(scalar_type(BIT), 2)
        }()
        .unwrap();
    }

    /// `get_random_in_range` should be uniform over `0..m` (chi-squared check).
    #[test]
    fn test_prng_random_u64_modulo() {
        || -> Result<()> {
            let mut g = PRNG::new(None).unwrap();
            let m = 100_u64;
            let mut counters = vec![0; m as usize];
            let expected_count_per_int = 1000;
            let n = expected_count_per_int * m;
            for _ in 0..n {
                let r = g.get_random_in_range(Some(m))?;
                counters[r as usize] += 1;
            }
            let chi2 = chi_statistics(&counters, expected_count_per_int);
            assert!(chi2 < 180.792_f64);
            Ok(())
        }()
        .unwrap();
    }

    /// Two PRFs with the same key must agree on every input.
    #[test]
    fn test_prf_fixed_key() {
        || -> Result<()> {
            let key = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
            let mut prf1 = Prf::new(Some(*key))?;
            let mut prf2 = Prf::new(Some(*key))?;
            for i in 0..100_000u64 {
                assert_eq!(prf1.output_bytes(i, 1)?, prf2.output_bytes(i, 1)?);
                assert_eq!(prf1.output_bytes(i, 5)?, prf2.output_bytes(i, 5)?);
            }
            Ok(())
        }()
        .unwrap();
    }

    /// PRF output bytes under a random key should have near-maximal entropy.
    #[test]
    fn test_prf_random_key() {
        || -> Result<()> {
            let mut prf = Prf::new(None)?;
            let mut counters = [0; 256];
            let n = 100_000u64;
            let k = 10u64;
            for i in 0..n {
                let out = prf.output_bytes(i, k)?;
                for byte in out {
                    counters[byte as usize] += 1;
                }
            }
            assert!(entropy_test(counters, n * k));
            Ok(())
        }()
        .unwrap();
    }

    /// `output_value` must be deterministic, well-typed, and vary across tuple elements.
    #[test]
    fn test_prf_output_value() {
        let mut g = Prf::new(None).unwrap();
        let mut helper = |t: Type| -> Result<()> {
            let v1 = g.output_value(15, t.clone())?;
            let v2 = g.output_value(15, t.clone())?;
            assert!(v1.check_type(t.clone())?);
            assert!(v2.check_type(t.clone())?);
            assert_eq!(v1, v2);
            if let Type::Tuple(_) | Type::Vector(_, _) | Type::NamedTuple(_) = t.clone() {
                let values = v1.to_vector()?;
                let mut all_equal = true;
                for i in 1..values.len() {
                    all_equal &= values[i - 1] == values[i];
                }
                assert!(!all_equal);
                let mut numbers = vec![];
                let types = get_types_vector(t)?;
                for i in 0..types.len() {
                    let tp = (*types[i]).clone();
                    if !tp.is_array() {
                        return Ok(());
                    }
                    if tp.get_scalar_type() != UINT64 {
                        return Ok(());
                    }
                    let mut tmp = values[i].to_flattened_array_u64(tp)?;
                    numbers.append(&mut tmp)
                }
                let mut tmp_numbers = numbers.clone();
                tmp_numbers.sort_unstable();
                tmp_numbers.dedup();
                assert_eq!(tmp_numbers.len(), numbers.len());
            }
            Ok(())
        };
        || -> Result<()> {
            helper(scalar_type(BIT))?;
            helper(scalar_type(UINT8))?;
            helper(scalar_type(INT32))?;
            helper(array_type(vec![3, 4], BIT))?;
            helper(array_type(vec![4, 2], UINT8))?;
            helper(array_type(vec![6, 2], INT32))?;
            helper(tuple_type(vec![scalar_type(BIT), scalar_type(INT32)]))?;
            helper(tuple_type(vec![
                vector_type(3, scalar_type(BIT)),
                vector_type(5, scalar_type(BIT)),
                scalar_type(BIT),
                scalar_type(INT32),
            ]))?;
            helper(tuple_type(vec![
                scalar_type(INT32),
                scalar_type(INT32),
                scalar_type(INT32),
                scalar_type(INT32),
            ]))?;
            helper(tuple_type(vec![
                array_type(vec![2, 2], INT32),
                array_type(vec![2, 2], INT32),
                array_type(vec![2, 2], INT32),
                array_type(vec![2, 2], INT32),
            ]))?;
            helper(tuple_type(vec![
                array_type(vec![2, 1, 2], UINT64),
                array_type(vec![2, 3, 2], UINT64),
                array_type(vec![2, 2, 1], UINT64),
                array_type(vec![3, 3, 2], UINT64),
            ]))?;
            helper(named_tuple_type(vec![
                ("field 1".to_owned(), scalar_type(BIT)),
                ("field 2".to_owned(), scalar_type(INT32)),
            ]))
        }()
        .unwrap();
        let mut helper_flush = |t: Type, expected: u8| -> Result<()> {
            let v = g.output_value(181, t.clone())?;
            v.access_bytes(|bytes| {
                if !bytes.is_empty() {
                    assert!(bytes.last() < Some(&expected));
                }
                Ok(())
            })?;
            Ok(())
        };
        || -> Result<()> {
            helper_flush(array_type(vec![1, 5], BIT), 32)?;
            helper_flush(array_type(vec![3, 3, 3], BIT), 8)?;
            helper_flush(array_type(vec![2, 6], BIT), 16)?;
            helper_flush(scalar_type(BIT), 2)
        }()
        .unwrap();
    }

    /// `generate_u32_in_range` should be uniform over `0..n` (chi-squared check).
    #[test]
    fn test_generate_u32_in_range() -> Result<()> {
        let prf = Prf::new(None)?;
        // Chi-squared critical values, indexed by degrees of freedom (n - 1).
        let critical_value = [0f64, 23.9281, 27.6310, 30.6648, 33.3768, 35.8882];
        for n in 2..6u32 {
            let mut session = PrfSession::new(0, 1)?;
            let expected_count = 1000000;
            let runs = n * expected_count;
            // Index by outcome so that never-drawn values still contribute to the statistic.
            let mut counters = vec![0u64; n as usize];
            for _ in 0..runs {
                let x = session.generate_u32_in_range(&prf.aes, n)?;
                assert!(x < n);
                counters[x as usize] += 1;
            }
            let chi2 = chi_statistics(&counters, expected_count as u64);
            assert!(chi2 < critical_value[(n - 1) as usize]);
        }
        Ok(())
    }

    /// Permutations must be valid and every permutation of `0..n` equally likely.
    #[test]
    fn test_prf_output_permutation() -> Result<()> {
        let mut prf = Prf::new(None)?;
        let mut helper = |n: u64| -> Result<()> {
            let result_type = array_type(vec![n], UINT64);
            let mut perm_statistics: HashMap<Vec<u64>, u64> = HashMap::new();
            let expected_count_per_perm = 100;
            let n_factorial: u64 = (2..=n).product();
            let runs = expected_count_per_perm * n_factorial;
            for input in 0..runs {
                let result_value = prf.output_permutation(input, n)?;
                let perm = result_value.to_flattened_array_u64(result_type.clone())?;
                let mut perm_sorted = perm.clone();
                perm_sorted.sort();
                let range_vec: Vec<u64> = (0..n).collect();
                assert_eq!(perm_sorted, range_vec);
                // Count every occurrence, including the first one.
                *perm_statistics.entry(perm).or_insert(0) += 1;
            }
            assert_eq!(perm_statistics.len() as u64, n_factorial);
            if n > 1 {
                let counters: Vec<u64> = perm_statistics.values().copied().collect();
                let chi2 = chi_statistics(&counters, expected_count_per_perm);
                if n == 4 {
                    assert!(chi2 < 70.5496_f64);
                }
                if n == 5 {
                    assert!(chi2 < 207.1986_f64);
                }
            }
            Ok(())
        };
        helper(1)?;
        helper(4)?;
        helper(5)
    }

    /// Large permutations must still contain every element of `0..n` exactly once.
    #[test]
    fn test_prf_output_permutation_correctness() -> Result<()> {
        let mut prf = Prf::new(None)?;
        let mut helper = |n: u64| -> Result<()> {
            let result_type = array_type(vec![n], UINT64);
            let result_value = prf.output_permutation(0, n)?;
            let perm = result_value.to_flattened_array_u64(result_type.clone())?;
            let mut perm_sorted = perm.clone();
            perm_sorted.sort();
            let range_vec: Vec<u64> = (0..n).collect();
            assert_eq!(perm_sorted, range_vec);
            Ok(())
        };
        helper(1)?;
        helper(10)?;
        helper(100)?;
        helper(1000)?;
        helper(10000)?;
        helper(100000)?;
        helper(1000000)?;
        Ok(())
    }
}