use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use p3_field::{
BasedVectorSpace, PrimeField, PrimeField32, absorb_radix_bits, max_absorb_injective_limbs,
reduce_packed, split_pf_to_field_order_limbs, squeeze_field_order_num_limbs,
};
use p3_symmetric::{CryptographicPermutation, Hash, MerkleCap};
use crate::{
CanFinalizeDigest, CanObserve, CanSample, CanSampleBits, DuplexChallenger, FieldChallenger,
};
/// Challenger that observes and samples elements of a 32-bit prime field `F`
/// while delegating all sponge work to a [`DuplexChallenger`] over a different
/// prime field `PF`.
///
/// Observed `F` elements are buffered, packed into `PF` elements and absorbed
/// by the inner challenger; sampled `F` elements are split out of the inner
/// challenger's `PF` outputs.
#[derive(Clone, Debug)]
pub struct MultiField32Challenger<F, PF, P, const WIDTH: usize, const RATE: usize>
where
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
/// Inner sponge operating on `PF` elements.
inner: DuplexChallenger<PF, P, WIDTH, RATE>,
/// `F` elements observed so far that have not yet been packed and absorbed.
f_buffer: Vec<F>,
/// `F` elements derived from the inner sponge's output, consumed by `sample`.
f_squeeze_buffer: Vec<F>,
}
impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Radix width, in bits, handed to `reduce_packed` when `F` elements are
    /// packed for absorption.
    ///
    /// Thin wrapper around the free function of the same name, instantiated at `F`.
    #[inline]
    #[must_use]
    pub const fn absorb_radix_bits(&self) -> u32 {
        absorb_radix_bits::<F>()
    }

    /// Number of `F` elements that are packed into a single `PF` element when absorbing.
    #[inline]
    #[must_use]
    pub fn absorb_num_f_elms(&self) -> usize {
        max_absorb_injective_limbs::<F, PF>()
    }

    /// Limb count requested from `split_pf_to_field_order_limbs` for every
    /// squeezed `PF` element.
    #[inline]
    #[must_use]
    pub fn squeeze_num_f_elms(&self) -> usize {
        squeeze_field_order_num_limbs::<PF, F>()
    }

    /// Number of `F` elements currently waiting in the squeeze buffer.
    #[inline]
    #[must_use]
    pub const fn pending_f_squeeze_len(&self) -> usize {
        self.f_squeeze_buffer.len()
    }

    /// Creates a challenger with empty buffers around `permutation`.
    ///
    /// # Errors
    ///
    /// Returns an error message when `F::order() >= PF::order()` or when
    /// `RATE >= WIDTH`.
    pub fn new(permutation: P) -> Result<Self, String> {
        if F::order() >= PF::order() {
            return Err(String::from("F::order() must be less than PF::order()"));
        }
        if RATE >= WIDTH {
            return Err(String::from("RATE must be less than WIDTH"));
        }
        Ok(Self {
            inner: DuplexChallenger::new(permutation),
            f_buffer: vec![],
            f_squeeze_buffer: vec![],
        })
    }

    /// Packs every buffered `F` element into `PF` elements and absorbs them
    /// into the inner sponge; does nothing when the buffer is empty.
    ///
    /// The number of buffered `F` elements is passed along as the absorb tag.
    ///
    /// # Panics
    ///
    /// Panics if more than `absorb_num_f_elms() * RATE` elements are buffered,
    /// or if the buffered element count does not fit in the `u8` tag.
    fn flush_f_if_non_empty(&mut self) {
        if self.f_buffer.is_empty() {
            return;
        }
        let n_in = self.f_buffer.len();
        let absorb_n = self.absorb_num_f_elms();
        assert!(n_in <= absorb_n * RATE);
        // A truncating cast would map distinct lengths (e.g. 256 and 0) onto the
        // same tag, so refuse lengths that do not fit instead of wrapping.
        let tag = u8::try_from(n_in).expect("buffered F element count must fit in the u8 absorb tag");
        let rb = self.absorb_radix_bits();
        // Each group of up to `absorb_n` F elements becomes one PF element.
        let packed: Vec<PF> = self
            .f_buffer
            .chunks(absorb_n)
            .map(|chunk| reduce_packed(chunk, rb))
            .collect();
        self.inner.absorb_rate_padded_with_tag(&packed, tag);
        self.f_buffer.clear();
        // Anything squeezed before this absorb no longer reflects the transcript.
        self.f_squeeze_buffer.clear();
    }

    /// Replaces the squeeze buffer with the `F` limbs of every element in the
    /// inner sponge's output buffer, then empties that output buffer.
    fn refill_f_squeeze_from_inner(&mut self) {
        self.f_squeeze_buffer.clear();
        let squeeze_n = self.squeeze_num_f_elms();
        for &pf in &self.inner.output_buffer {
            self.f_squeeze_buffer
                .extend(split_pf_to_field_order_limbs::<PF, F>(pf, squeeze_n));
        }
        // The PF outputs are now fully represented by their F limbs.
        self.inner.output_buffer.clear();
    }
}
// No items are overridden here: this impl relies entirely on whatever
// `FieldChallenger<F>` provides by default.
impl<F, PF, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
}
impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Buffers a single `F` element, absorbing the whole buffer once it holds
    /// `absorb_num_f_elms() * RATE` elements.
    ///
    /// Any output that was squeezed earlier is discarded, both the cached `F`
    /// limbs and the inner sponge's `PF` outputs.
    fn observe(&mut self, value: F) {
        // Fresh input makes every previously produced output stale.
        self.f_squeeze_buffer.clear();
        self.inner.output_buffer.clear();

        self.f_buffer.push(value);

        // A full batch packs into exactly RATE elements of PF.
        let full_batch = self.absorb_num_f_elms() * RATE;
        if self.f_buffer.len() == full_batch {
            self.flush_f_if_non_empty();
        }
    }
}
impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Observes the array's elements one at a time, in index order.
    fn observe(&mut self, values: [F; N]) {
        values.into_iter().for_each(|value| self.observe(value));
    }
}
impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, PF, N>>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Observes a digest made of `PF` elements by absorbing them directly into
    /// the inner sponge, `RATE` elements at a time, without going through the
    /// `F` packing path.
    ///
    /// Buffered `F` elements are flushed first, and previously squeezed output
    /// is discarded. Each chunk is absorbed with its own length as the tag.
    ///
    /// # Panics
    ///
    /// Panics if a chunk length does not fit in the `u8` tag (only possible
    /// when `RATE > 255`).
    fn observe(&mut self, values: Hash<F, PF, N>) {
        // Fresh input makes every previously produced output stale.
        self.inner.output_buffer.clear();
        self.f_squeeze_buffer.clear();
        // Pending F elements were observed earlier, so they must be absorbed first.
        self.flush_f_if_non_empty();

        let words: &[PF; N] = values.as_ref();
        for chunk in words.chunks(RATE) {
            // A truncating cast would let different chunk lengths share a tag.
            let tag = u8::try_from(chunk.len()).expect("chunk length must fit in the u8 absorb tag");
            self.inner.absorb_rate_padded_with_tag(chunk, tag);
        }
    }
}
impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
CanObserve<&MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
fn observe(&mut self, cap: &MerkleCap<F, [PF; N]>) {
for digest in cap.roots() {
self.observe(Hash::<F, PF, N>::from(*digest));
}
}
}
impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
    CanObserve<MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Observes an owned cap by forwarding to the by-reference implementation.
    fn observe(&mut self, cap: MerkleCap<F, [PF; N]>) {
        let borrowed: &MerkleCap<F, [PF; N]> = &cap;
        self.observe(borrowed);
    }
}
impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Observes every element of every inner vector, row by row, exactly as if
    /// the rows had been concatenated.
    fn observe(&mut self, valuess: Vec<Vec<F>>) {
        valuess
            .into_iter()
            .flatten()
            .for_each(|value| self.observe(value));
    }
}
impl<F, EF, PF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    EF: BasedVectorSpace<F>,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Samples an `EF` value whose basis coefficients are drawn one `F`
    /// element at a time from the squeeze buffer.
    ///
    /// Before each coefficient is drawn, buffered observations are absorbed.
    /// When no `F` limbs are cached, the inner sponge is permuted if it has
    /// pending input or no output, and the cache is rebuilt from its output.
    fn sample(&mut self) -> EF {
        EF::from_basis_coefficients_fn(|_| {
            // Everything observed so far must influence the value drawn.
            self.flush_f_if_non_empty();

            if self.f_squeeze_buffer.is_empty() {
                let has_pending_input = !self.inner.input_buffer.is_empty();
                let lacks_output = self.inner.output_buffer.is_empty();
                if has_pending_input || lacks_output {
                    self.inner.duplexing();
                }
                self.refill_f_squeeze_from_inner();
            }

            let drawn = self.f_squeeze_buffer.pop();
            drawn.expect("Output buffer should be non-empty")
        })
    }
}
impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    /// Samples one `F` element and returns the low `bits` bits of its
    /// canonical `u32` representation.
    ///
    /// # Panics
    ///
    /// Panics if `bits >= usize::BITS` or if `2^bits >= F::ORDER_U32`.
    fn sample_bits(&mut self, bits: usize) -> usize {
        assert!(bits < (usize::BITS as usize));
        // Compare in u64: `bits < usize::BITS <= 64` keeps this shift in range.
        // A u32 shift would overflow for `bits >= 32`, wrapping in release builds
        // so that the bound check passes and the mask below exceeds the field size.
        assert!((1u64 << bits) < u64::from(F::ORDER_U32));
        let rand_f: F = self.sample();
        let rand_usize = rand_f.as_canonical_u32() as usize;
        rand_usize & ((1 << bits) - 1)
    }
}
impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanFinalizeDigest
    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
