1use crate::{
2 batch_field_cast, squeeze_field_elements_with_sizes_default_impl, Absorb, CryptographicSponge,
3 FieldBasedCryptographicSponge, FieldElementSize, SpongeExt,
4};
5use ark_ff::{BigInteger, FpParameters, PrimeField};
6use ark_std::any::TypeId;
7use ark_std::rand::Rng;
8use ark_std::vec;
9use ark_std::vec::Vec;
10
11#[cfg(feature = "r1cs")]
13pub mod constraints;
14#[cfg(test)]
15mod tests;
16
17#[derive(Clone)]
18enum PoseidonSpongeMode {
19 Absorbing { next_absorb_index: usize },
20 Squeezing { next_squeeze_index: usize },
21}
22
23#[derive(Clone)]
24pub struct PoseidonSponge<F: PrimeField> {
31 full_rounds: u32,
34 partial_rounds: u32,
36 alpha: u64,
38 ark: Vec<Vec<F>>,
41 mds: Vec<Vec<F>>,
43 rate: usize,
45 capacity: usize,
47
48 state: Vec<F>,
51 mode: PoseidonSpongeMode,
53}
54
55impl<F: PrimeField> PoseidonSponge<F> {
56 fn apply_s_box(&self, state: &mut [F], is_full_round: bool) {
57 if is_full_round {
59 for elem in state {
60 *elem = elem.pow(&[self.alpha]);
61 }
62 }
63 else {
65 state[state.len() - 1] = state[state.len() - 1].pow(&[self.alpha]);
66 }
67 }
68
69 fn apply_ark(&self, state: &mut [F], round_number: usize) {
70 for (i, state_elem) in state.iter_mut().enumerate() {
71 state_elem.add_assign(&self.ark[round_number][i]);
72 }
73 }
74
75 fn apply_mds(&self, state: &mut [F]) {
76 let mut new_state = Vec::new();
77 for i in 0..state.len() {
78 let mut cur = F::zero();
79 for (j, state_elem) in state.iter().enumerate() {
80 let term = state_elem.mul(&self.mds[i][j]);
81 cur.add_assign(&term);
82 }
83 new_state.push(cur);
84 }
85 state.clone_from_slice(&new_state[..state.len()])
86 }
87
88 fn permute(&mut self) {
89 let full_rounds_over_2 = self.full_rounds / 2;
90 let mut state = self.state.clone();
91 for i in 0..full_rounds_over_2 {
92 self.apply_ark(&mut state, i as usize);
93 self.apply_s_box(&mut state, true);
94 self.apply_mds(&mut state);
95 }
96
97 for i in full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds) {
98 self.apply_ark(&mut state, i as usize);
99 self.apply_s_box(&mut state, false);
100 self.apply_mds(&mut state);
101 }
102
103 for i in
104 (full_rounds_over_2 + self.partial_rounds)..(self.partial_rounds + self.full_rounds)
105 {
106 self.apply_ark(&mut state, i as usize);
107 self.apply_s_box(&mut state, true);
108 self.apply_mds(&mut state);
109 }
110 self.state = state;
111 }
112
113 fn absorb_internal(&mut self, mut rate_start_index: usize, elements: &[F]) {
115 let mut remaining_elements = elements;
116
117 loop {
118 if rate_start_index + remaining_elements.len() <= self.rate {
120 for (i, element) in remaining_elements.iter().enumerate() {
121 self.state[i + rate_start_index] += element;
122 }
123 self.mode = PoseidonSpongeMode::Absorbing {
124 next_absorb_index: rate_start_index + remaining_elements.len(),
125 };
126
127 return;
128 }
129 let num_elements_absorbed = self.rate - rate_start_index;
131 for (i, element) in remaining_elements
132 .iter()
133 .enumerate()
134 .take(num_elements_absorbed)
135 {
136 self.state[i + rate_start_index] += element;
137 }
138 self.permute();
139 remaining_elements = &remaining_elements[num_elements_absorbed..];
141 rate_start_index = 0;
142 }
143 }
144
145 fn squeeze_internal(&mut self, mut rate_start_index: usize, output: &mut [F]) {
147 let mut output_remaining = output;
148 loop {
149 if rate_start_index + output_remaining.len() <= self.rate {
151 output_remaining.clone_from_slice(
152 &self.state[rate_start_index..(output_remaining.len() + rate_start_index)],
153 );
154 self.mode = PoseidonSpongeMode::Squeezing {
155 next_squeeze_index: rate_start_index + output_remaining.len(),
156 };
157 return;
158 }
159 let num_elements_squeezed = self.rate - rate_start_index;
161 output_remaining[..num_elements_squeezed].clone_from_slice(
162 &self.state[rate_start_index..(num_elements_squeezed + rate_start_index)],
163 );
164
165 if output_remaining.len() != self.rate {
167 self.permute();
168 }
169 output_remaining = &mut output_remaining[num_elements_squeezed..];
171 rate_start_index = 0;
172 }
173 }
174}
175
176#[derive(Clone, Debug)]
178pub struct PoseidonParameters<F: PrimeField> {
179 full_rounds: u32,
180 partial_rounds: u32,
181 alpha: u64,
182 mds: Vec<Vec<F>>,
183 ark: Vec<Vec<F>>,
184}
185
186impl<F: PrimeField> PoseidonParameters<F> {
187 pub fn new(
189 full_rounds: u32,
190 partial_rounds: u32,
191 alpha: u64,
192 mds: Vec<Vec<F>>,
193 ark: Vec<Vec<F>>,
194 ) -> Self {
195 assert_eq!(ark.len() as u32, full_rounds + partial_rounds);
197 for item in &ark {
198 assert_eq!(item.len(), 3);
199 }
200 Self {
201 full_rounds,
202 partial_rounds,
203 alpha,
204 mds,
205 ark,
206 }
207 }
208
209 pub fn random_ark<R: Rng>(full_rounds: u32, rng: &mut R) -> Vec<Vec<F>> {
211 let mut ark = Vec::new();
212
213 for _ in 0..full_rounds {
214 let mut res = Vec::new();
215
216 for _ in 0..3 {
217 res.push(F::rand(rng));
218 }
219 ark.push(res);
220 }
221
222 ark
223 }
224}
225
226impl<F: PrimeField> CryptographicSponge for PoseidonSponge<F> {
227 type Parameters = PoseidonParameters<F>;
228
229 fn new(params: &Self::Parameters) -> Self {
230 let full_rounds = params.full_rounds;
232 let partial_rounds = params.partial_rounds;
233 let alpha = params.alpha;
234
235 let mds = params.mds.clone();
236
237 let ark = params.ark.to_vec();
238
239 let rate = 2;
240 let capacity = 1;
241 let state = vec![F::zero(); rate + capacity];
242 let mode = PoseidonSpongeMode::Absorbing {
243 next_absorb_index: 0,
244 };
245
246 Self {
247 full_rounds,
248 partial_rounds,
249 alpha,
250 ark,
251 mds,
252
253 state,
254 rate,
255 capacity,
256 mode,
257 }
258 }
259
260 fn absorb(&mut self, input: &impl Absorb) {
261 let elems = input.to_sponge_field_elements_as_vec::<F>();
262 if elems.is_empty() {
263 return;
264 }
265
266 match self.mode {
267 PoseidonSpongeMode::Absorbing { next_absorb_index } => {
268 let mut absorb_index = next_absorb_index;
269 if absorb_index == self.rate {
270 self.permute();
271 absorb_index = 0;
272 }
273 self.absorb_internal(absorb_index, elems.as_slice());
274 }
275 PoseidonSpongeMode::Squeezing {
276 next_squeeze_index: _,
277 } => {
278 self.permute();
279 self.absorb_internal(0, elems.as_slice());
280 }
281 };
282 }
283
284 fn squeeze_bytes(&mut self, num_bytes: usize) -> Vec<u8> {
285 let usable_bytes = (F::Params::CAPACITY / 8) as usize;
286
287 let num_elements = (num_bytes + usable_bytes - 1) / usable_bytes;
288 let src_elements = self.squeeze_native_field_elements(num_elements);
289
290 let mut bytes: Vec<u8> = Vec::with_capacity(usable_bytes * num_elements);
291 for elem in &src_elements {
292 let elem_bytes = elem.into_repr().to_bytes_le();
293 bytes.extend_from_slice(&elem_bytes[..usable_bytes]);
294 }
295
296 bytes.truncate(num_bytes);
297 bytes
298 }
299
300 fn squeeze_bits(&mut self, num_bits: usize) -> Vec<bool> {
301 let usable_bits = F::Params::CAPACITY as usize;
302
303 let num_elements = (num_bits + usable_bits - 1) / usable_bits;
304 let src_elements = self.squeeze_native_field_elements(num_elements);
305
306 let mut bits: Vec<bool> = Vec::with_capacity(usable_bits * num_elements);
307 for elem in &src_elements {
308 let elem_bits = elem.into_repr().to_bits_le();
309 bits.extend_from_slice(&elem_bits[..usable_bits]);
310 }
311
312 bits.truncate(num_bits);
313 bits
314 }
315
316 fn squeeze_field_elements_with_sizes<F2: PrimeField>(
317 &mut self,
318 sizes: &[FieldElementSize],
319 ) -> Vec<F2> {
320 if F::characteristic() == F2::characteristic() {
321 let mut buf = Vec::with_capacity(sizes.len());
323 batch_field_cast(
324 &self.squeeze_native_field_elements_with_sizes(sizes),
325 &mut buf,
326 )
327 .unwrap();
328 buf
329 } else {
330 squeeze_field_elements_with_sizes_default_impl(self, sizes)
331 }
332 }
333
334 fn squeeze_field_elements<F2: PrimeField>(&mut self, num_elements: usize) -> Vec<F2> {
335 if TypeId::of::<F>() == TypeId::of::<F2>() {
336 let result = self.squeeze_native_field_elements(num_elements);
337 let mut cast = Vec::with_capacity(result.len());
338 batch_field_cast(&result, &mut cast).unwrap();
339 cast
340 } else {
341 self.squeeze_field_elements_with_sizes::<F2>(
342 vec![FieldElementSize::Full; num_elements].as_slice(),
343 )
344 }
345 }
346}
347
348impl<F: PrimeField> FieldBasedCryptographicSponge<F> for PoseidonSponge<F> {
349 fn squeeze_native_field_elements(&mut self, num_elements: usize) -> Vec<F> {
350 let mut squeezed_elems = vec![F::zero(); num_elements];
351 match self.mode {
352 PoseidonSpongeMode::Absorbing {
353 next_absorb_index: _,
354 } => {
355 self.permute();
356 self.squeeze_internal(0, &mut squeezed_elems);
357 }
358 PoseidonSpongeMode::Squeezing { next_squeeze_index } => {
359 let mut squeeze_index = next_squeeze_index;
360 if squeeze_index == self.rate {
361 self.permute();
362 squeeze_index = 0;
363 }
364 self.squeeze_internal(squeeze_index, &mut squeezed_elems);
365 }
366 };
367
368 squeezed_elems
369 }
370}
371
372#[derive(Clone)]
373pub struct PoseidonSpongeState<F: PrimeField> {
375 state: Vec<F>,
376 mode: PoseidonSpongeMode,
377}
378
379impl<CF: PrimeField> SpongeExt for PoseidonSponge<CF> {
380 type State = PoseidonSpongeState<CF>;
381
382 fn from_state(state: Self::State, params: &Self::Parameters) -> Self {
383 let mut sponge = Self::new(params);
384 sponge.mode = state.mode;
385 sponge.state = state.state;
386 sponge
387 }
388
389 fn into_state(self) -> Self::State {
390 Self::State {
391 state: self.state,
392 mode: self.mode,
393 }
394 }
395}