use crate::{
frontend::{
gadgets::poseidon::{
Arity, Elt, IOPattern, PoseidonConstants, Simplex, Sponge, SpongeAPI, SpongeCircuit,
SpongeOp, SpongeTrait, Strength,
},
num::AllocatedNum,
AllocatedBit, Boolean, ConstraintSystem, SynthesisError,
},
traits::{ROCircuitTrait, ROMode, ROTrait},
};
use ff::{PrimeField, PrimeFieldBits};
use generic_array::typenum::{U24, U5};
use serde::{Deserialize, Serialize};
/// Poseidon constants used by the random oracles defined in this file.
///
/// One constant set is kept per sponge arity: `wide` (arity `U24`) and `narrow`
/// (arity `U5`). The oracle's `ROMode` decides which set a squeeze uses.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct PoseidonConstantsCircuit<Scalar: PrimeField> {
// Constants for the arity-24 sponge; used when the oracle mode is `ROMode::Wide`.
wide: PoseidonConstants<Scalar, U24>,
// Constants for the arity-5 sponge; used when the oracle mode is `ROMode::Narrow`.
narrow: PoseidonConstants<Scalar, U5>,
}
impl<Scalar: PrimeField> Default for PoseidonConstantsCircuit<Scalar> {
  /// Generates both the arity-24 and the arity-5 constant sets at
  /// `Strength::Standard`.
  fn default() -> Self {
    let wide = Sponge::<Scalar, U24>::api_constants(Strength::Standard);
    let narrow = Sponge::<Scalar, U5>::api_constants(Strength::Standard);
    Self { wide, narrow }
  }
}
/// Native (out-of-circuit) Poseidon-based random oracle.
///
/// Absorbed elements are buffered in `state`; hashing only happens on squeeze.
#[derive(Serialize, Deserialize)]
pub struct PoseidonRO<Base: PrimeField> {
// Elements absorbed so far; replaced by the single hash output after each squeeze.
state: Vec<Base>,
// Both constant sets; `mode` picks which one a squeeze uses.
constants: PoseidonConstantsCircuit<Base>,
// Selects the wide or the narrow constant set.
mode: ROMode,
}
/// Hashes `state` with one absorb-then-squeeze pass of a Poseidon sponge.
///
/// The declared IO pattern is a single absorb of `state.len()` elements followed
/// by a squeeze of one element; that one element is returned.
///
/// # Panics
///
/// Panics if finishing the sponge reports an error.
fn poseidon_squeeze_native<Base: PrimeField, A: Arity<Base>>(
  constants: &PoseidonConstants<Base, A>,
  state: &[Base],
) -> Base {
  let absorb_len = state.len() as u32;
  let pattern = IOPattern(vec![
    SpongeOp::Absorb(absorb_len),
    SpongeOp::Squeeze(1u32),
  ]);

  // The native sponge is driven with a unit accumulator.
  let mut acc = ();
  let mut sponge = Sponge::new_with_constants(constants, Simplex);

  sponge.start(pattern, None, &mut acc);
  SpongeAPI::absorb(&mut sponge, absorb_len, state, &mut acc);
  let squeezed = SpongeAPI::squeeze(&mut sponge, 1, &mut acc);
  sponge.finish(&mut acc).unwrap();

  squeezed[0]
}
impl<Base> ROTrait<Base> for PoseidonRO<Base>
where
  Base: PrimeField + PrimeFieldBits + Serialize + for<'de> Deserialize<'de>,
{
  type CircuitRO = PoseidonROCircuit<Base>;
  type Constants = PoseidonConstantsCircuit<Base>;

  /// Creates an oracle with an empty state that hashes in `ROMode::Wide`.
  fn new(constants: PoseidonConstantsCircuit<Base>) -> Self {
    Self {
      state: Vec::new(),
      constants,
      mode: ROMode::Wide,
    }
  }

  /// Creates an oracle with an empty state that hashes in the given `mode`.
  fn new_with_mode(constants: PoseidonConstantsCircuit<Base>, mode: ROMode) -> Self {
    Self {
      state: Vec::new(),
      constants,
      mode,
    }
  }

  /// Appends `e` to the buffered state; no hashing happens until `squeeze`.
  fn absorb(&mut self, e: Base) {
    self.state.push(e);
  }

  /// Hashes the buffered state and returns the low `num_bits` little-endian bits
  /// of the hash, recomposed as a field element.
  ///
  /// After the call the state holds only the full hash, so consecutive squeezes
  /// are chained. When `start_with_one` is set, bit `num_bits - 1` of the result
  /// is forced to one.
  ///
  /// # Panics
  ///
  /// Panics if `start_with_one` is set while `num_bits` is zero, or if
  /// `num_bits` exceeds the number of bits returned by `to_le_bits`.
  fn squeeze(&mut self, num_bits: usize, start_with_one: bool) -> Base {
    // Previously this case surfaced as an arithmetic underflow / index panic.
    assert!(
      !start_with_one || num_bits > 0,
      "squeeze: start_with_one requires num_bits > 0"
    );

    let hash = match self.mode {
      ROMode::Wide => poseidon_squeeze_native(&self.constants.wide, &self.state),
      ROMode::Narrow => poseidon_squeeze_native(&self.constants.narrow, &self.state),
    };
    // The next squeeze starts from this hash alone.
    self.state = vec![hash];

    let bits = hash.to_le_bits();
    let mut res = Base::ZERO;
    // Weight (power of two) of the bit about to be read.
    let mut coeff = Base::ONE;
    // Weight of the last bit read; ends up as 2^(num_bits - 1). Remembering it
    // here avoids recovering it afterwards with a field inversion.
    let mut msb_coeff = Base::ZERO;
    for bit in bits[..num_bits].iter() {
      if *bit {
        res += coeff;
      }
      msb_coeff = coeff;
      coeff = coeff.double();
    }

    // Force the most significant of the selected bits to one if it is not set.
    if start_with_one && !bits[num_bits - 1] {
      res += msb_coeff;
    }
    res
  }
}
/// In-circuit Poseidon-based random oracle operating on allocated numbers.
///
/// Absorbed variables are buffered in `state`; the sponge circuit is only
/// synthesized when a squeeze is requested.
#[derive(Serialize, Deserialize)]
pub struct PoseidonROCircuit<Scalar: PrimeField> {
// Variables absorbed so far; replaced by the single hash variable after each squeeze.
state: Vec<AllocatedNum<Scalar>>,
// Both constant sets; `mode` picks which one a squeeze uses.
constants: PoseidonConstantsCircuit<Scalar>,
// Selects the wide or the narrow constant set.
mode: ROMode,
// Forwarded to `SpongeCircuit::set_compact` on every squeeze; starts out `false`.
compact: bool,
}
/// Synthesizes one absorb-then-squeeze Poseidon sponge pass over `$state` and
/// evaluates to the result of `Elt::ensure_allocated` on the squeezed element
/// (call sites in this file apply `?` to it).
///
/// Arguments:
/// - `$constants`: constants handed to `SpongeCircuit::new_with_constants`;
///   this file invokes the macro with both the `wide` and the `narrow` set.
/// - `$state`: collection of allocated numbers; it is asked for `.len()` and `.iter()`.
/// - `$compact`: flag forwarded to `SpongeCircuit::set_compact`.
/// - `$ns`: constraint-system expression; it is substituted several times, so it
///   must be cheap and repeatable (call sites pass `&mut ns`).
macro_rules! poseidon_squeeze_circuit {
($constants:expr, $state:expr, $compact:expr, $ns:expr) => {{
// Declare the IO pattern up front: absorb every state element, squeeze one.
let parameter = IOPattern(vec![
SpongeOp::Absorb($state.len() as u32),
SpongeOp::Squeeze(1u32),
]);
// The sponge lives only inside this inner block.
let hash = {
let mut sponge = SpongeCircuit::new_with_constants($constants, Simplex);
sponge.set_compact($compact);
sponge.start(parameter, None, $ns);
// Each state variable is cloned and wrapped as `Elt::Allocated` for the sponge.
SpongeAPI::absorb(
&mut sponge,
$state.len() as u32,
&$state
.iter()
.map(|e| Elt::Allocated(e.clone()))
.collect::<Vec<Elt<_>>>(),
$ns,
);
let output = SpongeAPI::squeeze(&mut sponge, 1, $ns);
// Panics if the sponge reports an error on finish.
sponge.finish($ns).unwrap();
output
};
// Convert the squeezed `Elt` via `ensure_allocated` in its own sub-namespace.
Elt::ensure_allocated(&hash[0], &mut $ns.namespace(|| "ensure allocated"))
}};
}
impl<Scalar> ROCircuitTrait<Scalar> for PoseidonROCircuit<Scalar>
where
  Scalar: PrimeField + PrimeFieldBits + Serialize + for<'de> Deserialize<'de>,
{
  type NativeRO = PoseidonRO<Scalar>;
  type Constants = PoseidonConstantsCircuit<Scalar>;

  /// Creates an oracle with an empty state, `ROMode::Wide`, and `compact` off.
  fn new(constants: PoseidonConstantsCircuit<Scalar>) -> Self {
    Self {
      state: Vec::new(),
      constants,
      mode: ROMode::Wide,
      compact: false,
    }
  }

  /// Creates an oracle with an empty state, the given `mode`, and `compact` off.
  fn new_with_mode(constants: PoseidonConstantsCircuit<Scalar>, mode: ROMode) -> Self {
    Self {
      state: Vec::new(),
      constants,
      mode,
      compact: false,
    }
  }

  /// Appends a clone of `e` to the buffered state; nothing is synthesized yet.
  fn absorb(&mut self, e: &AllocatedNum<Scalar>) {
    self.state.push(e.clone());
  }

  /// Synthesizes the hash of the buffered state and returns its low `num_bits`
  /// bits in little-endian order.
  ///
  /// After the call the state holds only the hash variable. When
  /// `start_with_one` is set, the last returned bit is replaced by a freshly
  /// allocated bit that is constrained to equal one.
  ///
  /// # Errors
  ///
  /// Propagates any `SynthesisError` from the sponge output allocation, the bit
  /// decomposition, or the bit allocation.
  ///
  /// # Panics
  ///
  /// Panics if `start_with_one` is set while `num_bits` is zero, if `num_bits`
  /// exceeds the number of decomposed bits, or if the decomposition yields a
  /// boolean that is not `Boolean::Is`.
  fn squeeze<CS: ConstraintSystem<Scalar>>(
    &mut self,
    mut cs: CS,
    num_bits: usize,
    start_with_one: bool,
  ) -> Result<Vec<AllocatedBit>, SynthesisError> {
    // Previously this case surfaced as an arithmetic underflow / index panic.
    assert!(
      !start_with_one || num_bits > 0,
      "squeeze: start_with_one requires num_bits > 0"
    );

    let mut ns = cs.namespace(|| "ns");
    let hash = match self.mode {
      ROMode::Wide => {
        poseidon_squeeze_circuit!(&self.constants.wide, &self.state, self.compact, &mut ns)?
      }
      ROMode::Narrow => {
        poseidon_squeeze_circuit!(&self.constants.narrow, &self.state, self.compact, &mut ns)?
      }
    };
    // The next squeeze starts from this hash alone.
    self.state = vec![hash.clone()];

    // Move the bits out of the decomposition instead of cloning each of them
    // twice (once into a temporary vector, once more when copying the prefix).
    // Every bit is still inspected, so the variant check is unchanged.
    let mut bits: Vec<AllocatedBit> = hash
      .to_bits_le_strict(ns.namespace(|| "poseidon hash to boolean"))?
      .into_iter()
      .map(|boolean| match boolean {
        Boolean::Is(x) => x,
        _ => panic!("Wrong type of input. We should have never reached there"),
      })
      .collect();
    assert!(
      num_bits <= bits.len(),
      "squeeze: requested {} bits but the hash only has {}",
      num_bits,
      bits.len()
    );
    bits.truncate(num_bits);

    if start_with_one {
      let msb_idx = num_bits - 1;
      bits[msb_idx] = AllocatedBit::alloc(ns.namespace(|| "set msb to 1"), Some(true))?;
      // bits[msb] * 1 = 1 pins the replacement bit to one.
      ns.enforce(
        || "check bits[msb] = 1",
        |lc| lc + bits[msb_idx].get_variable(),
        |lc| lc + CS::one(),
        |lc| lc + CS::one(),
      );
    }
    Ok(bits)
  }

  /// Synthesizes the hash of the buffered state and returns it as a single
  /// allocated number. After the call the state holds only that hash variable.
  ///
  /// # Errors
  ///
  /// Propagates any `SynthesisError` from allocating the sponge output.
  fn squeeze_scalar<CS: ConstraintSystem<Scalar>>(
    &mut self,
    mut cs: CS,
  ) -> Result<AllocatedNum<Scalar>, SynthesisError> {
    let mut ns = cs.namespace(|| "ns");
    let hash = match self.mode {
      ROMode::Wide => {
        poseidon_squeeze_circuit!(&self.constants.wide, &self.state, self.compact, &mut ns)?
      }
      ROMode::Narrow => {
        poseidon_squeeze_circuit!(&self.constants.narrow, &self.state, self.compact, &mut ns)?
      }
    };
    self.state = vec![hash.clone()];
    Ok(hash)
  }

  /// Sets the flag forwarded to `SpongeCircuit::set_compact` on later squeezes.
  fn set_compact(&mut self, compact: bool) {
    self.compact = compact;
  }
}
#[cfg(test)]
mod tests {
  use super::*;
  use crate::{
    constants::NUM_CHALLENGE_BITS,
    frontend::solver::SatisfyingAssignment,
    gadgets::utils::le_bits_to_num,
    provider::{
      Bn256EngineKZG, GrumpkinEngine, PallasEngine, Secp256k1Engine, Secq256k1Engine, VestaEngine,
    },
    traits::Engine,
  };
  use ff::Field;
  use rand::rngs::OsRng;

