1use ark_crypto_primitives::Error;
59use ark_ff::{BigInteger, PrimeField};
60use ark_std::{error::Error as ArkError, io::Read, rand::Rng, string::ToString, vec::Vec};
61use sbox::PoseidonSbox;
62
63use super::{from_field_elements, to_field_elements};
64
65pub mod sbox;
66
/// Errors reported by the Poseidon hasher.
#[derive(Debug)]

pub enum PoseidonError {
    /// The S-box configuration is not supported; carries the rejected value
    /// (presumably the S-box exponent — confirm against `sbox`).
    InvalidSboxSize(i8),

    /// Applying the S-box to a state element failed.
    ApplySboxFailed,

    /// The hasher was given inputs it cannot absorb (e.g. more than
    /// `width - 1` elements).
    InvalidInputs,
}

impl core::fmt::Display for PoseidonError {
    /// Renders a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match self {
            Self::InvalidSboxSize(size) => format!("sbox is not supported: {}", size),
            Self::ApplySboxFailed => "failed to apply sbox".to_string(),
            Self::InvalidInputs => "invalid inputs".to_string(),
        };
        f.write_str(&description)
    }
}

// Marker impl: relies on the trait's default methods (no error source).
impl ArkError for PoseidonError {}
99
/// Parameters describing one Poseidon permutation instance over the field `F`.
///
/// No consistency between the fields is enforced by this type; e.g. the number
/// of round keys and the MDS dimensions must match `width` and the round counts.
#[derive(Default, Clone, Debug)]

pub struct PoseidonParameters<F: PrimeField> {
    /// Round constants, consumed `width` at a time: in round `r`, state element
    /// `i` receives `round_keys[r * width + i]`.
    pub round_keys: Vec<F>,

    /// Matrix applied to the state after each S-box layer; the hasher reads
    /// `mds_matrix[i][j]` for `i, j < width`.
    pub mds_matrix: Vec<Vec<F>>,

    /// Number of rounds applying the S-box to every state element. The first
    /// `full_rounds / 2` run before the partial rounds, the rest after them.
    pub full_rounds: u8,

    /// Number of rounds applying the S-box to state element 0 only.
    pub partial_rounds: u8,

    /// Size of the hash state. Slot 0 is initialized to zero by the hasher, so
    /// at most `width - 1` inputs can be hashed at once.
    pub width: u8,

