use std::ops::Bound;
use crate::generators::aes_ctr::{AesBlockCipher, AesCtrGenerator, AesIndex, TableIndex};
use crate::generators::{widening_mul, BytesPerChild, ChildrenCount, ForkError};
/// Function-pointer type mapping `(child_index, (block_cipher, first_index,
/// bytes_per_child, aes_offset))` to a bounded child generator.
///
/// A named `fn` type (rather than an anonymous closure type) is needed so the
/// full concrete type of the parallel fork iterator can be written out — see
/// `ParallelChildrenIterator` — and so the closure in `par_try_fork` can be
/// cast to it.
pub type ParallelChildrenClosure<BlockCipher> = fn(
    (
        usize,
        (Box<BlockCipher>, TableIndex, BytesPerChild, AesIndex),
    ),
) -> AesCtrGenerator<BlockCipher>;
pub type ParallelChildrenIterator<BlockCipher> = rayon::iter::Map<
rayon::iter::Zip<
rayon::range::Iter<usize>,
rayon::iter::RepeatN<(Box<BlockCipher>, TableIndex, BytesPerChild, AesIndex)>,
>,
fn(
(
usize,
(Box<BlockCipher>, TableIndex, BytesPerChild, AesIndex),
),
) -> AesCtrGenerator<BlockCipher>,
>;
impl<BlockCipher: AesBlockCipher> AesCtrGenerator<BlockCipher> {
    /// Tries to fork this generator into `n_children` child generators, each
    /// allowed to output `n_bytes` bytes, returned as a rayon parallel
    /// iterator.
    ///
    /// On success the parent's state is advanced past the region handed out
    /// to the children (`new_parent_state` from `check_fork`).
    ///
    /// # Errors
    ///
    /// Returns [`ForkError::ForkTooLarge`] when `n_children` does not fit in
    /// a `usize`, or propagates the error produced by the internal
    /// `check_fork` validation.
    pub fn par_try_fork(
        &mut self,
        n_children: ChildrenCount,
        n_bytes: BytesPerChild,
    ) -> Result<ParallelChildrenIterator<BlockCipher>, ForkError>
    where
        BlockCipher: Send + Sync,
    {
        use rayon::prelude::*;
        // `n_children.0` is used below as an iterator length (`usize`);
        // reject counts that cannot be represented. This check only fires on
        // targets where usize is narrower than u64 (e.g. 32-bit).
        if n_children.0 > (usize::MAX as u64) {
            return Err(ForkError::ForkTooLarge);
        }
        // Validate the fork against this generator's bound and compute the
        // children's starting index, the aes offset to pass along, and the
        // parent's post-fork state.
        let (first_index, offset, new_parent_state) = self.state.check_fork(n_children, n_bytes)?;
        let output = (0..n_children.0 as usize)
            .into_par_iter()
            // Pair every child index with a clone of the cipher and the
            // shared fork parameters; `repeat_n` keeps both sides of the zip
            // the same length, which `zip` requires.
            .zip(rayon::iter::repeat_n(
                (self.block_cipher.clone(), first_index, n_bytes, offset),
                n_children.0 as usize,
            ))
            .map(
                // The cast to the named fn-pointer type is what makes the
                // concrete iterator type spell out as
                // `ParallelChildrenIterator<BlockCipher>`.
                (|(i, (block_cipher, first_index, n_bytes, offset))| {
                    // Child `i` covers the half-open index range
                    // [first_index + i*n_bytes, first_index + (i+1)*n_bytes).
                    let child_first_index =
                        first_index.increased(widening_mul(n_bytes.0, i as u64));
                    let child_bound_index =
                        first_index.increased(widening_mul(n_bytes.0, (i + 1) as u64));
                    AesCtrGenerator::from_block_cipher(
                        block_cipher,
                        child_first_index,
                        Bound::Excluded(child_bound_index),
                        offset,
                    )
                }) as ParallelChildrenClosure<BlockCipher>,
            );
        // Jump the parent past the bytes reserved for the children. Done
        // after `check_fork` succeeded so a failed fork leaves `self` intact.
        self.state = new_parent_state;
        Ok(output)
    }
}
/// Generic, backend-agnostic property tests for `par_try_fork`.
///
/// Each function is parametrized over the block cipher `G`; concrete
/// `AesBlockCipher` backends instantiate them in their own test modules.
#[cfg(test)]
pub mod aes_ctr_parallel_generic_tests {
    use std::ops::Bound;

    use super::*;
    use crate::generators::aes_ctr::aes_ctr_generic_test::{
        any_key, any_valid_fork, assert_generator_matches_cipher, make_ctr_pair,
    };
    use crate::generators::aes_ctr::index::AesIndex;
    use crate::generators::aes_ctr::{AesKey, BYTES_PER_AES_CALL};
    use rand::prelude::*;
    use rand::rngs::ThreadRng;
    use rand::thread_rng;
    use rayon::prelude::*;

