use crate::sponge::constraints::AbsorbGadget;
use crate::sponge::constraints::{CryptographicSpongeVar, SpongeWithGadget};
use crate::sponge::rescue::{RescueConfig, RescueSponge};
use crate::sponge::DuplexSpongeMode;
use ark_ff::PrimeField;
use ark_r1cs_std::fields::fp::FpVar;
use ark_r1cs_std::prelude::*;
use ark_relations::gr1cs::{ConstraintSystemRef, SynthesisError};
#[cfg(not(feature = "std"))]
use ark_std::vec::Vec;
/// Name of the constraint-system predicate used by [`RescueSpongeVar`]'s S-box.
///
/// It is only used when the S-box exponent is 5 and the constraint system reports this
/// predicate via `has_predicate`. Each use constrains a pair of variables `(a, b)`:
/// - In the forward S-box, `b` is a witness assigned `a^5`.
/// - In the inverse S-box, the roles are swapped.
///
/// NOTE(review): presumably the predicate enforces `a^5 == b`; confirm against the place
/// where it is registered.
pub const RESCUE_PREDICATE: &str = "Deg5-Mul";
/// Constraint-system gadget for the Rescue duplex sponge (counterpart of [`RescueSponge`]).
#[derive(Clone)]
pub struct RescueSpongeVar<F: PrimeField> {
    /// Constraint system the sponge was created with. It is queried for
    /// [`RESCUE_PREDICATE`] support and used to allocate inverse-S-box witnesses on the
    /// non-predicate path.
    pub cs: ConstraintSystemRef<F>,
    /// Rescue parameters: rate, capacity, S-box exponents, round constants and MDS matrix.
    pub parameters: RescueConfig<F>,
    /// Sponge state. As built by `new`, it holds `rate + capacity` elements: the capacity
    /// part first, then the rate part at indices `capacity..capacity + rate`.
    pub state: Vec<FpVar<F>>,
    /// Whether the sponge is absorbing or squeezing, plus the next rate index to use.
    pub mode: DuplexSpongeMode,
}
impl<F> SpongeWithGadget<F> for RescueSponge<F>
where
    F: PrimeField,
{
    /// The in-circuit counterpart of the native Rescue sponge.
    type Var = RescueSpongeVar<F>;
}
impl<F: PrimeField> RescueSpongeVar<F> {
    /// Applies the Rescue S-box, or its inverse, to every element of `state` in place.
    ///
    /// * `alpha` - the S-box exponent. The dedicated-predicate path is only taken when it
    ///   is 5.
    /// * `is_forward_pass` - `true` raises each element to `self.parameters.alpha`;
    ///   `false` raises it to `self.parameters.alpha_inv`.
    ///
    /// When `alpha == 5` and `self.cs` has [`RESCUE_PREDICATE`]:
    /// - Each variable is replaced by a fresh witness tied to it by one arity-2 predicate
    ///   constraint.
    /// - Constants are exponentiated directly.
    ///
    /// Otherwise:
    /// - The forward direction uses `pow_by_constant`.
    /// - The inverse direction allocates a witness `y` and enforces `y^alpha == x`.
    ///
    /// # Errors
    /// Propagates any [`SynthesisError`] from witness allocation or constraint enforcement.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn apply_s_box(
        &self,
        state: &mut [FpVar<F>],
        alpha: u64,
        is_forward_pass: bool,
    ) -> Result<(), SynthesisError> {
        if alpha == 5 && self.cs.has_predicate(RESCUE_PREDICATE) {
            use ark_relations::lc;
            // Take the constraint system from the state elements themselves, combined with
            // `ConstraintSystemRef::or`.
            let cs = state
                .iter()
                .fold(ConstraintSystemRef::None, |cs, item| cs.or(item.cs()));
            if is_forward_pass {
                for state_item in state {
                    if let FpVar::Var(ref fp) = state_item {
                        // Witness y = x^alpha, then constrain the pair (x, y).
                        let new_state_item = FpVar::new_witness(cs.clone(), || {
                            state_item.value().map(|e| e.pow([self.parameters.alpha]))
                        })?;
                        let FpVar::Var(ref new_fp) = new_state_item else {
                            return Err(SynthesisError::AssignmentMissing);
                        };
                        cs.enforce_constraint_arity_2(
                            RESCUE_PREDICATE,
                            || lc![fp.variable],
                            || lc![new_fp.variable],
                        )?;
                        *state_item = new_state_item;
                    } else {
                        *state_item = state_item.pow_by_constant([self.parameters.alpha])?;
                    }
                }
            } else {
                let alpha_inv = self.parameters.alpha_inv.to_u64_digits();
                for state_item in state {
                    if let FpVar::Var(ref fp) = state_item {
                        // Witness y = x^(alpha_inv). The constraint uses the same predicate
                        // as the forward pass with its operands swapped: (y, x).
                        let new_state_item = FpVar::new_witness(cs.clone(), || {
                            state_item.value().map(|e| e.pow(&alpha_inv))
                        })?;
                        let FpVar::Var(ref new_fp) = new_state_item else {
                            return Err(SynthesisError::AssignmentMissing);
                        };
                        cs.enforce_constraint_arity_2(
                            RESCUE_PREDICATE,
                            || lc![new_fp.variable],
                            || lc![fp.variable],
                        )?;
                        *state_item = new_state_item;
                    } else {
                        *state_item = state_item.pow_by_constant(&alpha_inv)?;
                    }
                }
            }
        } else if is_forward_pass {
            for state_item in state.iter_mut() {
                *state_item = state_item.pow_by_constant([self.parameters.alpha])?;
            }
        } else {
            // The inverse exponent is loop-invariant; convert it once rather than per element.
            let alpha_inv = self.parameters.alpha_inv.to_u64_digits();
            for state_item in state.iter_mut() {
                // Witness the root, then check it by raising it back to `alpha`.
                let output = FpVar::new_witness(self.cs(), || {
                    state_item.value().map(|e| e.pow(&alpha_inv))
                })?;
                let expected_input = output.pow_by_constant([alpha])?;
                expected_input.enforce_equal(state_item)?;
                *state_item = output;
            }
        }
        Ok(())
    }

    /// Adds the round constants `round_key` element-wise to `state` (no new constraints).
    ///
    /// # Panics
    /// Panics if `round_key` has fewer entries than `state`. Extra entries are ignored.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn apply_ark(&self, state: &mut [FpVar<F>], round_key: &[F]) -> Result<(), SynthesisError> {
        // Slicing up front preserves the out-of-bounds panic for a too-short round key.
        let keys = &round_key[..state.len()];
        for (state_elem, key) in state.iter_mut().zip(keys) {
            *state_elem += *key;
        }
        Ok(())
    }

    /// Multiplies `state` by the MDS matrix in place:
    /// `state[i] <- sum_j mds[i][j] * state[j]`.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn apply_mds(&self, state: &mut [FpVar<F>]) -> Result<(), SynthesisError> {
        let zero = FpVar::<F>::zero();
        let mut new_state = Vec::with_capacity(state.len());
        for i in 0..state.len() {
            let mut cur = zero.clone();
            for (j, state_elem) in state.iter().enumerate() {
                let term = state_elem * self.parameters.mds[i][j];
                cur += &term;
            }
            new_state.push(cur);
        }
        state.clone_from_slice(&new_state);
        Ok(())
    }

    /// Applies the full Rescue permutation to `self.state`.
    ///
    /// The state first gets the round constants `arc[0]` added. Then, for each remaining
    /// round key, it applies an S-box, the MDS mix, and that round key. The S-box is the
    /// inverse on even round indices and the forward one on odd indices.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn permute(&mut self) -> Result<(), SynthesisError> {
        let mut state = self.state.clone();
        self.apply_ark(&mut state, &self.parameters.arc[0])?;
        for (round, round_key) in self.parameters.arc[1..].iter().enumerate() {
            let is_forward_pass = round % 2 != 0;
            self.apply_s_box(&mut state, self.parameters.alpha, is_forward_pass)?;
            self.apply_mds(&mut state)?;
            self.apply_ark(&mut state, round_key)?;
        }
        self.state = state;
        Ok(())
    }

    /// Adds `elements` into the rate part of the state, starting at rate index
    /// `rate_start_index`.
    ///
    /// It permutes whenever the rate fills up and elements remain. It then records the
    /// next absorb index in `self.mode`.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn absorb_internal(
        &mut self,
        mut rate_start_index: usize,
        elements: &[FpVar<F>],
    ) -> Result<(), SynthesisError> {
        let mut remaining_elements = elements;
        loop {
            // Everything left fits in the current rate window: absorb it and stop.
            if rate_start_index + remaining_elements.len() <= self.parameters.rate {
                for (i, element) in remaining_elements.iter().enumerate() {
                    self.state[self.parameters.capacity + i + rate_start_index] += element;
                }
                self.mode = DuplexSpongeMode::Absorbing {
                    next_absorb_index: rate_start_index + remaining_elements.len(),
                };
                return Ok(());
            }
            // Otherwise fill the rest of the rate, permute, and continue from index 0.
            let num_elements_absorbed = self.parameters.rate - rate_start_index;
            for (i, element) in remaining_elements
                .iter()
                .enumerate()
                .take(num_elements_absorbed)
            {
                self.state[self.parameters.capacity + i + rate_start_index] += element;
            }
            self.permute()?;
            remaining_elements = &remaining_elements[num_elements_absorbed..];
            rate_start_index = 0;
        }
    }

    /// Fills `output` from the rate part of the state, starting at rate index
    /// `rate_start_index`.
    ///
    /// It permutes whenever the rate is exhausted and more output is still needed. It then
    /// records the next squeeze index in `self.mode`.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn squeeze_internal(
        &mut self,
        mut rate_start_index: usize,
        output: &mut [FpVar<F>],
    ) -> Result<(), SynthesisError> {
        let mut remaining_output = output;
        loop {
            // The rest of the request can be served from the current rate window.
            if rate_start_index + remaining_output.len() <= self.parameters.rate {
                remaining_output.clone_from_slice(
                    &self.state[self.parameters.capacity + rate_start_index
                        ..(self.parameters.capacity + remaining_output.len() + rate_start_index)],
                );
                self.mode = DuplexSpongeMode::Squeezing {
                    next_squeeze_index: rate_start_index + remaining_output.len(),
                };
                return Ok(());
            }
            // Otherwise drain the rest of the rate window first.
            let num_elements_squeezed = self.parameters.rate - rate_start_index;
            remaining_output[..num_elements_squeezed].clone_from_slice(
                &self.state[self.parameters.capacity + rate_start_index
                    ..(self.parameters.capacity + num_elements_squeezed + rate_start_index)],
            );
            remaining_output = &mut remaining_output[num_elements_squeezed..];
            // More output is still owed, so refresh the state before reading the rate again.
            // Gating this on `remaining_output.len() != rate` would skip the permutation when
            // exactly `rate` outputs are requested from a non-zero start index. That would
            // re-emit rate elements that were already squeezed.
            if !remaining_output.is_empty() {
                self.permute()?;
            }
            rate_start_index = 0;
        }
    }
}
impl<F: PrimeField> CryptographicSpongeVar<F, RescueSponge<F>> for RescueSpongeVar<F> {
    type Parameters = RescueConfig<F>;

