use std::ops::Bound;
use crate::generators::aes_ctr::block_cipher::{AesBlockCipher, AesKey};
use crate::generators::aes_ctr::index::TableIndex;
use crate::generators::aes_ctr::states::{BufferPointer, ShiftAction, State};
use crate::generators::aes_ctr::{AesIndex, BYTES_PER_BATCH};
use crate::generators::{widening_mul, ByteCount, BytesPerChild, ChildrenCount, ForkError};
/// Function pointer that materializes one child generator from the fork
/// payload `(child_index, (cipher, first_index, bytes_per_child, offset))`.
/// It must be a plain `fn` pointer (not an unnameable closure type) so that
/// [`ChildrenIterator`] below can be written as a concrete type.
pub type ChildrenClosure<BlockCipher> = fn(
(u64, (Box<BlockCipher>, TableIndex, BytesPerChild, AesIndex)),
) -> AesCtrGenerator<BlockCipher>;
/// Concrete iterator type returned by `AesCtrGenerator::try_fork`: each child
/// index in `0..n_children` is zipped with a repeated copy of the shared fork
/// parameters and mapped through a [`ChildrenClosure`].
pub type ChildrenIterator<BlockCipher> = std::iter::Map<
std::iter::Zip<
std::ops::Range<u64>,
std::iter::Repeat<(Box<BlockCipher>, TableIndex, BytesPerChild, AesIndex)>,
>,
ChildrenClosure<BlockCipher>,
>;
/// A generator that produces a stream of pseudo-random bytes by running AES in
/// counter (CTR) mode, one batch of `BYTES_PER_BATCH` bytes at a time.
#[derive(Clone)]
pub struct AesCtrGenerator<BlockCipher: AesBlockCipher> {
// The AES block cipher used to (re)fill `buffer`.
pub(crate) block_cipher: Box<BlockCipher>,
// The most recently generated batch of bytes; refreshed lazily by `next`.
pub(crate) buffer: [u8; BYTES_PER_BATCH],
// Tracks the current table index, the end bound, and the counter offset.
pub(crate) state: State,
}
#[allow(unused)]
impl<BlockCipher: AesBlockCipher> AesCtrGenerator<BlockCipher> {
    /// Creates a generator outputting the bytes of the AES-CTR table from
    /// `start_index` (inclusive) up to the `end` bound, with the AES counter
    /// shifted by `offset`.
    pub(crate) fn new(
        key: AesKey,
        start_index: TableIndex,
        end: Bound<TableIndex>,
        offset: AesIndex,
    ) -> AesCtrGenerator<BlockCipher> {
        AesCtrGenerator::from_block_cipher(
            Box::new(BlockCipher::new(key)),
            start_index,
            end,
            offset,
        )
    }

    /// Same as [`Self::new`] but reuses an already constructed block cipher.
    pub(crate) fn from_block_cipher(
        block_cipher: Box<BlockCipher>,
        start_index: TableIndex,
        end: Bound<TableIndex>,
        offset: AesIndex,
    ) -> AesCtrGenerator<BlockCipher> {
        // The buffer starts zeroed: it is filled lazily on the first call to
        // `next` (see the `Iterator` impl).
        let buffer = [0u8; BYTES_PER_BATCH];
        let state = State::new(start_index, end, offset);
        AesCtrGenerator {
            block_cipher,
            buffer,
            state,
        }
    }

    /// Builds an unbounded generator from high-level parameters.
    ///
    /// A plain CTR seed is used directly as the AES key with a zero counter
    /// offset; a XOF seed derives both key and offset through `xof_init`.
    pub(crate) fn from_params(params: impl Into<super::AesCtrParams>) -> Self {
        use crate::seeders::SeedKind;
        let params = params.into();
        // Fixed: this previously read `¶ms.seed`, apparently mojibake for
        // `&params.seed`, which is not valid Rust.
        let (key, offset) = match &params.seed {
            SeedKind::Ctr(s) => (AesKey(u128::from_le(s.0)), AesIndex(0)),
            SeedKind::Xof(xof) => super::xof_init(xof.clone()),
        };
        Self::new(key, params.first_index, Bound::Unbounded, offset)
    }

