openzeppelin_crypto/poseidon2/
mod.rs

//! This module contains the Poseidon2 hash function ([whitepaper]) implemented
//! as a [Sponge Function].
3//!
//! The Poseidon2 permutation here follows the original [rust implementation]
//! referenced in the [whitepaper], with slight improvements.
6//!
7//! ## Important Usage Notes
8//!
9//! This interface provides low-level primitives and does not implement padding
10//! or domain separation. Users are responsible for:
11//! - Padding inputs appropriately
12//! - Prepending domain separation tags when needed
13//! - Managing absorb/squeeze transitions correctly
14//! - Ensuring proper security practices for their specific use case
15//!
16//! [Sponge function]: https://en.wikipedia.org/wiki/Sponge_function
17//! [whitepaper]: https://eprint.iacr.org/2023/323.pdf
18//! [rust implementation]: https://github.com/HorizenLabs/poseidon2
19
20pub mod instance;
21pub mod params;
22
23use alloc::{boxed::Box, vec, vec::Vec};
24
25use crate::{field::prime::PrimeField, poseidon2::params::PoseidonParams};
26
/// Determines whether the Poseidon2 sponge is in absorbing or squeezing
/// state.
///
/// In squeezing state, the sponge can only squeeze elements.
// `Eq` and `Hash` are always valid for this fieldless enum and let callers
// use it as a map key or in `Eq`-bounded generic code.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Sponge is in absorbing state.
    Absorbing,
    /// Sponge is in squeezing state.
    Squeezing,
}
36
/// Poseidon2 sponge that can absorb any number of `F` field elements and be
/// squeezed to a finite number of `F` field elements.
///
/// ## Security Notice
///
/// This is a low-level primitive that does not implement padding or domain
/// separation. Users must ensure proper input formatting and security practices
/// for their specific cryptographic protocols.
#[derive(Clone, Debug)]
pub struct Poseidon2<P: PoseidonParams<F>, F: PrimeField> {
    /// Zero-sized marker tying the sponge to its parameter set `P`.
    phantom: core::marker::PhantomData<P>,
    /// Permutation state of `P::T` field elements. The first `P::CAPACITY`
    /// elements are never absorbed into or squeezed from directly; the
    /// remaining positions receive input and provide output.
    state: Box<[F]>,
    /// Whether the sponge is currently absorbing or squeezing.
    mode: Mode,
    /// Position in `state` of the next element to absorb into or squeeze
    /// from. Starts at `P::CAPACITY`; when it reaches `P::T` the state is
    /// permuted and the index is reset to `P::CAPACITY`.
    index: usize,
}
52
53impl<P: PoseidonParams<F>, F: PrimeField> Default for Poseidon2<P, F> {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl<P: PoseidonParams<F>, F: PrimeField> Poseidon2<P, F> {
60    /// Create a new Poseidon sponge.
61    #[must_use]
62    #[inline]
63    pub fn new() -> Self {
64        Self {
65            phantom: core::marker::PhantomData,
66            state: vec![F::zero(); P::T].into_boxed_slice(),
67            mode: Mode::Absorbing,
68            // Begin index from `CAPACITY`. Skip capacity elements.
69            index: P::CAPACITY,
70        }
71    }
72
73    /// Size of poseidon sponge's state.
74    #[must_use]
75    pub const fn state_size() -> usize {
76        P::T
77    }
78
79    /// Start index of partial rounds.
80    ///
81    /// This represents the point where the algorithm transitions from full
82    /// rounds to partial rounds in the Poseidon permutation.
83    #[must_use]
84    const fn partial_round_start() -> usize {
85        P::ROUNDS_F / 2
86    }
87
88    /// End index of partial rounds (noninclusive).
89    ///
90    /// This represents the point where the algorithm transitions from partial
91    /// rounds back to full rounds in the Poseidon permutation.
92    #[must_use]
93    const fn partial_round_end() -> usize {
94        Self::partial_round_start() + P::ROUNDS_P
95    }
96
97    /// Total number of rounds.
98    ///
99    /// This is the sum of full rounds and partial rounds in the Poseidon
100    /// permutation.
101    #[must_use]
102    const fn rounds() -> usize {
103        P::ROUNDS_F + P::ROUNDS_P
104    }
105
106    /// Absorb a single element into the sponge.
107    ///
108    /// Transitions from [`Mode::Absorbing`] to [`Mode::Squeezing`] mode are
109    /// unidirectional.
110    ///
111    /// # Panics
112    ///
113    /// May panic if absorbing while squeezing.
114    #[inline]
115    pub fn absorb(&mut self, elem: &F) {
116        if let Mode::Squeezing = self.mode {
117            panic!("cannot absorb while squeezing");
118        }
119
120        if self.index == Self::state_size() {
121            self.permute();
122            self.index = P::CAPACITY;
123        }
124
125        self.state[self.index] += elem;
126        self.index += 1;
127    }
128
129    /// Absorb batch of elements into the sponge.
130    #[inline]
131    pub fn absorb_batch(&mut self, elems: &[F]) {
132        for elem in elems {
133            self.absorb(elem);
134        }
135    }
136
137    /// Permute elements in the sponge.
138    #[inline]
139    pub fn permute(&mut self) {
140        // Linear layer at the beginning.
141        self.matmul_external();
142
143        // Run the first half of the full round.
144        for round in 0..Self::partial_round_start() {
145            self.external_round(round);
146        }
147
148        // Run the partial round.
149        for round in Self::partial_round_start()..Self::partial_round_end() {
150            self.internal_round(round);
151        }
152
153        // Run the second half of the full round.
154        for round in Self::partial_round_end()..Self::rounds() {
155            self.external_round(round);
156        }
157    }
158
159    /// Apply external round to the state.
160    ///
161    /// External rounds apply S-box to all elements of the state vector,
162    /// followed by the MDS matrix multiplication.
163    #[inline]
164    fn external_round(&mut self, round: usize) {
165        self.add_rc_external(round);
166        self.apply_sbox_external();
167        self.matmul_external();
168    }
169
170    /// Apply internal round to the state.
171    ///
172    /// Internal rounds apply S-box only to the first element of the state
173    /// vector, followed by the MDS matrix multiplication, which is more
174    /// efficient.
175    #[inline]
176    fn internal_round(&mut self, round: usize) {
177        self.add_rc_internal(round);
178        self.apply_sbox_internal();
179        self.matmul_internal();
180    }
181
182    /// Squeeze a single element from the sponge.
183    ///
184    /// When invoked from [`Mode::Absorbing`] mode, this function triggers a
185    /// permutation and transitions to [`Mode::Squeezing`] mode.
186    #[inline]
187    pub fn squeeze(&mut self) -> F {
188        if self.mode == Mode::Absorbing || self.index == Self::state_size() {
189            self.permute();
190            self.mode = Mode::Squeezing;
191            self.index = P::CAPACITY;
192        }
193
194        let elem = self.state[self.index];
195        self.index += 1;
196        elem
197    }
198
199    /// Squeeze a batch of elements from the sponge.
200    #[inline]
201    pub fn squeeze_batch(&mut self, n: usize) -> Vec<F> {
202        (0..n).map(|_| self.squeeze()).collect()
203    }
204
205    /// Apply sbox to the entire state in the external round.
206    ///
207    /// This raises each element in the state to the power of D, which is
208    /// the S-box degree defined in the Poseidon parameters.
209    #[inline]
210    fn apply_sbox_external(&mut self) {
211        for elem in &mut self.state {
212            *elem = elem.pow(P::D);
213        }
214    }
215
216    /// Apply sbox to the first element in the internal round.
217    ///
218    /// This applies the S-box (raising to power D) only to the first element of
219    /// the state, which is more efficient than applying it to all elements.
220    #[inline]
221    fn apply_sbox_internal(&mut self) {
222        self.state[0] = self.state[0].pow(P::D);
223    }
224
225    /// Apply the external MDS matrix `M_E` to the state.
226    ///
227    /// This function applies the Maximum Distance Separable (MDS) matrix
228    /// multiplication to the entire state vector for external rounds of the
229    /// Poseidon permutation. The implementation is optimized for different
230    /// state sizes.
231    #[allow(clippy::needless_range_loop)]
232    #[inline(always)]
233    fn matmul_external(&mut self) {
234        let t = Self::state_size();
235        match t {
236            2 => {
237                // Matrix circ(2, 1)
238                let sum = self.state[0] + self.state[1];
239                self.state[0] += sum;
240                self.state[1] += sum;
241            }
242            3 => {
243                // Matrix circ(2, 1, 1).
244                let sum = self.state[0] + self.state[1] + self.state[2];
245                self.state[0] += sum;
246                self.state[1] += sum;
247                self.state[2] += sum;
248            }
249            4 => {
250                self.matmul_m4();
251            }
252            8 | 12 | 16 | 20 | 24 => {
253                self.matmul_m4();
254
255                // Applying second cheap matrix for t > 4.
256                let t4 = t / 4;
257                let mut stored = [F::zero(); 4];
258                for l in 0..4 {
259                    stored[l] = self.state[l];
260                    for j in 1..t4 {
261                        stored[l] += &self.state[4 * j + l];
262                    }
263                }
264                for i in 0..self.state.len() {
265                    self.state[i] += &stored[i % 4];
266                }
267            }
268            _ => {
269                panic!("not supported state size")
270            }
271        }
272    }
273
274    /// Apply the cheap 4x4 MDS matrix to each 4-element part of the state.
275    ///
276    /// Optimized matrix multiplication for state sizes that are multiples of 4.
277    /// Uses efficient in-place operations instead of constructing the full
278    /// matrix.
279    #[inline(always)]
280    fn matmul_m4(&mut self) {
281        let state = &mut self.state;
282        let t = Self::state_size();
283        let t4 = t / 4;
284        for i in 0..t4 {
285            let start_index = i * 4;
286            let mut t_0 = state[start_index];
287            t_0 += &state[start_index + 1];
288            let mut t_1 = state[start_index + 2];
289            t_1 += &state[start_index + 3];
290            let mut t_2 = state[start_index + 1];
291            t_2.double_in_place();
292            t_2 += &t_1;
293            let mut t_3 = state[start_index + 3];
294            t_3.double_in_place();
295            t_3 += &t_0;
296            let mut t_4 = t_1;
297            t_4.double_in_place();
298            t_4.double_in_place();
299            t_4 += &t_3;
300            let mut t_5 = t_0;
301            t_5.double_in_place();
302            t_5.double_in_place();
303            t_5 += &t_2;
304            let mut t_6 = t_3;
305            t_6 += &t_5;
306            let mut t_7 = t_2;
307            t_7 += &t_4;
308            state[start_index] = t_6;
309            state[start_index + 1] = t_5;
310            state[start_index + 2] = t_7;
311            state[start_index + 3] = t_4;
312        }
313    }
314
315    /// Apply the internal MDS matrix to the state.
316    ///
317    /// Optimized matrix multiplication for internal rounds.
318    #[inline(always)]
319    fn matmul_internal(&mut self) {
320        let t = Self::state_size();
321
322        match t {
323            2 => {
324                // [2, 1]
325                // [1, 3]
326                let sum = self.state[0] + self.state[1];
327                self.state[0] += &sum;
328                self.state[1].double_in_place();
329                self.state[1] += &sum;
330            }
331            3 => {
332                // [2, 1, 1]
333                // [1, 2, 1]
334                // [1, 1, 3]
335                let sum = self.state[0] + self.state[1] + self.state[2];
336                self.state[0] += &sum;
337                self.state[1] += &sum;
338                self.state[2].double_in_place();
339                self.state[2] += &sum;
340            }
341            4 | 8 | 12 | 16 | 20 | 24 => {
342                let sum = self.state.iter().sum();
343
344                // Add sum + diag entry * element to each element.
345                for i in 0..self.state.len() {
346                    self.state[i] *= &P::MAT_INTERNAL_DIAG_M_1[i];
347                    self.state[i] += &sum;
348                }
349            }
350            _ => {
351                panic!("not supported state size")
352            }
353        }
354    }
355
356    /// Add a round constant to the entire state in external round.
357    #[inline]
358    fn add_rc_external(&mut self, round: usize) {
359        for (a, b) in
360            self.state.iter_mut().zip(P::ROUND_CONSTANTS[round].iter())
361        {
362            *a += b;
363        }
364    }
365
366    // Add a round constant to the first state element in internal round.
367    #[inline]
368    fn add_rc_internal(&mut self, round: usize) {
369        self.state[0] += P::ROUND_CONSTANTS[round][0];
370    }
371}