use crate::poseidon::{
State,
helpers::{AlgebraicSponge, DuplexSpongeMode},
};
use snarkvm_console_types::{Field, prelude::*};
use snarkvm_fields::PoseidonParameters;
use smallvec::SmallVec;
use std::{ops::DerefMut, sync::Arc};
/// A duplex sponge built on the Poseidon permutation, parameterized by the sponge rate
/// `RATE` and capacity `CAPACITY`.
#[derive(Clone, Debug)]
pub struct PoseidonSponge<E: Environment, const RATE: usize, const CAPACITY: usize> {
/// Shared (reference-counted) Poseidon parameters. The sponge reads the per-round constants
/// (`ark`), the MDS matrix (`mds`), the round counts (`full_rounds`, `partial_rounds`) and the
/// S-box exponent (`alpha`) from it.
parameters: Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>,
/// The permutation state. Input is added into, and output is copied out of, its rate portion.
state: State<E, RATE, CAPACITY>,
/// Whether the sponge is absorbing or squeezing, together with the rate index at which the
/// next absorb/squeeze resumes. Visible throughout `crate::poseidon`.
pub(in crate::poseidon) mode: DuplexSpongeMode,
}
impl<E: Environment, const RATE: usize, const CAPACITY: usize> AlgebraicSponge<E, RATE, CAPACITY>
    for PoseidonSponge<E, RATE, CAPACITY>
{
    type Parameters = Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>;

    /// Builds a sponge with a default state, in absorbing mode at rate index 0.
    fn new(parameters: &Self::Parameters) -> Self {
        let mode = DuplexSpongeMode::Absorbing { next_absorb_index: 0 };
        Self { parameters: Arc::clone(parameters), state: State::default(), mode }
    }

    /// Absorbs `input` into the sponge.
    ///
    /// Absorbing an empty slice is a no-op. Otherwise, the state is permuted first if the rate
    /// portion is already full or the sponge was squeezing. Absorption then continues at the
    /// recorded absorb index, or at 0 after such a permutation.
    fn absorb(&mut self, input: &[Field<E>]) {
        if input.is_empty() {
            return;
        }
        let start = match self.mode {
            // Room left in the rate portion: keep filling where the previous absorb stopped.
            DuplexSpongeMode::Absorbing { next_absorb_index: idx } if idx != RATE => idx,
            // Rate portion full, or switching from squeezing to absorbing: permute, restart at 0.
            DuplexSpongeMode::Absorbing { .. } | DuplexSpongeMode::Squeezing { .. } => {
                self.permute();
                0
            }
        };
        self.absorb_internal(start, input);
    }

    /// Squeezes `num_elements` field elements out of the sponge.
    ///
    /// Requesting 0 elements returns an empty vector and leaves the sponge untouched. Otherwise,
    /// the state is permuted first if the sponge was absorbing or every rate element has already
    /// been squeezed. Output is then read from the recorded squeeze index, or from 0 after such a
    /// permutation.
    fn squeeze(&mut self, num_elements: u16) -> SmallVec<[Field<E>; 10]> {
        let count = usize::from(num_elements);
        if count == 0 {
            return SmallVec::new();
        }
        // `smallvec!` keeps up to 10 elements inline and only spills to the heap beyond that.
        let mut output: SmallVec<[Field<E>; 10]> = smallvec::smallvec![Field::<E>::zero(); count];
        let start = match self.mode {
            // Unread rate elements remain: continue from where the previous squeeze stopped.
            DuplexSpongeMode::Squeezing { next_squeeze_index: idx } if idx != RATE => idx,
            // Switching from absorbing, or every rate element consumed: permute, restart at 0.
            DuplexSpongeMode::Absorbing { .. } | DuplexSpongeMode::Squeezing { .. } => {
                self.permute();
                0
            }
        };
        self.squeeze_internal(start, &mut output[..]);
        output
    }
}
impl<E: Environment, const RATE: usize, const CAPACITY: usize> PoseidonSponge<E, RATE, CAPACITY> {
#[inline]
fn apply_ark(&mut self, round_number: usize) {
for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
*state_elem += Field::<E>::new(*ark_elem);
}
}
#[inline]
fn apply_s_box(&mut self, is_full_round: bool) {
if is_full_round {
for elem in self.state.iter_mut() {
let e = elem.deref_mut();
*e = e.pow([self.parameters.alpha]);
}
}
else {
let e = self.state[0].deref_mut();
*e = e.pow([self.parameters.alpha]);
}
}
#[inline]
fn apply_mds(&mut self) {
let mut new_state = State::default();
let curr_state: Vec<<E as Environment>::Field> = self.state.iter().map(|e| *e.deref()).collect::<Vec<_>>();
new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
*new_elem = Field::new(E::Field::sum_of_products(curr_state.as_slice(), mds_row));
});
self.state = new_state;
}
#[inline]
fn permute(&mut self) {
let partial_rounds = self.parameters.partial_rounds;
let full_rounds = self.parameters.full_rounds;
let full_rounds_over_2 = full_rounds / 2;
let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);
for i in 0..(partial_rounds + full_rounds) {
let is_full_round = !partial_round_range.contains(&i);
self.apply_ark(i);
self.apply_s_box(is_full_round);
self.apply_mds();
}
}
#[inline]
fn absorb_internal(&mut self, mut rate_start: usize, input: &[Field<E>]) {
if !input.is_empty() {
let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
let num_elements_remaining = input.len() - first_chunk_size;
let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
let rest_chunks = rest_chunk.chunks(RATE);
let total_num_chunks = 1 + (num_elements_remaining / RATE) +
usize::from((num_elements_remaining % RATE) != 0);
for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state_mut()[rate_start..]) {
*state_elem += element;
}
if i == total_num_chunks - 1 {
self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
return;
} else {
self.permute();
}
rate_start = 0;
}
}
}
#[inline]
fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [Field<E>]) {
let output_size = output.len();
if output_size != 0 {
let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
let num_output_remaining = output.len() - first_chunk_size;
let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
assert_eq!(rest_chunk.len(), num_output_remaining);
let rest_chunks = rest_chunk.chunks_mut(RATE);
let total_num_chunks = 1 + (num_output_remaining / RATE) +
usize::from((num_output_remaining % RATE) != 0);
for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
let range = rate_start..(rate_start + chunk.len());
debug_assert_eq!(
chunk.len(),
self.state.rate_state(range.clone()).len(),
"Failed to squeeze {output_size} at rate {RATE} & rate_start {rate_start}"
);
chunk.copy_from_slice(self.state.rate_state(range));
if i == total_num_chunks - 1 {
self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
return;
} else {
self.permute();
}
rate_start = 0;
}
}
}
}