use std::io::{self, Read};
use group::GroupEncoding;
use midnight_proofs::transcript::{Hashable, Sampleable, TranscriptHash};
#[cfg(feature = "dev-curves")]
use {ff::PrimeField, midnight_curves::bn256};
use super::{
constants::{PoseidonField, NB_FULL_ROUNDS, NB_PARTIAL_ROUNDS, RATE, WIDTH},
round_skips::{PreComputedRoundCPU, PreComputedRoundCircuit},
PoseidonChip, NB_SKIPS_CIRCUIT,
};
use crate::{
field::foreign::params::MultiEmulationParams as MEP,
instructions::SpongeCPU,
types::{AssignedForeignPoint, Instantiable},
CircuitField,
};
/// Round-skip count used when evaluating batched partial rounds on the CPU
/// (passed as `eval::<NB_SKIPS_CPU>` by `partial_round_cpu`).
// NOTE(review): `permutation_cpu` sizes its batches from the pre-computed
// `partial_round_id.nb_skips`, which presumably equals this constant — confirm.
pub(crate) const NB_SKIPS_CPU: usize = 2;
/// State of the CPU Poseidon sponge driven by `PoseidonChip`'s `SpongeCPU` impl.
#[derive(Clone, Debug)]
pub struct PoseidonState<F: PoseidonField> {
/// Pre-computed partial-round data, passed to `permutation_cpu` on every permutation.
pre_computed: PreComputedRoundCPU<F>,
/// The `WIDTH`-element permutation state. Inputs are added to its first `RATE`
/// entries, and outputs are read from those same entries.
register: [F; WIDTH],
/// Absorbed inputs not yet mixed into `register`; flushed by `squeeze`.
queue: Vec<F>,
/// Index in `register` of the next output to return. 0 means the next squeeze
/// must first absorb the queue and run a permutation.
squeeze_position: usize,
/// `Some(len)` for a fixed-length sponge, `None` for variable-length mode.
input_len: Option<usize>,
}
/// Applies the MDS matrix and adds round constants in place:
/// `state <- MDS * state + constants`.
///
/// `constants` serves as the accumulator, so on return it also holds the new
/// state. Both slices are expected to have length `WIDTH`.
fn linear_layer<F: PoseidonField>(state: &mut [F], constants: &mut [F]) {
    // Indexed loops keep the row/column correspondence with `F::MDS` explicit;
    // an iterator rewrite would obscure the matrix-vector product.
    #[allow(clippy::needless_range_loop)]
    for row in 0..WIDTH {
        let acc = &mut constants[row];
        for col in 0..WIDTH {
            *acc += F::MDS[row][col] * state[col];
        }
    }
    state.copy_from_slice(constants);
}
/// Applies one full Poseidon round at position `round_index`.
///
/// Every element goes through the `x -> x^5` S-box. The linear layer then runs
/// with the next round's constants, or with zeros after the very last round.
pub(crate) fn full_round_cpu<F: PoseidonField>(round_index: usize, state: &mut [F]) {
    for x in state.iter_mut() {
        let x4 = x.square().square();
        *x = x4 * *x;
    }

    let last_round = NB_FULL_ROUNDS + NB_PARTIAL_ROUNDS - 1;
    let mut next_constants = if round_index == last_round {
        // No round follows, so no constants are injected.
        [F::ZERO; WIDTH]
    } else {
        F::ROUND_CONSTANTS[round_index + 1]
    };
    linear_layer(state, &mut next_constants);
}
/// Evaluates one batch of partial rounds using the pre-computed round-skip
/// data (`eval::<NB_SKIPS_CPU>`) and the constants of batch `round_batch_index`.
///
/// Any value returned by `eval` is discarded.
fn partial_round_cpu<F: PoseidonField>(
    pre_computed: &PreComputedRoundCPU<F>,
    round_batch_index: usize,
    state: &mut [F],
) {
    let batch_constants = &pre_computed.round_constants[round_batch_index];
    pre_computed.partial_round_id.eval::<NB_SKIPS_CPU>(batch_constants, state);
}
/// Circuit-side variant of `partial_round_cpu`: evaluates one batch of partial
/// rounds with `eval::<NB_SKIPS_CIRCUIT>` on the circuit's pre-computed data.
///
/// Returns the `NB_SKIPS_CIRCUIT` values produced by `eval`. Their meaning is
/// defined by `eval`, which is not visible here.
pub(crate) fn partial_round_cpu_for_circuits<F: PoseidonField>(
    pre_computed: &PreComputedRoundCircuit<F>,
    round_batch_index: usize,
    state: &mut [F],
) -> [F; NB_SKIPS_CIRCUIT] {
    let batch_constants = &pre_computed.round_constants[round_batch_index];
    pre_computed.partial_round_id.eval::<NB_SKIPS_CIRCUIT>(batch_constants, state)
}
/// Applies a single partial round without round skipping.
///
/// Only the last element goes through the `x -> x^5` S-box. The linear layer
/// then runs with `ROUND_CONSTANTS[round + 1]`.
fn partial_round_cpu_raw<F: PoseidonField>(round: usize, state: &mut [F]) {
    let last = &mut state[WIDTH - 1];
    let x4 = last.square().square();
    *last *= x4;

    let mut next_constants = F::ROUND_CONSTANTS[round + 1];
    linear_layer(state, &mut next_constants);
}
/// Applies the Poseidon permutation to `state` in place.
///
/// The round schedule is:
/// 1. add `ROUND_CONSTANTS[0]`;
/// 2. run `NB_FULL_ROUNDS / 2` full rounds;
/// 3. run the partial rounds in batches of `1 + nb_skips` via the pre-computed
///    data, then run any leftover partial rounds one at a time;
/// 4. run the remaining `NB_FULL_ROUNDS / 2` full rounds.
pub fn permutation_cpu<F: PoseidonField>(pre_computed: &PreComputedRoundCPU<F>, state: &mut [F]) {
    let batch_len = 1 + pre_computed.partial_round_id.nb_skips;
    let nb_batches = NB_PARTIAL_ROUNDS / batch_len;
    let nb_leftover = NB_PARTIAL_ROUNDS % batch_len;

    let half_full = NB_FULL_ROUNDS / 2;
    let partial_end = half_full + NB_PARTIAL_ROUNDS;

    state
        .iter_mut()
        .zip(F::ROUND_CONSTANTS[0])
        .for_each(|(x, k0)| *x += k0);

    for round_index in 0..half_full {
        full_round_cpu(round_index, state);
    }
    for batch_index in 0..nb_batches {
        partial_round_cpu(pre_computed, batch_index, state);
    }
    // Leftover partial rounds that do not fill a whole batch occupy the last
    // `nb_leftover` partial-round positions.
    for round_index in (partial_end - nb_leftover)..partial_end {
        partial_round_cpu_raw(round_index, state);
    }
    for round_index in partial_end..partial_end + half_full {
        full_round_cpu(round_index, state);
    }
}
impl<F: PoseidonField> SpongeCPU<F, F> for PoseidonChip<F> {
    type StateCPU = PoseidonState<F>;

