#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use core::marker::PhantomData;
use plonky2_field::extension::Extendable;
pub use qp_plonky2_core::Challenger;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField};
use crate::hash::hashing::PlonkyPermutation;
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::config::AlgebraicHasher;
/// In-circuit counterpart of the native [`Challenger`]: it observes circuit [`Target`]s and
/// derives challenge targets from them using the algebraic permutation of `H`.
///
/// Observed targets are buffered and only absorbed (via permutation gates added to the
/// `CircuitBuilder`) when a challenge is requested or the state is compacted. This module's
/// tests compare its outputs against the native `Challenger` for the same observations.
#[derive(Debug)]
pub struct RecursiveChallenger<F: RichField + Extendable<D>, H: AlgebraicHasher<F>, const D: usize>
{
/// Current permutation state of the sponge, as circuit targets.
sponge_state: H::AlgebraicPermutation,
/// Observed targets that have not yet been absorbed into `sponge_state`.
input_buffer: Vec<Target>,
/// Squeezed outputs not yet handed out as challenges; they are popped from the end.
output_buffer: Vec<Target>,
/// Keeps `F` and `H` as used type parameters (they otherwise appear only in bounds and in an
/// associated-type projection).
__: PhantomData<(F, H)>,
}
impl<F: RichField + Extendable<D>, H: AlgebraicHasher<F>, const D: usize>
RecursiveChallenger<F, H, D>
{
pub fn new(builder: &mut CircuitBuilder<F, D>) -> Self {
let zero = builder.zero();
Self {
sponge_state: H::AlgebraicPermutation::new(core::iter::repeat(zero)),
input_buffer: Vec::new(),
output_buffer: Vec::new(),
__: PhantomData,
}
}
pub fn from_state(sponge_state: H::AlgebraicPermutation) -> Self {
Self {
sponge_state,
input_buffer: vec![],
output_buffer: vec![],
__: PhantomData,
}
}
pub fn observe_element(&mut self, target: Target) {
self.output_buffer.clear();
self.input_buffer.push(target);
}
pub fn observe_elements(&mut self, targets: &[Target]) {
for &target in targets {
self.observe_element(target);
}
}
pub fn observe_hash(&mut self, hash: &HashOutTarget) {
self.observe_elements(&hash.elements)
}
pub fn observe_cap(&mut self, cap: &MerkleCapTarget) {
for hash in &cap.0 {
self.observe_hash(hash)
}
}
pub fn observe_extension_element(&mut self, element: ExtensionTarget<D>) {
self.observe_elements(&element.0);
}
pub fn observe_extension_elements(&mut self, elements: &[ExtensionTarget<D>]) {
for &element in elements {
self.observe_extension_element(element);
}
}
pub fn get_challenge(&mut self, builder: &mut CircuitBuilder<F, D>) -> Target {
self.absorb_buffered_inputs(builder);
if self.output_buffer.is_empty() {
self.sponge_state = builder.permute::<H>(self.sponge_state);
self.output_buffer = self.sponge_state.squeeze().to_vec();
}
self.output_buffer
.pop()
.expect("Output buffer should be non-empty")
}
pub fn get_n_challenges(
&mut self,
builder: &mut CircuitBuilder<F, D>,
n: usize,
) -> Vec<Target> {
(0..n).map(|_| self.get_challenge(builder)).collect()
}
pub fn get_hash(&mut self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
HashOutTarget {
elements: [
self.get_challenge(builder),
self.get_challenge(builder),
self.get_challenge(builder),
self.get_challenge(builder),
],
}
}
pub fn get_extension_challenge(
&mut self,
builder: &mut CircuitBuilder<F, D>,
) -> ExtensionTarget<D> {
self.get_n_challenges(builder, D).try_into().unwrap()
}
pub fn get_n_extension_challenges(
&mut self,
builder: &mut CircuitBuilder<F, D>,
n: usize,
) -> Vec<ExtensionTarget<D>> {
(0..n)
.map(|_| self.get_extension_challenge(builder))
.collect()
}
fn absorb_buffered_inputs(&mut self, builder: &mut CircuitBuilder<F, D>) {
if self.input_buffer.is_empty() {
return;
}
for input_chunk in self.input_buffer.chunks(H::AlgebraicPermutation::RATE) {
self.sponge_state.set_from_slice(input_chunk, 0);
self.sponge_state = builder.permute::<H>(self.sponge_state);
}
self.output_buffer = self.sponge_state.squeeze().to_vec();
self.input_buffer.clear();
}
pub fn compact(&mut self, builder: &mut CircuitBuilder<F, D>) -> H::AlgebraicPermutation {
self.absorb_buffered_inputs(builder);
self.output_buffer.clear();
self.sponge_state
}
}
#[cfg(test)]
#[cfg(feature = "rand")]
mod tests {
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;