    /// Creates a sponge in absorbing mode (next index 0). Its `rate + capacity` state
    /// elements all start as the constant zero.
    fn new(cs: ConstraintSystemRef<F>, parameters: &RescueConfig<F>) -> Self {
        let width = parameters.rate + parameters.capacity;
        Self {
            cs,
            parameters: parameters.clone(),
            state: vec![FpVar::<F>::zero(); width],
            mode: DuplexSpongeMode::Absorbing {
                next_absorb_index: 0,
            },
        }
    }

    /// Returns a handle to the constraint system this sponge was created with.
    fn cs(&self) -> ConstraintSystemRef<F> {
        self.cs.clone()
    }

    /// Absorbs the sponge-field-element encoding of `input`. An empty encoding is a no-op.
    fn absorb(&mut self, input: &impl AbsorbGadget<F>) -> Result<(), SynthesisError> {
        let elements = input.to_sponge_field_elements()?;
        if elements.is_empty() {
            return Ok(());
        }
        // Decide where absorption resumes. A full rate window, or a switch away from
        // squeezing, needs a fresh permutation first.
        let start_index = match self.mode {
            DuplexSpongeMode::Absorbing { next_absorb_index }
                if next_absorb_index != self.parameters.rate =>
            {
                next_absorb_index
            }
            DuplexSpongeMode::Absorbing { .. } | DuplexSpongeMode::Squeezing { .. } => {
                self.permute()?;
                0
            }
        };
        self.absorb_internal(start_index, elements.as_slice())
    }