    /// S-box applied in every round.
    pub sbox: PoseidonSbox,
}
121
122impl<F: PrimeField> PoseidonParameters<F> {
123 pub fn new(
124 round_keys: Vec<F>,
125 mds_matrix: Vec<Vec<F>>,
126 full_rounds: u8,
127 partial_rounds: u8,
128 width: u8,
129 sbox: PoseidonSbox,
130 ) -> Self {
131 Self {
132 round_keys,
133 mds_matrix,
134 width,
135 full_rounds,
136 partial_rounds,
137 sbox,
138 }
139 }
140
141 pub fn generate<R: Rng>(_rng: &mut R) -> Self {
142 unimplemented!();
143 }
144
145 pub fn create_mds<R: Rng>(_rng: &mut R) -> Vec<Vec<F>> {
151 unimplemented!();
152 }
153
154 pub fn create_round_keys<R: Rng>(_rng: &mut R) -> Vec<F> {
160 unimplemented!();
161 }
162
163 pub fn to_bytes(&self) -> Vec<u8> {
169 let max_elt_size = F::BigInt::NUM_LIMBS * 8;
170 let mut buf: Vec<u8> = vec![];
171
172 buf.extend(&self.width.to_be_bytes());
173 buf.extend(&self.full_rounds.to_be_bytes());
174 buf.extend(&self.partial_rounds.to_be_bytes());
175 buf.extend(&self.sbox.0.to_be_bytes());
176
177 let round_key_len = self.round_keys.len() * max_elt_size;
180 buf.extend_from_slice(&(round_key_len as u32).to_be_bytes());
181
182 buf.extend_from_slice(&from_field_elements(&self.round_keys).unwrap());
184
185 let mut stored = false;
192 for i in 0..self.mds_matrix.len() {
193 if !stored {
194 let inner_vec_len = self.mds_matrix[i].len() * max_elt_size;
196 buf.extend_from_slice(&(inner_vec_len as u32).to_be_bytes());
197 stored = true;
198 }
199
200 buf.extend_from_slice(&from_field_elements(&self.mds_matrix[i]).unwrap());
201 }
202 buf
203 }
204
205 pub fn from_bytes(mut bytes: &[u8]) -> Result<Self, Error> {
209 let mut width_u8 = [0u8; 1];
210 bytes.read_exact(&mut width_u8)?;
211 let width = u8::from_be_bytes(width_u8);
212
213 let mut full_rounds_len = [0u8; 1];
214 bytes.read_exact(&mut full_rounds_len)?;
215 let full_rounds = u8::from_be_bytes(full_rounds_len);
216
217 let mut partial_rounds_u8 = [0u8; 1];
218 bytes.read_exact(&mut partial_rounds_u8)?;
219 let partial_rounds = u8::from_be_bytes(partial_rounds_u8);
220
221 let mut exponentiation_u8 = [0u8; 1];
222 bytes.read_exact(&mut exponentiation_u8)?;
223 let exp = i8::from_be_bytes(exponentiation_u8);
224
225 let mut round_key_len = [0u8; 4];
226 bytes.read_exact(&mut round_key_len)?;
227
228 let round_key_len_usize: usize = u32::from_be_bytes(round_key_len) as usize;
229 let mut round_keys_buf = vec![0u8; round_key_len_usize];
230 bytes.read_exact(&mut round_keys_buf)?;
231
232 let round_keys = to_field_elements::<F>(&round_keys_buf)?;
233 let mut mds_matrix_inner_vec_len = [0u8; 4];
234 bytes.read_exact(&mut mds_matrix_inner_vec_len)?;
235
236 let inner_vec_len_usize = u32::from_be_bytes(mds_matrix_inner_vec_len) as usize;
237 let mut mds_matrix: Vec<Vec<F>> = vec![];
238 while !bytes.is_empty() {
239 let mut inner_vec_buf = vec![0u8; inner_vec_len_usize];
240 bytes.read_exact(&mut inner_vec_buf)?;
241
242 let inner_vec = to_field_elements::<F>(&inner_vec_buf)?;
243 mds_matrix.push(inner_vec);
244 }
245
246 Ok(Self {
247 round_keys,
248 mds_matrix,
249 width,
250 full_rounds,
251 partial_rounds,
252 sbox: PoseidonSbox(exp),
253 })
254 }
255}
256
257#[derive(Default, Clone, Debug)]
258
259pub struct Poseidon<F: PrimeField> {
264 pub params: PoseidonParameters<F>,
265}
266
267impl<F: PrimeField> Poseidon<F> {
268 pub fn new(params: PoseidonParameters<F>) -> Self {
269 Poseidon { params }
270 }
271}
272
/// A hash function that maps field elements to a single field element.
pub trait FieldHasher<F: PrimeField> {
    /// Hashes a slice of field elements.
    ///
    /// Implementations return an error for inputs they cannot absorb; for
    /// `Poseidon` that is more than `width - 1` elements.
    fn hash(&self, inputs: &[F]) -> Result<F, PoseidonError>;