    /// Returns the table index of the next byte to be output, or `None` once
    /// the generator is exhausted.
    pub(crate) fn next_table_index(&self) -> Option<TableIndex> {
        self.state.next_table_index()
    }

    /// Returns the last table index this generator may output (test helper).
    #[cfg(test)]
    pub(crate) fn get_last(&self) -> Option<TableIndex> {
        self.state.last()
    }

    /// Returns how many bytes can still be output before the bound is hit.
    pub(crate) fn remaining_bytes(&self) -> ByteCount {
        self.state.remaining_bytes()
    }

    /// Outputs the next byte, panicking if the bound was already reached.
    pub(crate) fn generate_next(&mut self) -> u8 {
        self.next()
            .expect("Tried to generate a byte after the bound.")
    }

    /// Tries to split off `n_children` child generators of `n_bytes` bytes
    /// each.
    ///
    /// On success the children partition the next `n_children * n_bytes` bytes
    /// of this generator's output and the parent is advanced past that region.
    /// Fails with [`ForkError`] when the bound does not leave enough room.
    pub(crate) fn try_fork(
        &mut self,
        n_children: ChildrenCount,
        n_bytes: BytesPerChild,
    ) -> Result<ChildrenIterator<BlockCipher>, ForkError> {
        let (first_index, offset, new_parent_state) = self.state.check_fork(n_children, n_bytes)?;
        // The closure is cast to a plain `fn` pointer so the returned iterator
        // can be spelled out as the concrete `ChildrenIterator` type.
        let output = (0..n_children.0)
            .zip(std::iter::repeat((
                self.block_cipher.clone(),
                first_index,
                n_bytes,
                offset,
            )))
            .map(
                (|(i, (block_cipher, first_index, n_bytes, offset))| {
                    // Child `i` covers [first + i * n_bytes, first + (i + 1) * n_bytes).
                    let child_first_index = first_index.increased(widening_mul(n_bytes.0, i));
                    let child_bound_index = first_index.increased(widening_mul(n_bytes.0, i + 1));
                    AesCtrGenerator::from_block_cipher(
                        block_cipher,
                        child_first_index,
                        Bound::Excluded(child_bound_index),
                        offset,
                    )
                }) as ChildrenClosure<BlockCipher>,
            );
        // Only move the parent forward once the fork is known to be valid.
        self.state = new_parent_state;
        Ok(output)
    }
}
/// Byte-level iteration: yields one pseudo-random byte per call, `None` once
/// the configured bound is reached.
impl<BlockCipher: AesBlockCipher> Iterator for AesCtrGenerator<BlockCipher> {
type Item = u8;
fn next(&mut self) -> Option<Self::Item> {
// The state machine decides whether the bound was hit, a byte can be served
// from the current buffer, or a fresh AES batch must be generated first.
match self.state.next() {
// Bound reached: the generator is exhausted.
ShiftAction::NoOutput => None,
// Current batch still has bytes left: serve one straight from the buffer.
ShiftAction::OutputByte(ptr) => Some(self.buffer[ptr.0]),
ShiftAction::RefreshBatchAndOutputByte(aes_index, BufferPointer(ptr)) => {
// Encrypt a run of consecutive counter values (little-endian encoded,
// wrapping on u128 overflow) to refresh the whole batch, then serve.
let aes_inputs =
core::array::from_fn(|i| aes_index.0.wrapping_add(i as u128).to_le());
self.buffer = self.block_cipher.generate_batch(aes_inputs);
Some(self.buffer[ptr])
}
}
}
}
#[cfg(test)]
#[allow(unused)] pub mod aes_ctr_generic_test {
use std::ops::Bound;
use super::*;
use crate::generators::aes_ctr::index::{AesIndex, ByteIndex};
use crate::generators::aes_ctr::BYTES_PER_AES_CALL;
use crate::seeders::{Seed, SeedKind};
use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use ctr::Ctr128LE;
use rand::rngs::ThreadRng;
use rand::{thread_rng, Rng};
const REPEATS: usize = 1_000_000;
/// Endless stream of random table indices, with the byte index always kept
/// within a single AES call.
pub fn any_table_index() -> impl Iterator<Item = TableIndex> {
    std::iter::repeat_with(|| {
        let aes_part: u128 = thread_rng().gen();
        let byte_part = thread_rng().gen::<usize>() % BYTES_PER_AES_CALL;
        TableIndex::new(AesIndex(aes_part), ByteIndex(byte_part))
    })
}
/// Endless stream of uniformly random `u128` values.
pub fn any_u128() -> impl Iterator<Item = u128> {
    std::iter::repeat_with(|| {
        let value: u128 = thread_rng().gen();
        value
    })
}
/// Endless stream of children counts in `1..=2048`.
pub fn any_children_count() -> impl Iterator<Item = ChildrenCount> {
    std::iter::repeat_with(|| {
        let raw: u64 = thread_rng().gen();
        ChildrenCount(raw % 2048 + 1)
    })
}
/// Endless stream of per-child byte counts in `1..=2048`.
pub fn any_bytes_per_child() -> impl Iterator<Item = BytesPerChild> {
    std::iter::repeat_with(|| {
        let raw: u64 = thread_rng().gen();
        BytesPerChild(raw % 2048 + 1)
    })
}
/// Endless stream of random 128-bit AES keys.
pub fn any_key() -> impl Iterator<Item = AesKey> {
    std::iter::repeat_with(|| {
        let raw: u128 = thread_rng().gen();
        AesKey(raw)
    })
}
/// Endless stream of `(start, n_children, n_bytes, extra)` tuples for which a
/// fork of `n_children * n_bytes` bytes plus `extra` trailing bytes fits
/// before `TableIndex::LAST`.
///
/// `checked_add` is used so that a randomly huge `extra` cannot overflow the
/// `u128` sum (a panic in debug builds); overflowing candidates are filtered
/// out like any other invalid fork.
pub fn any_valid_fork() -> impl Iterator<Item = (TableIndex, ChildrenCount, BytesPerChild, u128)>
{
    any_table_index()
        .zip(any_children_count())
        .zip(any_bytes_per_child())
        .zip(any_u128())
        .map(|(((t, nc), nb), i)| (t, nc, nb, i))
        .filter(|(t, nc, nb, i)| {
            widening_mul(nc.0, nb.0)
                .checked_add(*i)
                .map_or(false, |total| {
                    TableIndex::distance(&TableIndex::LAST, t).unwrap().0 > total
                })
        })
}
/// Property: the first child of a fork starts exactly at the parent's current
/// table index.
pub fn prop_fork_first_state_table_index<G: AesBlockCipher>() {
    for _ in 0..REPEATS {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let offset = AesIndex(rand::random());
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let original_generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let mut forked_generator = original_generator.clone();
        let first_child = forked_generator.try_fork(nc, nb).unwrap().next().unwrap();
        assert_eq!(
            original_generator.next_table_index(),
            first_child.next_table_index(),
            "key={k:?}, t={t:?}, offset={offset:?}"
        );
    }
}
/// Property: the last child's final table index immediately precedes the
/// parent's post-fork position.
pub fn prop_fork_last_bound_table_index<G: AesBlockCipher>() {
    for _ in 0..REPEATS {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let offset = AesIndex(rand::random());
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let mut parent_generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let last_child = parent_generator.try_fork(nc, nb).unwrap().last().unwrap();
        let child_last_index = last_child.get_last();
        let parent_next = parent_generator.next_table_index().unwrap();
        assert_eq!(
            child_last_index, Some(parent_next.decremented()),
            "last_child.last={child_last_index:?}, parent_generator.next={parent_next:?}\nkey={k:?}, t={t:?}, offset={offset:?}"
        );
    }
}
/// Property: forking does not move the parent's end bound.
pub fn prop_fork_parent_bound_table_index<G: AesBlockCipher>() {
    for _ in 0..REPEATS {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let offset = AesIndex(rand::random());
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let original_generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let mut forked_generator = original_generator.clone();
        forked_generator.try_fork(nc, nb).unwrap().last().unwrap();
        assert_eq!(
            original_generator.get_last(),
            forked_generator.get_last(),
            "key={k:?}, t={t:?}, offset={offset:?}"
        );
    }
}
/// Property: after a fork the parent resumes exactly `nc * nb` bytes past the
/// fork's first index.
pub fn prop_fork_parent_state_table_index<G: AesBlockCipher>() {
    for _ in 0..REPEATS {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let offset = AesIndex(rand::random());
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let original_generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let mut forked_generator = original_generator.clone();
        forked_generator.try_fork(nc, nb).unwrap().last().unwrap();
        assert_eq!(
            forked_generator.next_table_index().unwrap(),
            t.increased(widening_mul(nc.0, nb.0)),
            "key={k:?}, t={t:?}, offset={offset:?}"
        );
    }
}
/// Property: concatenating every child's output reproduces exactly the bytes
/// the un-forked parent would have produced over the forked range.
pub fn prop_fork<G: AesBlockCipher>() {
    for _ in 0..1000 {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let offset = AesIndex(rand::random());
        let k = any_key().next().unwrap();
        let bytes_to_go = nc.0 * nb.0;
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let original_generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let mut forked_generator = original_generator.clone();
        let initial_output: Vec<u8> = original_generator
            .take(usize::try_from(bytes_to_go).unwrap())
            .collect();
        // Children are themselves iterators, so they can be flattened directly.
        let forked_output: Vec<u8> = forked_generator
            .try_fork(nc, nb)
            .unwrap()
            .flatten()
            .collect();
        assert_eq!(
            initial_output, forked_output,
            "key={k:?}, t={t:?}, offset={offset:?}"
        );
    }
}
/// Property: right after a fork, every child reports exactly `nb` remaining
/// bytes.
pub fn prop_fork_children_remaining_bytes<G: AesBlockCipher>() {
    for _ in 0..REPEATS {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let offset = AesIndex(rand::random());
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let mut generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let children_ok = generator
            .try_fork(nc, nb)
            .unwrap()
            .all(|c| c.remaining_bytes().0 == nb.0 as u128);
        assert!(children_ok, "key={k:?}, t={t:?}, offset={offset:?}");
    }
}
/// Property: a fork consumes exactly `nc * nb` bytes from the parent's
/// remaining budget, even if the children are never iterated.
pub fn prop_fork_parent_remaining_bytes<G: AesBlockCipher>() {
    for _ in 0..REPEATS {
        let (t, nc, nb, i) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let offset = AesIndex(rand::random());
        let bytes_to_go = nc.0 * nb.0;
        let end = Bound::Excluded(t.increased(widening_mul(nc.0, nb.0) + i));
        let mut generator = AesCtrGenerator::<G>::new(k, t, end, offset);
        let before_remaining_bytes = generator.remaining_bytes();
        // The children iterator is dropped unconsumed on purpose: forking
        // alone must already advance the parent.
        drop(generator.try_fork(nc, nb).unwrap());
        let after_remaining_bytes = generator.remaining_bytes();
        assert_eq!(
            before_remaining_bytes.0 - after_remaining_bytes.0,
            bytes_to_go as u128,
            "key={k:?}, t={t:?}, offset={offset:?}"
        );
    }
}
/// Pulls `num_bytes` from `gen` in 1 KiB chunks and checks each chunk against
/// the keystream of `reference_cipher`; `msg_fn` builds the failure context
/// lazily.
///
/// Consumes exactly `num_bytes` from both `gen` and `reference_cipher`, so
/// successive calls with the same pair stay aligned.
pub(crate) fn assert_generator_matches_cipher(
    gen: &mut impl Iterator<Item = u8>,
    reference_cipher: &mut impl StreamCipher,
    num_bytes: usize,
    msg_fn: impl Fn() -> String,
) {
    let mut reference_buffer = [0u8; 1024];
    let mut buffer = [0u8; 1024];
    let n = num_bytes.div_ceil(buffer.len());
    let rem = num_bytes % buffer.len();
    for i in 0..n {
        // Every chunk is full except possibly the last. When `num_bytes` is an
        // exact multiple of the buffer size the last chunk is full as well:
        // the previous code used `num_bytes % buffer.len()` directly, which is
        // 0 in that case and silently skipped checking (and consuming) the
        // final 1024 bytes, desynchronizing generator and reference cipher.
        let valid = if i != n - 1 || rem == 0 {
            buffer.len()
        } else {
            rem
        };
        for o in buffer[..valid].iter_mut() {
            *o = gen.next().expect("Unexpected end of generator");
        }
        // Zero the reference buffer so applying the keystream leaves the raw
        // keystream bytes in place.
        reference_buffer.fill(0);
        reference_cipher.apply_keystream(&mut reference_buffer[..valid]);
        assert_eq!(
            &buffer[..valid],
            &reference_buffer[..valid],
            "{} #{i}",
            msg_fn()
        );
    }
}
/// Builds an [`AesCtrGenerator`] together with an equivalently-positioned
/// `ctr` crate reference cipher so their outputs can be compared
/// byte-for-byte.
pub(crate) fn make_ctr_pair<G: AesBlockCipher>(
    key: AesKey,
    aes_idx: u128,
    byte_idx: usize,
    offset: u128,
) -> (AesCtrGenerator<G>, Ctr128LE<Aes128>) {
    // NOTE(review): native-endian key bytes — assumes the generator lays the
    // key out the same way on this platform; the conformance tests exercise
    // this.
    let key_bytes = key.0.to_ne_bytes();
    // The reference counter starts at the generator's AES index plus offset
    // (wrapping, matching the generator's wrapping counter arithmetic).
    let counter = aes_idx.wrapping_add(offset);
    let nonce_bytes = counter.to_le_bytes();
    let mut cipher = Ctr128LE::<Aes128>::new_from_slices(&key_bytes, &nonce_bytes).unwrap();
    // `start` needs no `mut` (the previous `let mut` was an unused-mut).
    let start = TableIndex::new(AesIndex(aes_idx), ByteIndex(byte_idx));
    // Align the reference keystream with the generator's intra-block position.
    cipher.seek(byte_idx);
    let gen = AesCtrGenerator::<G>::new(key, start, Bound::Unbounded, AesIndex(offset));
    (gen, cipher)
}
/// Checks raw output conformance against the `ctr` crate, for every possible
/// starting byte position inside an AES block.
pub fn test_conformance_with_ctr_crate<G: AesBlockCipher>() {
    let mut rng = thread_rng();
    for _ in 0..1_000 {
        let key = AesKey(rng.gen());
        let aes_idx: u128 = rng.gen();
        let offset: u128 = rng.gen();
        for byte_idx in 0..BYTES_PER_AES_CALL {
            let (mut gen, mut cipher) = make_ctr_pair::<G>(key, aes_idx, byte_idx, offset);
            // Cap the comparison length at 64 KiB to keep the test fast.
            let num_bytes = gen.remaining_bytes().0.min(u128::from(u16::MAX)) as usize;
            assert_generator_matches_cipher(&mut gen, &mut cipher, num_bytes, || {
                format!("invalid bytes buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
            });
        }
    }
}
/// Checks conformance against the `ctr` crate across a fork: the pre-fork
/// bytes, every child's bytes, and the post-fork parent bytes must all match
/// the reference keystream in order.
pub fn test_forking_conformance_with_ctr_crate<G: AesBlockCipher>() {
    let mut rng = thread_rng();
    // Picks a random `(a, b)` with `a * b <= x`, sampling a few candidates and
    // keeping the pair whose product is largest (so forks cover most of the
    // available range).
    fn random_tuple_less_or_equal_than(rng: &mut ThreadRng, x: u64) -> (u64, u64) {
        let mut mul = 0;
        let mut a = 0;
        let mut b = 0;
        // `_`: only the best candidate matters, not the attempt index (the
        // previous code bound an unused `i` here).
        for _ in 0..30 {
            let a2: u64 = rng.gen_range(1..=x);
            let b2 = x / a2;
            let m2 = a2 * b2;
            if m2 > mul {
                mul = m2;
                a = a2;
                b = b2;
            }
        }
        (a, b)
    }
    for _ in 0..1_000 {
        let key = AesKey(rng.gen());
        let aes_idx: u128 = rng.gen_range(0..u128::MAX);
        let offset: u128 = rng.gen();
        for byte_idx in 0..BYTES_PER_AES_CALL {
            let (mut gen, mut cipher) = make_ctr_pair::<G>(key, aes_idx, byte_idx, offset);
            let bytes = gen.remaining_bytes().0.min(u128::from(u16::MAX));
            // A third of the budget before the fork, a third for the children,
            // a third after.
            let bytes_per_parts = (bytes / 3) as usize;
            if bytes_per_parts == 0 {
                continue;
            }
            assert_generator_matches_cipher(&mut gen, &mut cipher, bytes_per_parts, || {
                format!("invalid bytes pre-fork buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
            });
            let (nc, nb) = random_tuple_less_or_equal_than(&mut rng, bytes_per_parts as u64);
            for (child_i, mut child) in gen
                .try_fork(ChildrenCount(nc), BytesPerChild(nb))
                .unwrap()
                .enumerate()
            {
                assert_generator_matches_cipher(&mut child, &mut cipher, nb as usize, || {
                    format!("invalid bytes child #{child_i} buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
                });
            }
            assert_generator_matches_cipher(&mut gen, &mut cipher, bytes_per_parts, || {
                format!("invalid bytes post-fork buffer, key={key:?}, aes_idx={aes_idx}, byte_idx={byte_idx}, offset={offset}")
            });
        }
    }
}
/// Property: a generator that forks in the middle of its range produces, via
/// pre-fork bytes + children bytes + post-fork bytes, exactly the same stream
/// as an identical generator that never forks.
pub fn prop_fork_with_parent_continuation<G: AesBlockCipher>() {
    for _ in 0..10_000 {
        let (t, nc, nb, num_extra_bytes) = any_valid_fork().next().unwrap();
        let k = any_key().next().unwrap();
        let fork_bytes = widening_mul(nc.0, nb.0);
        let total_bytes = fork_bytes.saturating_add(num_extra_bytes);
        let offset = AesIndex(rand::random());
        let mut gen1 =
            AesCtrGenerator::<G>::new(k, t, Bound::Excluded(t.increased(total_bytes)), offset);
        let mut forked_gen = gen1.clone();
        // Consume half of the extra budget (capped) before forking.
        let bytes_per_parts = (num_extra_bytes.min(u128::from(u16::MAX)) / 2) as usize;
        for i in 0..bytes_per_parts {
            let byte = forked_gen.next().unwrap();
            let expected_byte = gen1.next().unwrap();
            assert_eq!(
                byte, expected_byte,
                "pre-fork bytes are not equal (byte index {i}), key={k:?}, t={t:?}, offset={offset:?} nc={nc:?}, nb={nb:?}, num_extra_bytes={num_extra_bytes}, bytes_per_parts={bytes_per_parts}"
            );
        }
        for (child_i, mut child) in forked_gen.try_fork(nc, nb).unwrap().enumerate() {
            let mut child_byte_count: u64 = 0;
            for byte in child.by_ref() {
                let i = child_byte_count;
                child_byte_count += 1;
                let expected_byte = gen1.next().unwrap();
                // Fixed: the message previously repeated `offset={offset:?}`
                // twice with no separator.
                assert_eq!(
                    byte, expected_byte,
                    "invalid byte at index {i} for child {child_i}, key={k:?}, t={t:?}, offset={offset:?} nc={nc:?}, nb={nb:?}, num_extra_bytes={num_extra_bytes}, bytes_per_parts={bytes_per_parts}"
                );
            }
            assert!(
                child.next().is_none(),
                "child {child_i} iterator not exhausted after yielding all bytes, key={k:?}, t={t:?}, offset={offset:?}"
            );
            assert_eq!(
                child_byte_count, nb.0,
                "child {child_i} produced {child_byte_count} bytes, expected {}, key={k:?}, t={t:?}, offset={offset:?}",
                nb.0
            );
        }
        for i in 0..bytes_per_parts {
            let byte = forked_gen.next().unwrap();
            let expected_byte = gen1.next().unwrap();
            // Fixed: the message previously duplicated `offset={offset:?}`.
            assert_eq!(byte, expected_byte, "post-fork bytes are not equal (byte index {i}), got {byte}, expected {expected_byte}, key={k:?}, t={t:?}, offset={offset:?} nc={nc:?}, nb={nb:?}, num_extra_bytes={num_extra_bytes}, bytes_per_parts={bytes_per_parts}");
        }
    }
}
/// Property: two generators with the same key and start index but different
/// counter offsets produce different byte streams, before, during, and after
/// a fork.
pub fn prop_different_offset_means_different_output<G: AesBlockCipher>() {
for _ in 0..10_000 {
let (t, nc, nb, num_extra_bytes) = any_valid_fork().next().unwrap();
let k = any_key().next().unwrap();
let fork_bytes = widening_mul(nc.0, nb.0);
let total_bytes = fork_bytes.saturating_add(num_extra_bytes);
let offset1 = AesIndex(rand::random());
let mut gen1 =
AesCtrGenerator::<G>::new(k, t, Bound::Excluded(t.increased(total_bytes)), offset1);
// Re-draw until the second offset actually differs from the first.
let offset2 = loop {
let offset2 = AesIndex(rand::random());
if offset1 != offset2 {
break offset2;
}
};
let mut gen2 =
AesCtrGenerator::<G>::new(k, t, Bound::Excluded(t.increased(total_bytes)), offset2);
// Cap the compared range at 64 KiB; half is consumed before the fork and
// half after.
let bytes = gen1.remaining_bytes().0.min(u128::from(u16::MAX));
let bytes_per_parts = (bytes / 2) as usize;
let mut slice1 = [0u8; 1024];
let mut slice2 = [0u8; 1024];
let n = bytes_per_parts.div_ceil(slice1.len());
let rest = bytes_per_parts % slice1.len();
for i in 0..n {
// Skip a short final chunk: with fewer than 16 fresh bytes the zero-padded
// arrays could spuriously compare equal.
if i == n - 1 && rest < 16 {
continue;
}
slice1.fill(0);
slice2.fill(0);
// `zip` stops when the generator is exhausted, leaving any tail zeroed.
// NOTE(review): a non-final chunk always reads a full 1024 bytes, so up to
// 1023 bytes beyond `bytes_per_parts` may be consumed here; this looks
// benign because `num_extra_bytes` is almost always astronomically larger,
// but worth confirming.
for (o, b) in slice1.iter_mut().zip(gen1.by_ref()) {
*o = b;
}
for (o, b) in slice2.iter_mut().zip(gen2.by_ref()) {
*o = b;
}
assert_ne!(
slice1, slice2,
"pre-fork bytes slices are equal but they should not (slice index {i}), key={k:?}, t={t:?}, offset1={offset1:?}, offset2={offset2:?}"
);
}
// Fork both generators identically and compare corresponding children.
for (mut child_1, mut child_2) in gen1
.try_fork(nc, nb)
.unwrap()
.zip(gen2.try_fork(nc, nb).unwrap())
{
let n = nb.0.div_ceil(slice1.len() as u64);
let rest = nb.0 % (slice1.len() as u64);
for i in 0..nb.0.div_ceil(slice1.len() as u64) {
if i == n - 1 && rest < 16 {
continue;
}
slice1.fill(0);
slice2.fill(0);
for (o, b) in slice1.iter_mut().zip(child_1.by_ref()) {
*o = b;
}
for (o, b) in slice2.iter_mut().zip(child_2.by_ref()) {
*o = b;
}
assert_ne!(
slice1, slice2,
"child bytes slices are equal but they should not (slice index {i}), key={k:?}, t={t:?}, offset1={offset1:?}, offset2={offset2:?}"
);
}
}
// Same comparison on the parents' post-fork continuation.
let n = bytes_per_parts.div_ceil(slice1.len());
let rest = bytes_per_parts % slice1.len();
for i in 0..n {
if i == n - 1 && rest < 16 {
continue;
}
slice1.fill(0);
slice2.fill(0);
for (o, b) in slice1.iter_mut().zip(gen1.by_ref()) {
*o = b;
}
for (o, b) in slice2.iter_mut().zip(gen2.by_ref()) {
*o = b;
}
assert_ne!(
slice1, slice2,
"post-fork bytes slices are equal but they should not (slice index {i}), key={k:?}, t={t:?}, offset1={offset1:?}, offset2={offset2:?}"
);
}
}
}
}