    // Number of random fork configurations tried by most properties.
    const REPEATS: usize = 1_000_000;

    /// The first child of a fork must start at the same table index the
    /// parent was at just before forking.
    pub fn prop_fork_first_state_table_index<G: AesBlockCipher>() {
        for _ in 0..REPEATS {
            // `t`: fork start index, `nc`: children count, `nb`: bytes per
            // child, `i`: extra bytes beyond the forked region (presumably —
            // see `any_valid_fork`'s definition for the exact contract).
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            let original_generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            let mut forked_generator = original_generator.clone();
            let first_child = forked_generator
                .par_try_fork(nc, nb)
                .unwrap()
                .find_first(|_| true)
                .unwrap();
            assert_eq!(
                original_generator.next_table_index(),
                first_child.next_table_index(),
                "key={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// The last child's final index must sit exactly one position before the
    /// parent's post-fork next index (no gap, no overlap).
    pub fn prop_fork_last_bound_table_index<G: AesBlockCipher>() {
        for _ in 0..REPEATS {
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            let mut parent_generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            let last_child = parent_generator
                .par_try_fork(nc, nb)
                .unwrap()
                .find_last(|_| true)
                .unwrap();
            let child_last_index = last_child.get_last();
            let parent_next = parent_generator.next_table_index().unwrap();
            assert_eq!(
                child_last_index, Some(parent_next.decremented()),
                "last_child.last={child_last_index:?}, parent_generator.next={parent_next:?}\nkey={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// Forking must not move the parent's upper bound.
    pub fn prop_fork_parent_bound_table_index<G: AesBlockCipher>() {
        for _ in 0..REPEATS {
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            let original_generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            let mut forked_generator = original_generator.clone();
            // The children themselves are irrelevant here; the fork is driven
            // to completion only for its side effect on the parent.
            forked_generator
                .par_try_fork(nc, nb)
                .unwrap()
                .find_last(|_| true)
                .unwrap();
            assert_eq!(
                original_generator.get_last(),
                forked_generator.get_last(),
                "key={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// After a fork, the parent's next index must have jumped past the whole
    /// region handed out to the children (t + nc * nb).
    pub fn prop_fork_parent_state_table_index<G: AesBlockCipher>() {
        for _ in 0..REPEATS {
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            let original_generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            let mut forked_generator = original_generator.clone();
            forked_generator
                .par_try_fork(nc, nb)
                .unwrap()
                .find_last(|_| true)
                .unwrap();
            assert_eq!(
                forked_generator.next_table_index().unwrap(),
                t.increased(widening_mul(nc.0, nb.0)),
                "key={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// The concatenation of all children's outputs must equal the byte stream
    /// the un-forked generator would have produced over the same region.
    pub fn prop_fork<G: AesBlockCipher>() {
        // Fewer repeats: this property consumes the generators byte by byte.
        for _ in 0..1000 {
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            // Plain (non-widening) product — assumes `any_valid_fork` keeps
            // nc * nb within u64 range; TODO confirm against its definition.
            let bytes_to_go = nc.0 * nb.0;
            let original_generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            let mut forked_generator = original_generator.clone();
            let initial_output: Vec<u8> = original_generator
                .take(usize::try_from(bytes_to_go).unwrap())
                .collect();
            let forked_output: Vec<u8> = forked_generator
                .par_try_fork(nc, nb)
                .unwrap()
                .flat_map(|child| child.collect::<Vec<_>>())
                .collect();
            assert_eq!(
                initial_output, forked_output,
                "key={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// Every child generator must report exactly `nb` remaining bytes.
    pub fn prop_fork_children_remaining_bytes<G: AesBlockCipher>() {
        for _ in 0..REPEATS {
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            let mut generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            assert!(
                generator
                    .par_try_fork(nc, nb)
                    .unwrap()
                    .all(|c| c.remaining_bytes().0 == nb.0 as u128),
                "key={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// Drawing bytes before the fork, then through every child, then from the
    /// parent after the fork must reproduce a single un-forked generator's
    /// stream with no gaps or repeats.
    pub fn prop_fork_with_parent_continuation<G: AesBlockCipher>() {
        for _ in 0..10_000 {
            let (t, nc, nb, num_extra_bytes) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let fork_bytes = widening_mul(nc.0, nb.0);
            let total_bytes = fork_bytes.saturating_add(num_extra_bytes);
            let offset = AesIndex(rand::random());
            let mut gen1 =
                AesCtrGenerator::<G>::new(k, t, Bound::Excluded(t.increased(total_bytes)), offset);
            let mut forked_gen = gen1.clone();
            // Cap the pre/post-fork comparison so an enormous
            // `num_extra_bytes` does not make the test run forever; half the
            // extra budget is read before the fork, half after.
            let bytes_per_parts = (num_extra_bytes.min(u128::from(u16::MAX)) / 2) as usize;
            for i in 0..bytes_per_parts {
                let byte = forked_gen.next().unwrap();
                let expected_byte = gen1.next().unwrap();
                assert_eq!(
                    byte, expected_byte,
                    "pre-fork bytes are not equal (byte index {i}), key={k:?}, t={t:?}, offset={offset:?}"
                );
            }
            let children: Vec<_> = forked_gen.par_try_fork(nc, nb).unwrap().collect();
            for (child_i, child) in children.into_iter().enumerate() {
                // Each child, drained in order, must continue the reference
                // stream exactly where the previous child stopped.
                for (i, byte) in child.enumerate() {
                    let expected_byte = gen1.next().unwrap();
                    assert_eq!(
                        byte, expected_byte,
                        "invalid byte at index {i} for child {child_i}, key={k:?}, t={t:?}, offset={offset:?}"
                    );
                }
            }
            // The forked parent must resume right after the children's region.
            for i in 0..bytes_per_parts {
                let byte = forked_gen.next().unwrap();
                let expected_byte = gen1.next().unwrap();
                assert_eq!(byte, expected_byte, "post-fork bytes are not equal (byte index {i}), got {byte}, expected {expected_byte}, key={k:?}, t={t:?}, offset={offset:?}");
            }
        }
    }

    /// A fork must consume exactly `nc * nb` bytes from the parent's budget,
    /// even when the returned child iterator is never driven.
    pub fn prop_fork_parent_remaining_bytes<G: AesBlockCipher>() {
        for _ in 0..REPEATS {
            let (t, nc, nb, i) = any_valid_fork().next().unwrap();
            let k = any_key().next().unwrap();
            let offset = AesIndex(rand::random());
            let bytes_to_go = nc.0 * nb.0;
            let mut generator = AesCtrGenerator::<G>::new(
                k,
                t,
                Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i)),
                offset,
            );
            let before_remaining_bytes = generator.remaining_bytes();
            // The iterator is dropped unused: the parent state update happens
            // inside `par_try_fork` itself, not when children are consumed.
            let _ = generator.par_try_fork(nc, nb).unwrap();
            let after_remaining_bytes = generator.remaining_bytes();
            assert_eq!(
                before_remaining_bytes.0 - after_remaining_bytes.0,
                bytes_to_go as u128,
                "key={k:?}, t={t:?}, offset={offset:?}"
            );
        }
    }

    /// The forked children's output must match the byte stream of the
    /// reference `ctr`-crate cipher, for random keys, start positions and
    /// fork geometries.
    pub fn test_forking_conformance_with_ctr_crate<G: AesBlockCipher>() {
        let mut rng = thread_rng();
        // Draws a random factor pair (a, b) with a * b == x, by rejection
        // sampling divisors of x.
        fn random_tuple_that_equals(rng: &mut ThreadRng, x: u64) -> (u64, u64) {
            loop {
                let a: u64 = rng.gen_range(1..=x);
                if x.is_multiple_of(a) {
                    let b = x / a;
                    return (a, b);
                }
            }
        }
        for _ in 0..1_000 {
            let key = AesKey(rng.gen());
            let aes_idx: u128 = rng.gen_range(0..u128::MAX);
            let offset: u128 = rng.gen();
            // Exercise every possible intra-block starting byte.
            for byte_idx in 0..BYTES_PER_AES_CALL {
                let (mut gen, mut cipher) = make_ctr_pair::<G>(key, aes_idx, byte_idx, offset);
                // Keep the total work bounded regardless of how many bytes
                // the generator could still produce.
                let bytes = gen.remaining_bytes().0.min(u128::from(u16::MAX));
                let bytes_per_parts = (bytes / 3) as usize;
                assert_generator_matches_cipher(&mut gen, &mut cipher, bytes_per_parts, || {
                    format!("invalid bytes pre-fork buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
                });
                if gen.remaining_bytes().0 == 0 {
                    continue;
                }
                // Split the middle third into an arbitrary (children, bytes
                // per child) geometry covering exactly that many bytes.
                let (nc, nb) = random_tuple_that_equals(&mut rng, bytes_per_parts as u64);
                let children = gen
                    .par_try_fork(ChildrenCount(nc), BytesPerChild(nb))
                    .unwrap()
                    .collect::<Vec<_>>();
                for (child_i, mut child) in children.into_iter().enumerate() {
                    assert_generator_matches_cipher(&mut child, &mut cipher, nb as usize, || {
                        format!("invalid bytes child #{child_i} buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
                    });
                }
                // The parent must continue the reference stream after the
                // forked region.
                assert_generator_matches_cipher(&mut gen, &mut cipher, bytes_per_parts, || {
                    format!("invalid bytes post-fork buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
                });
            }
        }
    }
}