Skip to main content

sp1_recursion_circuit/
hash.rs

1use std::{
2    borrow::Borrow,
3    iter::{repeat, zip},
4};
5
6use itertools::Itertools;
7use slop_algebra::{AbstractField, Field};
8use slop_bn254::{outer_perm, Bn254Fr, OUTER_CHALLENGER_STATE_WIDTH};
9use slop_challenger::IopCtx;
10use slop_symmetric::{CryptographicHasher, Permutation};
11use sp1_hypercube::{inner_perm, SP1InnerPcs};
12use sp1_primitives::{SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
13use sp1_recursion_compiler::ir::{Builder, DslIr, Felt, Var};
14use sp1_recursion_executor::{DIGEST_SIZE, HASH_RATE, PERMUTATION_WIDTH};
15
16use crate::{
17    challenger::{reduce_31, POSEIDON_2_BB_RATE},
18    CircuitConfig,
19};
20
21pub trait FieldHasher: IopCtx {
22    fn constant_compress(input: [Self::Digest; 2]) -> Self::Digest;
23
24    fn hash_slice(input: &[Self::F]) -> Self::Digest;
25}
26
27pub trait Poseidon2SP1FieldHasherVariable<C: CircuitConfig> {
28    fn poseidon2_permute(
29        builder: &mut Builder<C>,
30        state: [Felt<SP1Field>; PERMUTATION_WIDTH],
31    ) -> [Felt<SP1Field>; PERMUTATION_WIDTH];
32
33    /// Applies the Poseidon2 hash function to the given array.
34    ///
35    /// Reference: [p3_symmetric::PaddingFreeSponge]
36    fn poseidon2_hash(
37        builder: &mut Builder<C>,
38        input: &[Felt<SP1Field>],
39    ) -> [Felt<SP1Field>; DIGEST_SIZE] {
40        // static_assert(RATE < WIDTH)
41        let mut state = core::array::from_fn(|_| builder.eval(SP1Field::zero()));
42        for input_chunk in input.chunks(HASH_RATE) {
43            state[..input_chunk.len()].copy_from_slice(input_chunk);
44            state = Self::poseidon2_permute(builder, state);
45        }
46        let digest: [Felt<SP1Field>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
47        digest
48    }
49}
50
51pub trait FieldHasherVariable<C: CircuitConfig>: FieldHasher {
52    type DigestVariable: Clone + Copy;
53
54    fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable;
55
56    fn compress(builder: &mut Builder<C>, input: [Self::DigestVariable; 2])
57        -> Self::DigestVariable;
58
59    fn assert_digest_eq(builder: &mut Builder<C>, a: Self::DigestVariable, b: Self::DigestVariable);
60
61    // Encountered many issues trying to make the following two parametrically polymorphic.
62    fn select_chain_digest(
63        builder: &mut Builder<C>,
64        should_swap: C::Bit,
65        input: [Self::DigestVariable; 2],
66    ) -> [Self::DigestVariable; 2];
67
68    fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable);
69}
70
71impl FieldHasher for SP1GlobalContext {
72    fn constant_compress(
73        input: [<SP1GlobalContext as IopCtx>::Digest; 2],
74    ) -> <SP1GlobalContext as IopCtx>::Digest {
75        let mut pre_iter = input.into_iter().flatten().chain(repeat(SP1Field::zero()));
76        let mut pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
77        inner_perm().permute_mut(&mut pre);
78        pre[..DIGEST_SIZE].try_into().unwrap()
79    }
80
81    fn hash_slice(input: &[SP1Field]) -> <SP1GlobalContext as IopCtx>::Digest {
82        let mut state = [SP1Field::zero(); PERMUTATION_WIDTH];
83        for input_chunk in input.chunks(HASH_RATE) {
84            state[..input_chunk.len()].copy_from_slice(input_chunk);
85            inner_perm().permute_mut(&mut state);
86        }
87        let digest: [SP1Field; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
88        digest
89    }
90}
91
92impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1InnerPcs {
93    fn poseidon2_permute(
94        builder: &mut Builder<C>,
95        input: [Felt<SP1Field>; PERMUTATION_WIDTH],
96    ) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
97        C::poseidon2_permute_v2(builder, input)
98    }
99}
100
101impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1GlobalContext {
102    fn poseidon2_permute(
103        builder: &mut Builder<C>,
104        input: [Felt<SP1Field>; PERMUTATION_WIDTH],
105    ) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
106        C::poseidon2_permute_v2(builder, input)
107    }
108}
109
110impl<C: CircuitConfig<Bit = Felt<SP1Field>>> FieldHasherVariable<C> for SP1GlobalContext {
111    type DigestVariable = [Felt<SP1Field>; DIGEST_SIZE];
112
113    fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable {
114        <Self as Poseidon2SP1FieldHasherVariable<C>>::poseidon2_hash(builder, input)
115    }
116
117    fn compress(
118        builder: &mut Builder<C>,
119        input: [Self::DigestVariable; 2],
120    ) -> Self::DigestVariable {
121        C::poseidon2_compress_v2(builder, input.into_iter().flatten())
122    }
123
124    fn assert_digest_eq(
125        builder: &mut Builder<C>,
126        a: Self::DigestVariable,
127        b: Self::DigestVariable,
128    ) {
129        // Push the instruction directly instead of passing through `assert_felt_eq` in order to
130        //avoid symbolic expression overhead.
131        zip(a, b).for_each(|(e1, e2)| builder.push_op(DslIr::AssertEqF(e1, e2)));
132    }
133
134    fn select_chain_digest(
135        builder: &mut Builder<C>,
136        should_swap: <C as CircuitConfig>::Bit,
137        input: [Self::DigestVariable; 2],
138    ) -> [Self::DigestVariable; 2] {
139        let result0: [Felt<SP1Field>; DIGEST_SIZE] = core::array::from_fn(|_| builder.uninit());
140        let result1: [Felt<SP1Field>; DIGEST_SIZE] = core::array::from_fn(|_| builder.uninit());
141
142        (0..DIGEST_SIZE).for_each(|i| {
143            builder.push_op(DslIr::Select(
144                should_swap,
145                result0[i],
146                result1[i],
147                input[0][i],
148                input[1][i],
149            ));
150        });
151
152        [result0, result1]
153    }
154
155    fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable) {
156        for d in digest.iter() {
157            builder.print_f(*d);
158        }
159    }
160}
161
162impl<C: CircuitConfig> Poseidon2SP1FieldHasherVariable<C> for SP1OuterGlobalContext {
163    fn poseidon2_permute(
164        builder: &mut Builder<C>,
165        state: [Felt<SP1Field>; PERMUTATION_WIDTH],
166    ) -> [Felt<SP1Field>; PERMUTATION_WIDTH] {
167        let state: [Felt<_>; PERMUTATION_WIDTH] = state.map(|x| builder.eval(x));
168        builder.push_op(DslIr::CircuitPoseidon2PermuteKoalaBear(Box::new(state)));
169        state
170    }
171}
172
/// Number of BN254 field elements in an outer-context digest.
pub const BN254_DIGEST_SIZE: usize = 1;
174
175impl FieldHasher for SP1OuterGlobalContext {
176    fn constant_compress(input: [Self::Digest; 2]) -> Self::Digest {
177        let mut state = [
178            Borrow::<[Bn254Fr; 1]>::borrow(&input[0])[0],
179            Borrow::<[Bn254Fr; 1]>::borrow(&input[1])[0],
180            Bn254Fr::zero(),
181        ];
182        outer_perm().permute_mut(&mut state);
183        [state[0]; BN254_DIGEST_SIZE].into()
184    }
185
186    fn hash_slice(input: &[SP1Field]) -> Self::Digest {
187        SP1OuterGlobalContext::default_hasher_and_compressor().0.hash_slice(input)
188    }
189}
190
191impl<C: CircuitConfig<N = Bn254Fr, Bit = Var<Bn254Fr>>> FieldHasherVariable<C>
192    for SP1OuterGlobalContext
193{
194    type DigestVariable = [Var<Bn254Fr>; BN254_DIGEST_SIZE];
195
196    fn hash(builder: &mut Builder<C>, input: &[Felt<SP1Field>]) -> Self::DigestVariable {
197        assert!(C::N::bits() == slop_bn254::Bn254Fr::bits());
198        assert!(SP1Field::bits() == sp1_primitives::SP1Field::bits());
199        let num_f_elms = C::N::bits() / SP1Field::bits();
200        let mut state: [Var<C::N>; OUTER_CHALLENGER_STATE_WIDTH] =
201            [builder.eval(C::N::zero()), builder.eval(C::N::zero()), builder.eval(C::N::zero())];
202        for block_chunk in &input.iter().chunks(POSEIDON_2_BB_RATE) {
203            for (chunk_id, chunk) in (&block_chunk.chunks(num_f_elms)).into_iter().enumerate() {
204                let chunk = chunk.copied().collect::<Vec<_>>();
205                state[chunk_id] = reduce_31(builder, chunk.as_slice());
206            }
207            builder.push_op(DslIr::CircuitPoseidon2Permute(state))
208        }
209
210        [state[0]; BN254_DIGEST_SIZE]
211    }
212
213    fn compress(
214        builder: &mut Builder<C>,
215        input: [Self::DigestVariable; 2],
216    ) -> Self::DigestVariable {
217        let state: [Var<C::N>; OUTER_CHALLENGER_STATE_WIDTH] =
218            [builder.eval(input[0][0]), builder.eval(input[1][0]), builder.eval(C::N::zero())];
219        builder.push_op(DslIr::CircuitPoseidon2Permute(state));
220        [state[0]; BN254_DIGEST_SIZE]
221    }
222
223    fn assert_digest_eq(
224        builder: &mut Builder<C>,
225        a: Self::DigestVariable,
226        b: Self::DigestVariable,
227    ) {
228        zip(a, b).for_each(|(e1, e2)| builder.assert_var_eq(e1, e2));
229    }
230
231    fn select_chain_digest(
232        builder: &mut Builder<C>,
233        should_swap: <C as CircuitConfig>::Bit,
234        input: [Self::DigestVariable; 2],
235    ) -> [Self::DigestVariable; 2] {
236        let result0: [Var<_>; BN254_DIGEST_SIZE] = core::array::from_fn(|j| {
237            let result = builder.uninit();
238            builder.push_op(DslIr::CircuitSelectV(should_swap, input[1][j], input[0][j], result));
239            result
240        });
241        let result1: [Var<_>; BN254_DIGEST_SIZE] = core::array::from_fn(|j| {
242            let result = builder.uninit();
243            builder.push_op(DslIr::CircuitSelectV(should_swap, input[0][j], input[1][j], result));
244            result
245        });
246
247        [result0, result1]
248    }
249
250    fn print_digest(builder: &mut Builder<C>, digest: Self::DigestVariable) {
251        for d in digest.iter() {
252            builder.print_v(*d);
253        }
254    }
255}