use ff::Field;
use midnight_proofs::plonk::Expression;
use super::{sbox, PoseidonField, NB_SKIPS_CPU};
use crate::hash::poseidon::{
constants::{NB_FULL_ROUNDS, NB_PARTIAL_ROUNDS, WIDTH},
NB_SKIPS_CIRCUIT,
};
/// Larger of `NB_SKIPS_CPU` and `NB_SKIPS_CIRCUIT`.
///
/// Sizes the coefficient arrays of `RoundVarId` and the `ids` array of
/// `RoundId`, so that one layout fits both the CPU and the circuit variants.
pub(crate) const NB_SKIPS_MAX: usize = if NB_SKIPS_CPU > NB_SKIPS_CIRCUIT {
    NB_SKIPS_CPU
} else {
    NB_SKIPS_CIRCUIT
};
/// Round constants used by the CPU path: one row per step of
/// `1 + NB_SKIPS_CPU` partial rounds, each row holding `WIDTH + NB_SKIPS_CPU`
/// values. NOTE: "Contants" is a misspelling kept because the alias is
/// referenced by name elsewhere in this module.
type RoundContantsCPU<F> = [[F; WIDTH + NB_SKIPS_CPU]; NB_PARTIAL_ROUNDS / (1 + NB_SKIPS_CPU)];
/// Round constants used by the circuit: one row per step of
/// `1 + NB_SKIPS_CIRCUIT` partial rounds, each row holding
/// `WIDTH + NB_SKIPS_CIRCUIT` values.
type RoundContantsCircuit<F> =
    [[F; WIDTH + NB_SKIPS_CIRCUIT]; NB_PARTIAL_ROUNDS / (1 + NB_SKIPS_CIRCUIT)];
/// An affine form over the partial-round variables and the round constants,
/// stored as two dense coefficient vectors.
#[derive(Clone, Copy, Debug)]
pub(crate) struct RoundVarId<F> {
    /// Variable coefficients. Indices `0..WIDTH - 1` refer to the first
    /// `WIDTH - 1` state elements as they are. Index `WIDTH - 1 + k` refers to
    /// the S-box output of the last state element entering round `k` of the
    /// merged step.
    var_coeffs: [F; WIDTH + NB_SKIPS_MAX],
    /// Round-constant coefficients, laid out as `round_offset * WIDTH + column`.
    const_coeffs: [F; WIDTH * (1 + NB_SKIPS_MAX)],
}
/// Symbolic description of `1 + nb_skips` consecutive partial rounds, as
/// affine forms (`RoundVarId`) over the step's variables and round constants.
#[derive(Clone, Copy, Debug)]
pub(crate) struct RoundId<F> {
    /// Number of additional partial rounds merged into one step.
    pub nb_skips: usize,
    /// Layout of the entries:
    /// - `ids[..WIDTH - 1]` hold the first `WIDTH - 1` state elements at the
    ///   end of the step.
    /// - `ids[WIDTH + k]` holds the last state element after round `k` of the
    ///   step.
    /// - `ids[WIDTH - 1]` is never updated after initialisation.
    ids: [RoundVarId<F>; WIDTH + 1 + NB_SKIPS_MAX],
}
/// Partial-round precomputation for the CPU path, built with `NB_SKIPS_CPU`
/// skipped rounds per step (see `PreComputedRoundCPU::init`).
#[derive(Clone, Copy, Debug)]
pub struct PreComputedRoundCPU<F: PoseidonField> {
    /// One row of constants per complete step of `1 + NB_SKIPS_CPU` partial
    /// rounds. The trailing `NB_PARTIAL_ROUNDS % (1 + NB_SKIPS_CPU)` rounds do
    /// not get a row here.
    pub(crate) round_constants: RoundContantsCPU<F>,
    /// Linear description of one step (`RoundId::generate(NB_SKIPS_CPU)`), from
    /// which `round_constants` was derived.
    pub(crate) partial_round_id: RoundId<F>,
}
/// Partial-round precomputation for the circuit, built with
/// `NB_SKIPS_CIRCUIT` skipped rounds per step (see
/// `PreComputedRoundCircuit::init`).
#[derive(Clone, Copy, Debug)]
pub struct PreComputedRoundCircuit<F: PoseidonField> {
    /// One row of constants per step of `1 + NB_SKIPS_CIRCUIT` partial rounds.
    /// The number of partial rounds is asserted to be divisible by the step size.
    pub(crate) round_constants: RoundContantsCircuit<F>,
    /// Linear description of one step (`RoundId::generate(NB_SKIPS_CIRCUIT)`),
    /// from which `round_constants` was derived.
    pub(crate) partial_round_id: RoundId<F>,
}
impl<F: Field> RoundVarId<F> {
    /// In-place `self += c * rhs`, applied coefficient-wise to both the
    /// variable part and the constant part.
    fn add_and_mul(&mut self, rhs: &Self, c: &F) {
        for (dst, src) in self.var_coeffs.iter_mut().zip(rhs.var_coeffs.iter()) {
            *dst += *src * *c;
        }
        for (dst, src) in self.const_coeffs.iter_mut().zip(rhs.const_coeffs.iter()) {
            *dst += *src * *c;
        }
    }