    /// Creates a fresh CPU sponge state.
    ///
    /// `input_len == Some(len)` selects fixed-length mode. `None` selects
    /// variable-length mode, in which `squeeze` appends a length padding.
    ///
    /// The register entry at index `RATE` lies just past the entries that
    /// inputs are added to. It is initialised with `len`, or with `2^64` in
    /// variable-length mode, presumably to domain-separate the modes.
    fn init(input_len: Option<usize>) -> Self::StateCPU {
        let mut register = [F::ZERO; WIDTH];
        register[RATE] = F::from_u128(input_len.map(|l| l as u128).unwrap_or(1 << 64));
        PoseidonState {
            pre_computed: PreComputedRoundCPU::init(),
            register,
            queue: Vec::new(),
            squeeze_position: 0,
            input_len,
        }
    }

    /// Buffers `inputs`. They are only mixed into the register by the next
    /// `squeeze`.
    fn absorb(state: &mut Self::StateCPU, inputs: &[F]) {
        state.queue.extend(inputs);
        // New input invalidates any outputs left over from the last permutation.
        state.squeeze_position = 0;
    }

    /// Returns the next output element of the sponge.
    ///
    /// If outputs from the last permutation remain (`squeeze_position > 0`),
    /// the next one is returned without permuting. Otherwise the queued inputs
    /// are processed: variable-length mode pads them with their count, and
    /// fixed-length mode checks their count. They are then added to the
    /// register `RATE` elements at a time, with one permutation per chunk, and
    /// `register[0]` is returned.
    ///
    /// # Panics
    ///
    /// In fixed-length mode, panics if a further output is requested from the
    /// same permutation, or if the number of queued inputs differs from the
    /// declared length.
    fn squeeze(state: &mut Self::StateCPU) -> F {
        if state.squeeze_position > 0 {
            assert!(
                state.input_len.is_none(),
                "Attempting to squeeze multiple times a fixed-size Poseidon sponge (CPU)."
            );
            debug_assert!(state.queue.is_empty());
            let output = state.register[state.squeeze_position % RATE];
            state.squeeze_position = (state.squeeze_position + 1) % RATE;
            return output;
        }

        match state.input_len {
            // Variable-length mode: pad with the number of absorbed elements.
            None => {
                let padding = F::from(state.queue.len() as u64);
                state.queue.push(padding);
            }
            Some(len) => assert!(
                state.queue.len() == len,
                "Inconsistent lengths in fixed-size Poseidon sponge (CPU). Expected: {}, found: {}.",
                len,
                state.queue.len()
            ),
        }

        // NOTE(review): with `input_len == Some(0)` and nothing queued, no
        // permutation runs and `register[0]` is returned unchanged (zero for a
        // fresh sponge). Confirm this matches the in-circuit sponge before
        // changing it.
        for chunk in state.queue.chunks(RATE) {
            for (entry, value) in state.register.iter_mut().zip(chunk) {
                *entry += value;
            }
            permutation_cpu(&state.pre_computed, &mut state.register);
        }

        // `clear` keeps the queue's allocation for later absorbs, unlike
        // replacing it with a fresh `Vec`.
        state.queue.clear();
        // `% RATE` leaves the position at 0 when `RATE == 1`, so the next
        // squeeze triggers a new permutation.
        state.squeeze_position = 1 % RATE;
        state.register[0]
    }
}
impl<F: PoseidonField> TranscriptHash for PoseidonState<F> {
type Input = Vec<F>;
type Output = F;
fn init() -> Self {
PoseidonChip::init(None)
}
fn absorb(&mut self, input: &Self::Input) {
PoseidonChip::absorb(self, input)
}
fn squeeze(&mut self) -> Self::Output {
PoseidonChip::squeeze(self)
}
}
impl Hashable<PoseidonState<midnight_curves::Fq>> for midnight_curves::G1Projective {
fn to_input(&self) -> Vec<midnight_curves::Fq> {
AssignedForeignPoint::<midnight_curves::Fq, midnight_curves::G1Projective, MEP>::as_public_input(self)
}
fn to_bytes(&self) -> Vec<u8> {
<midnight_curves::G1Affine as GroupEncoding>::to_bytes(&self.into())
.as_ref()
.to_vec()
}
fn read(buffer: &mut impl Read) -> io::Result<Self> {
let mut bytes = <midnight_curves::G1Affine as GroupEncoding>::Repr::default();
buffer.read_exact(bytes.as_mut())?;
Option::from(midnight_curves::G1Affine::from_bytes(&bytes))
.ok_or_else(|| io::Error::other("Invalid BLS12-381 point encoding in proof"))
.map(|p: midnight_curves::G1Affine| p.into())
}
}
impl Hashable<PoseidonState<midnight_curves::Fq>> for midnight_curves::Fq {
    /// A field element is absorbed as itself.
    fn to_input(&self) -> Vec<midnight_curves::Fq> {
        Vec::from([*self])
    }

