use std::{
borrow::Borrow,
iter::{repeat, zip},
};
use itertools::Itertools;
use slop_algebra::{AbstractField, Field};
use slop_bn254::{outer_perm, Bn254Fr, OUTER_CHALLENGER_STATE_WIDTH};
use slop_challenger::IopCtx;
use slop_symmetric::{CryptographicHasher, Permutation};
use sp1_hypercube::{inner_perm, SP1InnerPcs};
use sp1_primitives::{SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
use sp1_recursion_compiler::ir::{Builder, DslIr, Felt, Var};
use sp1_recursion_executor::{DIGEST_SIZE, HASH_RATE, PERMUTATION_WIDTH};
use crate::{
challenger::{reduce_31, POSEIDON_2_BB_RATE},
CircuitConfig,
};
pub trait FieldHasher: IopCtx {
fn constant_compress(input: [Self::Digest; 2]) -> Self::Digest;
fn hash_slice(input: &[Self::F]) -> Self::Digest;
}
pub trait Poseidon2SP1FieldHasherVariable<C: CircuitConfig> {
fn poseidon2_permute(
builder: &mut Builder<C>,
state: [Felt<SP1Field>; PERMUTATION_WIDTH],
) -> [Felt<SP1Field>; PERMUTATION_WIDTH];
fn poseidon2_hash(
builder: &mut Builder<C>,
input: &[Felt<SP1Field>],
) -> [Felt<SP1Field>; DIGEST_SIZE] {
let mut state = core::array::from_fn(|_| builder.eval(SP1Field::zero()));
for input_chunk in input.chunks(HASH_RATE) {
state[..input_chunk.len()].copy_from_slice(input_chunk);
state = Self::poseidon2_permute(builder, state);
}
let digest: [Felt<SP1Field>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
digest
}
}
pub trait FieldHasherVariable<C: CircuitConfig>: FieldHasher {
type DigestVariable: Clone + Copy;
fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable;
fn compress(builder: &mut Builder<C>, input: [Self::DigestVariable; 2])
-> Self::DigestVariable;
fn assert_digest_eq(builder: &mut Builder<C>, a: Self::DigestVariable, b: Self::DigestVariable);
fn select_chain_digest(
builder: &mut Builder<C>,
should_swap: C::Bit,
input: [Self::DigestVariable; 2],
) -> [Self::DigestVariable; 2];
fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable);
}
impl FieldHasher for SP1GlobalContext {
fn constant_compress(
input: [<SP1GlobalContext as IopCtx>::Digest; 2],
) -> <SP1GlobalContext as IopCtx>::Digest {
let mut pre_iter = input.into_iter().flatten().chain(repeat(SP1Field::zero()));
let mut pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
inner_perm().permute_mut(&mut pre);
pre[..DIGEST_SIZE].try_into().unwrap()
}
fn hash_slice(input: &[SP1Field]) -> <SP1GlobalContext as IopCtx>::Digest {
let mut state = [SP1Field::zero(); PERMUTATION_WIDTH];
for input_chunk in input.chunks(HASH_RATE) {
state[..input_chunk.len()].copy_from_slice(input_chunk);
inner_perm().permute_mut(&mut state);
}
let digest: [SP1Field; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
digest
}
}
impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1InnerPcs {
fn poseidon2_permute(
builder: &mut Builder<C>,
input: [Felt<SP1Field>; PERMUTATION_WIDTH],
) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
C::poseidon2_permute_v2(builder, input)
}
}
impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1GlobalContext {
fn poseidon2_permute(
builder: &mut Builder<C>,
input: [Felt<SP1Field>; PERMUTATION_WIDTH],
) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
C::poseidon2_permute_v2(builder, input)
}
}
impl<C: CircuitConfig<Bit = Felt<SP1Field>>> FieldHasherVariable<C> for SP1GlobalContext {
type DigestVariable = [Felt<SP1Field>; DIGEST_SIZE];
fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable {
<Self as Poseidon2SP1FieldHasherVariable<C>>::poseidon2_hash(builder, input)
}
fn compress(
builder: &mut Builder<C>,
input: [Self::DigestVariable; 2],
) -> Self::DigestVariable {
C::poseidon2_compress_v2(builder, input.into_iter().flatten())
}
fn assert_digest_eq(
builder: &mut Builder<C>,
a: Self::DigestVariable,
b: Self::DigestVariable,
) {
zip(a, b).for_each(|(e1, e2)| builder.push_op(DslIr::AssertEqF(e1, e2)));
}
fn select_chain_digest(
builder: &mut Builder<C>,
should_swap: <C as CircuitConfig>::Bit,
input: [Self::DigestVariable; 2],
) -> [Self::DigestVariable; 2] {
let result0: [Felt<SP1Field>; DIGEST_SIZE] = core::array::from_fn(|_| builder.uninit());
let result1: [Felt<SP1Field>; DIGEST_SIZE] = core::array::from_fn(|_| builder.uninit());
(0..DIGEST_SIZE).for_each(|i| {
builder.push_op(DslIr::Select(
should_swap,
result0[i],
result1[i],
input[0][i],
input[1][i],
));
});
[result0, result1]
}
fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable) {
for d in digest.iter() {
builder.print_f(*d);
}
}
}
impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1OuterGlobalContext {
fn poseidon2_permute(
builder: &mut Builder<C>,
state: [Felt<SP1Field>; PERMUTATION_WIDTH],
) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
let state: [Felt<_>; PERMUTATION_WIDTH] = state.map(|x| builder.eval(x));
builder.push_op(DslIr::CircuitPoseidon2PermuteKoalaBear(Box::new(state)));
state
}
}
pub const BN254_DIGEST_SIZE: usize = 1;
impl FieldHasher for SP1OuterGlobalContext {
fn constant_compress(input: [Self::Digest; 2]) -> Self::Digest {
let mut state = [
Borrow::<[Bn254Fr; 1]>::borrow(&input[0])[0],
Borrow::<[Bn254Fr; 1]>::borrow(&input[1])[0],
Bn254Fr::zero(),
];
outer_perm().permute_mut(&mut state);
[state[0]; BN254_DIGEST_SIZE].into()
}
fn hash_slice(input: &[SP1Field]) -> Self::Digest {
SP1OuterGlobalContext::default_hasher_and_compressor().0.hash_slice(input)
}
}
impl<C: CircuitConfig<N = Bn254Fr, Bit = Var<Bn254Fr>>> FieldHasherVariable<C>
for SP1OuterGlobalContext
{
type DigestVariable = [Var<Bn254Fr>; BN254_DIGEST_SIZE];
fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable {
assert!(C::N::bits() == slop_bn254::Bn254Fr::bits());
assert!(SP1Field::bits() == sp1_primitives::SP1Field::bits());
let num_f_elms = C::N::bits() / SP1Field::bits();
let mut state: [Var<C::N>; OUTER_CHALLENGER_STATE_WIDTH] =
[builder.eval(C::N::zero()), builder.eval(C::N::zero()), builder.eval(C::N::zero())];
for block_chunk in &input.iter().chunks(POSEIDON_2_BB_RATE) {
for (chunk_id, chunk) in (&block_chunk.chunks(num_f_elms)).into_iter().enumerate() {
let chunk = chunk.copied().collect::<Vec<_>>();
state[chunk_id] = reduce_31(builder, chunk.as_slice());
}
builder.push_op(DslIr::CircuitPoseidon2Permute(state))
}
[state[0]; BN254_DIGEST_SIZE]
}
fn compress(
builder: &mut Builder<C>,
input: [Self::DigestVariable; 2],
) -> Self::DigestVariable {
let state: [Var<C::N>; OUTER_CHALLENGER_STATE_WIDTH] =
[builder.eval(input[0][0]), builder.eval(input[1][0]), builder.eval(C::N::zero())];
builder.push_op(DslIr::CircuitPoseidon2Permute(state));
[state[0]; BN254_DIGEST_SIZE]
}
fn assert_digest_eq(
builder: &mut Builder<C>,
a: Self::DigestVariable,
b: Self::DigestVariable,
) {
zip(a, b).for_each(|(e1, e2)| builder.assert_var_eq(e1, e2));
}
fn select_chain_digest(
builder: &mut Builder<C>,
should_swap: <C as CircuitConfig>::Bit,
input: [Self::DigestVariable; 2],
) -> [Self::DigestVariable; 2] {
let result0: [Var<_>; BN254_DIGEST_SIZE] = core::array::from_fn(|j| {
let result = builder.uninit();
builder.push_op(DslIr::CircuitSelectV(should_swap, input[1][j], input[0][j], result));
result
});
let result1: [Var<_>; BN254_DIGEST_SIZE] = core::array::from_fn(|j| {
let result = builder.uninit();
builder.push_op(DslIr::CircuitSelectV(should_swap, input[0][j], input[1][j], result));
result
});
[result0, result1]
}
fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable) {
for d in digest.iter() {
builder.print_v(*d);
}
}
}