  /// Absorbs `num_absorbs` random scalars into both oracles, squeezes
  /// `NUM_CHALLENGE_BITS` bits from each, and asserts that the native value
  /// equals the number recomposed from the circuit's bits.
  fn assert_native_matches_circuit<E: Engine>(
    mut ro: PoseidonRO<E::Scalar>,
    mut ro_gadget: PoseidonROCircuit<E::Scalar>,
    num_absorbs: usize,
    start_with_one: bool,
  ) {
    let mut csprng: OsRng = OsRng;
    let mut cs = SatisfyingAssignment::<E>::new();
    for i in 0..num_absorbs {
      let num = E::Scalar::random(&mut csprng);
      ro.absorb(num);
      let num_gadget = AllocatedNum::alloc_infallible(cs.namespace(|| format!("data {i}")), || num);
      num_gadget
        .inputize(&mut cs.namespace(|| format!("input {i}")))
        .unwrap();
      ro_gadget.absorb(&num_gadget);
    }
    let num = ro.squeeze(NUM_CHALLENGE_BITS, start_with_one);
    let num2_bits = ro_gadget
      .squeeze(&mut cs, NUM_CHALLENGE_BITS, start_with_one)
      .unwrap();
    let num2 = le_bits_to_num(&mut cs, &num2_bits).unwrap();
    assert_eq!(num, num2.get_value().unwrap());
  }