    use plonky2_field::types::Sample;

    use crate::iop::challenger::{Challenger, RecursiveChallenger};
    use crate::iop::generator::generate_partial_witness;
    use crate::iop::target::Target;
    use crate::iop::witness::{PartialWitness, Witness};
    use crate::plonk::circuit_builder::CircuitBuilder;
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// Challenges drawn across several observe/squeeze rounds must all be distinct.
    #[test]
    fn no_duplicate_challenges() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        let mut challenger = Challenger::<F, <C as GenericConfig<D>>::InnerHasher>::new();
        let mut challenges = Vec::new();

        for i in 1..10 {
            challenges.extend(challenger.get_n_challenges(i));
            challenger.observe_element(F::rand());
        }

        // `Vec::dedup` only collapses *consecutive* equal elements, so it would miss a challenge
        // that reappears later in the sequence. Compare each challenge against all earlier ones
        // instead. The list is tiny (45 entries), so the quadratic scan costs nothing.
        for (i, challenge) in challenges.iter().enumerate() {
            assert!(
                !challenges[..i].contains(challenge),
                "challenge at index {i} duplicates an earlier challenge"
            );
        }
    }

    /// The in-circuit challenger must produce the same challenge values as the native one when
    /// both observe the same inputs and draw the same number of outputs per round.
    #[test]
    fn test_consistency() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        // These are mostly arbitrary, but we want to test some rounds with enough inputs/outputs
        // to trigger multiple absorptions/squeezes.
        let num_inputs_per_round = [2, 5, 3];
        let num_outputs_per_round = [1, 2, 4];

        // Generate random input messages.
        let inputs_per_round: Vec<Vec<F>> = num_inputs_per_round
            .iter()
            .map(|&n| F::rand_vec(n))
            .collect();

        let mut challenger = Challenger::<F, <C as GenericConfig<D>>::InnerHasher>::new();
        let mut outputs_per_round: Vec<Vec<F>> = Vec::new();
        for (r, inputs) in inputs_per_round.iter().enumerate() {
            challenger.observe_elements(inputs);
            outputs_per_round.push(challenger.get_n_challenges(num_outputs_per_round[r]));
        }

        let config = CircuitConfig::standard_recursion_config();
        let mut builder = CircuitBuilder::<F, D>::new(config);
        let mut recursive_challenger =
            RecursiveChallenger::<F, <C as GenericConfig<D>>::InnerHasher, D>::new(&mut builder);
        let mut recursive_outputs_per_round: Vec<Vec<Target>> = Vec::new();
        for (r, inputs) in inputs_per_round.iter().enumerate() {
            recursive_challenger.observe_elements(&builder.constants(inputs));
            recursive_outputs_per_round.push(
                recursive_challenger.get_n_challenges(&mut builder, num_outputs_per_round[r]),
            );
        }

        let circuit = builder.build::<C>();
        let inputs = PartialWitness::new();
        let witness =
            generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap();
        let recursive_output_values_per_round: Vec<Vec<F>> = recursive_outputs_per_round
            .iter()
            .map(|outputs| witness.get_targets(outputs))
            .collect();

        assert_eq!(outputs_per_round, recursive_output_values_per_round);
    }
}