use alloc::vec;
use alloc::vec::Vec;
use core::error::Error;
use core::fmt::{Display, Formatter};
use p3_field::{BasedVectorSpace, Field, PrimeField, PrimeField64};
use p3_monty_31::{MontyField31, MontyParameters};
use p3_symmetric::{CryptographicPermutation, Hash, MerkleCap};
use crate::{
CanFinalizeDigest, CanObserve, CanSample, CanSampleBits, CanSampleUniformBits, FieldChallenger,
};
/// A duplex-sponge Fiat-Shamir challenger over the field `F`.
///
/// Observed values are buffered and absorbed by overwriting the first elements of
/// `sponge_state` and applying `permutation`; sampled values are taken from the first `RATE`
/// elements of the permuted state.
#[derive(Clone, Debug)]
pub struct DuplexChallenger<F, P, const WIDTH: usize, const RATE: usize>
where
F: Clone,
P: CryptographicPermutation<[F; WIDTH]>,
{
/// Full permutation state: the first `RATE` elements form the rate, the remainder the capacity.
pub sponge_state: [F; WIDTH],
/// Observed values not yet absorbed; flushed once it holds `RATE` elements or before the next sample.
pub input_buffer: Vec<F>,
/// Rate elements squeezed by the last duplexing and not yet handed out; samples `pop` from its end.
pub output_buffer: Vec<F>,
/// Permutation applied on every duplexing step.
pub permutation: P,
}
impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Creates a challenger with an all-default sponge state and empty buffers.
    ///
    /// `RATE` must be non-zero and strictly smaller than `WIDTH`; this is checked at compile
    /// time when the constructor is instantiated.
    pub fn new(permutation: P) -> Self
    where
        F: Default,
    {
        const {
            assert!(RATE > 0 && RATE < WIDTH);
        }
        let sponge_state = [F::default(); WIDTH];
        Self {
            permutation,
            sponge_state,
            input_buffer: vec![],
            output_buffer: vec![],
        }
    }

    /// Performs one duplexing step.
    ///
    /// Pending inputs overwrite the leading state elements, the permutation is applied, and
    /// the output buffer is refilled with the first `RATE` elements of the new state.
    ///
    /// # Panics
    /// If more than `RATE` inputs are pending.
    pub(crate) fn duplexing(&mut self) {
        assert!(self.input_buffer.len() <= RATE);

        // Overwrite (not add) the leading state elements with the pending inputs.
        let absorbed = self.input_buffer.len();
        self.sponge_state[..absorbed].copy_from_slice(&self.input_buffer);
        self.input_buffer.clear();

        self.permutation.permute_mut(&mut self.sponge_state);

        // Expose the fresh rate portion as the next batch of samples.
        self.output_buffer.clear();
        self.output_buffer
            .extend_from_slice(&self.sponge_state[..RATE]);
    }
}
impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy + Default + PrimeField,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Absorbs up to `RATE` values in a single permutation call.
    ///
    /// The values overwrite the start of the rate, the rest of the rate is zeroed, and
    /// `length_tag` is added into the first capacity element (`sponge_state[RATE]`). Afterwards
    /// `input_buffer` is empty and `output_buffer` holds the rate portion of the permuted state.
    ///
    /// NOTE(review): any observations still pending in `input_buffer` are dropped without
    /// being absorbed. Confirm callers never call this with un-flushed observations.
    ///
    /// # Panics
    /// If `values.len() > RATE`.
    pub fn absorb_rate_padded_with_tag(&mut self, values: &[F], length_tag: u8) {
        const {
            assert!(
                RATE < WIDTH,
                "RATE must be less than WIDTH for capacity length slot"
            );
        }
        assert!(values.len() <= RATE);

        self.input_buffer.clear();
        self.output_buffer.clear();

        // Split the state into rate | capacity, then the rate into payload | zero padding.
        let (rate, capacity) = self.sponge_state.split_at_mut(RATE);
        let (payload, padding) = rate.split_at_mut(values.len());
        payload.copy_from_slice(values);
        padding.fill(F::ZERO);
        capacity[0] += F::from_u8(length_tag);

        self.permutation.permute_mut(&mut self.sponge_state);
        self.output_buffer
            .extend_from_slice(&self.sponge_state[..RATE]);
    }
}
// `FieldChallenger` is implemented with an empty body: everything it offers presumably comes
// from the trait's default methods and its supertrait impls (the `CanObserve`/`CanSample`
// impls in this file). NOTE(review): confirm against the trait definition in the crate root.
impl<F, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
for DuplexChallenger<F, P, WIDTH, RATE>
where
F: PrimeField64,
P: CryptographicPermutation<[F; WIDTH]>,
{
}
impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Buffers one value, invalidating any previously squeezed output, and duplexes as soon
    /// as a full `RATE` of inputs is pending.
    fn observe(&mut self, value: F) {
        self.input_buffer.push(value);
        // Squeezed output no longer reflects the transcript once new input arrives.
        self.output_buffer.clear();

        let rate_filled = self.input_buffer.len() == RATE;
        if rate_filled {
            self.duplexing();
        }
    }
}
impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Observes the array's elements one by one, in index order.
    fn observe(&mut self, values: [F; N]) {
        values
            .into_iter()
            .for_each(|element| self.observe(element));
    }
}
impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, F, N>>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Observes every field element of the hash digest, in iteration order.
    fn observe(&mut self, values: Hash<F, F, N>) {
        values
            .into_iter()
            .for_each(|element| self.observe(element));
    }
}
impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<&MerkleCap<F, [F; N]>>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Observes every element of every root in the cap, root by root.
    fn observe(&mut self, cap: &MerkleCap<F, [F; N]>) {
        for root in cap.roots() {
            for &element in root {
                self.observe(element);
            }
        }
    }
}
impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<MerkleCap<F, [F; N]>>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Owned-cap convenience: identical to observing the cap by reference.
    fn observe(&mut self, cap: MerkleCap<F, [F; N]>) {
        let borrowed = &cap;
        <Self as CanObserve<&MerkleCap<F, [F; N]>>>::observe(self, borrowed);
    }
}
impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Observes a nested vector row by row, each row left to right.
    fn observe(&mut self, rows: Vec<Vec<F>>) {
        rows.into_iter()
            .flatten()
            .for_each(|element| self.observe(element));
    }
}
impl<F, EF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Field,
    EF: BasedVectorSpace<F>,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Samples an `EF` by drawing one base-field element per basis coefficient.
    ///
    /// Before each draw, the sponge is duplexed when inputs are pending (so the challenge
    /// reflects them) or when the squeezed output has been exhausted.
    fn sample(&mut self) -> EF {
        EF::from_basis_coefficients_fn(|_| {
            let must_refresh = self.output_buffer.is_empty() || !self.input_buffer.is_empty();
            if must_refresh {
                self.duplexing();
            }
            self.output_buffer
                .pop()
                .expect("Output buffer should be non-empty")
        })
    }
}
impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: PrimeField64,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Returns the low `bits` bits of one sampled base-field element (no rejection sampling).
    ///
    /// # Panics
    /// If `bits >= usize::BITS`, or if `2^bits` is not below the field order.
    fn sample_bits(&mut self, bits: usize) -> usize {
        assert!(bits < (usize::BITS as usize));
        assert!((1 << bits) < F::ORDER_U64);
        let mask: usize = (1 << bits) - 1;
        let drawn: F = self.sample();
        (drawn.as_canonical_u64() as usize) & mask
    }
}
/// Per-field constants driving uniform bit sampling by rejection.
///
/// As used by `sample_uniform_bits_with_strategy` in this file: a request for `k` bits with
/// `k <= MAX_SINGLE_SAMPLE_BITS` is served from one field element, larger requests from two;
/// a field element drawn for `k` bits is accepted only if its canonical value is strictly
/// below `SAMPLING_BITS_M[k]`.
pub trait UniformSamplingField {
/// Largest bit count that is drawn from a single field element.
const MAX_SINGLE_SAMPLE_BITS: usize;
/// Acceptance bounds indexed by bit count; samples with canonical value `>= SAMPLING_BITS_M[k]` are rejected.
const SAMPLING_BITS_M: [u64; 64];
}
/// Forwards the sampling constants from the Montgomery parameter type `MP`, so every
/// `MontyField31<MP>` whose parameters implement `UniformSamplingField` does too.
impl<MP> UniformSamplingField for MontyField31<MP>
where
MP: UniformSamplingField + MontyParameters,
{
const MAX_SINGLE_SAMPLE_BITS: usize = MP::MAX_SINGLE_SAMPLE_BITS;
const SAMPLING_BITS_M: [u64; 64] = MP::SAMPLING_BITS_M;
}
/// Bit-sampling strategy marker: keep drawing field elements until one is accepted.
pub(super) struct ResampleOnRejection;
/// Bit-sampling strategy marker: return a `ResamplingError` on the first rejected draw.
pub(super) struct ErrorOnRejection;
/// Returned by uniform-bit sampling when a drawn field element must be rejected (its
/// canonical value is not below the acceptance bound) and resampling is disabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ResamplingError {
    /// Canonical value of the rejected sample.
    value: u64,
    /// Exclusive acceptance bound the sample was checked against.
    m: u64,
}

