use std::mem::size_of;
use ark_ff::Field;
use derive_more::AsRef;
use derive_where::derive_where;
use rs_merkle::Hasher;
/// A deterministic, reseedable source of challenge bytes.
///
/// Implementors provide [`reseed`](Self::reseed) and [`next_bytes`](Self::next_bytes);
/// field elements and query positions are derived from that byte stream by the
/// default methods.
pub trait ReseedableRng {
    /// Value mixed into the generator state by [`reseed`](Self::reseed).
    type Seed;

    /// Mixes `seed` into the generator state.
    fn reseed(&mut self, seed: Self::Seed);

    /// Produces the next chunk of pseudo-random bytes, passes it to `f`, and
    /// returns whatever `f` returns.
    fn next_bytes<T, F>(&mut self, f: F) -> T
    where
        F: FnOnce(&[u8]) -> T;

    /// Draws a field element by rejection sampling.
    ///
    /// Repeatedly feeds fresh `next_bytes` chunks to `F::from_random_bytes` and
    /// returns the first value it accepts. Never returns if every chunk is
    /// rejected.
    fn draw_alpha<F: Field>(&mut self) -> F {
        loop {
            if let Some(alpha) = self.next_bytes(|bytes| F::from_random_bytes(bytes)) {
                return alpha;
            }
        }
    }

