use alloc::vec;
use alloc::vec::Vec;
use p3_symmetric::CryptographicHasher;
use crate::{CanFinalizeDigest, CanObserve, CanSample};
/// A challenger that absorbs values into a buffer and derives samples by hashing it.
///
/// Observed values are appended to `input_buffer`. When a sample is requested and no
/// unused hash output remains, the entire `input_buffer` is hashed with `hasher`; the
/// resulting `OUT_LEN`-element digest replaces the input buffer (so later hashes are
/// chained to it) and also becomes the pool that samples are drawn from.
#[derive(Clone, Debug)]
pub struct HashChallenger<
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
    const OUT_LEN: usize,
> {
    /// Pending hash input: the initial state or the last digest, followed by every
    /// value observed since.
    input_buffer: Vec<T>,
    /// Unused elements of the most recent digest; samples are popped from its end.
    output_buffer: Vec<T>,
    /// Hash function applied to `input_buffer` to produce an `OUT_LEN`-element digest.
    hasher: H,
}
impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
where
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
{
    /// Creates a challenger whose pending hash input starts out as `initial_state`.
    ///
    /// No hashing happens here: the output buffer starts empty, so the first sample
    /// hashes `initial_state` together with anything observed after construction.
    pub const fn new(initial_state: Vec<T>, hasher: H) -> Self {
        Self {
            hasher,
            input_buffer: initial_state,
            output_buffer: Vec::new(),
        }
    }

    /// Hashes all pending input and refreshes both buffers with the digest.
    ///
    /// Afterwards `output_buffer` holds the full digest, ready to be sampled. The
    /// `input_buffer` holds only that digest, so the next hash is chained to it.
    fn flush(&mut self) {
        let digest = self.hasher.hash_iter(self.input_buffer.drain(..));
        self.output_buffer = digest.to_vec();
        // Chaining: the digest is the prefix of whatever gets hashed next.
        self.input_buffer.extend(digest);
    }
}
impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
where
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
{
    /// Absorbs a single value into the pending hash input.
    ///
    /// Any unused output from an earlier hash is discarded, so the next sample comes
    /// from a fresh hash that includes `value`.
    fn observe(&mut self, value: T) {
        self.input_buffer.push(value);
        // Stale samples must not survive a new observation.
        self.output_buffer.clear();
    }
}
impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
    for HashChallenger<T, H, OUT_LEN>
