use crate::constraints::AbsorbGadget;
use crate::constraints::CryptographicSpongeVar;
use crate::poseidon::{PoseidonParameters, PoseidonSponge, PoseidonSpongeMode};
use ark_ff::{FpParameters, PrimeField};
use ark_r1cs_std::fields::fp::FpVar;
use ark_r1cs_std::prelude::*;
use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
use ark_std::vec;
use ark_std::vec::Vec;
/// R1CS gadget holding the state of a Poseidon sponge together with the
/// permutation parameters it was constructed from.
#[derive(Clone)]
pub struct PoseidonSpongeVar<F: PrimeField> {
/// Constraint system handle; `cs()` returns a clone of it.
pub cs: ConstraintSystemRef<F>,
/// Number of full rounds. `permute` runs `full_rounds / 2` of them before the
/// partial rounds and the remaining ones after.
pub full_rounds: u32,
/// Number of partial rounds, in which the S-box is applied to the last state
/// element only.
pub partial_rounds: u32,
/// Exponent of the S-box (`x -> x^alpha`).
pub alpha: u64,
/// Constants added to the state each round, indexed as
/// `ark[round_number][state_index]`.
pub ark: Vec<Vec<F>>,
/// Matrix multiplied with the state each round, indexed as `mds[row][column]`.
pub mds: Vec<Vec<F>>,
/// Number of state elements absorbed into / squeezed out of between two
/// permutations. `new` sets this to 2.
pub rate: usize,
/// `new` sets this to 1 and sizes `state` as `rate + capacity`; the field is
/// not otherwise read in this file.
pub capacity: usize,
/// Current sponge state.
pub state: Vec<FpVar<F>>,
/// Whether the sponge is currently absorbing or squeezing, and the next rate
/// index to use in that mode.
mode: PoseidonSpongeMode,
}
impl<F: PrimeField> PoseidonSpongeVar<F> {
/// Applies the S-box `x -> x^alpha` to `state` in place.
///
/// In a full round every element is exponentiated; otherwise only the last
/// element is. `state` must be non-empty when `is_full_round` is `false`,
/// because the last element is addressed as `state.len() - 1`.
#[tracing::instrument(target = "r1cs", skip(self))]
fn apply_s_box(
    &self,
    state: &mut [FpVar<F>],
    is_full_round: bool,
) -> Result<(), SynthesisError> {
    let exponent = [self.alpha];

    if !is_full_round {
        // Partial round: only the final state element goes through the S-box.
        let last = state.len() - 1;
        state[last] = state[last].pow_by_constant(&exponent)?;
        return Ok(());
    }

    // Full round: every state element goes through the S-box.
    for elem in state.iter_mut() {
        *elem = elem.pow_by_constant(&exponent)?;
    }
    Ok(())
}
/// Adds the constants `ark[round_number][i]` to `state[i]` for every `i`.
#[tracing::instrument(target = "r1cs", skip(self))]
fn apply_ark(&self, state: &mut [FpVar<F>], round_number: usize) -> Result<(), SynthesisError> {
    state
        .iter_mut()
        .enumerate()
        .for_each(|(idx, elem)| *elem += self.ark[round_number][idx]);
    Ok(())
}
/// Replaces `state` with the matrix-vector product `mds * state`.
///
/// Only the leading `state.len()` rows and columns of `mds` are read.
#[tracing::instrument(target = "r1cs", skip(self))]
fn apply_mds(&self, state: &mut [FpVar<F>]) -> Result<(), SynthesisError> {
    let width = state.len();
    let mut mixed = Vec::with_capacity(width);

    for row in 0..width {
        // Dot product of row `row` of the matrix with the current state.
        let mut acc = FpVar::<F>::zero();
        for (col, elem) in state.iter().enumerate() {
            acc += &(elem * self.mds[row][col]);
        }
        mixed.push(acc);
    }

    // The old state had to stay intact until every row was computed.
    state.clone_from_slice(&mixed);
    Ok(())
}
/// Runs the full permutation on `self.state`.
///
/// Rounds are numbered `0..partial_rounds + full_rounds`. The first
/// `full_rounds / 2` and the last `full_rounds - full_rounds / 2` rounds are
/// full rounds; the `partial_rounds` rounds in between are partial. Each round
/// adds the round constants, applies the S-box, then mixes with the matrix.
/// `self.state` is only overwritten once every round has succeeded.
#[tracing::instrument(target = "r1cs", skip(self))]
fn permute(&mut self) -> Result<(), SynthesisError> {
    let first_partial = self.full_rounds / 2;
    let first_trailing_full = first_partial + self.partial_rounds;
    let total_rounds = self.partial_rounds + self.full_rounds;

    let mut working = self.state.clone();
    for round in 0..total_rounds {
        let is_full_round = round < first_partial || round >= first_trailing_full;
        self.apply_ark(&mut working, round as usize)?;
        self.apply_s_box(&mut working, is_full_round)?;
        self.apply_mds(&mut working)?;
    }

    self.state = working;
    Ok(())
}
/// Adds `elements` into the state starting at rate position
/// `rate_start_index`, permuting each time the rate portion fills up while
/// input still remains.
///
/// On return the sponge is in absorbing mode, with `next_absorb_index`
/// pointing just past the last position written.
#[tracing::instrument(target = "r1cs", skip(self))]
fn absorb_internal(
    &mut self,
    mut rate_start_index: usize,
    elements: &[FpVar<F>],
) -> Result<(), SynthesisError> {
    let mut pending = elements;

    // While the input does not fit in what is left of the rate portion:
    // fill it, permute, and start again from rate position 0.
    while rate_start_index + pending.len() > self.rate {
        let room = self.rate - rate_start_index;
        let (chunk, rest) = pending.split_at(room);
        for (offset, element) in chunk.iter().enumerate() {
            self.state[rate_start_index + offset] += element;
        }
        self.permute()?;
        pending = rest;
        rate_start_index = 0;
    }

    // The tail fits: add it in and remember where the next absorb continues.
    for (offset, element) in pending.iter().enumerate() {
        self.state[rate_start_index + offset] += element;
    }
    self.mode = PoseidonSpongeMode::Absorbing {
        next_absorb_index: rate_start_index + pending.len(),
    };
    Ok(())
}
/// Fills `output` with state elements, starting at rate position
/// `rate_start_index` and continuing from position 0 after each pass over the
/// rate portion.
///
/// On return the sponge is in squeezing mode, with `next_squeeze_index`
/// pointing just past the last position read.
#[tracing::instrument(target = "r1cs", skip(self))]
fn squeeze_internal(
&mut self,
mut rate_start_index: usize,
output: &mut [FpVar<F>],
) -> Result<(), SynthesisError> {
let mut remaining_output = output;
loop {
// The rest of the request fits in what is left of the rate portion:
// copy it out, record where the next squeeze continues, and stop.
if rate_start_index + remaining_output.len() <= self.rate {
remaining_output.clone_from_slice(
&self.state[rate_start_index..(remaining_output.len() + rate_start_index)],
);
self.mode = PoseidonSpongeMode::Squeezing {
next_squeeze_index: rate_start_index + remaining_output.len(),
};
return Ok(());
}
// Otherwise hand out everything from `rate_start_index` up to `rate`.
let num_elements_squeezed = self.rate - rate_start_index;
remaining_output[..num_elements_squeezed].clone_from_slice(
&self.state[rate_start_index..(num_elements_squeezed + rate_start_index)],
);
// NOTE(review): this branch is only reached when
// `rate_start_index + remaining_output.len() > self.rate`, so output always
// remains after the copy above. When `rate_start_index > 0` and exactly
// `rate` elements were still requested, the permutation below is skipped
// and the next pass copies `state[0..rate_start_index]` from the
// un-permuted state — presumably elements an earlier squeeze already
// returned. Confirm against the native `PoseidonSponge` before changing:
// the gadget has to produce the same outputs as the native sponge.
if remaining_output.len() != self.rate {
self.permute()?;
}
// Continue with the unfilled tail of the output, from rate position 0.
remaining_output = &mut remaining_output[num_elements_squeezed..];
rate_start_index = 0;
}
}
}
impl<F: PrimeField> CryptographicSpongeVar<F, PoseidonSponge<F>> for PoseidonSpongeVar<F> {
type Parameters = PoseidonParameters<F>;

/// Builds a sponge gadget over `cs` from `params`.
///
/// The rate is fixed to 2 and the capacity to 1, so the state holds three
/// elements, all initialised to zero. The sponge starts in absorbing mode at
/// rate position 0.
#[tracing::instrument(target = "r1cs", skip(cs))]
fn new(cs: ConstraintSystemRef<F>, params: &PoseidonParameters<F>) -> Self {
    let rate = 2;
    let capacity = 1;

    Self {
        cs,
        full_rounds: params.full_rounds,
        partial_rounds: params.partial_rounds,
        alpha: params.alpha,
        ark: params.ark.to_vec(),
        mds: params.mds.to_vec(),
        state: vec![FpVar::<F>::zero(); rate + capacity],
        rate,
        capacity,
        mode: PoseidonSpongeMode::Absorbing {
            next_absorb_index: 0,
        },
    }
}
#[tracing::instrument(target = "r1cs", skip(self))]
fn cs(&self) -> ConstraintSystemRef<F> {
self.cs.clone()
}
/// Absorbs `input` into the sponge.
///
/// `input` is first converted to field elements; an empty conversion leaves
/// the sponge untouched. If the sponge was squeezing, or the previous absorb
/// filled the rate portion exactly, the state is permuted first and absorption
/// restarts at rate position 0.
#[tracing::instrument(target = "r1cs", skip(self, input))]
fn absorb(&mut self, input: &impl AbsorbGadget<F>) -> Result<(), SynthesisError> {
    let elements = input.to_sponge_field_elements()?;
    if elements.is_empty() {
        return Ok(());
    }

    let rate_start_index = match self.mode {
        PoseidonSpongeMode::Absorbing { next_absorb_index } => {
            if next_absorb_index == self.rate {
                // Rate portion is full: permute before writing again.
                self.permute()?;
                0
            } else {
                next_absorb_index
            }
        }
        PoseidonSpongeMode::Squeezing { .. } => {
            // Switching from squeezing to absorbing always permutes first.
            self.permute()?;
            0
        }
    };

    self.absorb_internal(rate_start_index, &elements)
}
/// Squeezes `num_bytes` bytes out of the sponge.
///
/// Each squeezed field element contributes the first `CAPACITY / 8` bytes of
/// its `to_bytes()` output; enough elements are squeezed to cover the request
/// and the surplus is dropped.
#[tracing::instrument(target = "r1cs", skip(self))]
fn squeeze_bytes(&mut self, num_bytes: usize) -> Result<Vec<UInt8<F>>, SynthesisError> {
    let bytes_per_element = (F::Params::CAPACITY / 8) as usize;
    // Ceiling division: number of elements needed to supply `num_bytes`.
    let num_elements = (num_bytes + bytes_per_element - 1) / bytes_per_element;

    let elements = self.squeeze_field_elements(num_elements)?;

    let mut bytes = Vec::with_capacity(bytes_per_element * num_elements);
    for element in elements.iter() {
        let element_bytes = element.to_bytes()?;
        bytes.extend_from_slice(&element_bytes[..bytes_per_element]);
    }

    bytes.truncate(num_bytes);
    Ok(bytes)
}
/// Squeezes `num_bits` bits out of the sponge.
///
/// Each squeezed field element contributes the first `CAPACITY` bits of its
/// `to_bits_le()` output; enough elements are squeezed to cover the request
/// and the surplus is dropped.
#[tracing::instrument(target = "r1cs", skip(self))]
fn squeeze_bits(&mut self, num_bits: usize) -> Result<Vec<Boolean<F>>, SynthesisError> {
    let bits_per_element = F::Params::CAPACITY as usize;
    // Ceiling division: number of elements needed to supply `num_bits`.
    let num_elements = (num_bits + bits_per_element - 1) / bits_per_element;

    let elements = self.squeeze_field_elements(num_elements)?;

    let mut bits = Vec::with_capacity(bits_per_element * num_elements);
    for element in elements.iter() {
        let element_bits = element.to_bits_le()?;
        bits.extend_from_slice(&element_bits[..bits_per_element]);
    }

    bits.truncate(num_bits);
    Ok(bits)
}
/// Squeezes `num_elements` field elements out of the sponge.
///
/// If the sponge was absorbing, or the previous squeeze consumed the rate
/// portion exactly, the state is permuted first and squeezing restarts at rate
/// position 0; otherwise it continues where the previous squeeze stopped.
#[tracing::instrument(target = "r1cs", skip(self))]
fn squeeze_field_elements(
    &mut self,
    num_elements: usize,
) -> Result<Vec<FpVar<F>>, SynthesisError> {
    // Placeholder values; every slot is overwritten by `squeeze_internal`.
    let mut output = vec![FpVar::<F>::zero(); num_elements];

    let rate_start_index = match self.mode {
        PoseidonSpongeMode::Absorbing { .. } => {
            // Switching from absorbing to squeezing always permutes first.
            self.permute()?;
            0
        }
        PoseidonSpongeMode::Squeezing { next_squeeze_index } => {
            if next_squeeze_index == self.rate {
                // Rate portion fully consumed: permute before reading again.
                self.permute()?;
                0
            } else {
                next_squeeze_index
            }
        }
    };

    self.squeeze_internal(rate_start_index, &mut output)?;
    Ok(output)
}
}
#[cfg(test)]
mod tests {
    use crate::constraints::CryptographicSpongeVar;
    use crate::poseidon::constraints::PoseidonSpongeVar;
    use crate::poseidon::tests::poseidon_parameters_for_test;
    use crate::poseidon::PoseidonSponge;
    use crate::{CryptographicSponge, FieldBasedCryptographicSponge};
    use ark_ff::UniformRand;
    use ark_r1cs_std::fields::fp::FpVar;
    use ark_r1cs_std::prelude::*;
    use ark_relations::r1cs::ConstraintSystem;
    use ark_relations::*;
    use ark_std::test_rng;
    use ark_test_curves::bls12_381::Fr;

