openzeppelin_crypto/poseidon2/mod.rs
//! This module contains the Poseidon2 hash function ([whitepaper]) implemented
//! as a [Sponge Function].
//!
//! The Poseidon2 permutation follows the original [rust implementation]
//! referenced in the [whitepaper], with slight improvements.
6//!
7//! ## Important Usage Notes
8//!
9//! This interface provides low-level primitives and does not implement padding
10//! or domain separation. Users are responsible for:
11//! - Padding inputs appropriately
12//! - Prepending domain separation tags when needed
13//! - Managing absorb/squeeze transitions correctly
14//! - Ensuring proper security practices for their specific use case
15//!
16//! [Sponge function]: https://en.wikipedia.org/wiki/Sponge_function
17//! [whitepaper]: https://eprint.iacr.org/2023/323.pdf
18//! [rust implementation]: https://github.com/HorizenLabs/poseidon2
19
20pub mod instance;
21pub mod params;
22
23use alloc::{boxed::Box, vec, vec::Vec};
24
25use crate::{field::prime::PrimeField, poseidon2::params::PoseidonParams};
26
/// Determines whether the Poseidon2 sponge is in the absorbing or squeezing
/// state.
///
/// In the squeezing state, the sponge can only squeeze elements; the
/// transition from absorbing to squeezing is one-way.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Sponge is in absorbing state.
    Absorbing,
    /// Sponge is in squeezing state.
    Squeezing,
}
36
37/// Poseidon2 sponge that can absorb any number of `F` field elements and be
38/// squeezed to a finite number of `F` field elements.
39///
40/// ## Security Notice
41///
42/// This is a low-level primitive that does not implement padding or domain
43/// separation. Users must ensure proper input formatting and security practices
44/// for their specific cryptographic protocols.
45#[derive(Clone, Debug)]
46pub struct Poseidon2<P: PoseidonParams<F>, F: PrimeField> {
47 phantom: core::marker::PhantomData<P>,
48 state: Box<[F]>,
49 mode: Mode,
50 index: usize,
51}
52
53impl<P: PoseidonParams<F>, F: PrimeField> Default for Poseidon2<P, F> {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl<P: PoseidonParams<F>, F: PrimeField> Poseidon2<P, F> {
60 /// Create a new Poseidon sponge.
61 #[must_use]
62 #[inline]
63 pub fn new() -> Self {
64 Self {
65 phantom: core::marker::PhantomData,
66 state: vec![F::zero(); P::T].into_boxed_slice(),
67 mode: Mode::Absorbing,
68 // Begin index from `CAPACITY`. Skip capacity elements.
69 index: P::CAPACITY,
70 }
71 }
72
73 /// Size of poseidon sponge's state.
74 #[must_use]
75 pub const fn state_size() -> usize {
76 P::T
77 }
78
79 /// Start index of partial rounds.
80 ///
81 /// This represents the point where the algorithm transitions from full
82 /// rounds to partial rounds in the Poseidon permutation.
83 #[must_use]
84 const fn partial_round_start() -> usize {
85 P::ROUNDS_F / 2
86 }
87
88 /// End index of partial rounds (noninclusive).
89 ///
90 /// This represents the point where the algorithm transitions from partial
91 /// rounds back to full rounds in the Poseidon permutation.
92 #[must_use]
93 const fn partial_round_end() -> usize {
94 Self::partial_round_start() + P::ROUNDS_P
95 }
96
97 /// Total number of rounds.
98 ///
99 /// This is the sum of full rounds and partial rounds in the Poseidon
100 /// permutation.
101 #[must_use]
102 const fn rounds() -> usize {
103 P::ROUNDS_F + P::ROUNDS_P
104 }
105
106 /// Absorb a single element into the sponge.
107 ///
108 /// Transitions from [`Mode::Absorbing`] to [`Mode::Squeezing`] mode are
109 /// unidirectional.
110 ///
111 /// # Panics
112 ///
113 /// May panic if absorbing while squeezing.
114 #[inline]
115 pub fn absorb(&mut self, elem: &F) {
116 if let Mode::Squeezing = self.mode {
117 panic!("cannot absorb while squeezing");
118 }
119
120 if self.index == Self::state_size() {
121 self.permute();
122 self.index = P::CAPACITY;
123 }
124
125 self.state[self.index] += elem;
126 self.index += 1;
127 }
128
129 /// Absorb batch of elements into the sponge.
130 #[inline]
131 pub fn absorb_batch(&mut self, elems: &[F]) {
132 for elem in elems {
133 self.absorb(elem);
134 }
135 }
136
137 /// Permute elements in the sponge.
138 #[inline]
139 pub fn permute(&mut self) {
140 // Linear layer at the beginning.
141 self.matmul_external();
142
143 // Run the first half of the full round.
144 for round in 0..Self::partial_round_start() {
145 self.external_round(round);
146 }
147
148 // Run the partial round.
149 for round in Self::partial_round_start()..Self::partial_round_end() {
150 self.internal_round(round);
151 }
152
153 // Run the second half of the full round.
154 for round in Self::partial_round_end()..Self::rounds() {
155 self.external_round(round);
156 }
157 }
158
159 /// Apply external round to the state.
160 ///
161 /// External rounds apply S-box to all elements of the state vector,
162 /// followed by the MDS matrix multiplication.
163 #[inline]
164 fn external_round(&mut self, round: usize) {
165 self.add_rc_external(round);
166 self.apply_sbox_external();
167 self.matmul_external();
168 }
169
170 /// Apply internal round to the state.
171 ///
172 /// Internal rounds apply S-box only to the first element of the state
173 /// vector, followed by the MDS matrix multiplication, which is more
174 /// efficient.
175 #[inline]
176 fn internal_round(&mut self, round: usize) {
177 self.add_rc_internal(round);
178 self.apply_sbox_internal();
179 self.matmul_internal();
180 }
181
182 /// Squeeze a single element from the sponge.
183 ///
184 /// When invoked from [`Mode::Absorbing`] mode, this function triggers a
185 /// permutation and transitions to [`Mode::Squeezing`] mode.
186 #[inline]
187 pub fn squeeze(&mut self) -> F {
188 if self.mode == Mode::Absorbing || self.index == Self::state_size() {
189 self.permute();
190 self.mode = Mode::Squeezing;
191 self.index = P::CAPACITY;
192 }
193
194 let elem = self.state[self.index];
195 self.index += 1;
196 elem
197 }
198
199 /// Squeeze a batch of elements from the sponge.
200 #[inline]
201 pub fn squeeze_batch(&mut self, n: usize) -> Vec<F> {
202 (0..n).map(|_| self.squeeze()).collect()
203 }
204
205 /// Apply sbox to the entire state in the external round.
206 ///
207 /// This raises each element in the state to the power of D, which is
208 /// the S-box degree defined in the Poseidon parameters.
209 #[inline]
210 fn apply_sbox_external(&mut self) {
211 for elem in &mut self.state {
212 *elem = elem.pow(P::D);
213 }
214 }
215
216 /// Apply sbox to the first element in the internal round.
217 ///
218 /// This applies the S-box (raising to power D) only to the first element of
219 /// the state, which is more efficient than applying it to all elements.
220 #[inline]
221 fn apply_sbox_internal(&mut self) {
222 self.state[0] = self.state[0].pow(P::D);
223 }
224
225 /// Apply the external MDS matrix `M_E` to the state.
226 ///
227 /// This function applies the Maximum Distance Separable (MDS) matrix
228 /// multiplication to the entire state vector for external rounds of the
229 /// Poseidon permutation. The implementation is optimized for different
230 /// state sizes.
231 #[allow(clippy::needless_range_loop)]
232 #[inline(always)]
233 fn matmul_external(&mut self) {
234 let t = Self::state_size();
235 match t {
236 2 => {
237 // Matrix circ(2, 1)
238 let sum = self.state[0] + self.state[1];
239 self.state[0] += sum;
240 self.state[1] += sum;
241 }
242 3 => {
243 // Matrix circ(2, 1, 1).
244 let sum = self.state[0] + self.state[1] + self.state[2];
245 self.state[0] += sum;
246 self.state[1] += sum;
247 self.state[2] += sum;
248 }
249 4 => {
250 self.matmul_m4();
251 }
252 8 | 12 | 16 | 20 | 24 => {
253 self.matmul_m4();
254
255 // Applying second cheap matrix for t > 4.
256 let t4 = t / 4;
257 let mut stored = [F::zero(); 4];
258 for l in 0..4 {
259 stored[l] = self.state[l];
260 for j in 1..t4 {
261 stored[l] += &self.state[4 * j + l];
262 }
263 }
264 for i in 0..self.state.len() {
265 self.state[i] += &stored[i % 4];
266 }
267 }
268 _ => {
269 panic!("not supported state size")
270 }
271 }
272 }
273
274 /// Apply the cheap 4x4 MDS matrix to each 4-element part of the state.
275 ///
276 /// Optimized matrix multiplication for state sizes that are multiples of 4.
277 /// Uses efficient in-place operations instead of constructing the full
278 /// matrix.
279 #[inline(always)]
280 fn matmul_m4(&mut self) {
281 let state = &mut self.state;
282 let t = Self::state_size();
283 let t4 = t / 4;
284 for i in 0..t4 {
285 let start_index = i * 4;
286 let mut t_0 = state[start_index];
287 t_0 += &state[start_index + 1];
288 let mut t_1 = state[start_index + 2];
289 t_1 += &state[start_index + 3];
290 let mut t_2 = state[start_index + 1];
291 t_2.double_in_place();
292 t_2 += &t_1;
293 let mut t_3 = state[start_index + 3];
294 t_3.double_in_place();
295 t_3 += &t_0;
296 let mut t_4 = t_1;
297 t_4.double_in_place();
298 t_4.double_in_place();
299 t_4 += &t_3;
300 let mut t_5 = t_0;
301 t_5.double_in_place();
302 t_5.double_in_place();
303 t_5 += &t_2;
304 let mut t_6 = t_3;
305 t_6 += &t_5;
306 let mut t_7 = t_2;
307 t_7 += &t_4;
308 state[start_index] = t_6;
309 state[start_index + 1] = t_5;
310 state[start_index + 2] = t_7;
311 state[start_index + 3] = t_4;
312 }
313 }
314
315 /// Apply the internal MDS matrix to the state.
316 ///
317 /// Optimized matrix multiplication for internal rounds.
318 #[inline(always)]
319 fn matmul_internal(&mut self) {
320 let t = Self::state_size();
321
322 match t {
323 2 => {
324 // [2, 1]
325 // [1, 3]
326 let sum = self.state[0] + self.state[1];
327 self.state[0] += ∑
328 self.state[1].double_in_place();
329 self.state[1] += ∑
330 }
331 3 => {
332 // [2, 1, 1]
333 // [1, 2, 1]
334 // [1, 1, 3]
335 let sum = self.state[0] + self.state[1] + self.state[2];
336 self.state[0] += ∑
337 self.state[1] += ∑
338 self.state[2].double_in_place();
339 self.state[2] += ∑
340 }
341 4 | 8 | 12 | 16 | 20 | 24 => {
342 let sum = self.state.iter().sum();
343
344 // Add sum + diag entry * element to each element.
345 for i in 0..self.state.len() {
346 self.state[i] *= &P::MAT_INTERNAL_DIAG_M_1[i];
347 self.state[i] += ∑
348 }
349 }
350 _ => {
351 panic!("not supported state size")
352 }
353 }
354 }
355
356 /// Add a round constant to the entire state in external round.
357 #[inline]
358 fn add_rc_external(&mut self, round: usize) {
359 for (a, b) in
360 self.state.iter_mut().zip(P::ROUND_CONSTANTS[round].iter())
361 {
362 *a += b;
363 }
364 }
365
366 // Add a round constant to the first state element in internal round.
367 #[inline]
368 fn add_rc_internal(&mut self, round: usize) {
369 self.state[0] += P::ROUND_CONSTANTS[round][0];
370 }
371}