    /// Byte form written to the proof: the little-endian encoding.
    fn to_bytes(&self) -> Vec<u8> {
        let le_bytes = self.to_bytes_le();
        le_bytes.as_ref().to_vec()
    }

    /// Reads `NUM_BYTES` little-endian bytes and decodes them. Fails with an
    /// I/O error on a short read or an invalid encoding.
    fn read(buffer: &mut impl Read) -> io::Result<Self> {
        use midnight_curves::Fq;
        let mut raw = [0u8; <Fq as CircuitField>::NUM_BYTES];
        buffer.read_exact(&mut raw)?;
        <Fq as CircuitField>::from_bytes_le(&raw)
            .ok_or_else(|| io::Error::other("Invalid BLS12-381 scalar encoding in proof"))
    }
}
/// The field element squeezed from the transcript is used directly as the
/// sampled challenge.
impl Sampleable<PoseidonState<midnight_curves::Fq>> for midnight_curves::Fq {
fn sample(out: midnight_curves::Fq) -> Self {
out
}
}
#[cfg(feature = "dev-curves")]
impl Hashable<PoseidonState<bn256::Fr>> for bn256::G1 {
    /// Field-element form absorbed by the Poseidon transcript: the point's
    /// foreign-point public-input encoding.
    fn to_input(&self) -> Vec<bn256::Fr> {
        type Point = AssignedForeignPoint<bn256::Fr, bn256::G1, MEP>;
        Point::as_public_input(self)
    }