    /// Checks that the gadget and the native sponge squeeze the same element
    /// after absorbing field elements, and again after absorbing byte vectors,
    /// and that the constraint system is satisfied at both points.
    #[test]
    fn absorb_test() {
        let mut rng = test_rng();
        let cs = ConstraintSystem::new_ref();

        // First input: 256 random field elements, allocated as public inputs.
        let field_inputs: Vec<_> = (0..256).map(|_| Fr::rand(&mut rng)).collect();
        let field_input_vars: Vec<_> = field_inputs
            .iter()
            .map(|value| FpVar::new_input(ns!(cs, "absorb1"), || Ok(*value)).unwrap())
            .collect();

        // Second input: eight 3-byte vectors, allocated as public inputs.
        let byte_inputs: Vec<_> = (0..8).map(|i| vec![i, i + 1, i + 2]).collect();
        let byte_input_vars: Vec<_> = byte_inputs
            .iter()
            .map(|bytes| UInt8::new_input_vec(ns!(cs, "absorb2"), bytes).unwrap())
            .collect();

        let sponge_params = poseidon_parameters_for_test();
        let mut native_sponge = PoseidonSponge::<Fr>::new(&sponge_params);
        let mut constraint_sponge = PoseidonSpongeVar::<Fr>::new(cs.clone(), &sponge_params);

        // Round 1: absorb the field elements, squeeze one element, compare.
        native_sponge.absorb(&field_inputs);
        constraint_sponge.absorb(&field_input_vars).unwrap();

        let native_out = native_sponge.squeeze_native_field_elements(1);
        let gadget_out = constraint_sponge.squeeze_field_elements(1).unwrap();
        assert_eq!(gadget_out.value().unwrap(), native_out);
        assert!(cs.is_satisfied().unwrap());

        // Round 2: absorb the byte vectors on top, squeeze again, compare.
        native_sponge.absorb(&byte_inputs);
        constraint_sponge.absorb(&byte_input_vars).unwrap();

        let native_out = native_sponge.squeeze_native_field_elements(1);
        let gadget_out = constraint_sponge.squeeze_field_elements(1).unwrap();
        assert_eq!(gadget_out.value().unwrap(), native_out);
        assert!(cs.is_satisfied().unwrap());
    }
}