where
    F: PrimeField32,
    PF: PrimeField,
    P: CryptographicPermutation<[PF; WIDTH]>,
{
    type Digest = [PF; RATE];

    /// Consumes the challenger and returns the first `RATE` elements of the
    /// inner sponge state.
    ///
    /// Buffered `F` observations are absorbed first; when there are none, the
    /// inner sponge is duplexed once instead.
    fn finalize(mut self) -> [PF; RATE] {
        if self.f_buffer.is_empty() {
            // Nothing to absorb, so run the inner duplexing step directly.
            self.inner.duplexing();
        } else {
            self.flush_f_if_non_empty();
        }

        let rate_part: &[PF] = &self.inner.sponge_state[..RATE];
        rate_part.try_into().unwrap()
    }
}
// Unit tests for `MultiField32Challenger`, mostly instantiated with
// `F = BabyBear` and `PF = Goldilocks` over toy permutations.
#[cfg(test)]
mod tests {
use p3_baby_bear::BabyBear;
use p3_field::{
Field, PrimeCharacteristicRing, PrimeField, injective_pack_bits, split_pf_to_packed_limbs,
squeeze_field_order_num_limbs,
};
use p3_goldilocks::Goldilocks;
use p3_symmetric::Permutation;
use super::*;
const WIDTH: usize = 8;
const RATE: usize = 4;
type F = BabyBear;
type PF = Goldilocks;
/// Toy permutation that ignores its input and writes `i + 1` into slot `i`.
#[derive(Clone)]
struct TestPermutation;
impl Permutation<[PF; WIDTH]> for TestPermutation {
fn permute_mut(&self, input: &mut [PF; WIDTH]) {
for (i, val) in input.iter_mut().enumerate() {
*val = PF::from_u8((i + 1) as u8);
}
}
}
// Marker impl only; this toy permutation is not actually cryptographic.
impl CryptographicPermutation<[PF; WIDTH]> for TestPermutation {}
/// Toy permutation whose output depends on the input: slot `i` becomes the
/// sum of all input slots plus `i + 1`.
#[derive(Clone)]
struct MixingPermutation;
impl Permutation<[PF; WIDTH]> for MixingPermutation {
fn permute_mut(&self, input: &mut [PF; WIDTH]) {
let sum: PF = input.iter().copied().sum();
for (i, val) in input.iter_mut().enumerate() {
*val = sum + PF::from_u8((i + 1) as u8);
}
}
}
// Marker impl only; this toy permutation is not actually cryptographic.
impl CryptographicPermutation<[PF; WIDTH]> for MixingPermutation {}
/// Pins the packing parameters for BabyBear over Goldilocks: a 31-bit radix,
/// two `F` per `PF` when absorbing and one `F` per `PF` when squeezing.
#[test]
fn test_packing() {
let c = MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
assert_eq!(c.absorb_radix_bits(), 31);
assert_eq!(c.absorb_num_f_elms(), 2);
assert_eq!(c.squeeze_num_f_elms(), 1);
assert_eq!(squeeze_field_order_num_limbs::<PF, F>(), 1);
}
/// After one sample on a fresh challenger, the squeeze buffer holds
/// `RATE * squeeze_num_f_elms - 1` elements and the inner output buffer is empty.
#[test]
fn test_output_buffer_excludes_capacity() {
let permutation = TestPermutation;
let mut challenger =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(permutation).unwrap();
let squeeze_n = challenger.squeeze_num_f_elms();
let _: F = challenger.sample();
let expected_output_size = RATE * squeeze_n;
assert_eq!(
challenger.pending_f_squeeze_len(),
expected_output_size - 1,
"Pending F squeeze should be RATE * squeeze_num_f_elms minus one sample"
);
assert_eq!(
challenger.inner.output_buffer.len(),
0,
"After refill, inner PF output buffer is drained like popped F outputs"
);
}
/// Identical observation sequences finalize to the same digest; differing
/// sequences finalize to different digests.
#[test]
fn test_finalize() {
let new_chal =
|| MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
let mut c1 = new_chal();
let mut c2 = new_chal();
for i in 0..5u8 {
c1.observe(F::from_u8(i));
c2.observe(F::from_u8(i));
}
assert_eq!(c1.finalize(), c2.finalize());
let mut c1 = new_chal();
let mut c2 = new_chal();
for i in 0..5u8 {
c1.observe(F::from_u8(i));
c2.observe(F::from_u8(i + 1));
}
assert_ne!(c1.finalize(), c2.finalize());
}
/// Checks how the number of samples taken before `finalize` affects the digest:
/// it changes between 0 and 1 samples, stays fixed from 1 up to one full squeeze
/// batch, and changes again once a sample beyond the batch is drawn.
#[test]
fn test_finalize_sample_interaction() {
// Number of F elements produced by a single refill of the squeeze buffer.
let batch_size = {
let c =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
c.squeeze_num_f_elms() * RATE
};
// Observes three fixed elements, draws `n_samples` samples, then finalizes.
let digest = |n_samples: usize| {
let mut c =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
for i in 0..3u8 {
c.observe(F::from_u8(i));
}
for _ in 0..n_samples {
let _: F = c.sample();
}
c.finalize()
};
assert_ne!(digest(0), digest(1));
assert_eq!(digest(1), digest(2));
assert_eq!(digest(1), digest(batch_size));
assert_ne!(digest(batch_size), digest(batch_size + 1));
assert_eq!(digest(batch_size + 1), digest(batch_size + 2));
}
/// Observing `[1]` must not collide with observing `[1, 0]`, even though the
/// second sequence only appends zeros up to one full packed element.
#[test]
fn test_partial_absorb_length_distinct_from_padded_equivalent() {
let ne = {
let c =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
c.absorb_num_f_elms()
};
assert_eq!(ne, 2);
let mut a =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
a.observe(F::ONE);
let mut b =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
b.observe(F::ONE);
for _ in 1..ne {
b.observe(F::ZERO);
}
assert_ne!(a.finalize(), b.finalize());
}
/// Observing `(2^30, 0)` and `(0, 1)` must give different digests.
#[test]
fn test_absorb_no_radix_overflow_collision() {
let mut a =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
a.observe(F::from_u32(1 << 30));
a.observe(F::ZERO);
let mut b =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
b.observe(F::ZERO);
b.observe(F::ONE);
assert_ne!(a.finalize(), b.finalize());
}
/// After observing exactly one full batch of `F` elements, the inner output
/// buffer holds `RATE` elements and no `F` limbs have been produced yet.
#[test]
fn test_duplexing_respects_rate() {
let permutation = TestPermutation;
let mut challenger =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(permutation).unwrap();
let absorb_n = challenger.absorb_num_f_elms();
for i in 0..(absorb_n * RATE) {
challenger.observe(F::from_u8(i as u8));
}
assert_eq!(
challenger.inner.output_buffer.len(),
RATE,
"After a full F batch flush, inner holds one rate row of PF elements"
);
// NOTE(review): the message below names `split_pf_to_packed_limbs`, but the
// challenger's refill path calls `split_pf_to_field_order_limbs`.
assert_eq!(
challenger.pending_f_squeeze_len(),
0,
"F limbs are produced on sample() via split_pf_to_packed_limbs, not on observe"
);
}
/// A `PF` value of `F::ORDER + 2^pack_bits + 1` must split into a single limb
/// equal to `2^pack_bits + 1`, i.e. above the `2^pack_bits` threshold.
#[test]
fn test_squeeze_covers_full_f_range() {
use p3_field::split_pf_to_field_order_limbs;
let pack_bits = injective_pack_bits::<F>();
let threshold = 1u32 << pack_bits;
let v_raw = F::ORDER_U32 as u64 + threshold as u64 + 1;
let pf_val = PF::from_u64(v_raw);
let limbs = split_pf_to_field_order_limbs::<PF, F>(pf_val, 1);
assert_eq!(limbs[0].as_canonical_u32(), threshold + 1);
// Already implied by the equality above; kept for its explanatory message.
assert!(
limbs[0].as_canonical_u32() > threshold,
"c0 must exceed the old base-2^30 ceiling"
);
}
/// Two Bn254 values that differ only by `2^200` must give different digests
/// when observed as native hashes, although their low three 64-bit digits agree.
#[test]
fn test_observe_hash_native_pf_high_bits_distinct() {
use num_bigint::BigUint;
use p3_bn254::Bn254;
use p3_field::split_pf_to_packed_limbs;
use p3_symmetric::Hash;
type PF254 = Bn254;
// Same mixing rule as `MixingPermutation`, over Bn254.
#[derive(Clone)]
struct Bn254MixingPermutation;
impl Permutation<[PF254; WIDTH]> for Bn254MixingPermutation {
fn permute_mut(&self, input: &mut [PF254; WIDTH]) {
let sum: PF254 = input.iter().copied().sum();
for (i, val) in input.iter_mut().enumerate() {
*val = sum + PF254::from_u8((i + 1) as u8);
}
}
}
impl CryptographicPermutation<[PF254; WIDTH]> for Bn254MixingPermutation {}
let pack_bits = injective_pack_bits::<F>();
// Number of `pack_bits`-wide limbs needed to cover every bit of a Bn254 element.
let observe_n = PF254::bits().div_ceil(pack_bits as usize);
let a = PF254::from_biguint(BigUint::from(1u32)).unwrap();
let b = PF254::from_biguint(BigUint::from(1u32) + (BigUint::from(1u32) << 200)).unwrap();
assert_ne!(a, b);
let digest = |h: PF254| {
let mut c =
MultiField32Challenger::<F, PF254, _, WIDTH, RATE>::new(Bn254MixingPermutation)
.unwrap();
c.observe(Hash::<F, PF254, 1>::from([h]));
c.finalize()
};
assert_ne!(digest(a), digest(b));
let limbs_a = split_pf_to_packed_limbs::<PF254, F>(a, observe_n, pack_bits);
let limbs_b = split_pf_to_packed_limbs::<PF254, F>(b, observe_n, pack_bits);
assert_ne!(limbs_a, limbs_b);
let d_a = a.as_canonical_biguint().to_u64_digits();
let d_b = b.as_canonical_biguint().to_u64_digits();
// First three 64-bit digits, zero-filled when fewer are present.
let take3 = |d: &[u64]| {
let mut v = [0u64; 3];
for (i, x) in d.iter().take(3).enumerate() {
v[i] = *x;
}
v
};
assert_eq!(take3(&d_a), take3(&d_b));
}
/// Observing a `PF` element as a native hash must give a different digest than
/// observing the `F` limbs obtained by splitting that same element.
#[test]
fn test_observe_hash_native_vs_expanded_f_not_equal() {
use p3_symmetric::Hash;
let g = PF::from_u64(123456789);
let mut native =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
native.observe(Hash::<F, PF, 1>::from([g]));
let mut via_f =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
let pb = injective_pack_bits::<F>();
let n = PF::bits().div_ceil(pb as usize);
for f in split_pf_to_packed_limbs::<PF, F>(g, n, pb) {
via_f.observe(f);
}
assert_ne!(native.finalize(), via_f.finalize());
}
/// Observing eight `F` elements must leave the inner sponge in the same state
/// as absorbing the four manually packed `PF` elements with tag 8.
#[test]
fn test_inner_sponge_matches_manual_absorb_chain() {
let mut m =
MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
for i in 0..8u8 {
m.observe(F::from_u8(i));
}
let d_m = m.inner.sponge_state;
let mut inner = DuplexChallenger::<PF, _, WIDTH, RATE>::new(MixingPermutation);
// Pack consecutive pairs (0,1), (2,3), (4,5), (6,7) exactly as the flush does.
let packed: Vec<PF> = (0..8)
.step_by(2)
.map(|j| {
reduce_packed::<F, PF>(
&[F::from_u8(j), F::from_u8(j + 1)],
absorb_radix_bits::<F>(),
)
})
.collect();
inner.absorb_rate_padded_with_tag(&packed, 8);
assert_eq!(d_m, inner.sponge_state);
}
}