    /// Squeezes `num_bytes` bytes.
    ///
    /// Each squeezed field element contributes only its low `(MODULUS_BIT_SIZE - 1) / 8`
    /// little-endian bytes.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn squeeze_bytes(&mut self, num_bytes: usize) -> Result<Vec<UInt8<F>>, SynthesisError> {
        let bytes_per_elem = ((F::MODULUS_BIT_SIZE - 1) / 8) as usize;
        let elem_count = (num_bytes + bytes_per_elem - 1) / bytes_per_elem;
        let squeezed = self.squeeze_field_elements(elem_count)?;
        let mut out: Vec<UInt8<F>> = Vec::with_capacity(bytes_per_elem * elem_count);
        for field_elem in squeezed.iter() {
            let le_bytes = field_elem.to_bytes_le()?;
            out.extend_from_slice(&le_bytes[..bytes_per_elem]);
        }
        out.truncate(num_bytes);
        Ok(out)
    }

    /// Squeezes `num_bits` bits.
    ///
    /// Each squeezed field element contributes only its low `MODULUS_BIT_SIZE - 1`
    /// little-endian bits.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn squeeze_bits(&mut self, num_bits: usize) -> Result<Vec<Boolean<F>>, SynthesisError> {
        let bits_per_elem = (F::MODULUS_BIT_SIZE - 1) as usize;
        let elem_count = (num_bits + bits_per_elem - 1) / bits_per_elem;
        let squeezed = self.squeeze_field_elements(elem_count)?;
        let mut out: Vec<Boolean<F>> = Vec::with_capacity(bits_per_elem * elem_count);
        for field_elem in squeezed.iter() {
            let le_bits = field_elem.to_bits_le()?;
            out.extend_from_slice(&le_bits[..bits_per_elem]);
        }
        out.truncate(num_bits);
        Ok(out)
    }

    /// Squeezes `num_elements` field elements.
    ///
    /// It permutes first when coming from absorption or when the current rate window is
    /// exhausted.
    #[tracing::instrument(target = "gr1cs", skip(self))]
    fn squeeze_field_elements(
        &mut self,
        num_elements: usize,
    ) -> Result<Vec<FpVar<F>>, SynthesisError> {
        let mut squeezed = vec![FpVar::<F>::zero(); num_elements];
        let start_index = match self.mode {
            DuplexSpongeMode::Squeezing { next_squeeze_index }
                if next_squeeze_index != self.parameters.rate =>
            {
                next_squeeze_index
            }
            DuplexSpongeMode::Absorbing { .. } | DuplexSpongeMode::Squeezing { .. } => {
                self.permute()?;
                0
            }
        };
        self.squeeze_internal(start_index, &mut squeezed)?;
        Ok(squeezed)
    }
}