ark_sponge/poseidon/mod.rs

1use crate::{
2    batch_field_cast, squeeze_field_elements_with_sizes_default_impl, Absorb, CryptographicSponge,
3    FieldBasedCryptographicSponge, FieldElementSize, SpongeExt,
4};
5use ark_ff::{BigInteger, FpParameters, PrimeField};
6use ark_std::any::TypeId;
7use ark_std::rand::Rng;
8use ark_std::vec;
9use ark_std::vec::Vec;
10
/// Constraint-system counterpart of the Poseidon sponge; compiled only with the `r1cs` feature.
#[cfg(feature = "r1cs")]
pub mod constraints;

#[cfg(test)]
mod tests;
16
/// Whether the sponge is currently absorbing or squeezing, together with the
/// index into the rate portion of the state that the next operation will use.
/// An index equal to the rate means the rate portion is exhausted.
#[derive(Clone)]
enum PoseidonSpongeMode {
    /// Absorbing: the next input element is added at `state[next_absorb_index]`.
    Absorbing { next_absorb_index: usize },
    /// Squeezing: the next output element is read from `state[next_squeeze_index]`.
    Squeezing { next_squeeze_index: usize },
}
22
23#[derive(Clone)]
24/// A duplex sponge based using the Poseidon permutation.
25///
26/// This implementation of Poseidon is entirely from Fractal's implementation in [COS20][cos]
27/// with small syntax changes.
28///
29/// [cos]: https://eprint.iacr.org/2019/1076
30pub struct PoseidonSponge<F: PrimeField> {
31    // Sponge Parameters
32    /// number of rounds in a full-round operation
33    full_rounds: u32,
34    /// number of rounds in a partial-round operation
35    partial_rounds: u32,
36    /// Exponent used in S-boxes
37    alpha: u64,
38    /// Additive Round keys. These are added before each MDS matrix application to make it an affine shift.
39    /// They are indexed by `ark[round_num][state_element_index]`
40    ark: Vec<Vec<F>>,
41    /// Maximally Distance Separating Matrix.
42    mds: Vec<Vec<F>>,
43    /// the rate (in terms of number of field elements)
44    rate: usize,
45    /// the capacity (in terms of number of field elements)
46    capacity: usize,
47
48    // Sponge State
49    /// current sponge's state (current elements in the permutation block)
50    state: Vec<F>,
51    /// current mode (whether its absorbing or squeezing)
52    mode: PoseidonSpongeMode,
53}
54
55impl<F: PrimeField> PoseidonSponge<F> {
56    fn apply_s_box(&self, state: &mut [F], is_full_round: bool) {
57        // Full rounds apply the S Box (x^alpha) to every element of state
58        if is_full_round {
59            for elem in state {
60                *elem = elem.pow(&[self.alpha]);
61            }
62        }
63        // Partial rounds apply the S Box (x^alpha) to just the final element of state
64        else {
65            state[state.len() - 1] = state[state.len() - 1].pow(&[self.alpha]);
66        }
67    }
68
69    fn apply_ark(&self, state: &mut [F], round_number: usize) {
70        for (i, state_elem) in state.iter_mut().enumerate() {
71            state_elem.add_assign(&self.ark[round_number][i]);
72        }
73    }
74
75    fn apply_mds(&self, state: &mut [F]) {
76        let mut new_state = Vec::new();
77        for i in 0..state.len() {
78            let mut cur = F::zero();
79            for (j, state_elem) in state.iter().enumerate() {
80                let term = state_elem.mul(&self.mds[i][j]);
81                cur.add_assign(&term);
82            }
83            new_state.push(cur);
84        }
85        state.clone_from_slice(&new_state[..state.len()])
86    }
87
88    fn permute(&mut self) {
89        let full_rounds_over_2 = self.full_rounds / 2;
90        let mut state = self.state.clone();
91        for i in 0..full_rounds_over_2 {
92            self.apply_ark(&mut state, i as usize);
93            self.apply_s_box(&mut state, true);
94            self.apply_mds(&mut state);
95        }
96
97        for i in full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds) {
98            self.apply_ark(&mut state, i as usize);
99            self.apply_s_box(&mut state, false);
100            self.apply_mds(&mut state);
101        }
102
103        for i in
104            (full_rounds_over_2 + self.partial_rounds)..(self.partial_rounds + self.full_rounds)
105        {
106            self.apply_ark(&mut state, i as usize);
107            self.apply_s_box(&mut state, true);
108            self.apply_mds(&mut state);
109        }
110        self.state = state;
111    }
112
113    // Absorbs everything in elements, this does not end in an absorbtion.
114    fn absorb_internal(&mut self, mut rate_start_index: usize, elements: &[F]) {
115        let mut remaining_elements = elements;
116
117        loop {
118            // if we can finish in this call
119            if rate_start_index + remaining_elements.len() <= self.rate {
120                for (i, element) in remaining_elements.iter().enumerate() {
121                    self.state[i + rate_start_index] += element;
122                }
123                self.mode = PoseidonSpongeMode::Absorbing {
124                    next_absorb_index: rate_start_index + remaining_elements.len(),
125                };
126
127                return;
128            }
129            // otherwise absorb (rate - rate_start_index) elements
130            let num_elements_absorbed = self.rate - rate_start_index;
131            for (i, element) in remaining_elements
132                .iter()
133                .enumerate()
134                .take(num_elements_absorbed)
135            {
136                self.state[i + rate_start_index] += element;
137            }
138            self.permute();
139            // the input elements got truncated by num elements absorbed
140            remaining_elements = &remaining_elements[num_elements_absorbed..];
141            rate_start_index = 0;
142        }
143    }
144
145    // Squeeze |output| many elements. This does not end in a squeeze
146    fn squeeze_internal(&mut self, mut rate_start_index: usize, output: &mut [F]) {
147        let mut output_remaining = output;
148        loop {
149            // if we can finish in this call
150            if rate_start_index + output_remaining.len() <= self.rate {
151                output_remaining.clone_from_slice(
152                    &self.state[rate_start_index..(output_remaining.len() + rate_start_index)],
153                );
154                self.mode = PoseidonSpongeMode::Squeezing {
155                    next_squeeze_index: rate_start_index + output_remaining.len(),
156                };
157                return;
158            }
159            // otherwise squeeze (rate - rate_start_index) elements
160            let num_elements_squeezed = self.rate - rate_start_index;
161            output_remaining[..num_elements_squeezed].clone_from_slice(
162                &self.state[rate_start_index..(num_elements_squeezed + rate_start_index)],
163            );
164
165            // Unless we are done with squeezing in this call, permute.
166            if output_remaining.len() != self.rate {
167                self.permute();
168            }
169            // Repeat with updated output slices
170            output_remaining = &mut output_remaining[num_elements_squeezed..];
171            rate_start_index = 0;
172        }
173    }
174}
175
176/// Parameters and RNG used
177#[derive(Clone, Debug)]
178pub struct PoseidonParameters<F: PrimeField> {
179    full_rounds: u32,
180    partial_rounds: u32,
181    alpha: u64,
182    mds: Vec<Vec<F>>,
183    ark: Vec<Vec<F>>,
184}
185
186impl<F: PrimeField> PoseidonParameters<F> {
187    /// Initialize the parameter for Poseidon Sponge.
188    pub fn new(
189        full_rounds: u32,
190        partial_rounds: u32,
191        alpha: u64,
192        mds: Vec<Vec<F>>,
193        ark: Vec<Vec<F>>,
194    ) -> Self {
195        // shape check
196        assert_eq!(ark.len() as u32, full_rounds + partial_rounds);
197        for item in &ark {
198            assert_eq!(item.len(), 3);
199        }
200        Self {
201            full_rounds,
202            partial_rounds,
203            alpha,
204            mds,
205            ark,
206        }
207    }
208
209    /// Return a random round constant.
210    pub fn random_ark<R: Rng>(full_rounds: u32, rng: &mut R) -> Vec<Vec<F>> {
211        let mut ark = Vec::new();
212
213        for _ in 0..full_rounds {
214            let mut res = Vec::new();
215
216            for _ in 0..3 {
217                res.push(F::rand(rng));
218            }
219            ark.push(res);
220        }
221
222        ark
223    }
224}
225
226impl<F: PrimeField> CryptographicSponge for PoseidonSponge<F> {
227    type Parameters = PoseidonParameters<F>;
228
229    fn new(params: &Self::Parameters) -> Self {
230        // Requires F to be Alt_Bn128Fr
231        let full_rounds = params.full_rounds;
232        let partial_rounds = params.partial_rounds;
233        let alpha = params.alpha;
234
235        let mds = params.mds.clone();
236
237        let ark = params.ark.to_vec();
238
239        let rate = 2;
240        let capacity = 1;
241        let state = vec![F::zero(); rate + capacity];
242        let mode = PoseidonSpongeMode::Absorbing {
243            next_absorb_index: 0,
244        };
245
246        Self {
247            full_rounds,
248            partial_rounds,
249            alpha,
250            ark,
251            mds,
252
253            state,
254            rate,
255            capacity,
256            mode,
257        }
258    }
259
260    fn absorb(&mut self, input: &impl Absorb) {
261        let elems = input.to_sponge_field_elements_as_vec::<F>();
262        if elems.is_empty() {
263            return;
264        }
265
266        match self.mode {
267            PoseidonSpongeMode::Absorbing { next_absorb_index } => {
268                let mut absorb_index = next_absorb_index;
269                if absorb_index == self.rate {
270                    self.permute();
271                    absorb_index = 0;
272                }
273                self.absorb_internal(absorb_index, elems.as_slice());
274            }
275            PoseidonSpongeMode::Squeezing {
276                next_squeeze_index: _,
277            } => {
278                self.permute();
279                self.absorb_internal(0, elems.as_slice());
280            }
281        };
282    }
283
284    fn squeeze_bytes(&mut self, num_bytes: usize) -> Vec<u8> {
285        let usable_bytes = (F::Params::CAPACITY / 8) as usize;
286
287        let num_elements = (num_bytes + usable_bytes - 1) / usable_bytes;
288        let src_elements = self.squeeze_native_field_elements(num_elements);
289
290        let mut bytes: Vec<u8> = Vec::with_capacity(usable_bytes * num_elements);
291        for elem in &src_elements {
292            let elem_bytes = elem.into_repr().to_bytes_le();
293            bytes.extend_from_slice(&elem_bytes[..usable_bytes]);
294        }
295
296        bytes.truncate(num_bytes);
297        bytes
298    }
299
300    fn squeeze_bits(&mut self, num_bits: usize) -> Vec<bool> {
301        let usable_bits = F::Params::CAPACITY as usize;
302
303        let num_elements = (num_bits + usable_bits - 1) / usable_bits;
304        let src_elements = self.squeeze_native_field_elements(num_elements);
305
306        let mut bits: Vec<bool> = Vec::with_capacity(usable_bits * num_elements);
307        for elem in &src_elements {
308            let elem_bits = elem.into_repr().to_bits_le();
309            bits.extend_from_slice(&elem_bits[..usable_bits]);
310        }
311
312        bits.truncate(num_bits);
313        bits
314    }
315
316    fn squeeze_field_elements_with_sizes<F2: PrimeField>(
317        &mut self,
318        sizes: &[FieldElementSize],
319    ) -> Vec<F2> {
320        if F::characteristic() == F2::characteristic() {
321            // native case
322            let mut buf = Vec::with_capacity(sizes.len());
323            batch_field_cast(
324                &self.squeeze_native_field_elements_with_sizes(sizes),
325                &mut buf,
326            )
327            .unwrap();
328            buf
329        } else {
330            squeeze_field_elements_with_sizes_default_impl(self, sizes)
331        }
332    }
333
334    fn squeeze_field_elements<F2: PrimeField>(&mut self, num_elements: usize) -> Vec<F2> {
335        if TypeId::of::<F>() == TypeId::of::<F2>() {
336            let result = self.squeeze_native_field_elements(num_elements);
337            let mut cast = Vec::with_capacity(result.len());
338            batch_field_cast(&result, &mut cast).unwrap();
339            cast
340        } else {
341            self.squeeze_field_elements_with_sizes::<F2>(
342                vec![FieldElementSize::Full; num_elements].as_slice(),
343            )
344        }
345    }
346}
347
348impl<F: PrimeField> FieldBasedCryptographicSponge<F> for PoseidonSponge<F> {
349    fn squeeze_native_field_elements(&mut self, num_elements: usize) -> Vec<F> {
350        let mut squeezed_elems = vec![F::zero(); num_elements];
351        match self.mode {
352            PoseidonSpongeMode::Absorbing {
353                next_absorb_index: _,
354            } => {
355                self.permute();
356                self.squeeze_internal(0, &mut squeezed_elems);
357            }
358            PoseidonSpongeMode::Squeezing { next_squeeze_index } => {
359                let mut squeeze_index = next_squeeze_index;
360                if squeeze_index == self.rate {
361                    self.permute();
362                    squeeze_index = 0;
363                }
364                self.squeeze_internal(squeeze_index, &mut squeezed_elems);
365            }
366        };
367
368        squeezed_elems
369    }
370}
371
372#[derive(Clone)]
373/// Stores the state of a Poseidon Sponge. Does not store any parameter.
374pub struct PoseidonSpongeState<F: PrimeField> {
375    state: Vec<F>,
376    mode: PoseidonSpongeMode,
377}
378
379impl<CF: PrimeField> SpongeExt for PoseidonSponge<CF> {
380    type State = PoseidonSpongeState<CF>;
381
382    fn from_state(state: Self::State, params: &Self::Parameters) -> Self {
383        let mut sponge = Self::new(params);
384        sponge.mode = state.mode;
385        sponge.state = state.state;
386        sponge
387    }
388
389    fn into_state(self) -> Self::State {
390        Self::State {
391            state: self.state,
392            mode: self.mode,
393        }
394    }
395}