1use std::{
2 borrow::Borrow,
3 iter::{repeat, zip},
4};
5
6use itertools::Itertools;
7use slop_algebra::{AbstractField, Field};
8use slop_bn254::{outer_perm, Bn254Fr, OUTER_CHALLENGER_STATE_WIDTH};
9use slop_challenger::IopCtx;
10use slop_symmetric::{CryptographicHasher, Permutation};
11use sp1_hypercube::{inner_perm, SP1InnerPcs};
12use sp1_primitives::{SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
13use sp1_recursion_compiler::ir::{Builder, DslIr, Felt, Var};
14use sp1_recursion_executor::{DIGEST_SIZE, HASH_RATE, PERMUTATION_WIDTH};
15
16use crate::{
17 challenger::{reduce_31, POSEIDON_2_BB_RATE},
18 CircuitConfig,
19};
20
21pub trait FieldHasher: IopCtx {
22 fn constant_compress(input: [Self::Digest; 2]) -> Self::Digest;
23
24 fn hash_slice(input: &[Self::F]) -> Self::Digest;
25}
26
27pub trait Poseidon2SP1FieldHasherVariable<C: CircuitConfig> {
28 fn poseidon2_permute(
29 builder: &mut Builder<C>,
30 state: [Felt<SP1Field>; PERMUTATION_WIDTH],
31 ) -> [Felt<SP1Field>; PERMUTATION_WIDTH];
32
33 fn poseidon2_hash(
37 builder: &mut Builder<C>,
38 input: &[Felt<SP1Field>],
39 ) -> [Felt<SP1Field>; DIGEST_SIZE] {
40 let mut state = core::array::from_fn(|_| builder.eval(SP1Field::zero()));
42 for input_chunk in input.chunks(HASH_RATE) {
43 state[..input_chunk.len()].copy_from_slice(input_chunk);
44 state = Self::poseidon2_permute(builder, state);
45 }
46 let digest: [Felt<SP1Field>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
47 digest
48 }
49}
50
51pub trait FieldHasherVariable<C: CircuitConfig>: FieldHasher {
52 type DigestVariable: Clone + Copy;
53
54 fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable;
55
56 fn compress(builder: &mut Builder<C>, input: [Self::DigestVariable; 2])
57 -> Self::DigestVariable;
58
59 fn assert_digest_eq(builder: &mut Builder<C>, a: Self::DigestVariable, b: Self::DigestVariable);
60
61 fn select_chain_digest(
63 builder: &mut Builder<C>,
64 should_swap: C::Bit,
65 input: [Self::DigestVariable; 2],
66 ) -> [Self::DigestVariable; 2];
67
68 fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable);
69}
70
71impl FieldHasher for SP1GlobalContext {
72 fn constant_compress(
73 input: [<SP1GlobalContext as IopCtx>::Digest; 2],
74 ) -> <SP1GlobalContext as IopCtx>::Digest {
75 let mut pre_iter = input.into_iter().flatten().chain(repeat(SP1Field::zero()));
76 let mut pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
77 inner_perm().permute_mut(&mut pre);
78 pre[..DIGEST_SIZE].try_into().unwrap()
79 }
80
81 fn hash_slice(input: &[SP1Field]) -> <SP1GlobalContext as IopCtx>::Digest {
82 let mut state = [SP1Field::zero(); PERMUTATION_WIDTH];
83 for input_chunk in input.chunks(HASH_RATE) {
84 state[..input_chunk.len()].copy_from_slice(input_chunk);
85 inner_perm().permute_mut(&mut state);
86 }
87 let digest: [SP1Field; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
88 digest
89 }
90}
91
92impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1InnerPcs {
93 fn poseidon2_permute(
94 builder: &mut Builder<C>,
95 input: [Felt<SP1Field>; PERMUTATION_WIDTH],
96 ) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
97 C::poseidon2_permute_v2(builder, input)
98 }
99}
100
101impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1GlobalContext {
102 fn poseidon2_permute(
103 builder: &mut Builder<C>,
104 input: [Felt<SP1Field>; PERMUTATION_WIDTH],
105 ) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
106 C::poseidon2_permute_v2(builder, input)
107 }
108}
109
110impl<C: CircuitConfig<Bit = Felt<SP1Field>>> FieldHasherVariable<C> for SP1GlobalContext {
111 type DigestVariable = [Felt<SP1Field>; DIGEST_SIZE];
112
113 fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable {
114 <Self as Poseidon2SP1FieldHasherVariable<C>>::poseidon2_hash(builder, input)
115 }
116
117 fn compress(
118 builder: &mut Builder<C>,
119 input: [Self::DigestVariable; 2],
120 ) -> Self::DigestVariable {
121 C::poseidon2_compress_v2(builder, input.into_iter().flatten())
122 }
123
124 fn assert_digest_eq(
125 builder: &mut Builder<C>,
126 a: Self::DigestVariable,
127 b: Self::DigestVariable,
128 ) {
129 zip(a, b).for_each(|(e1, e2)| builder.push_op(DslIr::AssertEqF(e1, e2)));
132 }
133
134 fn select_chain_digest(
135 builder: &mut Builder<C>,
136 should_swap: <C as CircuitConfig>::Bit,
137 input: [Self::DigestVariable; 2],
138 ) -> [Self::DigestVariable; 2] {
139 let result0: [Felt<SP1Field>; DIGEST_SIZE] = core::array::from_fn(|_| builder.uninit());
140 let result1: [Felt<SP1Field>; DIGEST_SIZE] = core::array::from_fn(|_| builder.uninit());
141
142 (0..DIGEST_SIZE).for_each(|i| {
143 builder.push_op(DslIr::Select(
144 should_swap,
145 result0[i],
146 result1[i],
147 input[0][i],
148 input[1][i],
149 ));
150 });
151
152 [result0, result1]
153 }
154
155 fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable) {
156 for d in digest.iter() {
157 builder.print_f(*d);
158 }
159 }
160}
161
162impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1OuterGlobalContext {
163 fn poseidon2_permute(
164 builder: &mut Builder<C>,
165 state: [Felt<SP1Field>; PERMUTATION_WIDTH],
166 ) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
167 let state: [Felt<_>; PERMUTATION_WIDTH] = state.map(|x| builder.eval(x));
168 builder.push_op(DslIr::CircuitPoseidon2PermuteKoalaBear(Box::new(state)));
169 state
170 }
171}
172
/// Number of BN254 scalar-field elements in an `SP1OuterGlobalContext` digest.
pub const BN254_DIGEST_SIZE: usize = 1;
174
175impl FieldHasher for SP1OuterGlobalContext {
176 fn constant_compress(input: [Self::Digest; 2]) -> Self::Digest {
177 let mut state = [
178 Borrow::<[Bn254Fr; 1]>::borrow(&input[0])[0],
179 Borrow::<[Bn254Fr; 1]>::borrow(&input[1])[0],
180 Bn254Fr::zero(),
181 ];
182 outer_perm().permute_mut(&mut state);
183 [state[0]; BN254_DIGEST_SIZE].into()
184 }
185
186 fn hash_slice(input: &[SP1Field]) -> Self::Digest {
187 SP1OuterGlobalContext::default_hasher_and_compressor().0.hash_slice(input)
188 }
189}
190
191impl<C: CircuitConfig<N = Bn254Fr, Bit = Var<Bn254Fr>>> FieldHasherVariable<C>
192 for SP1OuterGlobalContext
193{
194 type DigestVariable = [Var<Bn254Fr>; BN254_DIGEST_SIZE];
195
196 fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable {
197 assert!(C::N::bits() == slop_bn254::Bn254Fr::bits());
198 assert!(SP1Field::bits() == sp1_primitives::SP1Field::bits());
199 let num_f_elms = C::N::bits() / SP1Field::bits();
200 let mut state: [Var<C::N>; OUTER_CHALLENGER_STATE_WIDTH] =
201 [builder.eval(C::N::zero()), builder.eval(C::N::zero()), builder.eval(C::N::zero())];
202 for block_chunk in &input.iter().chunks(POSEIDON_2_BB_RATE) {
203 for (chunk_id, chunk) in (&block_chunk.chunks(num_f_elms)).into_iter().enumerate() {
204 let chunk = chunk.copied().collect::<Vec<_>>();
205 state[chunk_id] = reduce_31(builder, chunk.as_slice());
206 }
207 builder.push_op(DslIr::CircuitPoseidon2Permute(state))
208 }
209
210 [state[0]; BN254_DIGEST_SIZE]
211 }
212
213 fn compress(
214 builder: &mut Builder<C>,
215 input: [Self::DigestVariable; 2],
216 ) -> Self::DigestVariable {
217 let state: [Var<C::N>; OUTER_CHALLENGER_STATE_WIDTH] =
218 [builder.eval(input[0][0]), builder.eval(input[1][0]), builder.eval(C::N::zero())];
219 builder.push_op(DslIr::CircuitPoseidon2Permute(state));
220 [state[0]; BN254_DIGEST_SIZE]
221 }
222
223 fn assert_digest_eq(
224 builder: &mut Builder<C>,
225 a: Self::DigestVariable,
226 b: Self::DigestVariable,
227 ) {
228 zip(a, b).for_each(|(e1, e2)| builder.assert_var_eq(e1, e2));
229 }
230
231 fn select_chain_digest(
232 builder: &mut Builder<C>,
233 should_swap: <C as CircuitConfig>::Bit,
234 input: [Self::DigestVariable; 2],
235 ) -> [Self::DigestVariable; 2] {
236 let result0: [Var<_>; BN254_DIGEST_SIZE] = core::array::from_fn(|j| {
237 let result = builder.uninit();
238 builder.push_op(DslIr::CircuitSelectV(should_swap, input[1][j], input[0][j], result));
239 result
240 });
241 let result1: [Var<_>; BN254_DIGEST_SIZE] = core::array::from_fn(|j| {
242 let result = builder.uninit();
243 builder.push_op(DslIr::CircuitSelectV(should_swap, input[0][j], input[1][j], result));
244 result
245 });
246
247 [result0, result1]
248 }
249
250 fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable) {
251 for d in digest.iter() {
252 builder.print_v(*d);
253 }
254 }
255}