    /// Draws `count` positions in `0..domain_size` (repeats are possible).
    ///
    /// Each position is the little-endian integer formed from the first
    /// `size_of::<usize>()` bytes of a fresh `next_bytes` chunk (zero-padded when
    /// the chunk is shorter), masked with `domain_size - 1`.
    ///
    /// # Panics
    ///
    /// Panics if `domain_size` is not a power of two (this includes zero), or if a
    /// chunk has fewer bits than are needed to index the domain.
    fn draw_positions(&mut self, count: usize, domain_size: usize) -> Vec<usize> {
        // Checked in release builds too: with a non-power-of-two size the mask
        // silently skips part of the domain, and with zero it wraps to
        // `usize::MAX` and yields out-of-range positions.
        assert!(
            domain_size.is_power_of_two(),
            "Domain size must be a power of two"
        );
        let mask = domain_size - 1;
        // Number of low-order random bits that survive the mask.
        let bits_needed = domain_size.trailing_zeros() as usize;
        (0..count)
            .map(|_| {
                let number = self.next_bytes(|bytes| {
                    assert!(
                        8 * bytes.len() >= bits_needed,
                        "Random chunk is too short to index the domain"
                    );
                    // Little-endian puts the low-order bytes first, so zero-padding
                    // a short chunk leaves the bits kept by the mask untouched.
                    let mut buf = [0u8; size_of::<usize>()];
                    let len = bytes.len().min(buf.len());
                    buf[..len].copy_from_slice(&bytes[..len]);
                    usize::from_le_bytes(buf)
                });
                number & mask
            })
            .collect()
    }
}
/// Hash-based challenge generator built on the Merkle hasher `H`.
///
/// State is a running seed digest plus a counter of draws made since the last
/// reseed/reset; each draw hashes the counter together with the seed.
// No bounds are listed after `;` for `Clone`/`PartialEq`; `Debug` is bounded on
// `H::Hash` (derive_where bound syntax).
#[derive_where(Clone, PartialEq;)]
#[derive_where(Debug; H::Hash)]
pub struct FriChallenger<H: Hasher> {
    // Current seed digest: starts as `H::hash(&[])` and is folded with each new
    // seed by `reseed`.
    seed: H::Hash,
    // Draws made since the last reseed/reset; incremented before every draw.
    counter: usize,
}
impl<H: Hasher> Default for FriChallenger<H> {
    /// Starts from the digest of the empty byte string, with no draws made yet.
    fn default() -> Self {
        let seed = H::hash(&[]);
        Self { seed, counter: 0 }
    }
}
impl<H: Hasher> ReseedableRng for FriChallenger<H>
where
    H::Hash: AsRef<[u8]>,
{
    type Seed = H::Hash;

    /// Folds `seed` into the current seed with `H::concat_and_hash` and restarts
    /// the draw counter.
    fn reseed(&mut self, seed: Self::Seed) {
        self.seed = H::concat_and_hash(&self.seed, Some(&seed));
        self.counter = 0;
    }

    /// Increments the counter and passes `f` the bytes of
    /// `H::concat_and_hash(seed, H::hash(counter as 8-byte little-endian))`.
    fn next_bytes<T, F>(&mut self, f: F) -> T
    where
        F: FnOnce(&[u8]) -> T,
    {
        self.counter += 1;
        // Hash the counter as a fixed-width u64. `usize::to_le_bytes` yields 4
        // bytes on 32-bit targets and 8 on 64-bit ones, so parties on different
        // architectures (e.g. a prover and a verifier) would otherwise derive
        // different challenges. On 64-bit targets the bytes are unchanged, and the
        // widening cast is lossless since `usize` is at most 64 bits on current targets.
        let counter_bytes = (self.counter as u64).to_le_bytes();
        let hash = H::concat_and_hash(&self.seed, Some(&H::hash(&counter_bytes)));
        f(hash.as_ref())
    }
}
impl<H: Hasher> FriChallenger<H> {
    /// Returns the challenger to its initial state, the same one `Default` builds.
    ///
    /// This means a seed of `H::hash(&[])` and a counter of zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
/// Wrapper around a generator `R` that records the positions returned by its
/// most recent `draw_positions` call.
#[derive(AsRef, Default, Clone, PartialEq, Eq, Debug)]
pub struct MemoryRng<R> {
    // Wrapped generator; `#[as_ref]` makes the derived `AsRef` expose this field.
    #[as_ref]
    inner: R,
    // Copy of the positions from the latest `draw_positions` call (empty until
    // the first one); callers may also edit it through `last_positions_mut`.
    last_positions: Vec<usize>,
}
impl<R> From<R> for MemoryRng<R> {
    /// Wraps `value` with an empty record of drawn positions.
    fn from(value: R) -> Self {
        let last_positions = Vec::new();
        Self {
            inner: value,
            last_positions,
        }
    }
}
impl<R: ReseedableRng> ReseedableRng for MemoryRng<R> {
type Seed = R::Seed;
fn reseed(&mut self, seed: Self::Seed) {
self.inner.reseed(seed);
}
fn next_bytes<T, F>(&mut self, f: F) -> T
where
F: FnOnce(&[u8]) -> T,
{
self.inner.next_bytes(f)
}
fn draw_positions(&mut self, count: usize, domain_size: usize) -> Vec<usize> {
self.last_positions = self.inner.draw_positions(count, domain_size);
self.last_positions.clone()
}
}
impl<R> MemoryRng<R> {
    /// Positions recorded by the latest `draw_positions` call, including any
    /// edits made through [`last_positions_mut`](Self::last_positions_mut).
    pub fn last_positions(&self) -> &[usize] {
        self.last_positions.as_slice()
    }

    /// Mutable access to the recorded positions.
    pub fn last_positions_mut(&mut self) -> &mut Vec<usize> {
        &mut self.last_positions
    }

    /// Unwraps the generator, discarding the recorded positions.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
/// Lets a mutable borrow stand in for the generator it points to; every method,
/// including the defaulted ones, forwards to `R` so its overrides are used.
impl<R> ReseedableRng for &mut R
where
    R: ReseedableRng + ?Sized,
{
    type Seed = R::Seed;

    fn reseed(&mut self, seed: Self::Seed) {
        R::reseed(&mut **self, seed)
    }

    fn next_bytes<T, F>(&mut self, f: F) -> T
    where
        F: FnOnce(&[u8]) -> T,
    {
        R::next_bytes(&mut **self, f)
    }

    fn draw_alpha<F: Field>(&mut self) -> F {
        R::draw_alpha(&mut **self)
    }

    fn draw_positions(&mut self, count: usize, domain_size: usize) -> Vec<usize> {
        R::draw_positions(&mut **self, count, domain_size)
    }
}