    /// Hashes an ordered pair of field elements.
    fn hash_two(&self, left: &F, right: &F) -> Result<F, PoseidonError>;
}
284
285impl<F: PrimeField> FieldHasher<F> for Poseidon<F> {
287 fn hash(&self, inputs: &[F]) -> Result<F, PoseidonError> {
288 let width = self.params.width as usize;
290 let partial_rounds = self.params.partial_rounds as usize;
291 let full_rounds = self.params.full_rounds as usize;
292
293 if inputs.len() > width - 1 {
295 return Err(PoseidonError::InvalidInputs);
296 }
297 let mut state = vec![F::zero()];
298 for f in inputs {
299 state.push(*f);
300 }
301 while state.len() < width {
302 state.push(F::zero());
303 }
304
305 let nr = full_rounds + partial_rounds;
306 for r in 0..nr {
307 state.iter_mut().enumerate().for_each(|(i, a)| {
309 let c = self.params.round_keys[(r * width + i)];
310 a.add_assign(c);
311 });
312
313 let half_rounds = full_rounds / 2;
314
315 if r < half_rounds || r >= half_rounds + partial_rounds {
316 state
319 .iter_mut()
320 .try_for_each(|a| self.params.sbox.apply_sbox(*a).map(|f| *a = f))?;
321 } else {
322 state[0] = self.params.sbox.apply_sbox(state[0])?;
325 }
326
327 state = state
329 .iter()
330 .enumerate()
331 .map(|(i, _)| {
332 state.iter().enumerate().fold(F::zero(), |acc, (j, a)| {
333 let m = self.params.mds_matrix[i][j];
334 acc.add(m.mul(*a))
335 })
336 })
337 .collect();
338 }
339
340 Ok(state[0])
341 }
342
343 fn hash_two(&self, left: &F, right: &F) -> Result<F, PoseidonError> {
344 self.hash(&[*left, *right])
345 }
346}
347
#[cfg(test)]
pub mod test {
    use crate::poseidon::{FieldHasher, Poseidon, PoseidonParameters, PoseidonSbox};
    use ark_ed_on_bn254::Fq;
    use ark_ff::{fields::Field, PrimeField};
    use ark_std::{vec::Vec, One};

    use arkworks_utils::{
        bytes_matrix_to_f, bytes_vec_to_f, parse_vec, poseidon_params::setup_poseidon_params, Curve,
    };

    /// Loads the reference Poseidon parameters for `curve`, S-box exponent `exp`
    /// and state width `width` from `arkworks_utils`.
    pub fn setup_params<F: PrimeField>(curve: Curve, exp: i8, width: u8) -> PoseidonParameters<F> {
        let data = setup_poseidon_params(curve, exp, width).unwrap();
        PoseidonParameters {
            mds_matrix: bytes_matrix_to_f(&data.mds),
            round_keys: bytes_vec_to_f(&data.rounds),
            full_rounds: data.full_rounds,
            partial_rounds: data.partial_rounds,
            sbox: PoseidonSbox(data.exp),
            width: data.width,
        }
    }

    /// Decodes one `0x`-prefixed hex string into an `Fq` via the same helpers
    /// the reference vectors use.
    fn fq_from_hex(hex: &str) -> Fq {
        let bytes = parse_vec(vec![hex]).unwrap();
        let values: Vec<Fq> = bytes_vec_to_f(&bytes);
        values[0]
    }

    type PoseidonHasher = Poseidon<Fq>;