  /// Wide-mode agreement check, run with and without the forced top bit.
  fn test_poseidon_ro_with<E: Engine>() {
    let constants = PoseidonConstantsCircuit::<E::Scalar>::default();
    let num_absorbs = 32;
    assert_native_matches_circuit::<E>(
      PoseidonRO::new(constants.clone()),
      PoseidonROCircuit::new(constants.clone()),
      num_absorbs,
      false,
    );
    assert_native_matches_circuit::<E>(
      PoseidonRO::new(constants.clone()),
      PoseidonROCircuit::new(constants),
      num_absorbs,
      true,
    );
  }

  #[test]
  fn test_poseidon_ro() {
    test_poseidon_ro_with::<PallasEngine>();
    test_poseidon_ro_with::<VestaEngine>();
    test_poseidon_ro_with::<Bn256EngineKZG>();
    test_poseidon_ro_with::<GrumpkinEngine>();
    test_poseidon_ro_with::<Secp256k1Engine>();
    test_poseidon_ro_with::<Secq256k1Engine>();
  }

  /// Narrow-mode agreement check, run with and without the forced top bit.
  fn test_poseidon_ro_narrow_with<E: Engine>() {
    let constants = PoseidonConstantsCircuit::<E::Scalar>::default();
    let num_absorbs = 4;
    assert_native_matches_circuit::<E>(
      PoseidonRO::new_with_mode(constants.clone(), ROMode::Narrow),
      PoseidonROCircuit::new_with_mode(constants.clone(), ROMode::Narrow),
      num_absorbs,
      false,
    );
    assert_native_matches_circuit::<E>(
      PoseidonRO::new_with_mode(constants.clone(), ROMode::Narrow),
      PoseidonROCircuit::new_with_mode(constants, ROMode::Narrow),
      num_absorbs,
      true,
    );
  }

  #[test]
  fn test_poseidon_ro_narrow() {
    test_poseidon_ro_narrow_with::<PallasEngine>();
    test_poseidon_ro_narrow_with::<VestaEngine>();
    test_poseidon_ro_narrow_with::<Bn256EngineKZG>();
    test_poseidon_ro_narrow_with::<GrumpkinEngine>();
  }
}