    /// Byte form written to the proof: the `GroupEncoding` of the affine point.
    fn to_bytes(&self) -> Vec<u8> {
        let affine: bn256::G1Affine = self.into();
        let encoded = <bn256::G1Affine as GroupEncoding>::to_bytes(&affine);
        encoded.as_ref().to_vec()
    }

    /// Reads one `GroupEncoding` representation and decodes it. Fails with an
    /// I/O error on a short read or an invalid encoding.
    fn read(buffer: &mut impl Read) -> io::Result<Self> {
        let mut repr = <bn256::G1Affine as GroupEncoding>::Repr::default();
        buffer.read_exact(repr.as_mut())?;
        let decoded: Option<bn256::G1Affine> = bn256::G1Affine::from_bytes(&repr).into();
        decoded
            .map(|point| point.into())
            .ok_or_else(|| io::Error::other("Invalid BN256 point encoding in proof"))
    }
}
#[cfg(feature = "dev-curves")]
impl Hashable<PoseidonState<bn256::Fr>> for bn256::Fr {
    /// A field element is absorbed as itself.
    fn to_input(&self) -> Vec<bn256::Fr> {
        Vec::from([*self])
    }

    /// Byte form written to the proof: whatever the field's `to_bytes` returns.
    fn to_bytes(&self) -> Vec<u8> {
        let encoded = self.to_bytes();
        encoded.to_vec()
    }

    /// Reads one `PrimeField::Repr` and decodes it with `from_repr`. Fails with
    /// an I/O error on a short read or a non-canonical encoding.
    fn read(buffer: &mut impl Read) -> io::Result<Self> {
        let mut repr = <Self as PrimeField>::Repr::default();
        buffer.read_exact(repr.as_mut())?;
        let decoded: Option<Self> = Self::from_repr(repr).into();
        decoded.ok_or_else(|| io::Error::other("Invalid BN256 scalar encoding in proof"))
    }
}
/// The field element squeezed from the transcript is used directly as the
/// sampled challenge.
#[cfg(feature = "dev-curves")]
impl Sampleable<PoseidonState<bn256::Fr>> for bn256::Fr {
fn sample(out: bn256::Fr) -> Self {
out
}
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha12Rng;

    use super::*;
    use crate::hash::poseidon::permutation_cpu;

    /// Reference permutation: every partial round is applied individually with
    /// `partial_round_cpu_raw`, without any round skipping.
    fn permutation_cpu_raw<F: PoseidonField>(state: &mut [F]) {
        let partial_start = NB_FULL_ROUNDS / 2;
        let partial_end = partial_start + NB_PARTIAL_ROUNDS;

        state
            .iter_mut()
            .zip(F::ROUND_CONSTANTS[0])
            .for_each(|(x, k0)| *x += k0);
        (0..partial_start).for_each(|round| full_round_cpu(round, state));
        (partial_start..partial_end).for_each(|round| partial_round_cpu_raw(round, state));
        (partial_end..partial_end + NB_FULL_ROUNDS / 2)
            .for_each(|round| full_round_cpu(round, state));
    }

    /// Checks that the round-skipping `permutation_cpu` agrees with
    /// `permutation_cpu_raw` on `nb_samples` inputs from a fixed-seed RNG.
    fn consistency_cpu<F: PoseidonField + ff::FromUniformBytes<64>>(nb_samples: usize) {
        println!(
            ">> Testing the consistency between the two cpu implementations of the permutation ({NB_SKIPS_CPU} round skips VS no round skips)."
        );
        let pre_computed = PreComputedRoundCPU::init();
        let mut rng = ChaCha12Rng::seed_from_u64(0xf007ba11);
        for _ in 0..nb_samples {
            let input: [F; WIDTH] = core::array::from_fn(|_| F::random(&mut rng));
            let mut without_skips = input;
            let mut with_skips = input;
            permutation_cpu_raw(&mut without_skips);
            permutation_cpu(&pre_computed, &mut with_skips);
            assert!(
                without_skips == with_skips,
                "=> Inconsistencies between the cpu implementations of the permutations.\n\nOn input x = {:?},\n\npermutation_cpu_no_skip(x) = {:?}\n\npermutation_cpu_with_skips(x) = {:?}\n",
                input,
                without_skips,
                with_skips
            );
        }
        println!("=> No internal inconsistency found.")
    }

    #[test]
    fn cpu_test() {
        consistency_cpu::<midnight_curves::Fq>(1);
    }
}