    /// The zero form: every variable and constant coefficient is zero.
    fn init() -> Self {
        let zero = F::ZERO;
        Self {
            var_coeffs: [zero; WIDTH + NB_SKIPS_MAX],
            const_coeffs: [zero; WIDTH * (1 + NB_SKIPS_MAX)],
        }
    }

    /// The form equal to the single round constant at
    /// (`round_offset`, `column`), i.e. a unit vector in `const_coeffs`.
    fn from_constant_index(round_offset: &usize, column: &usize) -> Self {
        let flat_index = *round_offset * WIDTH + *column;
        let mut unit = Self::init();
        unit.const_coeffs[flat_index] = F::ONE;
        unit
    }

    /// Evaluates the constant part against concrete round constants.
    /// `instances` is read row by row, matching the
    /// `round_offset * WIDTH + column` layout of `const_coeffs`.
    fn eval_constants(&self, instances: &[[F; WIDTH]]) -> F {
        let mut acc = F::ZERO;
        for (coeff, value) in self.const_coeffs.iter().zip(instances.iter().flatten()) {
            acc += *coeff * value;
        }
        acc
    }

    /// Evaluates the variable part against concrete variable values and adds
    /// `constant`. Only the first `min(var_coeffs.len(), instances.len())`
    /// coefficients contribute.
    fn eval_vars(&self, instances: &[F], constant: F) -> F {
        let mut acc = constant;
        for (coeff, value) in self.var_coeffs.iter().zip(instances) {
            acc += *coeff * value;
        }
        acc
    }