impl ResamplingError {
    /// Canonical value of the rejected field sample.
    #[must_use]
    pub const fn value(&self) -> u64 {
        self.value
    }

    /// Exclusive acceptance bound the rejected sample failed to stay below.
    #[must_use]
    pub const fn m(&self) -> u64 {
        self.m
    }
}

impl Display for ResamplingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Encountered value {0}, which requires resampling for uniform bits as it is not smaller than {1}. But resampling is not enabled.",
            self.value, self.m
        )
    }
}

impl Error for ResamplingError {}
pub(super) trait BitSamplingStrategy<F, P, const W: usize, const R: usize>
where
    F: PrimeField64,
    P: CryptographicPermutation<[F; W]>,
{
    /// `true` to fail on the first rejected draw, `false` to keep drawing until acceptance.
    const ERROR_ON_REJECTION: bool;

    /// Draws one base-field element from `challenger` whose canonical value is below `m`.
    ///
    /// # Errors
    /// Returns [`ResamplingError`] if the first draw is `>= m` and `ERROR_ON_REJECTION` is
    /// `true`. Otherwise it keeps drawing, which never terminates if `m == 0`.
    #[inline]
    fn sample_value(
        challenger: &mut DuplexChallenger<F, P, W, R>,
        m: u64,
    ) -> Result<F, ResamplingError> {
        loop {
            let candidate: F = challenger.sample();
            let canonical = candidate.as_canonical_u64();
            if canonical < m {
                return Ok(candidate);
            }
            if Self::ERROR_ON_REJECTION {
                return Err(ResamplingError {
                    value: canonical,
                    m,
                });
            }
        }
    }
}
// Rejected draws are silently replaced by fresh ones (the default `sample_value` loops).
impl<F, P, const W: usize, const R: usize> BitSamplingStrategy<F, P, W, R> for ResampleOnRejection
where
F: PrimeField64,
P: CryptographicPermutation<[F; W]>,
{
const ERROR_ON_REJECTION: bool = false;
}
// A rejected draw makes the default `sample_value` return a `ResamplingError` immediately.
impl<F, P, const W: usize, const R: usize> BitSamplingStrategy<F, P, W, R> for ErrorOnRejection
where
F: PrimeField64,
P: CryptographicPermutation<[F; W]>,
{
const ERROR_ON_REJECTION: bool = true;
}
impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
where
    F: UniformSamplingField + PrimeField64,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Samples `bits` uniformly distributed bits. `S` decides what happens when a field draw
    /// falls outside the accepted range.
    ///
    /// Up to `F::MAX_SINGLE_SAMPLE_BITS` bits come from one field element. Wider requests are
    /// split into a low chunk of `bits / 2` bits and a high chunk of the remaining bits. Each
    /// chunk is drawn from its own field element, low chunk first, which reduces the bias of
    /// a non-power-of-two field order.
    ///
    /// # Panics
    /// If `bits >= usize::BITS`, or if `2^bits` is not below the field order.
    ///
    /// # Errors
    /// Propagates [`ResamplingError`] from `S::sample_value`.
    #[inline]
    fn sample_uniform_bits_with_strategy<S>(
        &mut self,
        bits: usize,
    ) -> Result<usize, ResamplingError>
    where
        S: BitSamplingStrategy<F, P, WIDTH, RATE>,
    {
        if bits == 0 {
            return Ok(0);
        }
        assert!(bits < usize::BITS as usize, "bit count must be valid");
        assert!(
            (1u64 << bits) < F::ORDER_U64,
            "bit count exceeds field order"
        );

        if bits > F::MAX_SINGLE_SAMPLE_BITS {
            // Slow path: two independent draws, each covering roughly half the bits.
            let low_width = bits / 2;
            let high_width = bits - low_width;
            let low = self.sample_masked_chunk::<S>(low_width)?;
            let high = self.sample_masked_chunk::<S>(high_width)?;
            return Ok(low | (high << low_width));
        }
        // Fast path: one draw is uniform enough.
        self.sample_masked_chunk::<S>(bits)
    }

    /// Draws one field element accepted against `F::SAMPLING_BITS_M[width]` and keeps its
    /// low `width` bits.
    #[inline]
    fn sample_masked_chunk<S>(&mut self, width: usize) -> Result<usize, ResamplingError>
    where
        S: BitSamplingStrategy<F, P, WIDTH, RATE>,
    {
        let accepted = S::sample_value(self, F::SAMPLING_BITS_M[width])?;
        let mask: usize = (1 << width) - 1;
        Ok(accepted.as_canonical_u64() as usize & mask)
    }
}
impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleUniformBits<F>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: UniformSamplingField + PrimeField64,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// Samples `bits` uniform bits.
    ///
    /// With `RESAMPLE = true`, rejected draws are redrawn. With `RESAMPLE = false`, the first
    /// rejected draw yields a [`ResamplingError`].
    fn sample_uniform_bits<const RESAMPLE: bool>(
        &mut self,
        bits: usize,
    ) -> Result<usize, ResamplingError> {
        if !RESAMPLE {
            return self.sample_uniform_bits_with_strategy::<ErrorOnRejection>(bits);
        }
        self.sample_uniform_bits_with_strategy::<ResampleOnRejection>(bits)
    }
}
impl<F, P, const WIDTH: usize, const RATE: usize> CanFinalizeDigest
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: Copy,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    type Digest = [F; RATE];

    /// Consumes the challenger and returns the rate portion of the state after one final
    /// duplexing step. That step absorbs any pending observations, and it runs even when
    /// none are pending.
    fn finalize(mut self) -> [F; RATE] {
        self.duplexing();
        let state = &self.sponge_state;
        core::array::from_fn(|i| state[i])
    }
}
// Unit tests for `DuplexChallenger`. `TestPermutation` simply reverses the sponge state, so
// every expected value below can be derived by hand.
#[cfg(test)]
mod tests {
use core::iter;
use p3_baby_bear::BabyBear;
use p3_field::PrimeCharacteristicRing;
use p3_field::extension::BinomialExtensionField;
use p3_goldilocks::Goldilocks;
use p3_symmetric::Permutation;
use super::*;
use crate::grinding_challenger::GrindingChallenger;
const WIDTH: usize = 24;
const RATE: usize = 16;
type G = Goldilocks;
type EF2G = BinomialExtensionField<G, 2>;
type BB = BabyBear;
// Deterministic stand-in permutation: reverses the state in place.
#[derive(Clone)]
struct TestPermutation {}
impl<F: Clone> Permutation<[F; WIDTH]> for TestPermutation {
fn permute_mut(&self, input: &mut [F; WIDTH]) {
input.reverse();
}
}
impl<F: Clone> CryptographicPermutation<[F; WIDTH]> for TestPermutation {}
// 12 observations (< RATE) do not duplex; the first sample absorbs them, reverses the
// state, and samples are then popped from the end of the rate portion.
#[test]
fn test_duplex_challenger() {
type Chal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
let permutation = TestPermutation {};
let mut duplex_challenger = DuplexChallenger::new(permutation);
(0..12).for_each(|element| duplex_challenger.observe(G::from_u8(element as u8)));
let state_after_duplexing: Vec<_> = iter::repeat_n(G::ZERO, 12)
.chain((0..12).map(G::from_u8).rev())
.collect();
let expected_samples: Vec<G> = state_after_duplexing[..16].iter().copied().rev().collect();
let samples = <Chal as CanSample<G>>::sample_vec(&mut duplex_challenger, 16);
assert_eq!(samples, expected_samples);
}
// 129 bits is not below usize::BITS, so sample_bits must panic.
#[test]
#[should_panic]
fn test_duplex_challenger_sample_bits_security() {
type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
let permutation = TestPermutation {};
let mut duplex_challenger = GoldilocksChal::new(permutation);
for _ in 0..100 {
assert!(duplex_challenger.sample_bits(129) < 4);
}
}
// 2^40 is not below BabyBear's order, so the field-order assertion in sample_bits must panic.
#[test]
#[should_panic]
fn test_duplex_challenger_sample_bits_security_small_field() {
type BabyBearChal = DuplexChallenger<BB, TestPermutation, WIDTH, RATE>;
let permutation = TestPermutation {};
let mut duplex_challenger = BabyBearChal::new(permutation);
for _ in 0..100 {
assert!(duplex_challenger.sample_bits(40) < 1 << 31);
}
}
// Expects grinding with usize::BITS bits to panic. The check lives in the grinding_challenger
// module, which is not visible here.
#[test]
#[should_panic]
fn test_duplex_challenger_grind_security() {
type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
let permutation = TestPermutation {};
let mut duplex_challenger = GoldilocksChal::new(permutation);
let too_many_bits = usize::BITS as usize;
let witness = duplex_challenger.grind(too_many_bits);
assert!(duplex_challenger.check_witness(too_many_bits, witness));
}
// A single observation is only buffered; no duplexing, no output.
#[test]
fn test_observe_single_value() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal.observe(G::from_u8(42));
assert_eq!(chal.input_buffer, vec![G::from_u8(42)]);
assert!(chal.output_buffer.is_empty());
}
// Arrays are observed element by element, in order.
#[test]
fn test_observe_array_of_values() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal.observe([G::from_u8(1), G::from_u8(2), G::from_u8(3)]);
assert_eq!(
chal.input_buffer,
vec![G::from_u8(1), G::from_u8(2), G::from_u8(3)]
);
assert!(chal.output_buffer.is_empty());
}
// Hash digests are observed element by element.
#[test]
fn test_observe_hash_array() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let hash = Hash::<G, G, 4>::from([G::from_u8(10); 4]);
chal.observe(hash);
assert_eq!(chal.input_buffer, vec![G::from_u8(10); 4]);
}
// Nested vectors are flattened row by row.
#[test]
fn test_observe_nested_vecs() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal.observe(vec![
vec![G::from_u8(1), G::from_u8(2)],
vec![G::from_u8(3)],
]);
assert_eq!(
chal.input_buffer,
vec![G::from_u8(1), G::from_u8(2), G::from_u8(3)]
);
}
// Sampling with pending input forces a duplexing, which refills the output buffer.
#[test]
fn test_sample_triggers_duplex() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal.observe(G::from_u8(5));
assert!(chal.output_buffer.is_empty());
let _sample: G = chal.sample();
assert!(!chal.output_buffer.is_empty());
}
// Smoke test: extension-field samples can be drawn repeatedly without panicking.
#[test]
fn test_sample_multiple_extension_field() {
use p3_field::extension::BinomialExtensionField;
type EF = BinomialExtensionField<G, 2>;
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal.observe(G::from_u8(1));
chal.observe(G::from_u8(2));
let _: EF = chal.sample();
let _: EF = chal.sample();
}
// NOTE(review): after RATE observations the popped value is 8 (see
// test_output_buffer_pops_correctly), and 8 & 0b111 == 0. So "expected" matches the
// ZERO-based formula only incidentally. Consider asserting against G::from_u8(8) explicitly.
#[test]
fn test_sample_bits_within_bounds() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
for i in 0..RATE {
chal.observe(G::from_u8(i as u8));
}
let bits = 3;
let value = chal.sample_bits(bits);
let expected = G::ZERO.as_canonical_u64() as usize & ((1 << bits) - 1);
assert_eq!(value, expected);
}
// Empty buffers force a duplexing of the all-zero state, which stays zero under reversal.
#[test]
fn test_sample_bits_trigger_duplex_when_empty() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
assert_eq!(chal.input_buffer.len(), 0);
assert_eq!(chal.output_buffer.len(), 0);
let bits = 2;
let sample = chal.sample_bits(bits);
let expected = G::ZERO.as_canonical_u64() as usize & ((1 << bits) - 1);
assert_eq!(sample, expected);
}
// RATE observations trigger a duplexing. The state [0..16, 0 x 8] reversed gives
// [0 x 8, 15..=0], whose first RATE entries become the output buffer. Samples pop from the
// end: 8, then 9.
#[test]
fn test_output_buffer_pops_correctly() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
for i in 0..RATE {
chal.observe(G::from_u8(i as u8));
}
let expected = [
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(15),
G::from_u8(14),
G::from_u8(13),
G::from_u8(12),
G::from_u8(11),
G::from_u8(10),
G::from_u8(9),
G::from_u8(8),
]
.to_vec();
assert_eq!(chal.output_buffer, expected);
let first: G = chal.sample();
let second: G = chal.sample();
assert_eq!(first, G::from_u8(8));
assert_eq!(second, G::from_u8(9));
}
// With no pending input and a non-empty output buffer, sampling just pops (no duplexing).
#[test]
fn test_duplexing_only_when_needed() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal.output_buffer = vec![G::from_u8(10), G::from_u8(20)];
let sample: G = chal.sample();
assert_eq!(sample, G::from_u8(20));
assert_eq!(chal.output_buffer, vec![G::from_u8(10)]);
}
// Filling the input buffer to RATE flushes it immediately, with the same output as above.
#[test]
fn test_flush_when_input_full() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
for i in 0..RATE {
chal.observe(G::from_u8(i as u8));
}
let expected_output = [
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(0),
G::from_u8(15),
G::from_u8(14),
G::from_u8(13),
G::from_u8(12),
G::from_u8(11),
G::from_u8(10),
G::from_u8(9),
G::from_u8(8),
]
.to_vec();
assert!(chal.input_buffer.is_empty());
assert_eq!(chal.output_buffer, expected_output);
}
// Observing a base element "as an algebra element" must match observing its embedding.
#[test]
fn test_observe_base_as_algebra_element_consistency_with_direct_observe() {
let mut chal1 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let mut chal2 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let base_val = G::from_u8(99);
chal1.observe_base_as_algebra_element::<EF2G>(base_val);
let ext_val = EF2G::from(base_val);
chal2.observe_algebra_element(ext_val);
assert_eq!(chal1.input_buffer, chal2.input_buffer);
assert_eq!(chal1.output_buffer, chal2.output_buffer);
assert_eq!(chal1.sponge_state, chal2.sponge_state);
}
// Same equivalence over a stream long enough (25 * 2 base elements) to cross duplexing boundaries.
#[test]
fn test_observe_base_as_algebra_element_stream_consistency() {
let mut chal1 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let mut chal2 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let base_values: Vec<_> = (0u8..25).map(G::from_u8).collect();
for &val in &base_values {
chal1.observe_base_as_algebra_element::<EF2G>(val);
}
for &val in &base_values {
let ext_val = EF2G::from(val);
chal2.observe_algebra_element(ext_val);
}
assert_eq!(chal1.input_buffer, chal2.input_buffer);
assert_eq!(chal1.output_buffer, chal2.output_buffer);
assert_eq!(chal1.sponge_state, chal2.sponge_state);
let sample1: EF2G = chal1.sample_algebra_element();
let sample2: EF2G = chal2.sample_algebra_element();
assert_eq!(sample1, sample2);
assert_eq!(chal1.input_buffer, chal2.input_buffer);
assert_eq!(chal1.output_buffer, chal2.output_buffer);
assert_eq!(chal1.sponge_state, chal2.sponge_state);
}
// Observing a slice must match observing its elements one at a time.
#[test]
fn test_observe_algebra_elements_equivalence() {
let mut chal1 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let mut chal2 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let ext_values: Vec<EF2G> = (0u8..10).map(|i| EF2G::from(G::from_u8(i))).collect();
chal1.observe_algebra_slice(&ext_values);
for ext_val in &ext_values {
chal2.observe_algebra_element(*ext_val);
}
assert_eq!(chal1.input_buffer, chal2.input_buffer);
assert_eq!(chal1.output_buffer, chal2.output_buffer);
assert_eq!(chal1.sponge_state, chal2.sponge_state);
let sample1: EF2G = chal1.sample_algebra_element();
let sample2: EF2G = chal2.sample_algebra_element();
assert_eq!(sample1, sample2);
}
// Observing an empty slice must leave the challenger untouched.
#[test]
fn test_observe_algebra_elements_empty_slice() {
let mut chal1 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let mut chal2 =
DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
chal1.observe(G::from_u8(42));
chal2.observe(G::from_u8(42));
let empty: Vec<EF2G> = vec![];
chal1.observe_algebra_slice(&empty);
assert_eq!(chal1.input_buffer, chal2.input_buffer);
assert_eq!(chal1.output_buffer, chal2.output_buffer);
assert_eq!(chal1.sponge_state, chal2.sponge_state);
}
// 8 degree-2 elements = 16 = RATE base elements, so exactly one duplexing is triggered.
#[test]
fn test_observe_algebra_elements_triggers_duplexing() {
let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
let ext_values: Vec<EF2G> = (0u8..8).map(|i| EF2G::from(G::from_u8(i))).collect();
assert!(chal.input_buffer.is_empty());
assert!(chal.output_buffer.is_empty());
chal.observe_algebra_slice(&ext_values);
assert!(chal.input_buffer.is_empty());
assert!(!chal.output_buffer.is_empty());
}
// Equal transcripts give equal digests; different transcripts give different digests.
#[test]
fn test_finalize() {
let new_chal = || DuplexChallenger::<G, _, WIDTH, RATE>::new(TestPermutation {});
let mut c1 = new_chal();
let mut c2 = new_chal();
for i in 0..5u8 {
c1.observe(G::from_u8(i));
c2.observe(G::from_u8(i));
}
assert_eq!(c1.finalize(), c2.finalize());
let mut c1 = new_chal();
let mut c2 = new_chal();
for i in 0..10u8 {
c1.observe(G::from_u8(i));
c2.observe(G::from_u8(i + 1));
}
assert_ne!(c1.finalize(), c2.finalize());
}
// finalize always adds one duplexing. Samples served from an existing output batch leave
// sponge_state unchanged, so digest(1) == digest(2) == digest(RATE). The (RATE+1)-th sample
// exhausts the batch and triggers another permutation, which changes the digest.
#[test]
fn test_finalize_sample_interaction() {
let digest = |n_samples: usize| {
let mut c = DuplexChallenger::<G, _, WIDTH, RATE>::new(TestPermutation {});
for i in 0..5u8 {
c.observe(G::from_u8(i));
}
for _ in 0..n_samples {
let _: G = c.sample();
}
c.finalize()
};
assert_ne!(digest(0), digest(1));
assert_eq!(digest(1), digest(2));
assert_eq!(digest(1), digest(RATE));
assert_ne!(digest(RATE), digest(RATE + 1));
assert_eq!(digest(RATE + 1), digest(RATE + 2));
}
}