    #[test]
    fn test_width_3_circom_bn_254() {
        let poseidon = PoseidonHasher::new(setup_params(Curve::Bn254, 5, 3));

        // hash(1, 2) against the circom reference output.
        let expected =
            fq_from_hex("0x115cc0f5e7d690413df64c6b9662e9cf2a3617f2743245519e19607a4417189a");
        let hashed = poseidon
            .hash_two(&Fq::one(), &Fq::one().double())
            .unwrap();
        assert_eq!(expected, hashed, "{} != {}", expected, hashed);

        let left = Fq::from_be_bytes_mod_order(&[
            0x06, 0x9c, 0x63, 0x81, 0xac, 0x0b, 0x96, 0x8e, 0x88, 0x1c, 0x91, 0x3c, 0x17, 0xd8,
            0x36, 0x06, 0x7f, 0xd1, 0x5f, 0x2c, 0xc7, 0x9f, 0x90, 0x2c, 0x80, 0x70, 0xb3, 0x6d,
            0x28, 0x66, 0x17, 0xdd,
        ]);
        let right = Fq::from_be_bytes_mod_order(&[
            0xc3, 0x3b, 0x60, 0x04, 0x2f, 0x76, 0xc7, 0xfb, 0xd0, 0x5d, 0xb7, 0x76, 0x23, 0xcb,
            0x17, 0xb8, 0x1d, 0x49, 0x41, 0x4b, 0x82, 0xe5, 0x6a, 0x2e, 0xc0, 0x18, 0xf7, 0xa5,
            0x5c, 0x3f, 0x30, 0x0b,
        ]);
        let expected =
            fq_from_hex("0x0a13ad844d3487ad3dbaf3876760eb971283d48333fa5a9e97e6ee422af9554b");
        let hashed = poseidon.hash_two(&left, &right).unwrap();
        assert_eq!(expected, hashed, "{} != {}", expected, hashed);
    }

    #[test]
    fn test_compare_hashes_with_circom_bn_254() {
        let poseidon2: PoseidonHasher = Poseidon::new(setup_params(Curve::Bn254, 5, 2));
        let poseidon4: PoseidonHasher = Poseidon::new(setup_params(Curve::Bn254, 5, 4));
        let poseidon5: PoseidonHasher = Poseidon::new(setup_params(Curve::Bn254, 5, 5));

        // Public key = hash(private_key).
        let expected_public_key =
            fq_from_hex("0x07a1f74bf9feda741e1e9099012079df28b504fc7a19a02288435b8e02ae21fa");
        let private_key =
            fq_from_hex("0xb2ac10dccfb5a5712d632464a359668bb513e80e9d145ab5a88381de83af1046");
        let computed_public_key = poseidon2.hash(&[private_key]).unwrap();
        println!("poseidon_res = {:?}", computed_public_key);
        assert_eq!(
            expected_public_key, computed_public_key,
            "{} != {}",
            expected_public_key, computed_public_key
        );

        // Leaf = hash(chain_id, amount, public_key, blinding).
        let chain_id =
            fq_from_hex("0x0000000000000000000000000000000000000000000000000000000000007a69");
        let amount =
            fq_from_hex("0x0000000000000000000000000000000000000000000000000000000000989680");
        let blinding =
            fq_from_hex("0x00a668ba0dcb34960aca597f433d0d3289c753046afa26d97e1613148c05f2c0");
        let expected_leaf =
            fq_from_hex("0x15206d966a7fb3e3fbbb7f4d7b623ca1c7c9b5c6e6d0a3348df428189441a1e4");
        let computed_leaf = poseidon5
            .hash(&[chain_id, amount, expected_public_key, blinding])
            .unwrap();
        assert_eq!(
            expected_leaf, computed_leaf,
            "{} != {}",
            expected_leaf, computed_leaf
        );

        // Nullifier = hash(leaf, path_index, private_key).
        let path_index =
            fq_from_hex("0x0000000000000000000000000000000000000000000000000000000000000000");
        let expected_nullifier =
            fq_from_hex("0x21423c7374ce5b3574f04f92243449359ae3865bb8e34cb2b7b5e4187ba01fca");
        let computed_nullifier = poseidon4
            .hash(&[expected_leaf, path_index, private_key])
            .unwrap();
        assert_eq!(
            expected_nullifier, computed_nullifier,
            "{} != {}",
            expected_nullifier, computed_nullifier
        );
    }

    #[test]
    fn test_parameter_to_and_from_bytes() {
        let params = setup_params::<Fq>(Curve::Bn254, 5, 3);

        let serialized = params.to_bytes();
        let restored: PoseidonParameters<Fq> =
            PoseidonParameters::from_bytes(&serialized).unwrap();
        assert_eq!(serialized, restored.to_bytes());
    }
}