where
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
{
    /// Absorbs every element of `values`, in order.
    ///
    /// Observing an empty array is a no-op. Unlike a non-empty observation, it leaves
    /// any unused hash output in place, because nothing new was absorbed.
    fn observe(&mut self, values: [T; N]) {
        if N > 0 {
            self.input_buffer.extend(values);
            self.output_buffer.clear();
        }
    }
}
impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
where
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
{
    /// Returns the next challenge element.
    ///
    /// Elements are popped from the end of the current digest. Once it is exhausted,
    /// the pending input is hashed again and sampling continues from the new digest.
    ///
    /// # Panics
    ///
    /// Panics if a fresh hash yields no elements (i.e. when `OUT_LEN == 0`).
    fn sample(&mut self) -> T {
        if let Some(value) = self.output_buffer.pop() {
            return value;
        }
        self.flush();
        self.output_buffer
            .pop()
            .expect("Output buffer should be non-empty")
    }
}
impl<T, H, const OUT_LEN: usize> CanFinalizeDigest for HashChallenger<T, H, OUT_LEN>
where
    T: Clone,
    H: CryptographicHasher<T, [T; OUT_LEN]>,
{
    type Digest = [T; OUT_LEN];

    /// Consumes the challenger and returns the hash of all pending input.
    ///
    /// The pending input is the initial state or the last digest, followed by every
    /// value observed since. Unused elements left in the output buffer by earlier
    /// samples do not influence the result.
    fn finalize(self) -> [T; OUT_LEN] {
        // `self` is consumed, so chaining the digest back into the input buffer or
        // cloning it out of the output buffer would be wasted work. Hash the owned
        // input buffer directly and hand the digest back as-is.
        self.hasher.hash_iter(self.input_buffer)
    }
}
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;

    use super::*;

    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    /// Toy hasher for tests: the digest is `[sum of inputs, number of inputs]`.
    ///
    /// It is deliberately not cryptographic; it only makes the challenger's buffer
    /// handling easy to predict by hand.
    #[derive(Clone)]
    struct TestHasher {}

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_usize(len)]
        }

        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_usize(len)]
        }
    }

    #[test]
    fn test_hash_challenger() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        // 1 + 2 + ... + 10 = 55, over 10 inputs.
        let expected_sum = F::from_u8(55);
        let expected_len = F::from_u8(10);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_u8(11);
        hash_challenger.observe(new_element);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // Hash of [55, 10, 11]: sum 76, length 3.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        // Samples pop from the end of the digest, so the length comes out first.
        let new_element = hash_challenger.sample();
        assert_eq!(new_element, F::from_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_u8(new_expected_sum)]
        );
    }

    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        let first_sample = hash_challenger.sample();
        let second_sample = hash_challenger.sample();

        assert_eq!(first_sample, F::from_u8(10));
        assert_eq!(second_sample, F::from_u8(55));
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_single_value() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        let value = F::from_u8(42);
        hash_challenger.observe(value);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(42)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        let values = [F::from_u8(1), F::from_u8(2), F::from_u8(3)];
        hash_challenger.observe(values);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_empty_array_keeps_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(1)], test_hasher);
        hash_challenger.output_buffer = vec![F::from_u8(42)];

        let empty: [F; 0] = [];
        hash_challenger.observe(empty);

        // Nothing was absorbed, so the pending sample must survive.
        assert_eq!(hash_challenger.input_buffer, vec![F::from_u8(1)]);
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    #[test]
    fn test_sample_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(5), F::from_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // Hash of [5, 10] is [15, 2]; the last element is popped first.
        let sample = hash_challenger.sample();
        assert_eq!(sample, F::from_u8(2));
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(15)]);
    }

    #[test]
    fn test_flush_empty_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.flush();

        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    #[test]
    fn test_flush_with_data() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.flush();

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
    }

    #[test]
    fn test_sample_after_observe() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.observe(F::from_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );

        // Hash of [1, 2, 3] is [6, 3]; the length (3) is popped first.
        let sample = hash_challenger.sample();
        assert_eq!(sample, F::from_u8(3));
    }

    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer = vec![F::from_u8(42), F::from_u8(24)];

        let sample = hash_challenger.sample();
        assert_eq!(sample, F::from_u8(24));
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    #[test]
    fn test_finalize() {
        let new_chal = || HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});

        // Identical transcripts yield identical digests.
        let mut h1 = new_chal();
        let mut h2 = new_chal();
        h1.observe(F::from_u8(42));
        h2.observe(F::from_u8(42));
        assert_eq!(h1.finalize(), h2.finalize());

        // Different transcripts yield different digests.
        let mut h1 = new_chal();
        let mut h2 = new_chal();
        h1.observe(F::from_u8(1));
        h2.observe(F::from_u8(2));
        assert_ne!(h1.finalize(), h2.finalize());
    }

    #[test]
    fn test_finalize_hashes_pending_input() {
        let mut hash_challenger =
            HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});
        hash_challenger.observe(F::from_u8(3));

        // Pending input is [1, 2, 3]: sum 6, length 3.
        assert_eq!(hash_challenger.finalize(), [F::from_u8(6), F::from_u8(3)]);
    }

    #[test]
    fn test_finalize_sample_interaction() {
        let digest = |n_samples: usize| {
            let mut c = HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], TestHasher {});
            c.observe(F::from_u8(42));
            for _ in 0..n_samples {
                let _: F = c.sample();
            }
            c.finalize()
        };

        // Drawing from the same digest does not change the pending input, so the
        // final digest only changes when a sample forces a new hash.
        assert_ne!(digest(0), digest(1));
        assert_eq!(digest(1), digest(OUT_LEN));
        assert_ne!(digest(OUT_LEN), digest(OUT_LEN + 1));
        assert_eq!(digest(OUT_LEN + 1), digest(2 * OUT_LEN));
    }

    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer.push(F::from_u8(42));
        assert!(!hash_challenger.output_buffer.is_empty());

        hash_challenger.observe(F::from_u8(3));
        assert!(hash_challenger.output_buffer.is_empty());
    }
}