1use crate::bytes::vec_u64_from_bytes;
2use crate::data_types::{get_size_in_bits, get_types_vector, Type, UINT64};
3use crate::data_values::Value;
4use crate::errors::Result;
5
6use aes::cipher::KeyInit;
7use aes::cipher::{generic_array::GenericArray, BlockEncrypt};
8use aes::Aes128;
9use cipher::block_padding::NoPadding;
10use rand::rngs::OsRng;
11use rand::RngCore;
12
/// AES block size in bytes; session batch sizes are rounded up to a multiple of it.
const BLOCK_SIZE: usize = 16;
14
15pub fn get_bytes_from_os(bytes: &mut [u8]) -> Result<()> {
20 OsRng
21 .try_fill_bytes(bytes)
22 .map_err(|_| runtime_error!("OS random generator failed"))?;
23 Ok(())
24}
25
/// Length in bytes of a PRNG seed and of a PRF key; both are used directly as AES-128 keys.
pub const SEED_SIZE: usize = 16;

/// Cap, in bytes, for session batch sizes (batches double until they reach it).
/// Also the initial batch size of the session behind [`PRNG`].
const BUFFER_SIZE: usize = 512;

/// Initial batch size, in bytes, of the session created by `Prf::output_value`.
const INITIAL_BUFFER_SIZE: usize = 64;
34
35pub struct PRNG {
42 aes: Aes128,
43 random_source: PrfSession,
44}
45impl PRNG {
48 pub fn new(seed: Option<[u8; SEED_SIZE]>) -> Result<PRNG> {
49 let bytes = match seed {
50 Some(bytes) => bytes,
51 None => {
52 let mut bytes = [0u8; SEED_SIZE];
53 get_bytes_from_os(&mut bytes)?;
54 bytes
55 }
56 };
57 let aes = aes::Aes128::new(GenericArray::from_slice(&bytes));
58 Ok(PRNG {
59 aes,
60 random_source: PrfSession::new(0, BUFFER_SIZE)?,
61 })
62 }
63
64 pub fn get_random_bytes(&mut self, n: usize) -> Result<Vec<u8>> {
65 self.random_source
66 .generate_random_bytes(&self.aes, n as u64)
67 }
68
69 fn get_random_key(&mut self) -> Result<[u8; SEED_SIZE]> {
70 let bytes = self.get_random_bytes(SEED_SIZE)?;
71 let mut new_seed = [0u8; SEED_SIZE];
72 new_seed.copy_from_slice(&bytes[0..SEED_SIZE]);
73 Ok(new_seed)
74 }
75
76 pub fn get_random_value(&mut self, t: Type) -> Result<Value> {
77 match t {
78 Type::Scalar(_) | Type::Array(_, _) => {
79 let bit_size = get_size_in_bits(t)?;
80 let byte_size = (bit_size + 7) / 8;
81 let bits_to_flush = 8 * byte_size - bit_size;
83 let mut bytes = self.get_random_bytes(byte_size as usize)?;
84 if !bytes.is_empty() {
86 *bytes.last_mut().unwrap() >>= bits_to_flush;
87 }
88 Ok(Value::from_bytes(bytes))
89 }
90 Type::Tuple(_) | Type::Vector(_, _) | Type::NamedTuple(_) => {
91 let ts = get_types_vector(t)?;
92 let mut v = vec![];
93 for sub_t in ts {
94 v.push(self.get_random_value((*sub_t).clone())?)
95 }
96 Ok(Value::from_vector(v))
97 }
98 }
99 }
100
101 pub fn get_random_in_range(&mut self, modulus: Option<u64>) -> Result<u64> {
109 if let Some(m) = modulus {
110 let rem = ((u64::MAX % m) + 1) % m;
111 let rejection_bound = u64::MAX - rem;
112 let mut r;
113 loop {
114 r = vec_u64_from_bytes(&self.get_random_bytes(8)?, UINT64)?[0];
115 if r <= rejection_bound {
116 break;
117 }
118 }
119 Ok(r % m)
120 } else {
121 Ok(vec_u64_from_bytes(&self.get_random_bytes(8)?, UINT64)?[0])
122 }
123 }
124}
125
126pub(super) struct Prf {
134 aes: Aes128,
135}
136
137impl Prf {
138 pub fn new(key: Option<[u8; SEED_SIZE]>) -> Result<Prf> {
139 let key_bytes = match key {
140 Some(bytes) => bytes,
141 None => {
142 let mut gen = PRNG::new(None)?;
143 gen.get_random_key()?
144 }
145 };
146 let aes = aes::Aes128::new(GenericArray::from_slice(&key_bytes));
147 Ok(Prf { aes })
148 }
149
150 #[cfg(test)]
151 fn output_bytes(&mut self, input: u64, n: u64) -> Result<Vec<u8>> {
152 let initial_buffer_size = usize::min(BUFFER_SIZE, n as usize);
153 PrfSession::new(input, initial_buffer_size)?.generate_random_bytes(&self.aes, n)
154 }
155
156 pub(super) fn output_value(&mut self, input: u64, t: Type) -> Result<Value> {
157 PrfSession::new(input, INITIAL_BUFFER_SIZE)?.recursively_generate_value(&self.aes, t)
158 }
159
160 pub(super) fn output_permutation(&mut self, input: u64, n: u64) -> Result<Value> {
161 if n > 2u64.pow(30) {
162 return Err(runtime_error!("n should be less than 2^30"));
163 }
164 let initial_buffer_size = usize::min(BUFFER_SIZE, n as usize);
167 let mut session = PrfSession::new(input, initial_buffer_size)?;
168 let mut a: Vec<u64> = (0..n).collect();
169 for i in 1..n {
170 let j = session.generate_u32_in_range(&self.aes, i as u32 + 1)?;
171 a.swap(i as usize, j as usize);
172 }
173 Value::from_flattened_array_u64(&a, UINT64)
174 }
175}
176
177struct PrfSession {
181 input: u128,
182 buffer: Vec<u8>,
183 next_byte: usize,
184 current_buffer_size: usize,
186 next_buffer_size: usize,
187}
188
189impl PrfSession {
190 pub fn new(input: u64, initial_buffer_size: usize) -> Result<Self> {
191 let initial_buffer_size = (initial_buffer_size + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
193 Ok(Self {
194 input: (input as u128) << 64,
195 buffer: vec![0u8; initial_buffer_size],
198 next_byte: initial_buffer_size,
199 current_buffer_size: initial_buffer_size,
200 next_buffer_size: initial_buffer_size,
201 })
202 }
203
204 fn generate_one_batch(&mut self, aes: &Aes128) -> Result<()> {
205 let mut i_bytes = vec![0u8; self.next_buffer_size];
206 for i in (0..i_bytes.len()).step_by(16) {
207 i_bytes[i..i + 16].copy_from_slice(&self.input.to_le_bytes());
208 self.input = self.input.wrapping_add(1);
209 }
210 let buffer_len = self.next_buffer_size;
211 if buffer_len != self.buffer.len() {
212 self.buffer.resize(buffer_len, 0);
213 }
214 aes.encrypt_padded_b2b::<NoPadding>(&i_bytes, &mut self.buffer)
215 .map_err(|e| runtime_error!("Encryption error: {e:?}"))?;
216 self.current_buffer_size = self.next_buffer_size;
217 if self.next_buffer_size < BUFFER_SIZE {
218 self.next_buffer_size = usize::min(BUFFER_SIZE, self.next_buffer_size * 2);
219 }
220 self.next_byte = 0;
221 Ok(())
222 }
223
224 fn generate_random_bytes(&mut self, aes: &Aes128, n: u64) -> Result<Vec<u8>> {
225 let mut bytes = vec![0u8; n as usize];
226 self.fill_random_bytes(aes, bytes.as_mut_slice())?;
227 Ok(bytes)
228 }
229
230 fn fill_random_bytes(&mut self, aes: &Aes128, mut buff: &mut [u8]) -> Result<()> {
231 while !buff.is_empty() {
232 let need_bytes = buff.len();
233 let ready_bytes = &self.buffer[self.next_byte..self.current_buffer_size];
234 if ready_bytes.len() >= need_bytes {
235 buff.clone_from_slice(&ready_bytes[..need_bytes]);
236 self.next_byte += need_bytes;
237 break;
238 } else {
239 buff[..ready_bytes.len()].clone_from_slice(ready_bytes);
240 buff = &mut buff[ready_bytes.len()..];
241 self.next_byte = 0;
242 self.generate_one_batch(aes)?;
243 }
244 }
245 Ok(())
246 }
247
248 fn recursively_generate_value(&mut self, aes: &Aes128, tp: Type) -> Result<Value> {
249 match tp {
250 Type::Scalar(_) | Type::Array(_, _) => {
251 let bit_size = get_size_in_bits(tp)?;
252 let byte_size = (bit_size + 7) / 8;
253 let bits_to_flush = 8 * byte_size - bit_size;
255 let mut bytes = self.generate_random_bytes(aes, byte_size)?;
256 if !bytes.is_empty() {
258 *bytes.last_mut().unwrap() >>= bits_to_flush;
259 }
260 Ok(Value::from_bytes(bytes))
261 }
262 Type::Tuple(_) | Type::Vector(_, _) | Type::NamedTuple(_) => {
263 let ts = get_types_vector(tp)?;
264 let mut v = vec![];
265 for sub_t in ts {
266 let value = self.recursively_generate_value(aes, (*sub_t).clone())?;
267 v.push(value);
268 }
269 Ok(Value::from_vector(v))
270 }
271 }
272 }
273
274 fn generate_random_number_const<const NEED_BYTES: usize>(
276 &mut self,
277 aes: &Aes128,
278 ) -> Result<u64> {
279 let mut res = [0u8; 8];
280 let use_bytes = std::cmp::min(self.current_buffer_size - self.next_byte, NEED_BYTES);
281 res[..use_bytes].copy_from_slice(&self.buffer[self.next_byte..self.next_byte + use_bytes]);
282 if use_bytes == NEED_BYTES {
283 self.next_byte += use_bytes;
284 } else {
285 self.generate_one_batch(aes)?;
286 self.next_byte = NEED_BYTES - use_bytes;
287 res[use_bytes..NEED_BYTES].copy_from_slice(&self.buffer[..self.next_byte]);
288 }
289 let mask = if NEED_BYTES == 8 {
290 u64::MAX
291 } else {
292 (1 << (NEED_BYTES * 8)) - 1
293 };
294 Ok(u64::from_le_bytes(res) & mask)
295 }
296
297 fn generate_random_number(&mut self, aes: &Aes128, need_bytes: usize) -> Result<u64> {
299 match need_bytes {
300 1 => self.generate_random_number_const::<1>(aes),
301 2 => self.generate_random_number_const::<2>(aes),
302 3 => self.generate_random_number_const::<3>(aes),
303 4 => self.generate_random_number_const::<4>(aes),
304 5 => self.generate_random_number_const::<5>(aes),
305 6 => self.generate_random_number_const::<6>(aes),
306 7 => self.generate_random_number_const::<7>(aes),
307 8 => self.generate_random_number_const::<8>(aes),
308 _ => Err(runtime_error!("Unsupported need bytes")),
309 }
310 }
311
312 fn generate_u32_in_range(&mut self, aes: &Aes128, modulus: u32) -> Result<u32> {
313 let modulus = modulus as u64;
314 let need_bytes = (modulus.next_power_of_two().trailing_zeros() + 7) / 8 + 1;
316 let max_rand_value = (1u64 << (need_bytes as u64 * 8)) - 1;
319 let num_biased = (max_rand_value + 1) % modulus;
320 let rejection_bound = max_rand_value - num_biased;
321 loop {
322 let rand_value = self.generate_random_number(aes, need_bytes as usize)?;
323 if rand_value <= rejection_bound {
324 return Ok((rand_value % modulus) as u32);
325 }
326 }
327 }
328}
329
/// Checks whether observed byte frequencies have Shannon entropy close to the
/// 8-bit maximum.
///
/// `counters[b]` is the number of observations of byte value `b`, and `n` is the
/// total number of observations. Returns `true` when the empirical entropy is
/// within `1020 / n` of 8 bits, and `false` when `n == 0`.
/// NOTE(review): the origin of the constant 1020 is not documented here.
pub fn entropy_test(counters: [u32; 256], n: u64) -> bool {
    if n == 0 {
        return false;
    }
    let total = n as f64;
    let entropy = counters
        .iter()
        // 0 * log2(0) is 0 by convention. Computing it literally gives NaN, which
        // would make the comparison below always fail.
        .filter(|&&c| c != 0)
        .fold(0f64, |acc, &c| {
            let prob_c = (c as f64) / total;
            acc - prob_c.log2() * prob_c
        });
    let precision = 1020_f64 / total;
    (entropy - 8f64).abs() < precision
}
350
/// Pearson chi-squared statistic of `counters` against a uniform expectation.
///
/// Every counter is expected to equal `expected_count_per_element`. The result
/// is `sum((observed - expected)^2) / expected`.
pub fn chi_statistics(counters: &[u64], expected_count_per_element: u64) -> f64 {
    let expected = expected_count_per_element as f64;
    let sum_of_squares = counters.iter().fold(0_f64, |acc, &observed| {
        let deviation = observed as f64 - expected;
        acc + deviation.powi(2)
    });
    sum_of_squares / expected
}
359
#[cfg(test)]
mod tests {
    // Determinism and statistical-sanity tests for PRNG, Prf and PrfSession.
    // Statistical thresholds below are chi-squared / entropy bounds.
    use std::collections::HashMap;

    use super::*;
    use crate::data_types::{
        array_type, named_tuple_type, scalar_type, tuple_type, vector_type, BIT, INT32, UINT64,
        UINT8,
    };

    #[test]
    fn test_prng_fixed_seed() {
        let helper = |n: usize| -> Result<()> {
            let seed = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
            let mut prng1 = PRNG::new(Some(*seed))?;
            let mut prng2 = PRNG::new(Some(*seed))?;
            let rand_bytes1 = prng1.get_random_bytes(n)?;
            let rand_bytes2 = prng2.get_random_bytes(n)?;
            assert_eq!(rand_bytes1, rand_bytes2);
            Ok(())
        };
        helper(1).unwrap();
        helper(19).unwrap();
        helper(1000).unwrap();
    }

    #[test]
    fn test_prng_random_seed() {
        let mut prng = PRNG::new(None).unwrap();
        let mut counters = [0; 256];
        let n = 10_000_001;
        let rand_bytes = prng.get_random_bytes(n).unwrap();
        for byte in rand_bytes {
            counters[byte as usize] += 1;
        }

        assert!(entropy_test(counters, n as u64));
    }

    #[test]
    fn test_prng_random_value() {
        let mut g = PRNG::new(None).unwrap();
        let mut helper = |t: Type| -> Result<()> {
            let v = g.get_random_value(t.clone())?;
            assert!(v.check_type(t)?);
            Ok(())
        };
        || -> Result<()> {
            helper(scalar_type(BIT))?;
            helper(scalar_type(UINT8))?;
            helper(scalar_type(INT32))?;
            helper(array_type(vec![2, 5], BIT))?;
            helper(array_type(vec![2, 5], UINT8))?;
            helper(array_type(vec![2, 5], INT32))?;
            helper(tuple_type(vec![scalar_type(BIT), scalar_type(INT32)]))?;
            helper(tuple_type(vec![
                vector_type(3, scalar_type(BIT)),
                vector_type(5, scalar_type(BIT)),
                scalar_type(BIT),
                scalar_type(INT32),
            ]))?;
            helper(named_tuple_type(vec![
                ("field 1".to_owned(), scalar_type(BIT)),
                ("field 2".to_owned(), scalar_type(INT32)),
            ]))
        }()
        .unwrap()
    }

    #[test]
    fn test_prng_random_value_flush() {
        let mut g = PRNG::new(None).unwrap();
        let mut helper = |t: Type, expected: u8| -> Result<()> {
            let v = g.get_random_value(t.clone())?;
            v.access_bytes(|bytes| {
                if !bytes.is_empty() {
                    assert!(bytes.last() < Some(&expected));
                }
                Ok(())
            })?;
            Ok(())
        };
        || -> Result<()> {
            helper(array_type(vec![2, 5], BIT), 4)?;
            helper(array_type(vec![3, 3], BIT), 2)?;
            helper(array_type(vec![7], BIT), 128)?;
            helper(scalar_type(BIT), 2)
        }()
        .unwrap();
    }

    #[test]
    fn test_prng_random_u64_modulo() {
        || -> Result<()> {
            let mut g = PRNG::new(None).unwrap();

            let m = 100_u64;
            let mut counters = vec![0; m as usize];
            let expected_count_per_int = 1000;
            let n = expected_count_per_int * m;
            for _ in 0..n {
                let r = g.get_random_in_range(Some(m))?;
                counters[r as usize] += 1;
            }

            let chi2 = chi_statistics(&counters, expected_count_per_int);
            assert!(chi2 < 180.792_f64);

            Ok(())
        }()
        .unwrap();
    }

    #[test]
    fn test_prf_fixed_key() {
        || -> Result<()> {
            let key = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
            let mut prf1 = Prf::new(Some(*key))?;
            let mut prf2 = Prf::new(Some(*key))?;
            for i in 0..100_000u64 {
                assert_eq!(prf1.output_bytes(i, 1)?, prf2.output_bytes(i, 1)?);
                assert_eq!(prf1.output_bytes(i, 5)?, prf2.output_bytes(i, 5)?);
            }
            Ok(())
        }()
        .unwrap();
    }

    #[test]
    fn test_prf_random_key() {
        || -> Result<()> {
            let mut prf = Prf::new(None)?;
            let mut counters = [0; 256];
            let n = 100_000u64;
            let k = 10u64;
            for i in 0..n {
                let out = prf.output_bytes(i, k)?;
                for byte in out {
                    counters[byte as usize] += 1;
                }
            }
            assert!(entropy_test(counters, n * k));
            Ok(())
        }()
        .unwrap();
    }

    #[test]
    fn test_prf_output_value() {
        let mut g = Prf::new(None).unwrap();
        let mut helper = |t: Type| -> Result<()> {
            let v1 = g.output_value(15, t.clone())?;
            let v2 = g.output_value(15, t.clone())?;
            assert!(v1.check_type(t.clone())?);
            assert!(v2.check_type(t.clone())?);
            assert_eq!(v1, v2);
            if let Type::Tuple(_) | Type::Vector(_, _) | Type::NamedTuple(_) = t.clone() {
                let values = v1.to_vector()?;
                // Elements of one output should not all coincide.
                let all_equal = values.windows(2).all(|w| w[0] == w[1]);
                assert!(!all_equal);

                // For tuples made only of UINT64 arrays, every generated number
                // should be distinct.
                let mut numbers = vec![];
                let types = get_types_vector(t)?;
                for i in 0..types.len() {
                    let tp = (*types[i]).clone();
                    if !tp.is_array() {
                        return Ok(());
                    }
                    if tp.get_scalar_type() != UINT64 {
                        return Ok(());
                    }
                    let mut tmp = values[i].to_flattened_array_u64(tp)?;
                    numbers.append(&mut tmp)
                }
                let mut tmp_numbers = numbers.clone();
                tmp_numbers.sort_unstable();
                tmp_numbers.dedup();
                assert_eq!(tmp_numbers.len(), numbers.len());
            }
            Ok(())
        };
        || -> Result<()> {
            helper(scalar_type(BIT))?;
            helper(scalar_type(UINT8))?;
            helper(scalar_type(INT32))?;
            helper(array_type(vec![3, 4], BIT))?;
            helper(array_type(vec![4, 2], UINT8))?;
            helper(array_type(vec![6, 2], INT32))?;
            helper(tuple_type(vec![scalar_type(BIT), scalar_type(INT32)]))?;
            helper(tuple_type(vec![
                vector_type(3, scalar_type(BIT)),
                vector_type(5, scalar_type(BIT)),
                scalar_type(BIT),
                scalar_type(INT32),
            ]))?;
            helper(tuple_type(vec![
                scalar_type(INT32),
                scalar_type(INT32),
                scalar_type(INT32),
                scalar_type(INT32),
            ]))?;
            helper(tuple_type(vec![
                array_type(vec![2, 2], INT32),
                array_type(vec![2, 2], INT32),
                array_type(vec![2, 2], INT32),
                array_type(vec![2, 2], INT32),
            ]))?;
            helper(tuple_type(vec![
                array_type(vec![2, 1, 2], UINT64),
                array_type(vec![2, 3, 2], UINT64),
                array_type(vec![2, 2, 1], UINT64),
                array_type(vec![3, 3, 2], UINT64),
            ]))?;
            helper(named_tuple_type(vec![
                ("field 1".to_owned(), scalar_type(BIT)),
                ("field 2".to_owned(), scalar_type(INT32)),
            ]))
        }()
        .unwrap();

        let mut helper_flush = |t: Type, expected: u8| -> Result<()> {
            let v = g.output_value(181, t.clone())?;
            v.access_bytes(|bytes| {
                if !bytes.is_empty() {
                    assert!(bytes.last() < Some(&expected));
                }
                Ok(())
            })?;
            Ok(())
        };
        || -> Result<()> {
            helper_flush(array_type(vec![1, 5], BIT), 32)?;
            helper_flush(array_type(vec![3, 3, 3], BIT), 8)?;
            helper_flush(array_type(vec![2, 6], BIT), 16)?;
            helper_flush(scalar_type(BIT), 2)
        }()
        .unwrap();
    }

    #[test]
    fn test_generate_u32_in_range() -> Result<()> {
        let prf = Prf::new(None)?;
        // Indexed by degrees of freedom (n - 1).
        let critical_value = [0f64, 23.9281, 27.6310, 30.6648, 33.3768, 35.8882];
        for n in 2..6 {
            let mut session = PrfSession::new(0, 1)?;
            let expected_count = 1000000;
            let runs = n * expected_count;
            let mut stats: HashMap<u32, u64> = HashMap::new();
            for _ in 0..runs {
                let x = session.generate_u32_in_range(&prf.aes, n)?;
                assert!(x < n);
                *stats.entry(x).or_default() += 1;
            }
            let counters: Vec<u64> = stats.values().copied().collect();
            let chi2 = chi_statistics(&counters, expected_count as u64);
            assert!(chi2 < critical_value[(n - 1) as usize]);
        }
        Ok(())
    }

    #[test]
    fn test_prf_output_permutation() -> Result<()> {
        let mut prf = Prf::new(None)?;
        let mut helper = |n: u64| -> Result<()> {
            let result_type = array_type(vec![n], UINT64);
            let mut perm_statistics: HashMap<Vec<u64>, u64> = HashMap::new();
            let expected_count_per_perm = 100;
            let n_factorial: u64 = (2..=n).product();
            let runs = expected_count_per_perm * n_factorial;
            for input in 0..runs {
                let result_value = prf.output_permutation(input, n)?;
                let perm = result_value.to_flattened_array_u64(result_type.clone())?;

                let mut perm_sorted = perm.clone();
                perm_sorted.sort();
                let range_vec: Vec<u64> = (0..n).collect();
                assert_eq!(perm_sorted, range_vec);

                // The first occurrence must count as 1. The previous
                // `.and_modify(+1).or_insert(0)` undercounted every permutation by one.
                *perm_statistics.entry(perm).or_default() += 1;
            }

            assert_eq!(perm_statistics.len() as u64, n_factorial);

            if n > 1 {
                let counters: Vec<u64> = perm_statistics.values().copied().collect();
                let chi2 = chi_statistics(&counters, expected_count_per_perm);
                if n == 4 {
                    assert!(chi2 < 70.5496_f64);
                }
                if n == 5 {
                    assert!(chi2 < 207.1986_f64);
                }
            }
            Ok(())
        };
        helper(1)?;
        helper(4)?;
        helper(5)
    }

    #[test]
    fn test_prf_output_permutation_correctness() -> Result<()> {
        let mut prf = Prf::new(None)?;
        let mut helper = |n: u64| -> Result<()> {
            let result_type = array_type(vec![n], UINT64);
            let result_value = prf.output_permutation(0, n)?;
            let perm = result_value.to_flattened_array_u64(result_type.clone())?;

            let mut perm_sorted = perm.clone();
            perm_sorted.sort();
            let range_vec: Vec<u64> = (0..n).collect();
            assert_eq!(perm_sorted, range_vec);
            Ok(())
        };
        helper(1)?;
        helper(10)?;
        helper(100)?;
        helper(1000)?;
        helper(10000)?;
        helper(100000)?;
        helper(1000000)?;
        Ok(())
    }
}