    /// Builds the variable part as a circuit expression.
    ///
    /// The first `WIDTH - 1` entries of `vars` are used as they are. Every
    /// later entry is passed through `sbox`. Zero coefficients are skipped.
    fn to_expression(self, vars: &[Expression<F>]) -> Expression<F> {
        let (lin_coeffs, pow_coeffs) = self.var_coeffs.split_at(WIDTH - 1);
        let (lin_vars, pow_vars) = vars.split_at(WIDTH - 1);

        let mut expr = Expression::Constant(F::ZERO);
        for (coeff, x) in lin_coeffs.iter().zip(lin_vars) {
            if !coeff.is_zero_vartime() {
                expr = expr + Expression::Constant(*coeff) * x.clone();
            }
        }
        for (coeff, x) in pow_coeffs.iter().zip(pow_vars) {
            if !coeff.is_zero_vartime() {
                expr = expr + Expression::Constant(*coeff) * sbox(x.clone());
            }
        }
        expr
    }
}
impl<F: PoseidonField> RoundId<F> {
fn init(nb_skips: usize) -> Self {
RoundId {
nb_skips,
ids: core::array::from_fn(|i| {
if i < WIDTH {
let mut id = RoundVarId::init();
id.var_coeffs[i] = F::ONE;
id
} else {
RoundVarId::init()
}
}),
}
}
fn row_id(&self, row: &usize) -> [RoundVarId<F>; WIDTH] {
let mut last = RoundVarId::init();
last.var_coeffs[WIDTH - 1 + *row] = F::ONE;
core::array::from_fn(|i| if i == WIDTH - 1 { last } else { self.ids[i] })
}
fn update_row(self: &mut RoundId<F>, round_offset: &usize) {
let current_row = self.row_id(round_offset);
#[allow(clippy::needless_range_loop)]
for i in 0..WIDTH - 1 {
self.ids[i] = RoundVarId::from_constant_index(round_offset, &i);
for j in 0..WIDTH {
self.ids[i].add_and_mul(¤t_row[j], &F::MDS[i][j]);
}
}
self.ids[WIDTH + *round_offset] =
RoundVarId::from_constant_index(round_offset, &(WIDTH - 1));
#[allow(clippy::needless_range_loop)]
for j in 0..WIDTH {
self.ids[WIDTH + *round_offset].add_and_mul(¤t_row[j], &F::MDS[WIDTH - 1][j]);
}
}
fn generate(nb_skips: usize) -> Self {
let mut ids = RoundId::<F>::init(nb_skips);
for row in 0..1 + nb_skips {
ids.update_row(&row);
}
ids
}
fn eval_constants(&self, round: usize, arg: &mut [F]) {
let instances = &F::ROUND_CONSTANTS[round + 1..round + 2 + self.nb_skips];
self.ids[..WIDTH - 1]
.iter()
.chain(self.ids[WIDTH..].iter())
.map(|id| id.eval_constants(instances))
.zip(arg.iter_mut())
.for_each(|(c, x)| *x = c)
}
pub(crate) fn eval<const NB_SKIPS: usize>(
&self,
round_constants: &[F], instances: &mut [F], ) -> [F; NB_SKIPS] {
let mut pow_instances = [F::ZERO; NB_SKIPS];
instances[WIDTH - 1] *= instances[WIDTH - 1].square().square();
let mut pow_instances_exp = instances
.iter()
.chain(std::iter::repeat_n(&F::ZERO, NB_SKIPS))
.copied()
.collect::<Vec<_>>();
#[allow(clippy::reversed_empty_ranges)]
for i in 0..self.nb_skips {
let next =
self.ids[WIDTH + i].eval_vars(&pow_instances_exp, round_constants[WIDTH - 1 + i]);
pow_instances[i] = next;
pow_instances_exp[WIDTH + i] = next.square().square() * next;
}
let mut output = [F::ZERO; WIDTH];
for i in 0..WIDTH - 1 {
output[i] = self.ids[i].eval_vars(&pow_instances_exp, round_constants[i]);
}
output[WIDTH - 1] = self.ids[WIDTH + NB_SKIPS]
.eval_vars(&pow_instances_exp, round_constants[WIDTH + NB_SKIPS - 1]);
instances.copy_from_slice(&output);
pow_instances
}
fn round_constants_cpu(&self) -> RoundContantsCPU<F> {
let mut v = [[F::ZERO; WIDTH + NB_SKIPS_CPU]; NB_PARTIAL_ROUNDS / (1 + NB_SKIPS_CPU)];
for (round, main_round) in (NB_FULL_ROUNDS / 2..)
.take(NB_PARTIAL_ROUNDS - NB_PARTIAL_ROUNDS % (1 + NB_SKIPS_CPU))
.step_by(1 + NB_SKIPS_CPU)
.zip(0..)
{
self.eval_constants(round, &mut v[main_round])
}
v
}
fn round_constants_circuit(&self) -> RoundContantsCircuit<F> {
assert_eq!(
NB_PARTIAL_ROUNDS % (1 + NB_SKIPS_CIRCUIT),
0,
"The Poseidon chip assumes that the number of partial round (NB_PARTIAL_ROUNDS = {}) is dividable by the number of round skips (1 + NB_SKIPS = {}).",
NB_PARTIAL_ROUNDS,
1 + NB_SKIPS_CIRCUIT
);
assert_eq!(
NB_FULL_ROUNDS % 2,
0,
"The Poseidon chip assumes the number of full round (NB_FULL_ROUNDS = {}) is even.",
NB_FULL_ROUNDS
);
let mut v =
[[F::ZERO; WIDTH + NB_SKIPS_CIRCUIT]; NB_PARTIAL_ROUNDS / (1 + NB_SKIPS_CIRCUIT)];
for (round, main_round) in (NB_FULL_ROUNDS / 2..)
.take(NB_PARTIAL_ROUNDS - NB_PARTIAL_ROUNDS % (1 + NB_SKIPS_CIRCUIT))
.step_by(1 + NB_SKIPS_CIRCUIT)
.zip(0..)
{
self.eval_constants(round, &mut v[main_round])
}
v
}
pub(crate) fn to_expression(self, vars: &[Expression<F>]) -> Vec<Expression<F>> {
self.ids[..WIDTH - 1]
.iter()
.chain(self.ids[WIDTH..].iter())
.map(|id| (*id).to_expression(vars))
.collect::<Vec<_>>()
}
}
impl<F: PoseidonField> PreComputedRoundCPU<F> {
    /// Builds the CPU precomputation.
    ///
    /// Generates the symbolic step for `NB_SKIPS_CPU` skipped rounds, then
    /// derives its per-step round constants from it.
    pub fn init() -> Self {
        let partial_round_id = RoundId::<F>::generate(NB_SKIPS_CPU);
        Self {
            round_constants: partial_round_id.round_constants_cpu(),
            partial_round_id,
        }
    }
}
impl<F: PoseidonField> PreComputedRoundCircuit<F> {
    /// Builds the circuit precomputation.
    ///
    /// Generates the symbolic step for `NB_SKIPS_CIRCUIT` skipped rounds, then
    /// derives its per-step round constants from it.
    ///
    /// # Panics
    /// Panics under the same conditions as `RoundId::round_constants_circuit`
    /// (round counts incompatible with the step size).
    pub(crate) fn init() -> Self {
        let partial_round_id = RoundId::<F>::generate(NB_SKIPS_CIRCUIT);
        Self {
            round_constants: partial_round_id.round_constants_circuit(),
            partial_round_id,
        }
    }
}